Merge branch 'logspace-ai:dev' into dev
This commit is contained in:
commit
9bbd013dcb
93 changed files with 6717 additions and 16968 deletions
|
|
@ -1,12 +1,13 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
|
||||
from langflow.api.schemas import ChatResponse
|
||||
|
||||
|
||||
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
|
||||
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
|
|
@ -15,3 +16,17 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
|
||||
class StreamingLLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
self.websocket = websocket
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
coroutine = self.websocket.send_json(resp.dict())
|
||||
asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
from fastapi import APIRouter, WebSocket
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
|
||||
from langflow.api.chat_manager import ChatManager
|
||||
from langflow.utils.logger import logger
|
||||
|
|
@ -12,7 +18,9 @@ async def websocket_endpoint(client_id: str, websocket: WebSocket):
|
|||
"""Websocket endpoint for chat."""
|
||||
try:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
except WebSocketException as exc:
|
||||
logger.error(exc)
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
|
||||
except WebSocketDisconnect as exc:
|
||||
logger.error(exc)
|
||||
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc))
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocket, status
|
||||
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache import cache_manager
|
||||
|
|
@ -47,7 +47,6 @@ class ChatManager:
|
|||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.chat_history = ChatHistory()
|
||||
self.chat_history.attach(self.on_chat_history_update)
|
||||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
|
||||
|
|
@ -91,7 +90,7 @@ class ChatManager:
|
|||
self.active_connections[client_id] = websocket
|
||||
|
||||
def disconnect(self, client_id: str):
|
||||
del self.active_connections[client_id]
|
||||
self.active_connections.pop(client_id, None)
|
||||
|
||||
async def send_message(self, client_id: str, message: str):
|
||||
websocket = self.active_connections[client_id]
|
||||
|
|
@ -109,7 +108,7 @@ class ChatManager:
|
|||
|
||||
graph_data = payload
|
||||
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
|
||||
self.chat_history.add_message(client_id, start_resp)
|
||||
await self.send_json(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
# Generate result and thought
|
||||
|
|
@ -143,11 +142,12 @@ class ChatManager:
|
|||
break
|
||||
|
||||
response = ChatResponse(
|
||||
message=result or "",
|
||||
message=result,
|
||||
intermediate_steps=intermediate_steps.strip(),
|
||||
type="end",
|
||||
files=file_responses,
|
||||
)
|
||||
await self.send_json(client_id, response)
|
||||
self.chat_history.add_message(client_id, response)
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
|
|
@ -171,17 +171,24 @@ class ChatManager:
|
|||
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
|
||||
except Exception as e:
|
||||
# Handle any exceptions that might occur
|
||||
logger.exception(e)
|
||||
# send a message to the client
|
||||
await self.send_message(client_id, str(e))
|
||||
raise e
|
||||
finally:
|
||||
await self.active_connections[client_id].close(
|
||||
code=1000, reason="Client disconnected"
|
||||
code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120]
|
||||
)
|
||||
self.disconnect(client_id)
|
||||
finally:
|
||||
try:
|
||||
connection = self.active_connections.get(client_id)
|
||||
if connection:
|
||||
await connection.close(code=1000, reason="Client disconnected")
|
||||
self.disconnect(client_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self.disconnect(client_id)
|
||||
|
||||
|
||||
async def process_graph(
|
||||
|
|
@ -203,8 +210,8 @@ async def process_graph(
|
|||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
result, intermediate_steps = await get_result_and_steps(
|
||||
langchain_object, chat_message.message or "", websocket=websocket
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.interface.run import process_graph_cached
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.api.schemas import (
|
||||
ExportedFlow,
|
||||
GraphData,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
)
|
||||
from langflow.interface.run import process_graph_cached
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
|
||||
# build router
|
||||
router = APIRouter()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Union, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.api.base import (
|
||||
|
|
@ -7,6 +9,7 @@ from langflow.api.base import (
|
|||
PromptValidationResponse,
|
||||
validate_prompt,
|
||||
)
|
||||
from langflow.graph.nodes import VectorStoreNode
|
||||
from langflow.interface.run import build_graph
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.validate import validate_code
|
||||
|
|
@ -44,10 +47,11 @@ def post_validate_node(node_id: str, data: dict):
|
|||
graph = build_graph(data)
|
||||
# validate node
|
||||
node = graph.get_node(node_id)
|
||||
if node is not None:
|
||||
_ = node.build()
|
||||
return str(node.params)
|
||||
raise Exception(f"Node {node_id} not found")
|
||||
if node is None:
|
||||
raise ValueError(f"Node {node_id} not found")
|
||||
if not isinstance(node, VectorStoreNode):
|
||||
node.build()
|
||||
return json.dumps({"valid": True, "params": str(node._built_object_repr())})
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
logger.exception(e)
|
||||
return json.dumps({"valid": False})
|
||||
|
|
|
|||
|
|
@ -232,6 +232,9 @@ class Node:
|
|||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
def _built_object_repr(self):
|
||||
return repr(self._built_object)
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source: "Node", target: "Node"):
|
||||
|
|
|
|||
|
|
@ -139,6 +139,13 @@ class DocumentLoaderNode(Node):
|
|||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="documentloaders")
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
if self._built_object:
|
||||
return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}..."""
|
||||
return f"{self.node_type}()"
|
||||
|
||||
|
||||
class EmbeddingNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
|
|
@ -149,6 +156,9 @@ class VectorStoreNode(Node):
|
|||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="vectorstores")
|
||||
|
||||
def _built_object_repr(self):
|
||||
return "Vector stores can take time to build. It will build on the first query."
|
||||
|
||||
|
||||
class MemoryNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
|
|
@ -158,3 +168,10 @@ class MemoryNode(Node):
|
|||
class TextSplitterNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="textsplitters")
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
if self._built_object:
|
||||
return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}..."""
|
||||
return f"{self.node_type}()"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from abc import ABC
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
|
|
@ -27,14 +28,31 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import (
|
|||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class JsonAgent(AgentExecutor):
|
||||
class CustomAgentExecutor(AgentExecutor, ABC):
|
||||
"""Custom agent executor"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomAgentExecutor"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class JsonAgent(CustomAgentExecutor):
|
||||
"""Json agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -70,7 +88,7 @@ class JsonAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CSVAgent(AgentExecutor):
|
||||
class CSVAgent(CustomAgentExecutor):
|
||||
"""CSV agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -118,7 +136,7 @@ class CSVAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreAgent(AgentExecutor):
|
||||
class VectorStoreAgent(CustomAgentExecutor):
|
||||
"""Vector Store agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -156,7 +174,7 @@ class VectorStoreAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class SQLAgent(AgentExecutor):
|
||||
class SQLAgent(CustomAgentExecutor):
|
||||
"""SQL agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -228,7 +246,7 @@ class SQLAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreRouterAgent(AgentExecutor):
|
||||
class VectorStoreRouterAgent(CustomAgentExecutor):
|
||||
"""Vector Store Router Agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -267,7 +285,7 @@ class VectorStoreRouterAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class InitializeAgent(AgentExecutor):
|
||||
class InitializeAgent(CustomAgentExecutor):
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from typing import Any, Type
|
|||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.agents import Agent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ from langchain.agents.load_tools import (
|
|||
)
|
||||
from langchain.agents.loading import load_agent_from_config
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -30,88 +30,19 @@ from langflow.utils import util, validate
|
|||
|
||||
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
|
||||
if node_type in CUSTOM_AGENTS:
|
||||
if custom_agent := CUSTOM_AGENTS.get(node_type):
|
||||
return custom_agent.initialize(**params) # type: ignore
|
||||
params = process_params(params)
|
||||
custom_agent = CUSTOM_AGENTS.get(node_type)
|
||||
if custom_agent:
|
||||
return custom_agent.initialize(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
# check if it is a class before using issubclass
|
||||
|
||||
# if isinstance(class_object, type) and issubclass(class_object, BaseModel):
|
||||
# # validate params
|
||||
# fields = class_object.__fields__
|
||||
# params = {key: value for key, value in params.items() if key in fields}
|
||||
|
||||
if base_type == "agents":
|
||||
# We need to initialize it differently
|
||||
return load_agent_executor(class_object, params)
|
||||
elif base_type == "prompts":
|
||||
if node_type == "ZeroShotPrompt":
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return ZeroShotAgent.create_prompt(**params)
|
||||
elif base_type == "tools":
|
||||
if node_type == "JsonSpec":
|
||||
params["dict_"] = load_file_into_dict(params.pop("path"))
|
||||
return class_object(**params)
|
||||
elif node_type == "PythonFunction":
|
||||
# If the node_type is "PythonFunction"
|
||||
# we need to get the function from the params
|
||||
# which will be a str containing a python function
|
||||
# and then we need to compile it and return the function
|
||||
# as the instance
|
||||
function_string = params["code"]
|
||||
if isinstance(function_string, str):
|
||||
return validate.eval_function(function_string)
|
||||
raise ValueError("Function should be a string")
|
||||
elif node_type.lower() == "tool":
|
||||
return class_object(**params)
|
||||
elif base_type == "toolkits":
|
||||
loaded_toolkit = class_object(**params)
|
||||
# Check if node_type has a loader
|
||||
if toolkits_creator.has_create_function(node_type):
|
||||
return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
return loaded_toolkit
|
||||
elif base_type == "embeddings":
|
||||
# ? Why remove model from params?
|
||||
try:
|
||||
params.pop("model")
|
||||
except KeyError:
|
||||
pass
|
||||
# remove all params that are not in class_object.__fields__
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.__fields__
|
||||
}
|
||||
return class_object(**params)
|
||||
elif base_type == "vectorstores":
|
||||
if len(params.get("documents", [])) == 0:
|
||||
# Error when the pdf or other source was not correctly
|
||||
# loaded.
|
||||
raise ValueError(
|
||||
"The source you provided did not load correctly or was empty."
|
||||
"This may cause an error in the vectorstore."
|
||||
)
|
||||
return class_object.from_documents(**params)
|
||||
elif base_type == "documentloaders":
|
||||
return class_object(**params).load()
|
||||
elif base_type == "textsplitters":
|
||||
documents = params.pop("documents")
|
||||
text_splitter = class_object(**params)
|
||||
return text_splitter.split_documents(documents)
|
||||
elif base_type == "utilities":
|
||||
if node_type == "SQLDatabase":
|
||||
return class_object.from_uri(params.pop("uri"))
|
||||
|
||||
return class_object(**params)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
||||
|
||||
def process_params(params):
|
||||
"""Process params"""
|
||||
def convert_params_to_sets(params):
|
||||
"""Convert certain params to sets"""
|
||||
if "allowed_special" in params:
|
||||
params["allowed_special"] = set(params["allowed_special"])
|
||||
if "disallowed_special" in params:
|
||||
|
|
@ -119,6 +50,100 @@ def process_params(params):
|
|||
return params
|
||||
|
||||
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(class_object, params)
|
||||
elif base_type == "prompts":
|
||||
return instantiate_prompt(class_object, node_type, params)
|
||||
elif base_type == "tools":
|
||||
return instantiate_tool(node_type, class_object, params)
|
||||
elif base_type == "toolkits":
|
||||
return instantiate_toolkit(node_type, class_object, params)
|
||||
elif base_type == "embeddings":
|
||||
return instantiate_embedding(class_object, params)
|
||||
elif base_type == "vectorstores":
|
||||
return instantiate_vectorstore(class_object, params)
|
||||
elif base_type == "documentloaders":
|
||||
return instantiate_documentloader(class_object, params)
|
||||
elif base_type == "textsplitters":
|
||||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
return instantiate_utility(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(class_object, params):
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
||||
def instantiate_prompt(class_object, node_type, params):
|
||||
if node_type == "ZeroShotPrompt":
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return ZeroShotAgent.create_prompt(**params)
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_tool(node_type, class_object, params):
|
||||
if node_type == "JsonSpec":
|
||||
params["dict_"] = load_file_into_dict(params.pop("path"))
|
||||
return class_object(**params)
|
||||
elif node_type == "PythonFunction":
|
||||
function_string = params["code"]
|
||||
if isinstance(function_string, str):
|
||||
return validate.eval_function(function_string)
|
||||
raise ValueError("Function should be a string")
|
||||
elif node_type.lower() == "tool":
|
||||
return class_object(**params)
|
||||
return None # Or some other default action
|
||||
|
||||
|
||||
def instantiate_toolkit(node_type, class_object, params):
|
||||
loaded_toolkit = class_object(**params)
|
||||
if toolkits_creator.has_create_function(node_type):
|
||||
return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
return loaded_toolkit
|
||||
|
||||
|
||||
def instantiate_embedding(class_object, params):
|
||||
params.pop("model", None)
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.__fields__
|
||||
}
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_vectorstore(class_object, params):
|
||||
if len(params.get("documents", [])) == 0:
|
||||
raise ValueError(
|
||||
"The source you provided did not load correctly or was empty."
|
||||
"This may cause an error in the vectorstore."
|
||||
)
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def instantiate_documentloader(class_object, params):
|
||||
return class_object(**params).load()
|
||||
|
||||
|
||||
def instantiate_textsplitter(class_object, params):
|
||||
documents = params.pop("documents")
|
||||
text_splitter = class_object(**params)
|
||||
return text_splitter.split_documents(documents)
|
||||
|
||||
|
||||
def instantiate_utility(node_type, class_object, params):
|
||||
if node_type == "SQLDatabase":
|
||||
return class_object.from_uri(params.pop("uri"))
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def load_flow_from_json(path: str, build=True):
|
||||
# This is done to avoid circular imports
|
||||
from langflow.graph import Graph
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
import contextlib
|
||||
import io
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException # type: ignore
|
||||
|
||||
from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore
|
||||
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.graph.graph import Graph
|
||||
from langflow.interface import loading
|
||||
from langflow.utils.logger import logger
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
|
||||
def load_langchain_object(data_graph, is_first_message=False):
|
||||
|
|
@ -66,40 +67,6 @@ def build_langchain_object(data_graph):
|
|||
return graph.build()
|
||||
|
||||
|
||||
def process_graph(data_graph: Dict[str, Any]):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
logger.debug("Loading langchain object")
|
||||
message = data_graph.pop("message", "")
|
||||
is_first_message = len(data_graph.get("chatHistory", [])) == 0
|
||||
computed_hash, langchain_object = load_langchain_object(
|
||||
data_graph, is_first_message
|
||||
)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
|
||||
# Save langchain_object to cache
|
||||
# We have to save it here because if the
|
||||
# memory is updated we need to keep the new values
|
||||
logger.debug("Saving langchain object to cache")
|
||||
# save_cache(computed_hash, langchain_object, is_first_message)
|
||||
logger.debug("Saved langchain object to cache")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def process_graph_cached(data_graph: Dict[str, Any], message: str):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
|
|
@ -184,8 +151,9 @@ def fix_memory_inputs(langchain_object):
|
|||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
def get_result_and_steps(langchain_object, message: str):
|
||||
async def get_result_and_steps(langchain_object, message: str, **kwargs):
|
||||
"""Get result and thought from extracted json"""
|
||||
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
|
|
@ -205,32 +173,28 @@ def get_result_and_steps(langchain_object, message: str):
|
|||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
langchain_object.return_intermediate_steps = True
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
try:
|
||||
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
|
||||
output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
|
||||
except Exception as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
|
||||
output = langchain_object(chat_input, callbacks=sync_callbacks)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
try:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
if intermediate_steps:
|
||||
thought = format_intermediate_steps(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
thought = format_actions(intermediate_steps) if intermediate_steps else ""
|
||||
except NotEnoughElementsException as exc:
|
||||
raise ValueError(
|
||||
"Error: Not enough documents for ChromaDB to index. Try reducing chunk size in TextSplitter."
|
||||
|
|
@ -286,7 +250,7 @@ def get_result_and_thought(langchain_object, message: str):
|
|||
else output
|
||||
)
|
||||
if intermediate_steps:
|
||||
thought = format_intermediate_steps(intermediate_steps)
|
||||
thought = format_actions(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
|
|
@ -295,19 +259,17 @@ def get_result_and_thought(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
def format_intermediate_steps(intermediate_steps):
|
||||
formatted_chain = "> Entering new AgentExecutor chain...\n"
|
||||
for step in intermediate_steps:
|
||||
action = step[0]
|
||||
observation = step[1]
|
||||
|
||||
formatted_chain += (
|
||||
f" {action.log}\nAction: {action.tool}\nAction Input: {action.tool_input}\n"
|
||||
)
|
||||
formatted_chain += f"Observation: {observation}\n"
|
||||
|
||||
final_answer = f"Final Answer: {observation}\n"
|
||||
formatted_chain += f"Thought: I now know the final answer\n{final_answer}\n"
|
||||
formatted_chain += "> Finished chain.\n"
|
||||
|
||||
return formatted_chain
|
||||
def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
|
||||
"""Format a list of (AgentAction, answer) tuples into a string."""
|
||||
output = []
|
||||
for action, answer in actions:
|
||||
log = action.log
|
||||
tool = action.tool
|
||||
tool_input = action.tool_input
|
||||
output.append(f"Log: {log}")
|
||||
if "Action" not in log and "Action Input" not in log:
|
||||
output.append(f"Tool: {tool}")
|
||||
output.append(f"Tool Input: {tool_input}")
|
||||
output.append(f"Answer: {answer}")
|
||||
output.append("") # Add a blank line
|
||||
return "\n".join(output)
|
||||
|
|
|
|||
|
|
@ -4,13 +4,9 @@ import os
|
|||
from io import BytesIO
|
||||
|
||||
import yaml
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain.llms import AzureOpenAI, OpenAI
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from PIL.Image import Image
|
||||
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
|
||||
|
||||
def load_file_into_dict(file_path: str) -> dict:
|
||||
if not os.path.exists(file_path):
|
||||
|
|
@ -48,10 +44,7 @@ def try_setting_streaming_options(langchain_object, websocket):
|
|||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
|
||||
if isinstance(llm, BaseLanguageModel):
|
||||
llm.streaming = bool(hasattr(llm, "streaming"))
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
stream_manager = AsyncCallbackManager([stream_handler])
|
||||
llm.callback_manager = stream_manager
|
||||
|
||||
return langchain_object
|
||||
|
|
|
|||
|
|
@ -123,6 +123,13 @@ class MidJourneyPromptChainNode(FrontendNode):
|
|||
multiline=False,
|
||||
name="llm",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseChatMemory",
|
||||
required=False,
|
||||
show=True,
|
||||
name="memory",
|
||||
advanced=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue