Fix import errors and type annotations

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-07 21:45:52 -03:00
commit a6b7b9d5a8
13 changed files with 87 additions and 43 deletions

View file

@@ -57,7 +57,7 @@ def read_flows(
try:
auth_settings = settings_service.auth_settings
if auth_settings.AUTO_LOGIN:
flows: list[Flow] = session.exec(
flows = session.exec(
select(Flow).where(
(Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa
)
@@ -65,7 +65,7 @@ def read_flows(
else:
flows = current_user.flows
flows = validate_is_component(flows)
flows = validate_is_component(flows) # type: ignore
flow_ids = [flow.id for flow in flows]
# with the session get the flows that DO NOT have a user_id
try:
@@ -77,7 +77,7 @@ def read_flows(
).all()
for example_flow in example_flows:
if example_flow.id not in flow_ids:
flows.append(example_flow)
flows.append(example_flow) # type: ignore
except Exception as e:
logger.error(e)
except Exception as e:

View file

@@ -33,7 +33,7 @@ class ConversationChainComponent(CustomComponent):
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke({"input": input_value})
if isinstance(result, dict):
result = result.get(chain.output_key)
result = result.get(chain.output_key, "") # type: ignore
elif isinstance(result, str):
result = result

View file

@@ -32,8 +32,10 @@ class RunFlowComponent(CustomComponent):
},
}
def build_records_from_result_data(self, result_data: ResultData) -> Record:
def build_records_from_result_data(self, result_data: ResultData) -> List[Record]:
messages = result_data.messages
if not messages:
return []
records = []
for message in messages:
message_dict = (
@@ -47,7 +49,7 @@ class RunFlowComponent(CustomComponent):
async def build(
self, input_value: Text, flow_name: str, tweaks: NestedDict
) -> Record:
) -> List[Record]:
results: List[Optional[ResultData]] = await self.run_flow(
input_value=input_value, flow_name=flow_name, tweaks=tweaks
)

View file

@@ -1,18 +1,21 @@
from typing import TYPE_CHECKING, Any, List, Optional
from loguru import logger
from pydantic import BaseModel, Field
from langflow.graph.edge.utils import build_clean_params
from langflow.graph.schema import INPUT_FIELD_NAME
from langflow.services.deps import get_monitor_service
from langflow.services.monitor.utils import log_message
from loguru import logger
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
@@ -20,7 +23,9 @@ class SourceHandle(BaseModel):
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
type: str = Field(..., description="Type of the target handle.")
@@ -49,16 +54,24 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has invalid handles"
)
def __setstate__(self, state):
self.source_id = state["source_id"]
@@ -75,7 +88,11 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
# Get what type of input the target node is expecting
self.matched_type = next(
@@ -86,7 +103,10 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has no matched type"
)
def __repr__(self) -> str:
return (
@@ -98,8 +118,13 @@ class Edge:
return hash(self.__repr__())
def __eq__(self, __o: object) -> bool:
# Create a better way to compare edges
return self._source_handle == __o._source_handle and self._target_handle == __o._target_handle
if not isinstance(__o, Edge):
return False
return (
self._source_handle == __o._source_handle
and self._target_handle == __o._target_handle
)
class ContractEdge(Edge):
@@ -156,7 +181,9 @@ class ContractEdge(Edge):
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
def log_transaction(
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
):
try:
monitor_service = get_monitor_service()
clean_params = build_clean_params(target)

View file

@@ -60,8 +60,8 @@ class Graph:
self._edges = self._graph_data["edges"]
self.inactivated_vertices: set = set()
self.activated_vertices: List[str] = []
self.vertices_layers = []
self.vertices_to_run = set()
self.vertices_layers: List[List[str]] = []
self.vertices_to_run: set[str] = set()
self.stop_vertex = None
self.inactive_vertices: set = set()
@@ -197,9 +197,9 @@ class Graph:
self,
inputs: list[Dict[str, Union[str, list[str]]]],
outputs: list[str],
session_id: str,
stream: Optional[bool] = False,
) -> List[Optional["ResultData"]]:
session_id: Optional[str] = None,
stream: bool = False,
) -> List[List[Optional["ResultData"]]]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
@@ -207,15 +207,16 @@ class Graph:
# of the vertices that are inputs
# if the value is a list, we need to run multiple times
vertex_outputs = []
if not isinstance(inputs_values, list):
inputs_values = [inputs_values]
for input_dict in inputs_values:
if not isinstance(inputs, list):
inputs = [inputs]
for input_dict in inputs:
components: list[str] = input_dict.get("components", [])
run_outputs = await self._run(
inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME)},
input_components=input_dict.get("components", []),
inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME, "")},
input_components=components,
outputs=outputs,
stream=stream,
session_id=session_id,
session_id=session_id or "",
)
logger.debug(f"Run outputs: {run_outputs}")
vertex_outputs.append(run_outputs)

