Merge branch 'logspace-ai:dev' into dev

This commit is contained in:
Deepankar Mahapatro 2023-05-15 17:48:52 +05:30 committed by GitHub
commit 9bbd013dcb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
93 changed files with 6717 additions and 16968 deletions

View file

@ -1,12 +1,13 @@
import asyncio
from typing import Any
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langflow.api.schemas import ChatResponse
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""
def __init__(self, websocket):
@ -15,3 +16,17 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.websocket.send_json(resp.dict())
class StreamingLLMCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming LLM responses."""
def __init__(self, websocket):
self.websocket = websocket
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
loop = asyncio.get_event_loop()
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)

View file

@ -1,4 +1,10 @@
from fastapi import APIRouter, WebSocket
from fastapi import (
APIRouter,
WebSocket,
WebSocketDisconnect,
WebSocketException,
status,
)
from langflow.api.chat_manager import ChatManager
from langflow.utils.logger import logger
@ -12,7 +18,9 @@ async def websocket_endpoint(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
try:
await chat_manager.handle_websocket(client_id, websocket)
except Exception as e:
# Log stack trace
logger.exception(e)
raise e
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except WebSocketDisconnect as exc:
logger.error(exc)
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc))

View file

@ -3,7 +3,7 @@ import json
from collections import defaultdict
from typing import Dict, List
from fastapi import WebSocket
from fastapi import WebSocket, status
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache import cache_manager
@ -47,7 +47,6 @@ class ChatManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.chat_history = ChatHistory()
self.chat_history.attach(self.on_chat_history_update)
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
@ -91,7 +90,7 @@ class ChatManager:
self.active_connections[client_id] = websocket
def disconnect(self, client_id: str):
del self.active_connections[client_id]
self.active_connections.pop(client_id, None)
async def send_message(self, client_id: str, message: str):
websocket = self.active_connections[client_id]
@ -109,7 +108,7 @@ class ChatManager:
graph_data = payload
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
self.chat_history.add_message(client_id, start_resp)
await self.send_json(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
# Generate result and thought
@ -143,11 +142,12 @@ class ChatManager:
break
response = ChatResponse(
message=result or "",
message=result,
intermediate_steps=intermediate_steps.strip(),
type="end",
files=file_responses,
)
await self.send_json(client_id, response)
self.chat_history.add_message(client_id, response)
async def handle_websocket(self, client_id: str, websocket: WebSocket):
@ -171,17 +171,24 @@ class ChatManager:
with self.cache_manager.set_client_id(client_id):
await self.process_message(client_id, payload)
except Exception as e:
# Handle any exceptions that might occur
logger.exception(e)
# send a message to the client
await self.send_message(client_id, str(e))
raise e
finally:
await self.active_connections[client_id].close(
code=1000, reason="Client disconnected"
code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120]
)
self.disconnect(client_id)
finally:
try:
connection = self.active_connections.get(client_id)
if connection:
await connection.close(code=1000, reason="Client disconnected")
self.disconnect(client_id)
except Exception as e:
logger.exception(e)
self.disconnect(client_id)
async def process_graph(
@ -203,8 +210,8 @@ async def process_graph(
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = get_result_and_steps(
langchain_object, chat_message.message or ""
result, intermediate_steps = await get_result_and_steps(
langchain_object, chat_message.message or "", websocket=websocket
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps

View file

@ -1,16 +1,15 @@
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException
from langflow.interface.run import process_graph_cached
from langflow.interface.types import build_langchain_types_dict
from langflow.api.schemas import (
ExportedFlow,
GraphData,
PredictRequest,
PredictResponse,
)
from langflow.interface.run import process_graph_cached
from langflow.interface.types import build_langchain_types_dict
# build router
router = APIRouter()

View file

@ -1,4 +1,4 @@
from typing import Any, Union, Dict, List
from typing import Any, Dict, List, Union
from pydantic import BaseModel, validator

View file

@ -1,3 +1,5 @@
import json
from fastapi import APIRouter, HTTPException
from langflow.api.base import (
@ -7,6 +9,7 @@ from langflow.api.base import (
PromptValidationResponse,
validate_prompt,
)
from langflow.graph.nodes import VectorStoreNode
from langflow.interface.run import build_graph
from langflow.utils.logger import logger
from langflow.utils.validate import validate_code
@ -44,10 +47,11 @@ def post_validate_node(node_id: str, data: dict):
graph = build_graph(data)
# validate node
node = graph.get_node(node_id)
if node is not None:
_ = node.build()
return str(node.params)
raise Exception(f"Node {node_id} not found")
if node is None:
raise ValueError(f"Node {node_id} not found")
if not isinstance(node, VectorStoreNode):
node.build()
return json.dumps({"valid": True, "params": str(node._built_object_repr())})
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e)) from e
logger.exception(e)
return json.dumps({"valid": False})

View file

@ -232,6 +232,9 @@ class Node:
def __hash__(self) -> int:
return id(self)
def _built_object_repr(self):
return repr(self._built_object)
class Edge:
def __init__(self, source: "Node", target: "Node"):

View file

@ -139,6 +139,13 @@ class DocumentLoaderNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="documentloaders")
def _built_object_repr(self):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object:
return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}..."""
return f"{self.node_type}()"
class EmbeddingNode(Node):
def __init__(self, data: Dict):
@ -149,6 +156,9 @@ class VectorStoreNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="vectorstores")
def _built_object_repr(self):
return "Vector stores can take time to build. It will build on the first query."
class MemoryNode(Node):
def __init__(self, data: Dict):
@ -158,3 +168,10 @@ class MemoryNode(Node):
class TextSplitterNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="textsplitters")
def _built_object_repr(self):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object:
return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}..."""
return f"{self.node_type}()"

View file

@ -1,3 +1,4 @@
from abc import ABC
from typing import Any, List, Optional
from langchain import LLMChain
@ -27,14 +28,31 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import (
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.memory.chat_memory import BaseChatMemory
from langchain.sql_database import SQLDatabase
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class JsonAgent(AgentExecutor):
class CustomAgentExecutor(AgentExecutor, ABC):
"""Custom agent executor"""
@staticmethod
def function_name():
return "CustomAgentExecutor"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
class JsonAgent(CustomAgentExecutor):
"""Json agent"""
@staticmethod
@ -70,7 +88,7 @@ class JsonAgent(AgentExecutor):
return super().run(*args, **kwargs)
class CSVAgent(AgentExecutor):
class CSVAgent(CustomAgentExecutor):
"""CSV agent"""
@staticmethod
@ -118,7 +136,7 @@ class CSVAgent(AgentExecutor):
return super().run(*args, **kwargs)
class VectorStoreAgent(AgentExecutor):
class VectorStoreAgent(CustomAgentExecutor):
"""Vector Store agent"""
@staticmethod
@ -156,7 +174,7 @@ class VectorStoreAgent(AgentExecutor):
return super().run(*args, **kwargs)
class SQLAgent(AgentExecutor):
class SQLAgent(CustomAgentExecutor):
"""SQL agent"""
@staticmethod
@ -228,7 +246,7 @@ class SQLAgent(AgentExecutor):
return super().run(*args, **kwargs)
class VectorStoreRouterAgent(AgentExecutor):
class VectorStoreRouterAgent(CustomAgentExecutor):
"""Vector Store Router Agent"""
@staticmethod
@ -267,7 +285,7 @@ class VectorStoreRouterAgent(AgentExecutor):
return super().run(*args, **kwargs)
class InitializeAgent(AgentExecutor):
class InitializeAgent(CustomAgentExecutor):
"""Implementation of initialize_agent function"""
@staticmethod

View file

@ -5,9 +5,9 @@ from typing import Any, Type
from langchain import PromptTemplate
from langchain.agents import Agent
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool

View file

@ -13,9 +13,9 @@ from langchain.agents.load_tools import (
)
from langchain.agents.loading import load_agent_from_config
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.loading import load_chain_from_config
from langchain.base_language import BaseLanguageModel
from langchain.llms.loading import load_llm_from_config
from pydantic import ValidationError
@ -30,88 +30,19 @@ from langflow.utils import util, validate
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
if node_type in CUSTOM_AGENTS:
if custom_agent := CUSTOM_AGENTS.get(node_type):
return custom_agent.initialize(**params) # type: ignore
params = process_params(params)
custom_agent = CUSTOM_AGENTS.get(node_type)
if custom_agent:
return custom_agent.initialize(**params)
class_object = import_by_type(_type=base_type, name=node_type)
# check if it is a class before using issubclass
# if isinstance(class_object, type) and issubclass(class_object, BaseModel):
# # validate params
# fields = class_object.__fields__
# params = {key: value for key, value in params.items() if key in fields}
if base_type == "agents":
# We need to initialize it differently
return load_agent_executor(class_object, params)
elif base_type == "prompts":
if node_type == "ZeroShotPrompt":
if "tools" not in params:
params["tools"] = []
return ZeroShotAgent.create_prompt(**params)
elif base_type == "tools":
if node_type == "JsonSpec":
params["dict_"] = load_file_into_dict(params.pop("path"))
return class_object(**params)
elif node_type == "PythonFunction":
# If the node_type is "PythonFunction"
# we need to get the function from the params
# which will be a str containing a python function
# and then we need to compile it and return the function
# as the instance
function_string = params["code"]
if isinstance(function_string, str):
return validate.eval_function(function_string)
raise ValueError("Function should be a string")
elif node_type.lower() == "tool":
return class_object(**params)
elif base_type == "toolkits":
loaded_toolkit = class_object(**params)
# Check if node_type has a loader
if toolkits_creator.has_create_function(node_type):
return load_toolkits_executor(node_type, loaded_toolkit, params)
return loaded_toolkit
elif base_type == "embeddings":
# ? Why remove model from params?
try:
params.pop("model")
except KeyError:
pass
# remove all params that are not in class_object.__fields__
try:
return class_object(**params)
except ValidationError:
params = {
key: value
for key, value in params.items()
if key in class_object.__fields__
}
return class_object(**params)
elif base_type == "vectorstores":
if len(params.get("documents", [])) == 0:
# Error when the pdf or other source was not correctly
# loaded.
raise ValueError(
"The source you provided did not load correctly or was empty."
"This may cause an error in the vectorstore."
)
return class_object.from_documents(**params)
elif base_type == "documentloaders":
return class_object(**params).load()
elif base_type == "textsplitters":
documents = params.pop("documents")
text_splitter = class_object(**params)
return text_splitter.split_documents(documents)
elif base_type == "utilities":
if node_type == "SQLDatabase":
return class_object.from_uri(params.pop("uri"))
return class_object(**params)
return instantiate_based_on_type(class_object, base_type, node_type, params)
def process_params(params):
"""Process params"""
def convert_params_to_sets(params):
"""Convert certain params to sets"""
if "allowed_special" in params:
params["allowed_special"] = set(params["allowed_special"])
if "disallowed_special" in params:
@ -119,6 +50,100 @@ def process_params(params):
return params
def instantiate_based_on_type(class_object, base_type, node_type, params):
if base_type == "agents":
return instantiate_agent(class_object, params)
elif base_type == "prompts":
return instantiate_prompt(class_object, node_type, params)
elif base_type == "tools":
return instantiate_tool(node_type, class_object, params)
elif base_type == "toolkits":
return instantiate_toolkit(node_type, class_object, params)
elif base_type == "embeddings":
return instantiate_embedding(class_object, params)
elif base_type == "vectorstores":
return instantiate_vectorstore(class_object, params)
elif base_type == "documentloaders":
return instantiate_documentloader(class_object, params)
elif base_type == "textsplitters":
return instantiate_textsplitter(class_object, params)
elif base_type == "utilities":
return instantiate_utility(node_type, class_object, params)
else:
return class_object(**params)
def instantiate_agent(class_object, params):
return load_agent_executor(class_object, params)
def instantiate_prompt(class_object, node_type, params):
if node_type == "ZeroShotPrompt":
if "tools" not in params:
params["tools"] = []
return ZeroShotAgent.create_prompt(**params)
return class_object(**params)
def instantiate_tool(node_type, class_object, params):
if node_type == "JsonSpec":
params["dict_"] = load_file_into_dict(params.pop("path"))
return class_object(**params)
elif node_type == "PythonFunction":
function_string = params["code"]
if isinstance(function_string, str):
return validate.eval_function(function_string)
raise ValueError("Function should be a string")
elif node_type.lower() == "tool":
return class_object(**params)
return None # Or some other default action
def instantiate_toolkit(node_type, class_object, params):
loaded_toolkit = class_object(**params)
if toolkits_creator.has_create_function(node_type):
return load_toolkits_executor(node_type, loaded_toolkit, params)
return loaded_toolkit
def instantiate_embedding(class_object, params):
params.pop("model", None)
try:
return class_object(**params)
except ValidationError:
params = {
key: value
for key, value in params.items()
if key in class_object.__fields__
}
return class_object(**params)
def instantiate_vectorstore(class_object, params):
if len(params.get("documents", [])) == 0:
raise ValueError(
"The source you provided did not load correctly or was empty."
"This may cause an error in the vectorstore."
)
return class_object.from_documents(**params)
def instantiate_documentloader(class_object, params):
return class_object(**params).load()
def instantiate_textsplitter(class_object, params):
documents = params.pop("documents")
text_splitter = class_object(**params)
return text_splitter.split_documents(documents)
def instantiate_utility(node_type, class_object, params):
if node_type == "SQLDatabase":
return class_object.from_uri(params.pop("uri"))
return class_object(**params)
def load_flow_from_json(path: str, build=True):
# This is done to avoid circular imports
from langflow.graph import Graph

View file

@ -1,13 +1,14 @@
import contextlib
import io
from typing import Any, Dict
from typing import Any, Dict, List, Tuple
from chromadb.errors import NotEnoughElementsException # type: ignore
from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
from langflow.graph.graph import Graph
from langflow.interface import loading
from langflow.utils.logger import logger
from langchain.schema import AgentAction
def load_langchain_object(data_graph, is_first_message=False):
@ -66,40 +67,6 @@ def build_langchain_object(data_graph):
return graph.build()
def process_graph(data_graph: Dict[str, Any]):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
# Load langchain object
logger.debug("Loading langchain object")
message = data_graph.pop("message", "")
is_first_message = len(data_graph.get("chatHistory", [])) == 0
computed_hash, langchain_object = load_langchain_object(
data_graph, is_first_message
)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
logger.debug("Generating result and thought")
result, thought = get_result_and_steps(langchain_object, message)
logger.debug("Generated result and thought")
# Save langchain_object to cache
# We have to save it here because if the
# memory is updated we need to keep the new values
logger.debug("Saving langchain object to cache")
# save_cache(computed_hash, langchain_object, is_first_message)
logger.debug("Saved langchain object to cache")
return {"result": str(result), "thought": thought.strip()}
def process_graph_cached(data_graph: Dict[str, Any], message: str):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
@ -184,8 +151,9 @@ def fix_memory_inputs(langchain_object):
update_memory_keys(langchain_object, possible_new_mem_key)
def get_result_and_steps(langchain_object, message: str):
async def get_result_and_steps(langchain_object, message: str, **kwargs):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
@ -205,32 +173,28 @@ def get_result_and_steps(langchain_object, message: str):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
langchain_object.return_intermediate_steps = True
fix_memory_inputs(langchain_object)
try:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
output = langchain_object(chat_input, callbacks=sync_callbacks)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
if intermediate_steps:
thought = format_intermediate_steps(intermediate_steps)
else:
thought = output_buffer.getvalue()
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
thought = format_actions(intermediate_steps) if intermediate_steps else ""
except NotEnoughElementsException as exc:
raise ValueError(
"Error: Not enough documents for ChromaDB to index. Try reducing chunk size in TextSplitter."
@ -286,7 +250,7 @@ def get_result_and_thought(langchain_object, message: str):
else output
)
if intermediate_steps:
thought = format_intermediate_steps(intermediate_steps)
thought = format_actions(intermediate_steps)
else:
thought = output_buffer.getvalue()
@ -295,19 +259,17 @@ def get_result_and_thought(langchain_object, message: str):
return result, thought
def format_intermediate_steps(intermediate_steps):
formatted_chain = "> Entering new AgentExecutor chain...\n"
for step in intermediate_steps:
action = step[0]
observation = step[1]
formatted_chain += (
f" {action.log}\nAction: {action.tool}\nAction Input: {action.tool_input}\n"
)
formatted_chain += f"Observation: {observation}\n"
final_answer = f"Final Answer: {observation}\n"
formatted_chain += f"Thought: I now know the final answer\n{final_answer}\n"
formatted_chain += "> Finished chain.\n"
return formatted_chain
def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
"""Format a list of (AgentAction, answer) tuples into a string."""
output = []
for action, answer in actions:
log = action.log
tool = action.tool
tool_input = action.tool_input
output.append(f"Log: {log}")
if "Action" not in log and "Action Input" not in log:
output.append(f"Tool: {tool}")
output.append(f"Tool Input: {tool_input}")
output.append(f"Answer: {answer}")
output.append("") # Add a blank line
return "\n".join(output)

View file

@ -4,13 +4,9 @@ import os
from io import BytesIO
import yaml
from langchain.callbacks.manager import AsyncCallbackManager
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.llms import AzureOpenAI, OpenAI
from langchain.base_language import BaseLanguageModel
from PIL.Image import Image
from langflow.api.callback import StreamingLLMCallbackHandler
def load_file_into_dict(file_path: str) -> dict:
if not os.path.exists(file_path):
@ -48,10 +44,7 @@ def try_setting_streaming_options(langchain_object, websocket):
langchain_object.llm_chain, "llm"
):
llm = langchain_object.llm_chain.llm
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
if isinstance(llm, BaseLanguageModel):
llm.streaming = bool(hasattr(llm, "streaming"))
stream_handler = StreamingLLMCallbackHandler(websocket)
stream_manager = AsyncCallbackManager([stream_handler])
llm.callback_manager = stream_manager
return langchain_object

View file

@ -123,6 +123,13 @@ class MidJourneyPromptChainNode(FrontendNode):
multiline=False,
name="llm",
),
TemplateField(
field_type="BaseChatMemory",
required=False,
show=True,
name="memory",
advanced=False,
),
],
)
description: str = "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."