Fix import errors and type annotations
This commit is contained in:
parent
45d4610812
commit
a6b7b9d5a8
13 changed files with 87 additions and 43 deletions
|
|
@ -57,7 +57,7 @@ def read_flows(
|
|||
try:
|
||||
auth_settings = settings_service.auth_settings
|
||||
if auth_settings.AUTO_LOGIN:
|
||||
flows: list[Flow] = session.exec(
|
||||
flows = session.exec(
|
||||
select(Flow).where(
|
||||
(Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa
|
||||
)
|
||||
|
|
@ -65,7 +65,7 @@ def read_flows(
|
|||
else:
|
||||
flows = current_user.flows
|
||||
|
||||
flows = validate_is_component(flows)
|
||||
flows = validate_is_component(flows) # type: ignore
|
||||
flow_ids = [flow.id for flow in flows]
|
||||
# with the session get the flows that DO NOT have a user_id
|
||||
try:
|
||||
|
|
@ -77,7 +77,7 @@ def read_flows(
|
|||
).all()
|
||||
for example_flow in example_flows:
|
||||
if example_flow.id not in flow_ids:
|
||||
flows.append(example_flow)
|
||||
flows.append(example_flow) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class ConversationChainComponent(CustomComponent):
|
|||
chain = ConversationChain(llm=llm, memory=memory)
|
||||
result = chain.invoke({"input": input_value})
|
||||
if isinstance(result, dict):
|
||||
result = result.get(chain.output_key)
|
||||
result = result.get(chain.output_key, "") # type: ignore
|
||||
|
||||
elif isinstance(result, str):
|
||||
result = result
|
||||
|
|
|
|||
|
|
@ -32,8 +32,10 @@ class RunFlowComponent(CustomComponent):
|
|||
},
|
||||
}
|
||||
|
||||
def build_records_from_result_data(self, result_data: ResultData) -> Record:
|
||||
def build_records_from_result_data(self, result_data: ResultData) -> List[Record]:
|
||||
messages = result_data.messages
|
||||
if not messages:
|
||||
return []
|
||||
records = []
|
||||
for message in messages:
|
||||
message_dict = (
|
||||
|
|
@ -47,7 +49,7 @@ class RunFlowComponent(CustomComponent):
|
|||
|
||||
async def build(
|
||||
self, input_value: Text, flow_name: str, tweaks: NestedDict
|
||||
) -> Record:
|
||||
) -> List[Record]:
|
||||
results: List[Optional[ResultData]] = await self.run_flow(
|
||||
input_value=input_value, flow_name=flow_name, tweaks=tweaks
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,21 @@
|
|||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.graph.edge.utils import build_clean_params
|
||||
from langflow.graph.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.deps import get_monitor_service
|
||||
from langflow.services.monitor.utils import log_message
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
|
||||
class SourceHandle(BaseModel):
|
||||
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
|
||||
baseClasses: List[str] = Field(
|
||||
..., description="List of base classes for the source handle."
|
||||
)
|
||||
dataType: str = Field(..., description="Data type for the source handle.")
|
||||
id: str = Field(..., description="Unique identifier for the source handle.")
|
||||
|
||||
|
|
@ -20,7 +23,9 @@ class SourceHandle(BaseModel):
|
|||
class TargetHandle(BaseModel):
|
||||
fieldName: str = Field(..., description="Field name for the target handle.")
|
||||
id: str = Field(..., description="Unique identifier for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(
|
||||
None, description="List of input types for the target handle."
|
||||
)
|
||||
type: str = Field(..., description="Type of the target handle.")
|
||||
|
||||
|
||||
|
|
@ -49,16 +54,24 @@ class Edge:
|
|||
|
||||
def validate_handles(self, source, target) -> None:
|
||||
if self.target_handle.inputTypes is None:
|
||||
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
|
||||
self.valid_handles = (
|
||||
self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
else:
|
||||
self.valid_handles = (
|
||||
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
|
||||
any(
|
||||
baseClass in self.target_handle.inputTypes
|
||||
for baseClass in self.source_handle.baseClasses
|
||||
)
|
||||
or self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
if not self.valid_handles:
|
||||
logger.debug(self.source_handle)
|
||||
logger.debug(self.target_handle)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
|
||||
raise ValueError(
|
||||
f"Edge between {source.vertex_type} and {target.vertex_type} "
|
||||
f"has invalid handles"
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source_id = state["source_id"]
|
||||
|
|
@ -75,7 +88,11 @@ class Edge:
|
|||
# Both lists contain strings and sometimes a string contains the value we are
|
||||
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
|
||||
# so we need to check if any of the strings in source_types is in target_reqs
|
||||
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
|
||||
self.valid = any(
|
||||
output in target_req
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
)
|
||||
# Get what type of input the target node is expecting
|
||||
|
||||
self.matched_type = next(
|
||||
|
|
@ -86,7 +103,10 @@ class Edge:
|
|||
if no_matched_type:
|
||||
logger.debug(self.source_types)
|
||||
logger.debug(self.target_reqs)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
|
||||
raise ValueError(
|
||||
f"Edge between {source.vertex_type} and {target.vertex_type} "
|
||||
f"has no matched type"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
|
@ -98,8 +118,13 @@ class Edge:
|
|||
return hash(self.__repr__())
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
# Create a better way to compare edges
|
||||
return self._source_handle == __o._source_handle and self._target_handle == __o._target_handle
|
||||
|
||||
if not isinstance(__o, Edge):
|
||||
return False
|
||||
return (
|
||||
self._source_handle == __o._source_handle
|
||||
and self._target_handle == __o._target_handle
|
||||
)
|
||||
|
||||
|
||||
class ContractEdge(Edge):
|
||||
|
|
@ -156,7 +181,9 @@ class ContractEdge(Edge):
|
|||
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
|
||||
|
||||
|
||||
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
|
||||
def log_transaction(
|
||||
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
|
||||
):
|
||||
try:
|
||||
monitor_service = get_monitor_service()
|
||||
clean_params = build_clean_params(target)
|
||||
|
|
|
|||
|
|
@ -60,8 +60,8 @@ class Graph:
|
|||
self._edges = self._graph_data["edges"]
|
||||
self.inactivated_vertices: set = set()
|
||||
self.activated_vertices: List[str] = []
|
||||
self.vertices_layers = []
|
||||
self.vertices_to_run = set()
|
||||
self.vertices_layers: List[List[str]] = []
|
||||
self.vertices_to_run: set[str] = set()
|
||||
self.stop_vertex = None
|
||||
|
||||
self.inactive_vertices: set = set()
|
||||
|
|
@ -197,9 +197,9 @@ class Graph:
|
|||
self,
|
||||
inputs: list[Dict[str, Union[str, list[str]]]],
|
||||
outputs: list[str],
|
||||
session_id: str,
|
||||
stream: Optional[bool] = False,
|
||||
) -> List[Optional["ResultData"]]:
|
||||
session_id: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> List[List[Optional["ResultData"]]]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
# inputs is {"message": "Hello, world!"}
|
||||
|
|
@ -207,15 +207,16 @@ class Graph:
|
|||
# of the vertices that are inputs
|
||||
# if the value is a list, we need to run multiple times
|
||||
vertex_outputs = []
|
||||
if not isinstance(inputs_values, list):
|
||||
inputs_values = [inputs_values]
|
||||
for input_dict in inputs_values:
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
for input_dict in inputs:
|
||||
components: list[str] = input_dict.get("components", [])
|
||||
run_outputs = await self._run(
|
||||
inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME)},
|
||||
input_components=input_dict.get("components", []),
|
||||
inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME, "")},
|
||||
input_components=components,
|
||||
outputs=outputs,
|
||||
stream=stream,
|
||||
session_id=session_id,
|
||||
session_id=session_id or "",
|
||||
)
|
||||
logger.debug(f"Run outputs: {run_outputs}")
|
||||
vertex_outputs.append(run_outputs)
|
||||
|
|
|
|||
|
|
@ -396,7 +396,7 @@ class Vertex:
|
|||
|
||||
self._built = True
|
||||
|
||||
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[str]:
|
||||
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:
|
||||
"""
|
||||
Extracts messages from the artifacts.
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def docs_to_records(documents: list[Document]) -> list[Record]:
|
|||
return [Record.from_document(document) for document in documents]
|
||||
|
||||
|
||||
def records_to_text(template: str, records: list[Record]) -> list[str]:
|
||||
def records_to_text(template: str, records: list[Record]) -> str:
|
||||
"""
|
||||
Converts a list of Records to a list of texts.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
from emoji import demojize, purely_emoji
|
||||
from emoji import demojize, purely_emoji # type: ignore
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ class CustomComponent(Component):
|
|||
_flows_records: Optional[List[Record]] = None
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
self.vertex.graph.update_state(
|
||||
name=name, record=value, caller=self.vertex.id
|
||||
|
|
@ -85,6 +87,8 @@ class CustomComponent(Component):
|
|||
raise ValueError(f"Error updating state: {e}")
|
||||
|
||||
def append_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
self.vertex.graph.append_state(
|
||||
name=name, record=value, caller=self.vertex.id
|
||||
|
|
@ -93,6 +97,8 @@ class CustomComponent(Component):
|
|||
raise ValueError(f"Error appending state: {e}")
|
||||
|
||||
def get_state(self, name: str):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
return self.vertex.graph.get_state(name=name)
|
||||
except Exception as e:
|
||||
|
|
@ -142,7 +148,7 @@ class CustomComponent(Component):
|
|||
def update_build_config(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_name: str,
|
||||
field_name: Optional[str],
|
||||
field_value: Any,
|
||||
):
|
||||
build_config[field_name] = field_value
|
||||
|
|
@ -173,6 +179,8 @@ class CustomComponent(Component):
|
|||
ValueError: If the input data is not of a valid type or if the specified keys are not found in the data.
|
||||
|
||||
"""
|
||||
if not keys:
|
||||
keys = []
|
||||
records = []
|
||||
if not isinstance(data, Sequence):
|
||||
data = [data]
|
||||
|
|
|
|||
|
|
@ -429,7 +429,7 @@ def build_custom_components(components_paths: List[str]):
|
|||
return {}
|
||||
|
||||
logger.info(f"Building custom components from {components_paths}")
|
||||
custom_components_from_file = {}
|
||||
custom_components_from_file: dict = {}
|
||||
processed_paths = set()
|
||||
for path in components_paths:
|
||||
path_str = str(path)
|
||||
|
|
@ -485,7 +485,9 @@ def update_field_dict(
|
|||
def sanitize_field_config(field_config: Union[Dict, TemplateField]):
|
||||
# If any of the already existing keys are in field_config, remove them
|
||||
if isinstance(field_config, TemplateField):
|
||||
field_config = field_config.to_dict()
|
||||
field_dict = field_config.to_dict()
|
||||
else:
|
||||
field_dict = field_config
|
||||
for key in [
|
||||
"name",
|
||||
"field_type",
|
||||
|
|
@ -496,8 +498,8 @@ def sanitize_field_config(field_config: Union[Dict, TemplateField]):
|
|||
"advanced",
|
||||
"show",
|
||||
]:
|
||||
field_config.pop(key, None)
|
||||
return field_config
|
||||
field_dict.pop(key, None)
|
||||
return field_dict
|
||||
|
||||
|
||||
def build_component(component):
|
||||
|
|
|
|||
|
|
@ -217,7 +217,7 @@ async def run_graph(
|
|||
graph = Graph.from_payload(graph, flow_id=flow_id)
|
||||
else:
|
||||
graph_data = graph._graph_data
|
||||
if not session_id and session_service is not None:
|
||||
if session_id is None and session_service is not None:
|
||||
session_id = session_service.generate_key(
|
||||
session_id=flow_id, data_graph=graph_data
|
||||
)
|
||||
|
|
@ -226,9 +226,9 @@ async def run_graph(
|
|||
|
||||
outputs = await graph.run(
|
||||
inputs,
|
||||
outputs,
|
||||
outputs or [],
|
||||
stream=stream,
|
||||
session_id=session_id,
|
||||
session_id=session_id or "",
|
||||
)
|
||||
if session_id and session_service:
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
|
|
@ -236,7 +236,7 @@ async def run_graph(
|
|||
|
||||
|
||||
def validate_input(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
graph_data: Dict[str, Any], tweaks: Union["Tweaks", Dict[str, Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
|
||||
raise ValueError("graph_data and tweaks should be dictionaries")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
|||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from emoji import purely_emoji
|
||||
from emoji import purely_emoji # type: ignore
|
||||
from pydantic import field_serializer, field_validator
|
||||
from sqlmodel import JSON, Column, Field, Relationship, SQLModel
|
||||
|
||||
|
|
@ -22,7 +22,9 @@ class FlowBase(SQLModel):
|
|||
icon_bg_color: Optional[str] = Field(default=None, nullable=True)
|
||||
data: Optional[Dict] = Field(default=None, nullable=True)
|
||||
is_component: Optional[bool] = Field(default=False, nullable=True)
|
||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, nullable=True)
|
||||
updated_at: Optional[datetime] = Field(
|
||||
default_factory=datetime.utcnow, nullable=True
|
||||
)
|
||||
folder: Optional[str] = Field(default=None, nullable=True)
|
||||
|
||||
@field_validator("icon_bg_color")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Any, Callable
|
|||
|
||||
|
||||
class TaskBackend(ABC):
|
||||
name: str
|
||||
|
||||
@abstractmethod
|
||||
def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue