ref: Apply ruff rule TC006 (#5088)

Apply ruff rule TC006
This commit is contained in:
Christophe Bornet 2024-12-08 12:35:00 +01:00 committed by GitHub
commit 60b1927cc5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 34 additions and 30 deletions

View file

@@ -17,7 +1,6 @@ from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput
from langflow.memory import delete_message
from langflow.schema import Data
from langflow.schema.content_block import ContentBlock
from langflow.schema.log import SendMessageFunctionType
from langflow.schema.message import Message
from langflow.template import Output
from langflow.utils.constants import MESSAGE_SENDER_AI
@@ -25,6 +24,8 @@ from langflow.utils.constants import MESSAGE_SENDER_AI
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langflow.schema.log import SendMessageFunctionType
DEFAULT_TOOLS_DESCRIPTION = "A helpful assistant with access to the following tools:"
DEFAULT_AGENT_NAME = "Agent ({tools_names})"
@@ -163,7 +164,7 @@ class LCAgentComponent(Component):
version="v2",
),
agent_message,
cast(SendMessageFunctionType, self.send_message),
cast("SendMessageFunctionType", self.send_message),
)
except ExceptionWithMessageError as e:
msg_id = e.agent_message.id

View file

@@ -177,7 +177,7 @@ class BaseCrewComponent(Component):
_id = self._vertex.id if self._vertex else self.display_name
if isinstance(agent_output, AgentFinish):
messages = agent_output.messages
self.log(cast(dict, messages[0].to_json()), name=f"Finish (Agent: {_id})")
self.log(cast("dict", messages[0].to_json()), name=f"Finish (Agent: {_id})")
elif isinstance(agent_output, list):
_messages_dict = {f"Action {i}": action.messages for i, (action, _) in enumerate(agent_output)}
# Serialize the messages with to_json() to avoid issues with circular references

View file

@@ -31,7 +31,7 @@ class ChatComponent(Component):
self.status = messages
self._send_messages_events(messages)
return cast(str | Message, message_text)
return cast("str | Message", message_text)
def _create_message(self, input_value, sender, sender_name, files, session_id) -> Message:
if isinstance(input_value, Data):

View file

@@ -75,7 +75,7 @@ class CohereRerankComponent(LCVectorStoreComponent):
user_agent=self.user_agent,
)
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
return cast(Retriever, retriever)
return cast("Retriever", retriever)
async def search_documents(self) -> list[Data]: # type: ignore[override]
retriever = self.build_base_retriever()

View file

@@ -1,14 +1,16 @@
from typing import cast
from typing import TYPE_CHECKING, cast
from pydantic import BaseModel, Field, create_model
from langflow.base.models.chat_result import get_chat_result
from langflow.custom import Component
from langflow.field_typing.constants import LanguageModel
from langflow.helpers.base_model import build_model_from_schema
from langflow.io import BoolInput, HandleInput, MessageTextInput, Output, StrInput, TableInput
from langflow.schema.data import Data
if TYPE_CHECKING:
from langflow.field_typing.constants import LanguageModel
class StructuredOutputComponent(Component):
display_name = "Structured Output"
@@ -94,7 +96,7 @@ class StructuredOutputComponent(Component):
else:
output_model = _output_model
try:
llm_with_structured_output = cast(LanguageModel, self.llm).with_structured_output(schema=output_model) # type: ignore[valid-type, attr-defined]
llm_with_structured_output = cast("LanguageModel", self.llm).with_structured_output(schema=output_model) # type: ignore[valid-type, attr-defined]
except NotImplementedError as exc:
msg = f"{self.llm.__class__.__name__} does not support structured output."

View file

@@ -63,7 +63,7 @@ class ChatVertexAIComponent(LCModelComponent):
credentials = None
return cast(
LanguageModel,
"LanguageModel",
ChatVertexAI(
credentials=credentials,
location=location,

View file

@@ -83,7 +83,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
nvidia_reranker = self.build_model()
retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever)
return cast(Retriever, retriever)
return cast("Retriever", retriever)
async def search_documents(self) -> list[Data]: # type: ignore[override]
retriever = self.build_base_retriever()

View file

@@ -51,4 +51,4 @@ class AmazonKendraRetrieverComponent(CustomComponent):
except Exception as e:
msg = "Could not connect to AmazonKendra API."
raise ValueError(msg) from e
return cast(Retriever, output)
return cast("Retriever", output)

View file

@@ -28,4 +28,4 @@ class MetalRetrieverComponent(CustomComponent):
except Exception as e:
msg = "Could not connect to Metal API."
raise ValueError(msg) from e
return cast(Retriever, MetalRetriever(client=metal, params=params or {}))
return cast("Retriever", MetalRetriever(client=metal, params=params or {}))

View file

@@ -44,4 +44,4 @@ class BingSearchAPIComponent(LCToolComponent):
)
else:
wrapper = BingSearchAPIWrapper(bing_subscription_key=self.bing_subscription_key)
return cast(Tool, BingSearchResults(api_wrapper=wrapper, num_results=self.k))
return cast("Tool", BingSearchResults(api_wrapper=wrapper, num_results=self.k))

View file

