From f86f6e6281796c3bdc1c69efce7b4198ceff6419 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 28 Feb 2024 20:15:23 -0300 Subject: [PATCH] Fix import errors and type annotations --- src/backend/langflow/api/utils.py | 6 +- src/backend/langflow/api/v1/callback.py | 125 +++--------------- src/backend/langflow/api/v1/endpoints.py | 12 +- .../langflow/components/chains/RetrievalQA.py | 2 +- .../documentloaders/GatherRecords.py | 2 +- .../langflow/components/io/base/chat.py | 13 +- .../model_specs/ChatLiteLLMSpecs.py | 1 + .../components/utilities/RunnableExecutor.py | 2 +- src/backend/langflow/graph/graph/base.py | 14 +- src/backend/langflow/graph/schema.py | 14 +- src/backend/langflow/graph/vertex/base.py | 27 +--- .../langflow/interface/custom/attributes.py | 3 +- .../custom_component/custom_component.py | 16 ++- src/backend/langflow/interface/run.py | 7 +- src/backend/langflow/processing/base.py | 21 +-- src/backend/langflow/processing/process.py | 59 +-------- .../langflow/services/monitor/service.py | 5 +- .../langflow/services/monitor/utils.py | 7 +- .../langflow/services/session/service.py | 4 +- .../langflow/services/settings/manager.py | 2 +- .../langflow/services/socket/service.py | 2 +- src/backend/langflow/services/socket/utils.py | 2 +- src/backend/langflow/services/storage/s3.py | 26 +++- src/backend/langflow/utils/schemas.py | 11 ++ 24 files changed, 138 insertions(+), 245 deletions(-) diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index 4db323537..222bb2748 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -222,7 +222,7 @@ def build_and_cache_graph( graph: Optional[Graph] = None, ): """Build and cache the graph.""" - flow: Flow = session.get(Flow, flow_id) + flow: Optional[Flow] = session.get(Flow, flow_id) if not flow or not flow.data: raise ValueError("Invalid flow ID") other_graph = Graph.from_payload(flow.data, flow_id) @@ -236,10 +236,12 @@ def build_and_cache_graph( def format_syntax_error_message(exc: SyntaxError) -> str: """Format a SyntaxError message for returning to the frontend.""" + if exc.text is None: + return f"Syntax error in code. Error on line {exc.lineno}" return f"Syntax error in code. Error on line {exc.lineno}: {exc.text.strip()}" -def get_causing_exception(exc: Exception) -> Exception: +def get_causing_exception(exc: BaseException) -> BaseException: """Get the causing exception from an exception.""" if hasattr(exc, "__cause__") and exc.__cause__: return get_causing_exception(exc.__cause__) diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py index 0a5e7f743..58e01a57a 100644 --- a/src/backend/langflow/api/v1/callback.py +++ b/src/backend/langflow/api/v1/callback.py @@ -4,117 +4,16 @@ from uuid import UUID from langchain.schema import AgentAction, AgentFinish from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler +from loguru import logger + from langflow.api.v1.schemas import ChatResponse, PromptResponse from langflow.services.deps import get_chat_service from langflow.utils.util import remove_ansi_escape_codes -from loguru import logger if TYPE_CHECKING: from langflow.services.socket.service import SocketIOService -class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): - """Callback handler for streaming LLM responses.""" - - def __init__(self, client_id: str = None): - self.chat_service = get_chat_service() - self.client_id = client_id - self.websocket = self.chat_service.active_connections[self.client_id] - - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - resp = ChatResponse(message=token, type="stream", intermediate_steps="") - await self.websocket.send_json(resp.model_dump()) - - async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any: - """Run when tool starts running.""" - resp = ChatResponse( - message="", - type="stream", - intermediate_steps=f"Tool input: {input_str}", - ) - await self.websocket.send_json(resp.model_dump()) - - async def on_tool_end(self, output: str, **kwargs: Any) -> Any: - """Run when tool ends running.""" - observation_prefix = kwargs.get("observation_prefix", "Tool output: ") - split_output = output.split() - first_word = split_output[0] - rest_of_output = split_output[1:] - # Create a formatted message. - intermediate_steps = f"{observation_prefix}{first_word}" - - # Create a ChatResponse instance. - resp = ChatResponse( - message="", - type="stream", - intermediate_steps=intermediate_steps, - ) - rest_of_resps = [ - ChatResponse( - message="", - type="stream", - intermediate_steps=f"{word}", - ) - for word in rest_of_output - ] - resps = [resp] + rest_of_resps - # Try to send the response, handle potential errors. - - try: - # This is to emulate the stream of tokens - for resp in resps: - await self.websocket.send_json(resp.model_dump()) - except Exception as exc: - logger.error(f"Error sending response: {exc}") - - async def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when tool errors.""" - - async def on_text(self, text: str, **kwargs: Any) -> Any: - """Run on arbitrary text.""" - # This runs when first sending the prompt - # to the LLM, adding it will send the final prompt - # to the frontend - if "Prompt after formatting" in text: - text = text.replace("Prompt after formatting:\n", "") - text = remove_ansi_escape_codes(text) - resp = PromptResponse( - prompt=text, - ) - await self.websocket.send_json(resp.model_dump()) - self.chat_service.chat_history.add_message(self.client_id, resp) - - async def on_agent_action(self, action: AgentAction, **kwargs: Any): - log = f"Thought: {action.log}" - # if there are line breaks, split them and send them - # as separate messages - if "\n" in log: - logs = log.split("\n") - for log in logs: - resp = ChatResponse(message="", type="stream", intermediate_steps=log) - await self.websocket.send_json(resp.model_dump()) - else: - resp = ChatResponse(message="", type="stream", intermediate_steps=log) - await self.websocket.send_json(resp.model_dump()) - - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run on agent end.""" - resp = ChatResponse( - message="", - type="stream", - intermediate_steps=finish.log, - ) - await self.websocket.send_json(resp.model_dump()) - - # https://github.com/hwchase17/chat-langchain/blob/master/callback.py class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler): """Callback handler for streaming LLM responses.""" @@ -130,7 +29,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler): resp = ChatResponse(message=token, type="stream", intermediate_steps="") await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump()) - async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any: + async def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> Any: """Run when tool starts running.""" resp = ChatResponse( message="", @@ -168,7 +69,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler): try: # This is to emulate the stream of tokens for resp in resps: - await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump()) + await self.socketio_service.emit_token( + to=self.sid, data=resp.model_dump() + ) except Exception as exc: logger.error(f"Error sending response: {exc}") @@ -194,7 +97,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler): resp = PromptResponse( prompt=text, ) - await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump()) + await self.socketio_service.emit_message( + to=self.sid, data=resp.model_dump() + ) self.chat_service.chat_history.add_message(self.client_id, resp) async def on_agent_action(self, action: AgentAction, **kwargs: Any): @@ -205,7 +110,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler): logs = log.split("\n") for log in logs: resp = ChatResponse(message="", type="stream", intermediate_steps=log) - await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump()) + await self.socketio_service.emit_token( + to=self.sid, data=resp.model_dump() + ) else: resp = ChatResponse(message="", type="stream", intermediate_steps=log) await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump()) @@ -232,5 +139,7 @@ class StreamingLLMCallbackHandler(BaseCallbackHandler): resp = ChatResponse(message=token, type="stream", intermediate_steps="") loop = asyncio.get_event_loop() - coroutine = self.socketio_service.emit_token(to=self.sid, data=resp.model_dump()) + coroutine = self.socketio_service.emit_token( + to=self.sid, data=resp.model_dump() + ) asyncio.run_coroutine_threadsafe(coroutine, loop) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index d972c1b27..7030639eb 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -9,6 +9,7 @@ from sqlmodel import Session, select from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( CustomComponentCode, + InputValueRequest, ProcessResponse, RunResponse, TaskStatusResponse, @@ -54,7 +55,7 @@ def get_all( async def run_flow_with_caching( session: Annotated[Session, Depends(get_session)], flow_id: str, - inputs: Optional[Union[List[dict], dict]] = None, + inputs: Optional[InputValueRequest] = None, tweaks: Optional[dict] = None, stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 @@ -62,6 +63,11 @@ async def run_flow_with_caching( session_service: SessionService = Depends(get_session_service), ): try: + if inputs is not None: + input_values_dict: dict[str, Union[str, list[str]]] = inputs.model_dump() + else: + input_values_dict = {} + if session_id: session_data = await session_service.load_session( session_id, flow_id=flow_id @@ -74,7 +80,7 @@ async def run_flow_with_caching( graph=graph, flow_id=flow_id, session_id=session_id, - inputs=inputs, + inputs=input_values_dict, artifacts=artifacts, session_service=session_service, stream=stream, @@ -99,7 +105,7 @@ async def run_flow_with_caching( graph=graph_data, flow_id=flow_id, session_id=session_id, - inputs=inputs, + inputs=input_values_dict, artifacts={}, session_service=session_service, stream=stream, diff --git a/src/backend/langflow/components/chains/RetrievalQA.py b/src/backend/langflow/components/chains/RetrievalQA.py index 567c62e93..30647641d 100644 --- a/src/backend/langflow/components/chains/RetrievalQA.py +++ b/src/backend/langflow/components/chains/RetrievalQA.py @@ -59,4 +59,4 @@ class RetrievalQAComponent(CustomComponent): final_result = "\n".join([str(result_str), references_str]) self.status = final_result - return final_result + return final_result # OK diff --git a/src/backend/langflow/components/documentloaders/GatherRecords.py b/src/backend/langflow/components/documentloaders/GatherRecords.py index dd7f86596..5cd6af317 100644 --- a/src/backend/langflow/components/documentloaders/GatherRecords.py +++ b/src/backend/langflow/components/documentloaders/GatherRecords.py @@ -102,7 +102,7 @@ class GatherRecordsComponent(CustomComponent): silent_errors: bool, max_concurrency: int, use_multithreading: bool, - ) -> List[Record]: + ) -> List[Optional[Record]]: if use_multithreading: records = self.parallel_load_records( file_paths, silent_errors, max_concurrency diff --git a/src/backend/langflow/components/io/base/chat.py b/src/backend/langflow/components/io/base/chat.py index 4d60f6bac..e5ff38249 100644 --- a/src/backend/langflow/components/io/base/chat.py +++ b/src/backend/langflow/components/io/base/chat.py @@ -79,6 +79,7 @@ class ChatComponent(CustomComponent): session_id: Optional[str] = None, return_record: Optional[bool] = False, ) -> Union[Text, Record]: + input_value_record: Optional[Record] = None if return_record: if isinstance(input_value, Record): # Update the data of the record @@ -86,7 +87,7 @@ class ChatComponent(CustomComponent): input_value.data["sender_name"] = sender_name input_value.data["session_id"] = session_id else: - input_value = Record( + input_value_record = Record( text=input_value, data={ "sender": sender, @@ -96,7 +97,11 @@ class ChatComponent(CustomComponent): ) if not input_value: input_value = "" - self.status = input_value + if return_record and input_value_record: + result = input_value_record + else: + result = input_value + self.status = result if session_id: - self.store_message(input_value, session_id, sender, sender_name) - return input_value + self.store_message(result, session_id, sender, sender_name) + return result diff --git a/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py b/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py index 320b2b10b..2aea38475 100644 --- a/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py +++ b/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py @@ -150,6 +150,7 @@ class ChatLiteLLMComponent(CustomComponent): LLM = ChatLiteLLM( model=model, + client=None, streaming=streaming, temperature=temperature, model_kwargs=model_kwargs if model_kwargs is not None else {}, diff --git a/src/backend/langflow/components/utilities/RunnableExecutor.py b/src/backend/langflow/components/utilities/RunnableExecutor.py index 502e1eec6..d5c54c69b 100644 --- a/src/backend/langflow/components/utilities/RunnableExecutor.py +++ b/src/backend/langflow/components/utilities/RunnableExecutor.py @@ -36,7 +36,7 @@ class RunnableExecComponent(CustomComponent): runnable: Runnable, output_key: str = "output", ) -> Text: - result = runnable.invoke({input_key: inputs}) + result = runnable.invoke({input_key: input_value}) result = result.get(output_key) self.status = result return result diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 0f715690a..d78b57f01 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -52,7 +52,9 @@ class Graph: self._vertices = self._graph_data["nodes"] self._edges = self._graph_data["edges"] - self.inactive_vertices = set() + self.inactive_vertices: set = set() + self.edges: List[ContractEdge] = [] + self.vertices: List[Vertex] = [] self._build_graph() self.build_graph_maps() self.define_vertices_lists() @@ -100,7 +102,7 @@ class Graph: async def run( self, inputs: Dict[str, Union[str, list[str]]], stream: bool - ) -> List["ResultData"]: + ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" # inputs is {"message": "Hello, world!"} @@ -108,7 +110,7 @@ class Graph: # of the vertices that are inputs # if the value is a list, we need to run multiple times outputs = [] - inputs_values = inputs.get(INPUT_FIELD_NAME) + inputs_values = inputs.get(INPUT_FIELD_NAME, "") if not isinstance(inputs_values, list): inputs_values = [inputs_values] for input_value in inputs_values: @@ -245,7 +247,7 @@ class Graph: return False return True - def update(self, other: "Graph") -> None: + def update(self, other: "Graph") -> "Graph": # Existing vertices in self graph existing_vertex_ids = set(vertex.id for vertex in self.vertices) # Vertex IDs in the other graph @@ -274,7 +276,7 @@ class Graph: if not self_vertex.pinned: self_vertex._built = False self_vertex.result = None - self_vertex.artifacts = None + self_vertex.artifacts = {} self_vertex.set_top_level(self.top_level_vertices) self.reset_all_edges_of_vertex(self_vertex) @@ -623,7 +625,7 @@ class Graph: queue = deque( vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 ) - layers = [] + layers: List[List[str]] = [] current_layer = 0 while queue: diff --git a/src/backend/langflow/graph/schema.py b/src/backend/langflow/graph/schema.py index 028b8db9f..f53a0833f 100644 --- a/src/backend/langflow/graph/schema.py +++ b/src/backend/langflow/graph/schema.py @@ -1,9 +1,11 @@ from enum import Enum from typing import Any, Optional -from langflow.graph.utils import serialize_field from pydantic import BaseModel, Field, field_serializer +from langflow.graph.utils import serialize_field +from langflow.utils.schemas import ContainsEnumMeta + class ResultData(BaseModel): results: Optional[Any] = Field(default_factory=dict) @@ -18,7 +20,7 @@ class ResultData(BaseModel): return serialize_field(value) -class InterfaceComponentTypes(str, Enum): +class InterfaceComponentTypes(str, Enum, metaclass=ContainsEnumMeta): # ChatInput and ChatOutput are the only ones that are # power components ChatInput = "ChatInput" @@ -26,6 +28,14 @@ class InterfaceComponentTypes(str, Enum): TextInput = "TextInput" TextOutput = "TextOutput" + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + else: + return True + INPUT_COMPONENTS = [ InterfaceComponentTypes.ChatInput, diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index fe2b29a02..3e9ffd64e 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -77,7 +77,7 @@ class Vertex: self.should_run = True self.result: Optional[ResultData] = None try: - self.is_interface_component = InterfaceComponentTypes(self.vertex_type) + self.is_interface_component = self.vertex_type in InterfaceComponentTypes except ValueError: self.is_interface_component = False @@ -107,29 +107,6 @@ class Vertex: def add_build_time(self, time): self.build_times.append(time) - # Build a result dict for each edge - # like so: {edge.target.id: {edge.target_param: self._built_object}} - async def get_result_dict(self, force: bool = False) -> Dict[str, Dict[str, Any]]: - """ - Returns a dictionary with the result of the build process. - """ - edge_results = {} - for edge in self.edges: - target = self.graph.get_vertex(edge.target_id) - if edge.is_fulfilled and isinstance( - await edge.get_result( - source=self, - target=target, - ), - str, - ): - if edge.target_id not in edge_results: - edge_results[edge.target_id] = {} - edge_results[edge.target_id][edge.target_param] = await edge.get_result( - source=self, target=target - ) - return edge_results - def set_result(self, result: ResultData) -> None: self.result = result @@ -626,7 +603,7 @@ class Vertex: return self.get_requester_result(requester) self._reset() - if self.is_input: + if self.is_input and inputs is not None: self.update_raw_params(inputs) # Run steps diff --git a/src/backend/langflow/interface/custom/attributes.py b/src/backend/langflow/interface/custom/attributes.py index d3119cd3d..9b91af43c 100644 --- a/src/backend/langflow/interface/custom/attributes.py +++ b/src/backend/langflow/interface/custom/attributes.py @@ -1,4 +1,5 @@ import warnings +from typing import Callable import emoji @@ -30,7 +31,7 @@ def getattr_return_bool(value): return value -ATTR_FUNC_MAPPING = { +ATTR_FUNC_MAPPING: dict[str, Callable] = { "display_name": getattr_return_str, "description": getattr_return_str, "beta": getattr_return_bool, diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index a8c81f041..1c3cabb42 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -35,6 +35,7 @@ from langflow.utils import validate if TYPE_CHECKING: from langflow.graph.edge.base import ContractEdge + from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import Vertex @@ -292,8 +293,9 @@ class CustomComponent(Component): def get_function(self): return validate.create_function(self.code, self.function_entrypoint_name) - async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any: - from langflow.processing.process import build_sorted_vertices, process_tweaks + async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph": + from langflow.graph.graph.base import Graph + from langflow.processing.process import process_tweaks db_service = get_db_service() with session_getter(db_service) as session: @@ -302,7 +304,15 @@ class CustomComponent(Component): raise ValueError(f"Flow {flow_id} not found") if tweaks: graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks) - return await build_sorted_vertices(graph_data, self.user_id) + graph = Graph(**graph_data) + return graph + + async def run_flow( + self, input_value: str, flow_id: str, tweaks: Optional[dict] = None + ) -> Any: + graph = await self.load_flow(flow_id, tweaks) + input_value_dict = {"input_value": input_value} + return await graph.run(input_value_dict) def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]: if not self._user_id: diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index b078fda04..aec602ac7 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -1,14 +1,11 @@ -from typing import Dict, Optional, Tuple, Union -from uuid import UUID +from typing import Dict, Tuple from loguru import logger from langflow.graph import Graph -async def build_sorted_vertices( - data_graph, flow_id: Optional[Union[str, UUID]] = None -) -> Tuple[Graph, Dict]: +async def build_sorted_vertices(data_graph, flow_id: str) -> Tuple[Graph, Dict]: """ Build langchain object from data_graph. """ diff --git a/src/backend/langflow/processing/base.py b/src/backend/langflow/processing/base.py index 26e0f396e..f61c22743 100644 --- a/src/backend/langflow/processing/base.py +++ b/src/backend/langflow/processing/base.py @@ -4,7 +4,7 @@ from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import BaseCallbackHandler from loguru import logger -from langflow.api.v1.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler +from langflow.api.v1.callback import StreamingLLMCallbackHandler from langflow.processing.process import fix_memory_inputs, format_actions from langflow.services.deps import get_plugins_service @@ -15,10 +15,7 @@ if TYPE_CHECKING: def setup_callbacks(sync, trace_id, **kwargs): """Setup callbacks for langchain object""" callbacks = [] - if sync: - callbacks.append(StreamingLLMCallbackHandler(**kwargs)) - else: - callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs)) + callbacks.append(StreamingLLMCallbackHandler(**kwargs)) plugin_service = get_plugins_service() plugin_callbacks = plugin_service.get_callbacks(_id=trace_id) @@ -42,7 +39,9 @@ def get_langfuse_callback(trace_id): return None -def flush_langfuse_callback_if_present(callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]): +def flush_langfuse_callback_if_present( + callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]] +): """ If langfuse callback is present, run callback.langfuse.flush() """ @@ -83,9 +82,15 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa # if langfuse callback is present, run callback.langfuse.flush() flush_langfuse_callback_if_present(callbacks) - intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else [] + intermediate_steps = ( + output.get("intermediate_steps", []) if isinstance(output, dict) else [] + ) - result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output + result = ( + output.get(langchain_object.output_keys[0]) + if isinstance(output, dict) + else output + ) try: thought = format_actions(intermediate_steps) if intermediate_steps else "" except Exception as exc: diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index fbb3986ab..1b7bc7c81 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -13,12 +13,7 @@ from pydantic import BaseModel from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import Vertex from langflow.interface.custom.custom_component import CustomComponent -from langflow.interface.run import ( - build_sorted_vertices, - get_memory_key, - update_memory_keys, -) -from langflow.services.deps import get_session_service +from langflow.interface.run import get_memory_key, update_memory_keys from langflow.services.session.service import SessionService @@ -203,62 +198,12 @@ class Result(BaseModel): session_id: str -async def process_graph_cached( - data_graph: Dict[str, Any], - inputs: Optional[Union[dict, List[dict]]] = None, - clear_cache=False, - session_id=None, -) -> Result: - session_service = get_session_service() - if clear_cache: - session_service.clear_session(session_id) - if session_id is None: - session_id = session_service.generate_key( - session_id=session_id, data_graph=data_graph - ) - # Load the graph using SessionService - session = await session_service.load_session( - session_id, data_graph, flow_id=flow_id - ) - graph, artifacts = session if session else (None, None) - if not graph: - raise ValueError("Graph not found in the session") - - result = await build_graph_and_generate_result( - graph=graph, - session_id=session_id, - inputs=inputs, - artifacts=artifacts, - session_service=session_service, - ) - - return result - - -async def build_graph_and_generate_result( - graph: "Graph", - session_id: str, - inputs: Optional[Union[dict, List[dict]]] = None, - artifacts: Optional[Dict[str, Any]] = None, - session_service: Optional[SessionService] = None, -): - """Build the graph and generate the result""" - built_object = await graph.build() - processed_inputs = process_inputs(inputs, artifacts or {}) - result = await generate_result(built_object, processed_inputs) - # langchain_object is now updated with the new memory - # we need to update the cache with the updated langchain_object - if session_id and session_service: - session_service.update_session(session_id, (graph, artifacts)) - return Result(result=result, session_id=session_id) - - async def run_graph( graph: Union["Graph", dict], flow_id: str, stream: bool, session_id: Optional[str] = None, - inputs: Optional[Union[dict, List[dict]]] = None, + inputs: Optional[dict[str, Union[List[str], str]]] = None, artifacts: Optional[Dict[str, Any]] = None, session_service: Optional[SessionService] = None, ): diff --git a/src/backend/langflow/services/monitor/service.py b/src/backend/langflow/services/monitor/service.py index 25ff65fd8..c8580158d 100644 --- a/src/backend/langflow/services/monitor/service.py +++ b/src/backend/langflow/services/monitor/service.py @@ -1,11 +1,10 @@ from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Union import duckdb from loguru import logger from platformdirs import user_cache_dir -from pydantic import BaseModel from langflow.services.base import Service from langflow.services.monitor.schema import ( @@ -56,7 +55,7 @@ class MonitorService(Service): ): # Make sure the model passed matches the table - model: Type[BaseModel] = self.table_map.get(table_name) + model = self.table_map.get(table_name) if model is None: raise ValueError(f"Unknown table name: {table_name}") diff --git a/src/backend/langflow/services/monitor/utils.py b/src/backend/langflow/services/monitor/utils.py index d308e7653..1ff1db24a 100644 --- a/src/backend/langflow/services/monitor/utils.py +++ b/src/backend/langflow/services/monitor/utils.py @@ -23,7 +23,10 @@ def get_table_schema_as_dict(conn: duckdb.DuckDBPyConnection, table_name: str) - def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict: columns = {} for field_name, field_type in model.model_fields.items(): - if hasattr(field_type.annotation, "__args__"): + if ( + hasattr(field_type.annotation, "__args__") + and field_type.annotation is not None + ): field_args = field_type.annotation.__args__ else: field_args = [] @@ -82,7 +85,7 @@ def drop_and_create_table_if_schema_mismatch( def add_row_to_table( conn: duckdb.DuckDBPyConnection, table_name: str, - model: Type[BaseModel], + model: Type, monitor_data: Union[Dict[str, Any], BaseModel], ): # Validate the data with the Pydantic model diff --git a/src/backend/langflow/services/session/service.py b/src/backend/langflow/services/session/service.py index 914ca7a3a..68fec0430 100644 --- a/src/backend/langflow/services/session/service.py +++ b/src/backend/langflow/services/session/service.py @@ -14,9 +14,7 @@ class SessionService(Service): def __init__(self, cache_service): self.cache_service: "BaseCacheService" = cache_service - async def load_session( - self, key, data_graph: Optional[dict] = None, flow_id: Optional[str] = None - ): + async def load_session(self, key, flow_id: str, data_graph: Optional[dict] = None): # Check if the data is cached if key in self.cache_service: return self.cache_service.get(key) diff --git a/src/backend/langflow/services/settings/manager.py b/src/backend/langflow/services/settings/manager.py index b4812058f..804328860 100644 --- a/src/backend/langflow/services/settings/manager.py +++ b/src/backend/langflow/services/settings/manager.py @@ -30,7 +30,7 @@ class SettingsService(Service): settings_dict = {k.upper(): v for k, v in settings_dict.items()} for key in settings_dict: - if key not in Settings.model_fields().keys(): + if key not in Settings.model_fields.keys(): raise KeyError(f"Key {key} not found in settings") logger.debug( f"Loading {len(settings_dict[key])} {key} from {file_path}" diff --git a/src/backend/langflow/services/socket/service.py b/src/backend/langflow/services/socket/service.py index b3ae2b08a..fd2c236d5 100644 --- a/src/backend/langflow/services/socket/service.py +++ b/src/backend/langflow/services/socket/service.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any -import socketio +import socketio # type: ignore from loguru import logger from langflow.services.base import Service diff --git a/src/backend/langflow/services/socket/utils.py b/src/backend/langflow/services/socket/utils.py index 48208403a..a819fb15d 100644 --- a/src/backend/langflow/services/socket/utils.py +++ b/src/backend/langflow/services/socket/utils.py @@ -1,7 +1,7 @@ import time from typing import Callable -import socketio +import socketio # type: ignore from sqlmodel import select from langflow.api.utils import format_elapsed_time diff --git a/src/backend/langflow/services/storage/s3.py b/src/backend/langflow/services/storage/s3.py index afd39760a..f509052f3 100644 --- a/src/backend/langflow/services/storage/s3.py +++ b/src/backend/langflow/services/storage/s3.py @@ -1,5 +1,5 @@ -import boto3 -from botocore.exceptions import ClientError, NoCredentialsError +import boto3 # type: ignore +from botocore.exceptions import ClientError, NoCredentialsError # type: ignore from loguru import logger from .service import StorageService @@ -25,7 +25,9 @@ class S3StorageService(StorageService): :raises Exception: If an error occurs during file saving. """ try: - self.s3_client.put_object(Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data) + self.s3_client.put_object( + Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data + ) logger.info(f"File {file_name} saved successfully in folder {folder}.") except NoCredentialsError: logger.error("Credentials not available for AWS S3.") @@ -44,8 +46,12 @@ class S3StorageService(StorageService): :raises Exception: If an error occurs during file retrieval. """ try: - response = self.s3_client.get_object(Bucket=self.bucket, Key=f"{folder}/{file_name}") - logger.info(f"File {file_name} retrieved successfully from folder {folder}.") + response = self.s3_client.get_object( + Bucket=self.bucket, Key=f"{folder}/{file_name}" + ) + logger.info( + f"File {file_name} retrieved successfully from folder {folder}." + ) return response["Body"].read() except ClientError as e: logger.error(f"Error retrieving file {file_name} from folder {folder}: {e}") @@ -61,7 +67,11 @@ class S3StorageService(StorageService): """ try: response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=folder) - files = [item["Key"] for item in response.get("Contents", []) if "/" not in item["Key"][len(folder) :]] + files = [ + item["Key"] + for item in response.get("Contents", []) + if "/" not in item["Key"][len(folder) :] + ] logger.info(f"{len(files)} files listed in folder {folder}.") return files except ClientError as e: @@ -77,7 +87,9 @@ class S3StorageService(StorageService): :raises Exception: If an error occurs during file deletion. """ try: - self.s3_client.delete_object(Bucket=self.bucket, Key=f"{folder}/{file_name}") + self.s3_client.delete_object( + Bucket=self.bucket, Key=f"{folder}/{file_name}" + ) logger.info(f"File {file_name} deleted successfully from folder {folder}.") except ClientError as e: logger.error(f"Error deleting file {file_name} from folder {folder}: {e}") diff --git a/src/backend/langflow/utils/schemas.py b/src/backend/langflow/utils/schemas.py index 8d0b2db12..354cb5949 100644 --- a/src/backend/langflow/utils/schemas.py +++ b/src/backend/langflow/utils/schemas.py @@ -1,3 +1,4 @@ +import enum from typing import Dict, List, Optional, Union from langchain_core.messages import BaseMessage @@ -40,3 +41,13 @@ class ChatOutputResponse(BaseModel): message = self.message.replace("\n\n", "\n") self.message = message.replace("\n", "\n\n") return self + + +class ContainsEnumMeta(enum.EnumMeta): + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + else: + return True