Merge remote-tracking branch 'origin/dev' into two_edges

This commit is contained in:
ogabrielluiz 2024-06-19 01:11:38 -03:00
commit bac6a8cdff
14 changed files with 27 additions and 46 deletions

View file

@ -32,21 +32,6 @@ jobs:
run: |
poetry env use ${{ matrix.python-version }}
poetry install
- name: Get .mypy_cache to speed up mypy
uses: actions/cache@v4
make lint
env:
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
with:
path: |
./.mypy_cache
key: ${{ runner.os }}-mypy-${{ hashFiles('**/pyproject.toml') }}
- name: Run linters
uses: wearerequired/lint-action@v2
with:
github_token: ${{ secrets.github_token }}
# Enable linters
git_email: "gabriel@langflow.org"
mypy: true
mypy_args: '--namespace-packages -p "langflow"'
mypy_command_prefix: "poetry run"
GITHUB_TOKEN: ${{ secrets.github_token }}

View file

@ -96,7 +96,7 @@ async def simple_run_flow(
if input_request.output_type == "debug"
or (
vertex.is_output
and (input_request.output_type == "any" or input_request.output_type in vertex.id.lower())
and (input_request.output_type == "any" or input_request.output_type in vertex.id.lower()) # type: ignore
)
]
task_result, session_id = await run_graph_internal(
@ -242,7 +242,7 @@ async def webhook_run_flow(
session_id=data_dict.get("session_id"),
)
logger.debug("Starting background task")
background_tasks.add_task(
background_tasks.add_task( # type: ignore
simple_run_flow,
flow=flow,
input_request=input_request,
@ -336,8 +336,6 @@ async def experimental_run_flow(
session_id=session_id,
inputs=inputs,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
@ -519,9 +517,8 @@ def get_config():
try:
from langflow.services.deps import get_settings_service
settings_service: "SettingsService" = get_settings_service()
settings_service: "SettingsService" = get_settings_service() # type: ignore
return settings_service.settings.model_dump()
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
raise HTTPException(status_code=500, detail=str(exc)) from exc

View file

@ -109,10 +109,10 @@ async def download_profile_picture(
try:
extension = file_name.split(".")[-1]
config_dir = get_storage_service().settings_service.settings.config_dir
config_path = Path(config_dir)
config_path = Path(config_dir) # type: ignore
folder_path = config_path / "profile_pictures" / folder_name
content_type = build_content_type_from_extension(extension)
file_content = await storage_service.get_file(flow_id=folder_path, file_name=file_name)
file_content = await storage_service.get_file(flow_id=folder_path, file_name=file_name) # type: ignore
return StreamingResponse(BytesIO(file_content), media_type=content_type)
except Exception as e:
@ -123,13 +123,13 @@ async def download_profile_picture(
async def list_profile_pictures(storage_service: StorageService = Depends(get_storage_service)):
try:
config_dir = get_storage_service().settings_service.settings.config_dir
config_path = Path(config_dir)
config_path = Path(config_dir) # type: ignore
people_path = config_path / "profile_pictures/People"
space_path = config_path / "profile_pictures/Space"
people = await storage_service.list_files(flow_id=people_path)
space = await storage_service.list_files(flow_id=space_path)
people = await storage_service.list_files(flow_id=people_path) # type: ignore
space = await storage_service.list_files(flow_id=space_path) # type: ignore
files = [Path("People") / i for i in people]
files += [Path("Space") / i for i in space]

View file

@ -43,7 +43,7 @@ def create_flow(
# based on the highest number found
if session.exec(select(Flow).where(Flow.name == flow.name).where(Flow.user_id == current_user.id)).first():
flows = session.exec(
select(Flow).where(Flow.name.like(f"{flow.name} (%")).where(Flow.user_id == current_user.id)
select(Flow).where(Flow.name.like(f"{flow.name} (%")).where(Flow.user_id == current_user.id) # type: ignore
).all()
if flows:
numbers = [int(flow.name.split("(")[1].split(")")[0]) for flow in flows]

View file

@ -88,7 +88,7 @@ async def update_message(
try:
message_dict = message.model_dump(exclude_none=True)
message_dict.pop("index", None)
monitor_service.update_message(message_id=message_id, **message_dict)
monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore
return MessageModelResponse(index=message_id, **message_dict)
except Exception as e:

View file

@ -65,7 +65,7 @@ def build_data_from_result_data(result_data: ResultData, get_final_results_only:
else:
return []
for message in messages:
for message in messages: # type: ignore
message_dict = message if isinstance(message, dict) else message.model_dump()
if get_final_results_only:
result_data_dict = result_data.model_dump()

View file

@ -82,9 +82,9 @@ class ChatComponent(Component):
if not return_message:
message_text = message.text
else:
message_text = message
message_text = message # type: ignore
self.status = message_text
if session_id and isinstance(message, Message) and isinstance(message.text, str):
self.store_message(message)
return message_text
return message_text # type: ignore

View file

@ -85,9 +85,9 @@ class LCModelComponent(Component):
}
}
else:
status_message = f"Response: {content}"
status_message = f"Response: {content}" # type: ignore
else:
status_message = f"Response: {message.content}"
status_message = f"Response: {message.content}" # type: ignore
return status_message
def get_chat_result(

View file

@ -229,7 +229,6 @@ class ContractEdge(Edge):
sender_name=target.params.get("sender_name", ""),
message=target.params.get(INPUT_FIELD_NAME, {}),
session_id=target.params.get("session_id", ""),
artifacts=target.artifacts,
flow_id=target.graph.flow_id,
)
return self.result

View file

@ -393,9 +393,9 @@ class Vertex:
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
return
if not overwrite:
for key in new_params.copy():
for key in new_params.copy(): # type: ignore
if key not in self._raw_params:
new_params.pop(key)
new_params.pop(key) # type: ignore
self._raw_params.update(new_params)
self.params = self._raw_params.copy()
self.updated_raw_params = True

View file

@ -133,9 +133,9 @@ class Data(BaseModel):
contents = [{"type": "text", "text": text}]
for file_path in files:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path})
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) # type: ignore
contents.append({"type": "image_url", "image_url": image_prompt_value.image_url})
human_message = HumanMessage(content=contents)
human_message = HumanMessage(content=contents) # type: ignore
else:
human_message = HumanMessage(
content=[{"type": "text", "text": text}],
@ -143,7 +143,7 @@ class Data(BaseModel):
return human_message
return AIMessage(content=text)
return AIMessage(content=text) # type: ignore
def __getattr__(self, key):
"""

View file

@ -68,7 +68,7 @@ class Message(Data):
if self.files:
contents = [{"type": "text", "text": self.text}]
contents.extend(self.get_file_content_dicts())
human_message = HumanMessage(content=contents)
human_message = HumanMessage(content=contents) # type: ignore
else:
human_message = HumanMessage(
content=[{"type": "text", "text": self.text}],
@ -76,7 +76,7 @@ class Message(Data):
return human_message
return AIMessage(content=self.text)
return AIMessage(content=self.text) # type: ignore
@classmethod
def from_data(cls, data: "Data") -> "Message":

View file

@ -107,8 +107,8 @@ async def get_current_user_by_jwt(
with warnings.catch_warnings():
warnings.simplefilter("ignore")
payload = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM])
user_id: UUID = payload.get("sub")
token_type: str = payload.get("type")
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if expires := payload.get("exp", None):
expires_datetime = datetime.fromtimestamp(expires, timezone.utc)
if datetime.now(timezone.utc) > expires_datetime:

View file

@ -76,7 +76,7 @@ class ChatOutputResponse(BaseModel):
):
"""Build chat output response from message."""
content = message.content
return cls(message=content, sender=sender, sender_name=sender_name)
return cls(message=content, sender=sender, sender_name=sender_name) # type: ignore
@model_validator(mode="after")
def validate_message(self):