Fix import errors and type annotations
parent 3c9f2a0e8c
commit f86f6e6281
24 changed files with 138 additions and 245 deletions
@@ -222,7 +222,7 @@ def build_and_cache_graph(
    graph: Optional[Graph] = None,
):
    """Build and cache the graph."""
    flow: Flow = session.get(Flow, flow_id)
    flow: Optional[Flow] = session.get(Flow, flow_id)
    if not flow or not flow.data:
        raise ValueError("Invalid flow ID")
    other_graph = Graph.from_payload(flow.data, flow_id)
@@ -236,10 +236,12 @@ def build_and_cache_graph(

def format_syntax_error_message(exc: SyntaxError) -> str:
    """Format a SyntaxError message for returning to the frontend."""
    if exc.text is None:
        return f"Syntax error in code. Error on line {exc.lineno}"
    return f"Syntax error in code. Error on line {exc.lineno}: {exc.text.strip()}"


def get_causing_exception(exc: Exception) -> Exception:
def get_causing_exception(exc: BaseException) -> BaseException:
    """Get the causing exception from an exception."""
    if hasattr(exc, "__cause__") and exc.__cause__:
        return get_causing_exception(exc.__cause__)
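A minimal usage sketch of the cause-chain helper above. This is a standalone copy, and the trailing "return exc" base case is assumed to sit just below the hunk:

    def get_causing_exception(exc: BaseException) -> BaseException:
        """Walk the __cause__ chain down to the root error."""
        if hasattr(exc, "__cause__") and exc.__cause__:
            return get_causing_exception(exc.__cause__)
        return exc

    try:
        try:
            int("not a number")
        except ValueError as err:
            raise RuntimeError("wrapping the parse failure") from err
    except RuntimeError as outer:
        # The helper surfaces the original ValueError, not the wrapper.
        print(type(get_causing_exception(outer)).__name__)  # ValueError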
@@ -4,117 +4,16 @@ from uuid import UUID

from langchain.schema import AgentAction, AgentFinish
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from loguru import logger

from langflow.api.v1.schemas import ChatResponse, PromptResponse
from langflow.services.deps import get_chat_service
from langflow.utils.util import remove_ansi_escape_codes
from loguru import logger

if TYPE_CHECKING:
    from langflow.services.socket.service import SocketIOService


class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""

    def __init__(self, client_id: str = None):
        self.chat_service = get_chat_service()
        self.client_id = client_id
        self.websocket = self.chat_service.active_connections[self.client_id]

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
        await self.websocket.send_json(resp.model_dump())

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
        """Run when tool starts running."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=f"Tool input: {input_str}",
        )
        await self.websocket.send_json(resp.model_dump())

    async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running."""
        observation_prefix = kwargs.get("observation_prefix", "Tool output: ")
        split_output = output.split()
        first_word = split_output[0]
        rest_of_output = split_output[1:]
        # Create a formatted message.
        intermediate_steps = f"{observation_prefix}{first_word}"

        # Create a ChatResponse instance.
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=intermediate_steps,
        )
        rest_of_resps = [
            ChatResponse(
                message="",
                type="stream",
                intermediate_steps=f"{word}",
            )
            for word in rest_of_output
        ]
        resps = [resp] + rest_of_resps
        # Try to send the response, handle potential errors.

        try:
            # This is to emulate the stream of tokens
            for resp in resps:
                await self.websocket.send_json(resp.model_dump())
        except Exception as exc:
            logger.error(f"Error sending response: {exc}")

    async def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool errors."""

    async def on_text(self, text: str, **kwargs: Any) -> Any:
        """Run on arbitrary text."""
        # This runs when first sending the prompt
        # to the LLM, adding it will send the final prompt
        # to the frontend
        if "Prompt after formatting" in text:
            text = text.replace("Prompt after formatting:\n", "")
            text = remove_ansi_escape_codes(text)
            resp = PromptResponse(
                prompt=text,
            )
            await self.websocket.send_json(resp.model_dump())
            self.chat_service.chat_history.add_message(self.client_id, resp)

    async def on_agent_action(self, action: AgentAction, **kwargs: Any):
        log = f"Thought: {action.log}"
        # if there are line breaks, split them and send them
        # as separate messages
        if "\n" in log:
            logs = log.split("\n")
            for log in logs:
                resp = ChatResponse(message="", type="stream", intermediate_steps=log)
                await self.websocket.send_json(resp.model_dump())
        else:
            resp = ChatResponse(message="", type="stream", intermediate_steps=log)
            await self.websocket.send_json(resp.model_dump())

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=finish.log,
        )
        await self.websocket.send_json(resp.model_dump())


# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""
@@ -130,7 +29,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
    async def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running."""
        resp = ChatResponse(
            message="",
@@ -168,7 +69,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
        try:
            # This is to emulate the stream of tokens
            for resp in resps:
                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
                await self.socketio_service.emit_token(
                    to=self.sid, data=resp.model_dump()
                )
        except Exception as exc:
            logger.error(f"Error sending response: {exc}")

@@ -194,7 +97,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
            resp = PromptResponse(
                prompt=text,
            )
            await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())
            await self.socketio_service.emit_message(
                to=self.sid, data=resp.model_dump()
            )
            self.chat_service.chat_history.add_message(self.client_id, resp)

    async def on_agent_action(self, action: AgentAction, **kwargs: Any):
@@ -205,7 +110,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
            logs = log.split("\n")
            for log in logs:
                resp = ChatResponse(message="", type="stream", intermediate_steps=log)
                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
                await self.socketio_service.emit_token(
                    to=self.sid, data=resp.model_dump()
                )
        else:
            resp = ChatResponse(message="", type="stream", intermediate_steps=log)
            await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
@@ -232,5 +139,7 @@ class StreamingLLMCallbackHandler(BaseCallbackHandler):
        resp = ChatResponse(message=token, type="stream", intermediate_steps="")

        loop = asyncio.get_event_loop()
        coroutine = self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
        coroutine = self.socketio_service.emit_token(
            to=self.sid, data=resp.model_dump()
        )
        asyncio.run_coroutine_threadsafe(coroutine, loop)
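For context on the synchronous StreamingLLMCallbackHandler change above: asyncio.run_coroutine_threadsafe is the standard way to hand a coroutine to an event loop running in another thread. A small self-contained sketch of the pattern; the emit_token coroutine here is a stand-in, not Langflow's service:

    import asyncio
    import threading

    async def emit_token(token: str) -> None:
        # Stand-in for socketio_service.emit_token.
        print(f"emitted: {token}")

    def main() -> None:
        loop = asyncio.new_event_loop()
        # Run the loop in a background thread, as a server framework would.
        t = threading.Thread(target=loop.run_forever, daemon=True)
        t.start()

        # From synchronous code (e.g. a sync callback), schedule the coroutine
        # onto the running loop and wait for its result.
        future = asyncio.run_coroutine_threadsafe(emit_token("hello"), loop)
        future.result(timeout=5)

        loop.call_soon_threadsafe(loop.stop)
        t.join()

    if __name__ == "__main__":
        main()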
@@ -9,6 +9,7 @@ from sqlmodel import Session, select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (
    CustomComponentCode,
    InputValueRequest,
    ProcessResponse,
    RunResponse,
    TaskStatusResponse,
@@ -54,7 +55,7 @@ def get_all(
async def run_flow_with_caching(
    session: Annotated[Session, Depends(get_session)],
    flow_id: str,
    inputs: Optional[Union[List[dict], dict]] = None,
    inputs: Optional[InputValueRequest] = None,
    tweaks: Optional[dict] = None,
    stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
    session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
@@ -62,6 +63,11 @@ async def run_flow_with_caching(
    session_service: SessionService = Depends(get_session_service),
):
    try:
        if inputs is not None:
            input_values_dict: dict[str, Union[str, list[str]]] = inputs.model_dump()
        else:
            input_values_dict = {}

        if session_id:
            session_data = await session_service.load_session(
                session_id, flow_id=flow_id
@@ -74,7 +80,7 @@ async def run_flow_with_caching(
                graph=graph,
                flow_id=flow_id,
                session_id=session_id,
                inputs=inputs,
                inputs=input_values_dict,
                artifacts=artifacts,
                session_service=session_service,
                stream=stream,
@@ -99,7 +105,7 @@ async def run_flow_with_caching(
                graph=graph_data,
                flow_id=flow_id,
                session_id=session_id,
                inputs=inputs,
                inputs=input_values_dict,
                artifacts={},
                session_service=session_service,
                stream=stream,
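The endpoint now takes a Pydantic model and flattens it with model_dump() before passing it to the graph. A sketch of that conversion, assuming InputValueRequest roughly exposes an input_value field; the real schema lives in langflow.api.v1.schemas and is not shown in this diff:

    from typing import List, Union
    from pydantic import BaseModel

    class InputValueRequest(BaseModel):
        # Hypothetical shape of the request model.
        input_value: Union[str, List[str]] = ""

    inputs = InputValueRequest(input_value="Hello, world!")
    input_values_dict: dict = inputs.model_dump()
    print(input_values_dict)  # {'input_value': 'Hello, world!'}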
@@ -59,4 +59,4 @@ class RetrievalQAComponent(CustomComponent):

        final_result = "\n".join([str(result_str), references_str])
        self.status = final_result
        return final_result
        return final_result # OK
@@ -102,7 +102,7 @@ class GatherRecordsComponent(CustomComponent):
        silent_errors: bool,
        max_concurrency: int,
        use_multithreading: bool,
    ) -> List[Record]:
    ) -> List[Optional[Record]]:
        if use_multithreading:
            records = self.parallel_load_records(
                file_paths, silent_errors, max_concurrency
@@ -79,6 +79,7 @@ class ChatComponent(CustomComponent):
        session_id: Optional[str] = None,
        return_record: Optional[bool] = False,
    ) -> Union[Text, Record]:
        input_value_record: Optional[Record] = None
        if return_record:
            if isinstance(input_value, Record):
                # Update the data of the record
@@ -86,7 +87,7 @@ class ChatComponent(CustomComponent):
                input_value.data["sender_name"] = sender_name
                input_value.data["session_id"] = session_id
            else:
                input_value = Record(
                input_value_record = Record(
                    text=input_value,
                    data={
                        "sender": sender,
@@ -96,7 +97,11 @@ class ChatComponent(CustomComponent):
                )
        if not input_value:
            input_value = ""
        self.status = input_value
        if return_record and input_value_record:
            result = input_value_record
        else:
            result = input_value
        self.status = result
        if session_id:
            self.store_message(input_value, session_id, sender, sender_name)
        return input_value
            self.store_message(result, session_id, sender, sender_name)
        return result
@@ -150,6 +150,7 @@ class ChatLiteLLMComponent(CustomComponent):

        LLM = ChatLiteLLM(
            model=model,
            client=None,
            streaming=streaming,
            temperature=temperature,
            model_kwargs=model_kwargs if model_kwargs is not None else {},
@@ -36,7 +36,7 @@ class RunnableExecComponent(CustomComponent):
        runnable: Runnable,
        output_key: str = "output",
    ) -> Text:
        result = runnable.invoke({input_key: inputs})
        result = runnable.invoke({input_key: input_value})
        result = result.get(output_key)
        self.status = result
        return result
@@ -52,7 +52,9 @@ class Graph:

        self._vertices = self._graph_data["nodes"]
        self._edges = self._graph_data["edges"]
        self.inactive_vertices = set()
        self.inactive_vertices: set = set()
        self.edges: List[ContractEdge] = []
        self.vertices: List[Vertex] = []
        self._build_graph()
        self.build_graph_maps()
        self.define_vertices_lists()
@@ -100,7 +102,7 @@ class Graph:

    async def run(
        self, inputs: Dict[str, Union[str, list[str]]], stream: bool
    ) -> List["ResultData"]:
    ) -> List[Optional["ResultData"]]:
        """Runs the graph with the given inputs."""

        # inputs is {"message": "Hello, world!"}
@@ -108,7 +110,7 @@ class Graph:
        # of the vertices that are inputs
        # if the value is a list, we need to run multiple times
        outputs = []
        inputs_values = inputs.get(INPUT_FIELD_NAME)
        inputs_values = inputs.get(INPUT_FIELD_NAME, "")
        if not isinstance(inputs_values, list):
            inputs_values = [inputs_values]
        for input_value in inputs_values:
@@ -245,7 +247,7 @@ class Graph:
            return False
        return True

    def update(self, other: "Graph") -> None:
    def update(self, other: "Graph") -> "Graph":
        # Existing vertices in self graph
        existing_vertex_ids = set(vertex.id for vertex in self.vertices)
        # Vertex IDs in the other graph
@@ -274,7 +276,7 @@ class Graph:
            if not self_vertex.pinned:
                self_vertex._built = False
                self_vertex.result = None
                self_vertex.artifacts = None
                self_vertex.artifacts = {}
                self_vertex.set_top_level(self.top_level_vertices)
                self.reset_all_edges_of_vertex(self_vertex)

@@ -623,7 +625,7 @@ class Graph:
        queue = deque(
            vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
        )
        layers = []
        layers: List[List[str]] = []

        current_layer = 0
        while queue:
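The inputs.get(INPUT_FIELD_NAME, "") default plus the isinstance check lets a single value and a list of values share the same loop in Graph.run. A standalone illustration of that normalization; the constant's value is an assumption here:

    from typing import List, Union

    INPUT_FIELD_NAME = "input_value"  # assumed name and value of the constant

    def normalize_inputs(inputs: dict) -> List[Union[str, list]]:
        inputs_values = inputs.get(INPUT_FIELD_NAME, "")
        if not isinstance(inputs_values, list):
            inputs_values = [inputs_values]
        return inputs_values

    print(normalize_inputs({"input_value": "hi"}))        # ['hi']
    print(normalize_inputs({"input_value": ["a", "b"]}))  # ['a', 'b']
    print(normalize_inputs({}))                           # ['']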
@@ -1,9 +1,11 @@
from enum import Enum
from typing import Any, Optional
from langflow.graph.utils import serialize_field

from pydantic import BaseModel, Field, field_serializer

from langflow.graph.utils import serialize_field
from langflow.utils.schemas import ContainsEnumMeta


class ResultData(BaseModel):
    results: Optional[Any] = Field(default_factory=dict)
@@ -18,7 +20,7 @@ class ResultData(BaseModel):
        return serialize_field(value)


class InterfaceComponentTypes(str, Enum):
class InterfaceComponentTypes(str, Enum, metaclass=ContainsEnumMeta):
    # ChatInput and ChatOutput are the only ones that are
    # power components
    ChatInput = "ChatInput"
@@ -26,6 +28,14 @@ class InterfaceComponentTypes(str, Enum):
    TextInput = "TextInput"
    TextOutput = "TextOutput"

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        else:
            return True


INPUT_COMPONENTS = [
    InterfaceComponentTypes.ChatInput,
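The metaclass=ContainsEnumMeta argument is what makes value-based membership tests such as "ChatInput" in InterfaceComponentTypes work; plain EnumMeta rejects or mishandles non-member values on some Python versions. A self-contained sketch of the same idea:

    import enum

    class ContainsEnumMeta(enum.EnumMeta):
        def __contains__(cls, item):
            try:
                cls(item)
            except ValueError:
                return False
            else:
                return True

    class InterfaceComponentTypes(str, enum.Enum, metaclass=ContainsEnumMeta):
        ChatInput = "ChatInput"
        ChatOutput = "ChatOutput"

    print("ChatInput" in InterfaceComponentTypes)      # True
    print("NotAComponent" in InterfaceComponentTypes)  # False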
@@ -77,7 +77,7 @@ class Vertex:
        self.should_run = True
        self.result: Optional[ResultData] = None
        try:
            self.is_interface_component = InterfaceComponentTypes(self.vertex_type)
            self.is_interface_component = self.vertex_type in InterfaceComponentTypes
        except ValueError:
            self.is_interface_component = False

@@ -107,29 +107,6 @@ class Vertex:
    def add_build_time(self, time):
        self.build_times.append(time)

    # Build a result dict for each edge
    # like so: {edge.target.id: {edge.target_param: self._built_object}}
    async def get_result_dict(self, force: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary with the result of the build process.
        """
        edge_results = {}
        for edge in self.edges:
            target = self.graph.get_vertex(edge.target_id)
            if edge.is_fulfilled and isinstance(
                await edge.get_result(
                    source=self,
                    target=target,
                ),
                str,
            ):
                if edge.target_id not in edge_results:
                    edge_results[edge.target_id] = {}
                edge_results[edge.target_id][edge.target_param] = await edge.get_result(
                    source=self, target=target
                )
        return edge_results

    def set_result(self, result: ResultData) -> None:
        self.result = result

@@ -626,7 +603,7 @@ class Vertex:
            return self.get_requester_result(requester)
        self._reset()

        if self.is_input:
        if self.is_input and inputs is not None:
            self.update_raw_params(inputs)

        # Run steps
@@ -1,4 +1,5 @@
import warnings
from typing import Callable

import emoji

@@ -30,7 +31,7 @@ def getattr_return_bool(value):
    return value


ATTR_FUNC_MAPPING = {
ATTR_FUNC_MAPPING: dict[str, Callable] = {
    "display_name": getattr_return_str,
    "description": getattr_return_str,
    "beta": getattr_return_bool,
@@ -35,6 +35,7 @@ from langflow.utils import validate

if TYPE_CHECKING:
    from langflow.graph.edge.base import ContractEdge
    from langflow.graph.graph.base import Graph
    from langflow.graph.vertex.base import Vertex


@@ -292,8 +293,9 @@ class CustomComponent(Component):
    def get_function(self):
        return validate.create_function(self.code, self.function_entrypoint_name)

    async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
        from langflow.processing.process import build_sorted_vertices, process_tweaks
    async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
        from langflow.graph.graph.base import Graph
        from langflow.processing.process import process_tweaks

        db_service = get_db_service()
        with session_getter(db_service) as session:
@@ -302,7 +304,15 @@ class CustomComponent(Component):
            raise ValueError(f"Flow {flow_id} not found")
        if tweaks:
            graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
        return await build_sorted_vertices(graph_data, self.user_id)
        graph = Graph(**graph_data)
        return graph

    async def run_flow(
        self, input_value: str, flow_id: str, tweaks: Optional[dict] = None
    ) -> Any:
        graph = await self.load_flow(flow_id, tweaks)
        input_value_dict = {"input_value": input_value}
        return await graph.run(input_value_dict)

    def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
        if not self._user_id:
@@ -1,14 +1,11 @@
from typing import Dict, Optional, Tuple, Union
from uuid import UUID
from typing import Dict, Tuple

from loguru import logger

from langflow.graph import Graph


async def build_sorted_vertices(
    data_graph, flow_id: Optional[Union[str, UUID]] = None
) -> Tuple[Graph, Dict]:
async def build_sorted_vertices(data_graph, flow_id: str) -> Tuple[Graph, Dict]:
    """
    Build langchain object from data_graph.
    """
@@ -4,7 +4,7 @@ from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from loguru import logger

from langflow.api.v1.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler
from langflow.api.v1.callback import StreamingLLMCallbackHandler
from langflow.processing.process import fix_memory_inputs, format_actions
from langflow.services.deps import get_plugins_service

@@ -15,10 +15,7 @@ if TYPE_CHECKING:
def setup_callbacks(sync, trace_id, **kwargs):
    """Setup callbacks for langchain object"""
    callbacks = []
    if sync:
        callbacks.append(StreamingLLMCallbackHandler(**kwargs))
    else:
        callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs))
    callbacks.append(StreamingLLMCallbackHandler(**kwargs))

    plugin_service = get_plugins_service()
    plugin_callbacks = plugin_service.get_callbacks(_id=trace_id)
@@ -42,7 +39,9 @@ def get_langfuse_callback(trace_id):
    return None


def flush_langfuse_callback_if_present(callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]):
def flush_langfuse_callback_if_present(
    callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]
):
    """
    If langfuse callback is present, run callback.langfuse.flush()
    """
@@ -83,9 +82,15 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
    # if langfuse callback is present, run callback.langfuse.flush()
    flush_langfuse_callback_if_present(callbacks)

    intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else []
    intermediate_steps = (
        output.get("intermediate_steps", []) if isinstance(output, dict) else []
    )

    result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output
    result = (
        output.get(langchain_object.output_keys[0])
        if isinstance(output, dict)
        else output
    )
    try:
        thought = format_actions(intermediate_steps) if intermediate_steps else ""
    except Exception as exc:
@@ -13,12 +13,7 @@ from pydantic import BaseModel
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.run import (
    build_sorted_vertices,
    get_memory_key,
    update_memory_keys,
)
from langflow.services.deps import get_session_service
from langflow.interface.run import get_memory_key, update_memory_keys
from langflow.services.session.service import SessionService


@@ -203,62 +198,12 @@ class Result(BaseModel):
    session_id: str


async def process_graph_cached(
    data_graph: Dict[str, Any],
    inputs: Optional[Union[dict, List[dict]]] = None,
    clear_cache=False,
    session_id=None,
) -> Result:
    session_service = get_session_service()
    if clear_cache:
        session_service.clear_session(session_id)
    if session_id is None:
        session_id = session_service.generate_key(
            session_id=session_id, data_graph=data_graph
        )
    # Load the graph using SessionService
    session = await session_service.load_session(
        session_id, data_graph, flow_id=flow_id
    )
    graph, artifacts = session if session else (None, None)
    if not graph:
        raise ValueError("Graph not found in the session")

    result = await build_graph_and_generate_result(
        graph=graph,
        session_id=session_id,
        inputs=inputs,
        artifacts=artifacts,
        session_service=session_service,
    )

    return result


async def build_graph_and_generate_result(
    graph: "Graph",
    session_id: str,
    inputs: Optional[Union[dict, List[dict]]] = None,
    artifacts: Optional[Dict[str, Any]] = None,
    session_service: Optional[SessionService] = None,
):
    """Build the graph and generate the result"""
    built_object = await graph.build()
    processed_inputs = process_inputs(inputs, artifacts or {})
    result = await generate_result(built_object, processed_inputs)
    # langchain_object is now updated with the new memory
    # we need to update the cache with the updated langchain_object
    if session_id and session_service:
        session_service.update_session(session_id, (graph, artifacts))
    return Result(result=result, session_id=session_id)


async def run_graph(
    graph: Union["Graph", dict],
    flow_id: str,
    stream: bool,
    session_id: Optional[str] = None,
    inputs: Optional[Union[dict, List[dict]]] = None,
    inputs: Optional[dict[str, Union[List[str], str]]] = None,
    artifacts: Optional[Dict[str, Any]] = None,
    session_service: Optional[SessionService] = None,
):
@@ -1,11 +1,10 @@
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Optional, Union

import duckdb
from loguru import logger
from platformdirs import user_cache_dir
from pydantic import BaseModel

from langflow.services.base import Service
from langflow.services.monitor.schema import (
@@ -56,7 +55,7 @@ class MonitorService(Service):
    ):
        # Make sure the model passed matches the table

        model: Type[BaseModel] = self.table_map.get(table_name)
        model = self.table_map.get(table_name)
        if model is None:
            raise ValueError(f"Unknown table name: {table_name}")

@@ -23,7 +23,10 @@ def get_table_schema_as_dict(conn: duckdb.DuckDBPyConnection, table_name: str) -
def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict:
    columns = {}
    for field_name, field_type in model.model_fields.items():
        if hasattr(field_type.annotation, "__args__"):
        if (
            hasattr(field_type.annotation, "__args__")
            and field_type.annotation is not None
        ):
            field_args = field_type.annotation.__args__
        else:
            field_args = []
@@ -82,7 +85,7 @@ def drop_and_create_table_if_schema_mismatch(
def add_row_to_table(
    conn: duckdb.DuckDBPyConnection,
    table_name: str,
    model: Type[BaseModel],
    model: Type,
    monitor_data: Union[Dict[str, Any], BaseModel],
):
    # Validate the data with the Pydantic model
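The reworked guard in model_to_sql_column_definitions checks for __args__ before unpacking a field's annotation. A quick standalone look at what that attribute holds for typed annotations, using typing.get_args as the documented accessor:

    from typing import Optional, get_args

    # Optional[int] is Union[int, None], which carries __args__:
    annotation = Optional[int]
    print(hasattr(annotation, "__args__"))  # True
    print(get_args(annotation))             # (<class 'int'>, <class 'NoneType'>)

    # A bare type such as int has no __args__, which is what the hasattr guard catches:
    print(hasattr(int, "__args__"))         # False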
@@ -14,9 +14,7 @@ class SessionService(Service):
    def __init__(self, cache_service):
        self.cache_service: "BaseCacheService" = cache_service

    async def load_session(
        self, key, data_graph: Optional[dict] = None, flow_id: Optional[str] = None
    ):
    async def load_session(self, key, flow_id: str, data_graph: Optional[dict] = None):
        # Check if the data is cached
        if key in self.cache_service:
            return self.cache_service.get(key)
@@ -30,7 +30,7 @@ class SettingsService(Service):
        settings_dict = {k.upper(): v for k, v in settings_dict.items()}

        for key in settings_dict:
            if key not in Settings.model_fields().keys():
            if key not in Settings.model_fields.keys():
                raise KeyError(f"Key {key} not found in settings")
            logger.debug(
                f"Loading {len(settings_dict[key])} {key} from {file_path}"
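The Settings.model_fields fix reflects Pydantic v2, where model_fields is a class-level dict attribute rather than a method. A minimal illustration with a hypothetical two-field model:

    from pydantic import BaseModel

    class Settings(BaseModel):
        # Hypothetical fields; the real Settings model has many more.
        database_url: str = "sqlite://"
        cache_type: str = "memory"

    print(list(Settings.model_fields.keys()))       # ['database_url', 'cache_type']
    print("database_url" in Settings.model_fields)  # True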
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any

import socketio
import socketio  # type: ignore
from loguru import logger

from langflow.services.base import Service
@@ -1,7 +1,7 @@
import time
from typing import Callable

import socketio
import socketio  # type: ignore
from sqlmodel import select

from langflow.api.utils import format_elapsed_time
@@ -1,5 +1,5 @@
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
import boto3  # type: ignore
from botocore.exceptions import ClientError, NoCredentialsError  # type: ignore
from loguru import logger

from .service import StorageService
@@ -25,7 +25,9 @@ class S3StorageService(StorageService):
        :raises Exception: If an error occurs during file saving.
        """
        try:
            self.s3_client.put_object(Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data)
            self.s3_client.put_object(
                Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data
            )
            logger.info(f"File {file_name} saved successfully in folder {folder}.")
        except NoCredentialsError:
            logger.error("Credentials not available for AWS S3.")
@@ -44,8 +46,12 @@ class S3StorageService(StorageService):
        :raises Exception: If an error occurs during file retrieval.
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
            logger.info(f"File {file_name} retrieved successfully from folder {folder}.")
            response = self.s3_client.get_object(
                Bucket=self.bucket, Key=f"{folder}/{file_name}"
            )
            logger.info(
                f"File {file_name} retrieved successfully from folder {folder}."
            )
            return response["Body"].read()
        except ClientError as e:
            logger.error(f"Error retrieving file {file_name} from folder {folder}: {e}")
@@ -61,7 +67,11 @@ class S3StorageService(StorageService):
        """
        try:
            response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=folder)
            files = [item["Key"] for item in response.get("Contents", []) if "/" not in item["Key"][len(folder) :]]
            files = [
                item["Key"]
                for item in response.get("Contents", [])
                if "/" not in item["Key"][len(folder) :]
            ]
            logger.info(f"{len(files)} files listed in folder {folder}.")
            return files
        except ClientError as e:
@@ -77,7 +87,9 @@ class S3StorageService(StorageService):
        :raises Exception: If an error occurs during file deletion.
        """
        try:
            self.s3_client.delete_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
            self.s3_client.delete_object(
                Bucket=self.bucket, Key=f"{folder}/{file_name}"
            )
            logger.info(f"File {file_name} deleted successfully from folder {folder}.")
        except ClientError as e:
            logger.error(f"Error deleting file {file_name} from folder {folder}: {e}")
@@ -1,3 +1,4 @@
import enum
from typing import Dict, List, Optional, Union

from langchain_core.messages import BaseMessage
@@ -40,3 +41,13 @@ class ChatOutputResponse(BaseModel):
        message = self.message.replace("\n\n", "\n")
        self.message = message.replace("\n", "\n\n")
        return self


class ContainsEnumMeta(enum.EnumMeta):
    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        else:
            return True