View file

@@ -396,7 +396,7 @@ class Vertex:
self._built = True
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[str]:
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:
"""
Extracts messages from the artifacts.

View file

@@ -16,7 +16,7 @@ def docs_to_records(documents: list[Document]) -> list[Record]:
return [Record.from_document(document) for document in documents]
def records_to_text(template: str, records: list[Record]) -> list[str]:
def records_to_text(template: str, records: list[Record]) -> str:
"""
Converts a list of Records to a list of texts.

View file

@@ -2,7 +2,7 @@ from datetime import datetime
from pathlib import Path
import orjson
from emoji import demojize, purely_emoji
from emoji import demojize, purely_emoji # type: ignore
from loguru import logger
from sqlmodel import select

View file

@@ -77,6 +77,8 @@ class CustomComponent(Component):
_flows_records: Optional[List[Record]] = None
def update_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
self.vertex.graph.update_state(
name=name, record=value, caller=self.vertex.id
@@ -85,6 +87,8 @@ class CustomComponent(Component):
raise ValueError(f"Error updating state: {e}")
def append_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
self.vertex.graph.append_state(
name=name, record=value, caller=self.vertex.id
@@ -93,6 +97,8 @@ class CustomComponent(Component):
raise ValueError(f"Error appending state: {e}")
def get_state(self, name: str):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
return self.vertex.graph.get_state(name=name)
except Exception as e:
@@ -142,7 +148,7 @@ class CustomComponent(Component):
def update_build_config(
self,
build_config: dotdict,
field_name: str,
field_name: Optional[str],
field_value: Any,
):
build_config[field_name] = field_value
@@ -173,6 +179,8 @@ class CustomComponent(Component):
ValueError: If the input data is not of a valid type or if the specified keys are not found in the data.
"""
if not keys:
keys = []
records = []
if not isinstance(data, Sequence):
data = [data]

View file

@@ -429,7 +429,7 @@ def build_custom_components(components_paths: List[str]):
return {}
logger.info(f"Building custom components from {components_paths}")
custom_components_from_file = {}
custom_components_from_file: dict = {}
processed_paths = set()
for path in components_paths:
path_str = str(path)
@@ -485,7 +485,9 @@ def update_field_dict(
def sanitize_field_config(field_config: Union[Dict, TemplateField]):
# If any of the already existing keys are in field_config, remove them
if isinstance(field_config, TemplateField):
field_config = field_config.to_dict()
field_dict = field_config.to_dict()
else:
field_dict = field_config
for key in [
"name",
"field_type",
@@ -496,8 +498,8 @@ def sanitize_field_config(field_config: Union[Dict, TemplateField]):
"advanced",
"show",
]:
field_config.pop(key, None)
return field_config
field_dict.pop(key, None)
return field_dict
def build_component(component):

View file

@@ -217,7 +217,7 @@ async def run_graph(
graph = Graph.from_payload(graph, flow_id=flow_id)
else:
graph_data = graph._graph_data
if not session_id and session_service is not None:
if session_id is None and session_service is not None:
session_id = session_service.generate_key(
session_id=flow_id, data_graph=graph_data
)
@@ -226,9 +226,9 @@ async def run_graph(
outputs = await graph.run(
inputs,
outputs,
outputs or [],
stream=stream,
session_id=session_id,
session_id=session_id or "",
)
if session_id and session_service:
session_service.update_session(session_id, (graph, artifacts))
@ -236,7 +236,7 @@ async def run_graph(
def validate_input(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
graph_data: Dict[str, Any], tweaks: Union["Tweaks", Dict[str, Dict[str, Any]]]
) -> List[Dict[str, Any]]:
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
raise ValueError("graph_data and tweaks should be dictionaries")

View file

@@ -4,7 +4,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Dict, Optional
from uuid import UUID, uuid4
from emoji import purely_emoji
from emoji import purely_emoji # type: ignore
from pydantic import field_serializer, field_validator
from sqlmodel import JSON, Column, Field, Relationship, SQLModel
@@ -22,7 +22,9 @@ class FlowBase(SQLModel):
icon_bg_color: Optional[str] = Field(default=None, nullable=True)
data: Optional[Dict] = Field(default=None, nullable=True)
is_component: Optional[bool] = Field(default=False, nullable=True)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, nullable=True)
updated_at: Optional[datetime] = Field(
default_factory=datetime.utcnow, nullable=True
)
folder: Optional[str] = Field(default=None, nullable=True)
@field_validator("icon_bg_color")

View file

@@ -3,6 +3,8 @@ from typing import Any, Callable
class TaskBackend(ABC):
name: str
@abstractmethod
def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any):
pass