@@ -37,7 +37,7 @@ class WikipediaAPIComponent(LCToolComponent):
def build_tool(self) -> Tool:
wrapper = self._build_wrapper()
return cast(Tool, WikipediaQueryRun(api_wrapper=wrapper))
return cast("Tool", WikipediaQueryRun(api_wrapper=wrapper))
def _build_wrapper(self) -> WikipediaAPIWrapper:
return WikipediaAPIWrapper(

View file

@@ -60,7 +60,7 @@ class VectaraSelfQueryRetriverComponent(CustomComponent):
metadata_field_obj.append(attribute_info)
return cast(
Retriever,
"Retriever",
SelfQueryRetriever.from_llm(
llm, vectorstore, document_content_description, metadata_field_obj, verbose=True
),

View file

@@ -22,7 +22,7 @@ class Edge:
self.is_cycle = False
if data := edge.get("data", {}):
self._source_handle = data.get("sourceHandle", {})
self._target_handle = cast(TargetHandleDict, data.get("targetHandle", {}))
self._target_handle = cast("TargetHandleDict", data.get("targetHandle", {}))
self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
if isinstance(self._target_handle, dict):
try:

View file

@@ -8,7 +8,6 @@ import queue
import threading
import uuid
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from datetime import datetime, timezone
from functools import partial
from itertools import chain
@@ -43,6 +42,8 @@ from langflow.services.deps import get_chat_service, get_tracing_service
from langflow.utils.async_helpers import run_until_complete
if TYPE_CHECKING:
from collections.abc import Generator, Iterable
from langflow.api.v1.schemas import InputValueRequest
from langflow.custom.custom_component.component import Component
from langflow.events.event_manager import EventManager
@@ -1728,7 +1729,7 @@ class Graph:
edges.add(new_edge)
if self.vertices and not edges:
logger.warning("Graph has vertices but no edges")
return list(cast(Iterable[CycleEdge], edges))
return list(cast("Iterable[CycleEdge]", edges))
def build_edge(self, edge: EdgeData) -> CycleEdge | Edge:
source = self.get_vertex(edge["source"])

View file

@@ -129,7 +129,7 @@ class ComponentVertex(Vertex):
if output.value is UNDEFINED:
result = self.results[edge.source_handle.name]
else:
result = cast(Any, output.value)
result = cast("Any", output.value)
except NoComponentInstanceError:
result = self.results[edge.source_handle.name]
break

View file

@@ -106,7 +106,7 @@ async def run_flow(
inputs_components = []
types = []
for input_dict in inputs:
inputs_list.append({INPUT_FIELD_NAME: cast(str, input_dict.get("input_value"))})
inputs_list.append({INPUT_FIELD_NAME: cast("str", input_dict.get("input_value"))})
inputs_components.append(input_dict.get("components", []))
types.append(input_dict.get("type", "chat"))

View file

@@ -175,10 +175,10 @@ def process_tweaks(
:return: The modified graph_data dictionary.
:raises ValueError: If the input is not in the expected format.
"""
tweaks_dict = cast(dict[str, Any], tweaks.model_dump()) if not isinstance(tweaks, dict) else tweaks
tweaks_dict = cast("dict[str, Any]", tweaks.model_dump()) if not isinstance(tweaks, dict) else tweaks
if "stream" not in tweaks_dict:
tweaks_dict |= {"stream": stream}
nodes = validate_input(graph_data, cast(dict[str, str | dict[str, Any]], tweaks_dict))
nodes = validate_input(graph_data, cast("dict[str, str | dict[str, Any]]", tweaks_dict))
nodes_map = {node.get("id"): node for node in nodes}
nodes_display_name_map = {node.get("data", {}).get("node", {}).get("display_name"): node for node in nodes}

View file

@@ -97,7 +97,7 @@ class Data(BaseModel):
Data: The converted Data.
"""
data: dict = {"text": message.content}
data["metadata"] = cast(dict, message.to_json())
data["metadata"] = cast("dict", message.to_json())
return cls(data=data, text_key="text")
def __add__(self, other: "Data") -> "Data":

View file

@@ -70,7 +70,7 @@ class DataFrame(pandas_DataFrame):
if isinstance(data, Data):
data = data.data
new_df = self._constructor([data])
return cast(DataFrame, pd.concat([self, new_df], ignore_index=True))
return cast("DataFrame", pd.concat([self, new_df], ignore_index=True))
def add_rows(self, data: list[dict | Data]) -> "DataFrame":
"""Adds multiple rows to the dataset.
@@ -88,7 +88,7 @@ class DataFrame(pandas_DataFrame):
else:
processed_data.append(item)
new_df = self._constructor(processed_data)
return cast(DataFrame, pd.concat([self, new_df], ignore_index=True))
return cast("DataFrame", pd.concat([self, new_df], ignore_index=True))
@property
def _constructor(self):

View file

@@ -162,15 +162,15 @@ class LangWatchTracer(BaseTracer):
if "prompt" in value:
prompt = value.load_lc_prompt()
if len(prompt.input_variables) == 0 and all(isinstance(m, BaseMessage) for m in prompt.messages):
value = langchain_messages_to_chat_messages([cast(list[BaseMessage], prompt.messages)])
value = langchain_messages_to_chat_messages([cast("list[BaseMessage]", prompt.messages)])
else:
value = cast(dict, value.load_lc_prompt())
value = cast("dict", value.load_lc_prompt())
elif value.sender:
value = langchain_message_to_chat_message(value.to_lc_message())
else:
value = cast(dict, value.to_lc_document())
value = cast("dict", value.to_lc_document())
elif isinstance(value, Data):
value = cast(dict, value.to_lc_document())
value = cast("dict", value.to_lc_document())
return value
def get_langchain_callback(self) -> BaseCallbackHandler | None:

View file

@@ -80,7 +80,7 @@ class Template(BaseModel):
if field is None:
msg = f"Field {field_name} not found in template {self.type_name}"
raise ValueError(msg)
return cast(Input, field)
return cast("Input", field)
def update_field(self, field_name: str, field: Input) -> None:
"""Updates the field with the given name."""