From 8ccb9e75977b985c201e98b177706cd05dda5f4e Mon Sep 17 00:00:00 2001 From: ming Date: Tue, 18 Jun 2024 13:52:54 -0400 Subject: [PATCH] Fix mypy (#2204) * mypy github action * fix endpoints.py mypy lint * directly run poetry run mypy * line based mypy error suppression * switch to use make lint * fix ruff issues * fix EmbedComponent lint * fix prompt.py's lint --- .github/workflows/lint-py.yml | 19 ++----------------- src/backend/base/langflow/api/v1/endpoints.py | 9 +++------ src/backend/base/langflow/api/v1/files.py | 10 +++++----- src/backend/base/langflow/api/v1/flows.py | 2 +- src/backend/base/langflow/api/v1/monitor.py | 2 +- .../langflow/base/flow_processing/utils.py | 2 +- src/backend/base/langflow/base/io/chat.py | 4 ++-- .../base/langflow/base/models/model.py | 10 +++++----- .../langflow/components/experimental/Embed.py | 2 +- .../components/helpers/MemoryComponent.py | 2 +- .../base/langflow/components/inputs/Prompt.py | 2 +- .../retrievers/SelfQueryRetriever.py | 6 +++--- .../base/langflow/field_typing/prompt.py | 2 +- src/backend/base/langflow/graph/edge/base.py | 1 - .../base/langflow/graph/vertex/base.py | 4 ++-- src/backend/base/langflow/schema/message.py | 8 ++++---- src/backend/base/langflow/schema/record.py | 6 +++--- .../base/langflow/services/auth/utils.py | 4 ++-- .../langflow/services/variable/service.py | 2 +- src/backend/base/langflow/utils/schemas.py | 2 +- 20 files changed, 40 insertions(+), 59 deletions(-) diff --git a/.github/workflows/lint-py.yml b/.github/workflows/lint-py.yml index eb28b3101..95fbb15e6 100644 --- a/.github/workflows/lint-py.yml +++ b/.github/workflows/lint-py.yml @@ -32,21 +32,6 @@ jobs: run: | poetry env use ${{ matrix.python-version }} poetry install - - name: Get .mypy_cache to speed up mypy - uses: actions/cache@v4 + make lint env: - SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" - with: - path: | - ./.mypy_cache - key: ${{ runner.os }}-mypy-${{ hashFiles('**/pyproject.toml') }} - - name: Run linters - uses: wearerequired/lint-action@v2 - with: - github_token: ${{ secrets.github_token }} - # Enable linters - git_email: "gabriel@langflow.org" - mypy: true - mypy_args: '--namespace-packages -p "langflow"' - mypy_command_prefix: "poetry run" - + GITHUB_TOKEN: ${{ secrets.github_token }} diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 4b31661fa..9124810c7 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -82,7 +82,7 @@ async def simple_run_flow( if input_request.output_type == "debug" or ( vertex.is_output - and (input_request.output_type == "any" or input_request.output_type in vertex.id.lower()) + and (input_request.output_type == "any" or input_request.output_type in vertex.id.lower()) # type: ignore ) ] task_result, session_id = await run_graph_internal( @@ -230,7 +230,7 @@ async def webhook_run_flow( session_id=data_dict.get("session_id"), ) logger.debug("Starting background task") - background_tasks.add_task( + background_tasks.add_task( # type: ignore simple_run_flow, db=db, flow=flow, @@ -325,8 +325,6 @@ async def experimental_run_flow( session_id=session_id, inputs=inputs, outputs=outputs, - artifacts=artifacts, - session_service=session_service, stream=stream, ) @@ -508,9 +506,8 @@ def get_config(): try: from langflow.services.deps import get_settings_service - settings_service: "SettingsService" = get_settings_service() + settings_service: "SettingsService" = get_settings_service() # type: ignore return settings_service.settings.model_dump() except Exception as exc: logger.exception(exc) raise HTTPException(status_code=500, detail=str(exc)) from exc - raise HTTPException(status_code=500, detail=str(exc)) from exc diff --git a/src/backend/base/langflow/api/v1/files.py b/src/backend/base/langflow/api/v1/files.py index a72ee7711..a96397e08 100644 --- a/src/backend/base/langflow/api/v1/files.py +++ b/src/backend/base/langflow/api/v1/files.py @@ -109,10 +109,10 @@ async def download_profile_picture( try: extension = file_name.split(".")[-1] config_dir = get_storage_service().settings_service.settings.config_dir - config_path = Path(config_dir) + config_path = Path(config_dir) # type: ignore folder_path = config_path / "profile_pictures" / folder_name content_type = build_content_type_from_extension(extension) - file_content = await storage_service.get_file(flow_id=folder_path, file_name=file_name) + file_content = await storage_service.get_file(flow_id=folder_path, file_name=file_name) # type: ignore return StreamingResponse(BytesIO(file_content), media_type=content_type) except Exception as e: @@ -123,13 +123,13 @@ async def download_profile_picture( async def list_profile_pictures(storage_service: StorageService = Depends(get_storage_service)): try: config_dir = get_storage_service().settings_service.settings.config_dir - config_path = Path(config_dir) + config_path = Path(config_dir) # type: ignore people_path = config_path / "profile_pictures/People" space_path = config_path / "profile_pictures/Space" - people = await storage_service.list_files(flow_id=people_path) - space = await storage_service.list_files(flow_id=space_path) + people = await storage_service.list_files(flow_id=people_path) # type: ignore + space = await storage_service.list_files(flow_id=space_path) # type: ignore files = [Path("People") / i for i in people] files += [Path("Space") / i for i in space] diff --git a/src/backend/base/langflow/api/v1/flows.py b/src/backend/base/langflow/api/v1/flows.py index 08957e502..299a433d6 100644 --- a/src/backend/base/langflow/api/v1/flows.py +++ b/src/backend/base/langflow/api/v1/flows.py @@ -43,7 +43,7 @@ def create_flow( # based on the highest number found if session.exec(select(Flow).where(Flow.name == flow.name).where(Flow.user_id == current_user.id)).first(): flows = session.exec( - select(Flow).where(Flow.name.like(f"{flow.name} (%")).where(Flow.user_id == current_user.id) + select(Flow).where(Flow.name.like(f"{flow.name} (%")).where(Flow.user_id == current_user.id) # type: ignore ).all() if flows: numbers = [int(flow.name.split("(")[1].split(")")[0]) for flow in flows] diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index 9714e4592..e0608f38a 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -87,7 +87,7 @@ async def update_message( try: message_dict = message.model_dump(exclude_none=True) message_dict.pop("index", None) - monitor_service.update_message(message_id=message_id, **message_dict) + monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore return MessageModelResponse(index=message_id, **message_dict) except Exception as e: diff --git a/src/backend/base/langflow/base/flow_processing/utils.py b/src/backend/base/langflow/base/flow_processing/utils.py index 3a4eb393c..34b3ed1be 100644 --- a/src/backend/base/langflow/base/flow_processing/utils.py +++ b/src/backend/base/langflow/base/flow_processing/utils.py @@ -62,7 +62,7 @@ def build_records_from_result_data(result_data: ResultData, get_final_results_on else: return [] - for message in messages: + for message in messages: # type: ignore message_dict = message if isinstance(message, dict) else message.model_dump() if get_final_results_only: result_data_dict = result_data.model_dump() diff --git a/src/backend/base/langflow/base/io/chat.py b/src/backend/base/langflow/base/io/chat.py index 7d59e8ef3..d855b523a 100644 --- a/src/backend/base/langflow/base/io/chat.py +++ b/src/backend/base/langflow/base/io/chat.py @@ -82,9 +82,9 @@ class ChatComponent(CustomComponent): if not return_message: message_text = message.text else: - message_text = message + message_text = message # type: ignore self.status = message_text if session_id and isinstance(message, Message) and isinstance(message.text, str): self.store_message(message) - return message_text + return message_text # type: ignore diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py index 74d350446..5cc4f1066 100644 --- a/src/backend/base/langflow/base/models/model.py +++ b/src/backend/base/langflow/base/models/model.py @@ -78,9 +78,9 @@ class LCModelComponent(CustomComponent): } } else: - status_message = f"Response: {content}" + status_message = f"Response: {content}" # type: ignore else: - status_message = f"Response: {message.content}" + status_message = f"Response: {message.content}" # type: ignore return status_message def get_chat_result( @@ -102,11 +102,11 @@ class LCModelComponent(CustomComponent): messages.append(input_value.to_lc_message()) else: messages.append(HumanMessage(content=input_value)) - inputs = messages or {} + inputs = messages or {} # type: ignore if stream: - return runnable.stream(inputs) + return runnable.stream(inputs) # type: ignore else: - message = runnable.invoke(inputs) + message = runnable.invoke(inputs) # type: ignore result = message.content if isinstance(message, AIMessage): status_message = self.build_status_message(message) diff --git a/src/backend/base/langflow/components/experimental/Embed.py b/src/backend/base/langflow/components/experimental/Embed.py index 88de23486..177eb135c 100644 --- a/src/backend/base/langflow/components/experimental/Embed.py +++ b/src/backend/base/langflow/components/experimental/Embed.py @@ -9,7 +9,7 @@ class EmbedComponent(CustomComponent): def build_config(self): return {"texts": {"display_name": "Texts"}, "embbedings": {"display_name": "Embeddings"}} - def build(self, texts: list[str], embbedings: Embeddings) -> Embeddings: + def build(self, texts: list[str], embbedings: Embeddings) -> Record: vectors = Record(vector=embbedings.embed_documents(texts)) self.status = vectors return vectors diff --git a/src/backend/base/langflow/components/helpers/MemoryComponent.py b/src/backend/base/langflow/components/helpers/MemoryComponent.py index 96e82da1e..235370bee 100644 --- a/src/backend/base/langflow/components/helpers/MemoryComponent.py +++ b/src/backend/base/langflow/components/helpers/MemoryComponent.py @@ -43,7 +43,7 @@ class MemoryComponent(BaseMemoryComponent): }, } - def get_messages(self, **kwargs) -> list[Message]: + def get_messages(self, **kwargs) -> list[Message]: # type: ignore # Validate kwargs by checking if it contains the correct keys if "sender" not in kwargs: kwargs["sender"] = None diff --git a/src/backend/base/langflow/components/inputs/Prompt.py b/src/backend/base/langflow/components/inputs/Prompt.py index e65d27576..a6140deee 100644 --- a/src/backend/base/langflow/components/inputs/Prompt.py +++ b/src/backend/base/langflow/components/inputs/Prompt.py @@ -16,7 +16,7 @@ class PromptComponent(CustomComponent): async def build( self, - template: Prompt, + template: str, **kwargs, ) -> Prompt: prompt = await Prompt.from_template_and_variables(template, kwargs) diff --git a/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py b/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py index 3e6d6f696..7dc53caff 100644 --- a/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py +++ b/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py @@ -62,7 +62,7 @@ class SelfQueryRetrieverComponent(CustomComponent): input_text = query else: raise ValueError(f"Query type {type(query)} not supported.") - documents = self_query_retriever.invoke(input=input_text) + documents = self_query_retriever.invoke(input=input_text) # type: ignore records = [Record.from_document(document) for document in documents] - self.status = records - return records + self.status = records # type: ignore + return records # type: ignore diff --git a/src/backend/base/langflow/field_typing/prompt.py b/src/backend/base/langflow/field_typing/prompt.py index 029261ac7..f7cdece35 100644 --- a/src/backend/base/langflow/field_typing/prompt.py +++ b/src/backend/base/langflow/field_typing/prompt.py @@ -37,6 +37,6 @@ class Prompt(Record): if isinstance(value, Message): content_dicts = await value.get_file_content_dicts() contents.extend(content_dicts) - prompt_template = ChatPromptTemplate.from_messages([HumanMessage(content=contents)]) + prompt_template = ChatPromptTemplate.from_messages([HumanMessage(content=contents)]) # type: ignore instance.prompt = prompt_template.to_json() return instance diff --git a/src/backend/base/langflow/graph/edge/base.py b/src/backend/base/langflow/graph/edge/base.py index 1fc2ef344..b99785ea6 100644 --- a/src/backend/base/langflow/graph/edge/base.py +++ b/src/backend/base/langflow/graph/edge/base.py @@ -153,7 +153,6 @@ class ContractEdge(Edge): sender_name=target.params.get("sender_name", ""), message=target.params.get(INPUT_FIELD_NAME, {}), session_id=target.params.get("session_id", ""), - artifacts=target.artifacts, flow_id=target.graph.flow_id, ) return self.result diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index 427512e5f..64c951ce3 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -389,9 +389,9 @@ class Vertex: if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params): return if not overwrite: - for key in new_params.copy(): + for key in new_params.copy(): # type: ignore if key not in self._raw_params: - new_params.pop(key) + new_params.pop(key) # type: ignore self._raw_params.update(new_params) self.params = self._raw_params.copy() self.updated_raw_params = True diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index 30a1d6a68..c8888fb24 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -34,8 +34,8 @@ class Message(BaseModel): if is_image_file(file): new_files.append(Image(path=file)) else: - new_files.append(file) - self.files = new_files + new_files.append(file) # type: ignore + self.files = new_files # type: ignore def to_lc_message( self, @@ -58,7 +58,7 @@ class Message(BaseModel): if self.files: contents = [{"type": "text", "text": self.text}] contents.extend(self.get_file_content_dicts()) - human_message = HumanMessage(content=contents) + human_message = HumanMessage(content=contents) # type: ignore else: human_message = HumanMessage( content=[{"type": "text", "text": self.text}], @@ -66,7 +66,7 @@ class Message(BaseModel): return human_message - return AIMessage(content=self.text) + return AIMessage(content=self.text) # type: ignore @classmethod def from_record(cls, record: "Record") -> "Message": diff --git a/src/backend/base/langflow/schema/record.py b/src/backend/base/langflow/schema/record.py index 67d9b5da8..c822d5ce5 100644 --- a/src/backend/base/langflow/schema/record.py +++ b/src/backend/base/langflow/schema/record.py @@ -133,9 +133,9 @@ class Record(BaseModel): contents = [{"type": "text", "text": text}] for file_path in files: image_template = ImagePromptTemplate() - image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) + image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) # type: ignore contents.append({"type": "image_url", "image_url": image_prompt_value.image_url}) - human_message = HumanMessage(content=contents) + human_message = HumanMessage(content=contents) # type: ignore else: human_message = HumanMessage( content=[{"type": "text", "text": text}], @@ -143,7 +143,7 @@ class Record(BaseModel): return human_message - return AIMessage(content=text) + return AIMessage(content=text) # type: ignore def __getattr__(self, key): """ diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index 49a1e4720..06836d21a 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -107,8 +107,8 @@ async def get_current_user_by_jwt( with warnings.catch_warnings(): warnings.simplefilter("ignore") payload = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM]) - user_id: UUID = payload.get("sub") - token_type: str = payload.get("type") + user_id: UUID = payload.get("sub") # type: ignore + token_type: str = payload.get("type") # type: ignore if expires := payload.get("exp", None): expires_datetime = datetime.fromtimestamp(expires, timezone.utc) if datetime.now(timezone.utc) > expires_datetime: diff --git a/src/backend/base/langflow/services/variable/service.py b/src/backend/base/langflow/services/variable/service.py index b2389e890..bab6801b1 100644 --- a/src/backend/base/langflow/services/variable/service.py +++ b/src/backend/base/langflow/services/variable/service.py @@ -61,7 +61,7 @@ class VariableService(Service): # credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first() variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first() - if variable.type == "Credential" and field == "session_id": + if variable.type == "Credential" and field == "session_id": # type: ignore raise TypeError( f"variable {name} of type 'Credential' cannot be used in a Session ID field " "because its purpose is to prevent the exposure of values." diff --git a/src/backend/base/langflow/utils/schemas.py b/src/backend/base/langflow/utils/schemas.py index 647941f59..a17fc8aa6 100644 --- a/src/backend/base/langflow/utils/schemas.py +++ b/src/backend/base/langflow/utils/schemas.py @@ -76,7 +76,7 @@ class ChatOutputResponse(BaseModel): ): """Build chat output response from message.""" content = message.content - return cls(message=content, sender=sender, sender_name=sender_name) + return cls(message=content, sender=sender, sender_name=sender_name) # type: ignore @model_validator(mode="after") def validate_message(self):