* mypy github action

* fix endpoints.py mypy lint

* directly run poetry run mypy

* line based mypy error suppression

* switch to use make lint

* fix ruff issues

* fix EmbedComponent lint

* fix prompt.py's lint
This commit is contained in:
ming 2024-06-18 13:52:54 -04:00 committed by GitHub
commit 8ccb9e7597
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 40 additions and 59 deletions

View file

@ -32,21 +32,6 @@ jobs:
run: |
poetry env use ${{ matrix.python-version }}
poetry install
- name: Get .mypy_cache to speed up mypy
uses: actions/cache@v4
make lint
env:
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
with:
path: |
./.mypy_cache
key: ${{ runner.os }}-mypy-${{ hashFiles('**/pyproject.toml') }}
- name: Run linters
uses: wearerequired/lint-action@v2
with:
github_token: ${{ secrets.github_token }}
# Enable linters
git_email: "gabriel@langflow.org"
mypy: true
mypy_args: '--namespace-packages -p "langflow"'
mypy_command_prefix: "poetry run"
GITHUB_TOKEN: ${{ secrets.github_token }}

View file

@ -82,7 +82,7 @@ async def simple_run_flow(
if input_request.output_type == "debug"
or (
vertex.is_output
and (input_request.output_type == "any" or input_request.output_type in vertex.id.lower())
and (input_request.output_type == "any" or input_request.output_type in vertex.id.lower()) # type: ignore
)
]
task_result, session_id = await run_graph_internal(
@ -230,7 +230,7 @@ async def webhook_run_flow(
session_id=data_dict.get("session_id"),
)
logger.debug("Starting background task")
background_tasks.add_task(
background_tasks.add_task( # type: ignore
simple_run_flow,
db=db,
flow=flow,
@ -325,8 +325,6 @@ async def experimental_run_flow(
session_id=session_id,
inputs=inputs,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
@ -508,9 +506,8 @@ def get_config():
try:
from langflow.services.deps import get_settings_service
settings_service: "SettingsService" = get_settings_service()
settings_service: "SettingsService" = get_settings_service() # type: ignore
return settings_service.settings.model_dump()
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
raise HTTPException(status_code=500, detail=str(exc)) from exc

View file

@ -109,10 +109,10 @@ async def download_profile_picture(
try:
extension = file_name.split(".")[-1]
config_dir = get_storage_service().settings_service.settings.config_dir
config_path = Path(config_dir)
config_path = Path(config_dir) # type: ignore
folder_path = config_path / "profile_pictures" / folder_name
content_type = build_content_type_from_extension(extension)
file_content = await storage_service.get_file(flow_id=folder_path, file_name=file_name)
file_content = await storage_service.get_file(flow_id=folder_path, file_name=file_name) # type: ignore
return StreamingResponse(BytesIO(file_content), media_type=content_type)
except Exception as e:
@ -123,13 +123,13 @@ async def download_profile_picture(
async def list_profile_pictures(storage_service: StorageService = Depends(get_storage_service)):
try:
config_dir = get_storage_service().settings_service.settings.config_dir
config_path = Path(config_dir)
config_path = Path(config_dir) # type: ignore
people_path = config_path / "profile_pictures/People"
space_path = config_path / "profile_pictures/Space"
people = await storage_service.list_files(flow_id=people_path)
space = await storage_service.list_files(flow_id=space_path)
people = await storage_service.list_files(flow_id=people_path) # type: ignore
space = await storage_service.list_files(flow_id=space_path) # type: ignore
files = [Path("People") / i for i in people]
files += [Path("Space") / i for i in space]

View file

@ -43,7 +43,7 @@ def create_flow(
# based on the highest number found
if session.exec(select(Flow).where(Flow.name == flow.name).where(Flow.user_id == current_user.id)).first():
flows = session.exec(
select(Flow).where(Flow.name.like(f"{flow.name} (%")).where(Flow.user_id == current_user.id)
select(Flow).where(Flow.name.like(f"{flow.name} (%")).where(Flow.user_id == current_user.id) # type: ignore
).all()
if flows:
numbers = [int(flow.name.split("(")[1].split(")")[0]) for flow in flows]

View file

@ -87,7 +87,7 @@ async def update_message(
try:
message_dict = message.model_dump(exclude_none=True)
message_dict.pop("index", None)
monitor_service.update_message(message_id=message_id, **message_dict)
monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore
return MessageModelResponse(index=message_id, **message_dict)
except Exception as e:

View file

@ -62,7 +62,7 @@ def build_records_from_result_data(result_data: ResultData, get_final_results_on
else:
return []
for message in messages:
for message in messages: # type: ignore
message_dict = message if isinstance(message, dict) else message.model_dump()
if get_final_results_only:
result_data_dict = result_data.model_dump()

View file

@ -82,9 +82,9 @@ class ChatComponent(CustomComponent):
if not return_message:
message_text = message.text
else:
message_text = message
message_text = message # type: ignore
self.status = message_text
if session_id and isinstance(message, Message) and isinstance(message.text, str):
self.store_message(message)
return message_text
return message_text # type: ignore

View file

@ -78,9 +78,9 @@ class LCModelComponent(CustomComponent):
}
}
else:
status_message = f"Response: {content}"
status_message = f"Response: {content}" # type: ignore
else:
status_message = f"Response: {message.content}"
status_message = f"Response: {message.content}" # type: ignore
return status_message
def get_chat_result(
@ -102,11 +102,11 @@ class LCModelComponent(CustomComponent):
messages.append(input_value.to_lc_message())
else:
messages.append(HumanMessage(content=input_value))
inputs = messages or {}
inputs = messages or {} # type: ignore
if stream:
return runnable.stream(inputs)
return runnable.stream(inputs) # type: ignore
else:
message = runnable.invoke(inputs)
message = runnable.invoke(inputs) # type: ignore
result = message.content
if isinstance(message, AIMessage):
status_message = self.build_status_message(message)

View file

@ -9,7 +9,7 @@ class EmbedComponent(CustomComponent):
def build_config(self):
return {"texts": {"display_name": "Texts"}, "embbedings": {"display_name": "Embeddings"}}
def build(self, texts: list[str], embbedings: Embeddings) -> Embeddings:
def build(self, texts: list[str], embbedings: Embeddings) -> Record:
vectors = Record(vector=embbedings.embed_documents(texts))
self.status = vectors
return vectors

View file

@ -43,7 +43,7 @@ class MemoryComponent(BaseMemoryComponent):
},
}
def get_messages(self, **kwargs) -> list[Message]:
def get_messages(self, **kwargs) -> list[Message]: # type: ignore
# Validate kwargs by checking if it contains the correct keys
if "sender" not in kwargs:
kwargs["sender"] = None

View file

@ -16,7 +16,7 @@ class PromptComponent(CustomComponent):
async def build(
self,
template: Prompt,
template: str,
**kwargs,
) -> Prompt:
prompt = await Prompt.from_template_and_variables(template, kwargs)

View file

@ -62,7 +62,7 @@ class SelfQueryRetrieverComponent(CustomComponent):
input_text = query
else:
raise ValueError(f"Query type {type(query)} not supported.")
documents = self_query_retriever.invoke(input=input_text)
documents = self_query_retriever.invoke(input=input_text) # type: ignore
records = [Record.from_document(document) for document in documents]
self.status = records
return records
self.status = records # type: ignore
return records # type: ignore

View file

@ -37,6 +37,6 @@ class Prompt(Record):
if isinstance(value, Message):
content_dicts = await value.get_file_content_dicts()
contents.extend(content_dicts)
prompt_template = ChatPromptTemplate.from_messages([HumanMessage(content=contents)])
prompt_template = ChatPromptTemplate.from_messages([HumanMessage(content=contents)]) # type: ignore
instance.prompt = prompt_template.to_json()
return instance

View file

@ -153,7 +153,6 @@ class ContractEdge(Edge):
sender_name=target.params.get("sender_name", ""),
message=target.params.get(INPUT_FIELD_NAME, {}),
session_id=target.params.get("session_id", ""),
artifacts=target.artifacts,
flow_id=target.graph.flow_id,
)
return self.result

View file

@ -389,9 +389,9 @@ class Vertex:
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
return
if not overwrite:
for key in new_params.copy():
for key in new_params.copy(): # type: ignore
if key not in self._raw_params:
new_params.pop(key)
new_params.pop(key) # type: ignore
self._raw_params.update(new_params)
self.params = self._raw_params.copy()
self.updated_raw_params = True

View file

@ -34,8 +34,8 @@ class Message(BaseModel):
if is_image_file(file):
new_files.append(Image(path=file))
else:
new_files.append(file)
self.files = new_files
new_files.append(file) # type: ignore
self.files = new_files # type: ignore
def to_lc_message(
self,
@ -58,7 +58,7 @@ class Message(BaseModel):
if self.files:
contents = [{"type": "text", "text": self.text}]
contents.extend(self.get_file_content_dicts())
human_message = HumanMessage(content=contents)
human_message = HumanMessage(content=contents) # type: ignore
else:
human_message = HumanMessage(
content=[{"type": "text", "text": self.text}],
@ -66,7 +66,7 @@ class Message(BaseModel):
return human_message
return AIMessage(content=self.text)
return AIMessage(content=self.text) # type: ignore
@classmethod
def from_record(cls, record: "Record") -> "Message":

View file

@ -133,9 +133,9 @@ class Record(BaseModel):
contents = [{"type": "text", "text": text}]
for file_path in files:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path})
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) # type: ignore
contents.append({"type": "image_url", "image_url": image_prompt_value.image_url})
human_message = HumanMessage(content=contents)
human_message = HumanMessage(content=contents) # type: ignore
else:
human_message = HumanMessage(
content=[{"type": "text", "text": text}],
@ -143,7 +143,7 @@ class Record(BaseModel):
return human_message
return AIMessage(content=text)
return AIMessage(content=text) # type: ignore
def __getattr__(self, key):
"""

View file

@ -107,8 +107,8 @@ async def get_current_user_by_jwt(
with warnings.catch_warnings():
warnings.simplefilter("ignore")
payload = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM])
user_id: UUID = payload.get("sub")
token_type: str = payload.get("type")
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if expires := payload.get("exp", None):
expires_datetime = datetime.fromtimestamp(expires, timezone.utc)
if datetime.now(timezone.utc) > expires_datetime:

View file

@ -61,7 +61,7 @@ class VariableService(Service):
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
if variable.type == "Credential" and field == "session_id":
if variable.type == "Credential" and field == "session_id": # type: ignore
raise TypeError(
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
"because its purpose is to prevent the exposure of values."

View file

@ -76,7 +76,7 @@ class ChatOutputResponse(BaseModel):
):
"""Build chat output response from message."""
content = message.content
return cls(message=content, sender=sender, sender_name=sender_name)
return cls(message=content, sender=sender, sender_name=sender_name) # type: ignore
@model_validator(mode="after")
def validate_message(self):