Merge remote-tracking branch 'origin/dev' into streaming
This commit is contained in:
commit
fc007e4349
25 changed files with 384 additions and 937 deletions
|
|
@ -5,6 +5,12 @@ from fastapi import APIRouter, HTTPException
|
|||
|
||||
from langflow.interface.run import process_graph_cached
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.api.schemas import (
|
||||
ExportedFlow,
|
||||
GraphData,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
)
|
||||
|
||||
# build router
|
||||
router = APIRouter()
|
||||
|
|
@ -16,10 +22,14 @@ def get_all():
|
|||
return build_langchain_types_dict()
|
||||
|
||||
|
||||
@router.post("/predict")
|
||||
def get_load(data: Dict[str, Any]):
|
||||
@router.post("/predict", response_model=PredictResponse)
|
||||
async def get_load(predict_request: PredictRequest):
|
||||
try:
|
||||
return process_graph_cached(data)
|
||||
exported_flow: ExportedFlow = predict_request.exported_flow
|
||||
graph_data: GraphData = exported_flow.data
|
||||
data = graph_data.dict()
|
||||
response = process_graph_cached(data, predict_request.message)
|
||||
return PredictResponse(result=response.get("result", ""))
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,37 @@
|
|||
from typing import Any, Union
|
||||
from typing import Any, Union, Dict, List
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class GraphData(BaseModel):
|
||||
"""Data inside the exported flow."""
|
||||
|
||||
nodes: List[Dict[str, Any]]
|
||||
edges: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ExportedFlow(BaseModel):
|
||||
"""Exported flow from LangFlow."""
|
||||
|
||||
description: str
|
||||
name: str
|
||||
id: str
|
||||
data: GraphData
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""Predict request schema."""
|
||||
|
||||
message: str
|
||||
exported_flow: ExportedFlow
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
"""Predict response schema."""
|
||||
|
||||
result: str
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Chat message schema."""
|
||||
|
||||
|
|
|
|||
|
|
@ -49,5 +49,5 @@ def post_validate_node(node_id: str, data: dict):
|
|||
return str(node.params)
|
||||
raise Exception(f"Node {node_id} not found")
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
|
|||
1
src/backend/langflow/cache/base.py
vendored
1
src/backend/langflow/cache/base.py
vendored
|
|
@ -48,6 +48,7 @@ def memoize_dict(maxsize=128):
|
|||
cache.clear()
|
||||
|
||||
wrapper.clear_cache = clear_cache # type: ignore
|
||||
wrapper.cache = cache # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ tools:
|
|||
- BingSearchRun
|
||||
- GoogleSearchRun
|
||||
- GoogleSearchResults
|
||||
- GoogleSerperRun
|
||||
- JsonListKeysTool
|
||||
- JsonGetValueTool
|
||||
- PythonREPLTool
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import (
|
|||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
|
@ -63,7 +63,7 @@ class JsonAgent(AgentExecutor):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) # type: ignore
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
@ -110,7 +110,7 @@ class CSVAgent(AgentExecutor):
|
|||
prompt=partial_prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ class VectorStoreAgent(AgentExecutor):
|
|||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(
|
||||
cls, llm: BaseLLM, vectorstoreinfo: VectorStoreInfo, **kwargs: Any
|
||||
cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any
|
||||
):
|
||||
"""Construct a vectorstore agent from an LLM and tools."""
|
||||
|
||||
|
|
@ -147,7 +147,7 @@ class VectorStoreAgent(AgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
)
|
||||
|
|
@ -171,7 +171,9 @@ class SQLAgent(AgentExecutor):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(cls, llm: BaseLLM, database_uri: str, **kwargs: Any):
|
||||
def from_toolkit_and_llm(
|
||||
cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any
|
||||
):
|
||||
"""Construct a sql agent from an LLM and tools."""
|
||||
db = SQLDatabase.from_uri(database_uri)
|
||||
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
||||
|
|
@ -213,7 +215,7 @@ class SQLAgent(AgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools} # type: ignore
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools, # type: ignore
|
||||
|
|
@ -256,7 +258,7 @@ class VectorStoreRouterAgent(AgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
)
|
||||
|
|
@ -275,7 +277,7 @@ class InitializeAgent(AgentExecutor):
|
|||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
tools: List[Tool],
|
||||
agent: str,
|
||||
memory: Optional[BaseChatMemory] = None,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class MalfoyAgent(AgentExecutor):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) # type: ignore
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@ from langchain import PromptTemplate
|
|||
from langchain.agents import Agent
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from langflow.interface.tools.base import tool_creator
|
||||
|
||||
|
||||
def import_module(module_path: str) -> Any:
|
||||
"""Import module from module path"""
|
||||
|
|
@ -100,15 +98,19 @@ def import_agent(agent: str) -> Agent:
|
|||
return import_class(f"langchain.agents.{agent}")
|
||||
|
||||
|
||||
def import_llm(llm: str) -> BaseLLM:
|
||||
def import_llm(llm: str) -> BaseLanguageModel:
|
||||
"""Import llm from llm name"""
|
||||
return import_class(f"langchain.llms.{llm}")
|
||||
|
||||
|
||||
def import_tool(tool: str) -> BaseTool:
|
||||
"""Import tool from tool name"""
|
||||
from langflow.interface.tools.base import tool_creator
|
||||
|
||||
return tool_creator.type_to_loader_dict[tool]["fcn"]
|
||||
if tool in tool_creator.type_to_loader_dict:
|
||||
return tool_creator.type_to_loader_dict[tool]["fcn"]
|
||||
|
||||
return import_class(f"langchain.tools.{tool}")
|
||||
|
||||
|
||||
def import_chain(chain: str) -> Type[Chain]:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from langchain.agents.loading import load_agent_from_config
|
|||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -74,12 +74,10 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
return loaded_toolkit
|
||||
elif base_type == "embeddings":
|
||||
# ? Why remove model from params?
|
||||
|
||||
try:
|
||||
params.pop("model")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# remove all params that are not in class_object.__fields__
|
||||
try:
|
||||
return class_object(**params)
|
||||
|
|
@ -188,7 +186,7 @@ def load_langchain_type_from_config(config: Dict[str, Any]):
|
|||
|
||||
def load_agent_executor_from_config(
|
||||
config: dict,
|
||||
llm: Optional[BaseLLM] = None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
tools: Optional[list[Tool]] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
|
|
|
|||
|
|
@ -101,13 +101,12 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def process_graph_cached(data_graph: Dict[str, Any]):
|
||||
def process_graph_cached(data_graph: Dict[str, Any], message: str):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
message = data_graph.pop("message", "")
|
||||
is_first_message = len(data_graph.get("chatHistory", [])) == 0
|
||||
langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
|
@ -120,7 +119,7 @@ def process_graph_cached(data_graph: Dict[str, Any]):
|
|||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
result, thought = get_result_and_thought(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
|
@ -247,7 +246,7 @@ async def get_result_and_steps(langchain_object, message: str, **kwargs):
|
|||
return result, thought
|
||||
|
||||
|
||||
def async_get_result_and_steps(langchain_object, message: str):
|
||||
def get_result_and_thought(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
|
|
@ -302,34 +301,6 @@ def async_get_result_and_steps(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
langchain_object = loading.load_langchain_type_from_config(
|
||||
config=extracted_json
|
||||
)
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
output = langchain_object(message)
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
|
||||
if intermediate_steps:
|
||||
thought = format_intermediate_steps(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
except Exception as e:
|
||||
result = f"Error: {str(e)}"
|
||||
thought = ""
|
||||
return result, thought
|
||||
|
||||
|
||||
def format_intermediate_steps(intermediate_steps):
|
||||
formatted_chain = "> Entering new AgentExecutor chain...\n"
|
||||
for step in intermediate_steps:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ TOOL_INPUTS = {
|
|||
placeholder="",
|
||||
value="",
|
||||
),
|
||||
"llm": TemplateField(field_type="BaseLLM", required=True, is_list=False, show=True),
|
||||
"llm": TemplateField(
|
||||
field_type="BaseLanguageModel", required=True, is_list=False, show=True
|
||||
),
|
||||
"func": TemplateField(
|
||||
field_type="function",
|
||||
required=True,
|
||||
|
|
@ -65,6 +67,7 @@ class ToolCreator(LangChainTypeCreator):
|
|||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.tools_dict is None:
|
||||
all_tools = {}
|
||||
|
||||
for tool, tool_fcn in ALL_TOOLS_NAMES.items():
|
||||
tool_params = get_tool_params(tool_fcn)
|
||||
tool_name = tool_params.get("name", tool)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from langchain import tools
|
||||
from langchain.agents import Tool
|
||||
from langchain.agents.load_tools import (
|
||||
_BASE_TOOLS,
|
||||
|
|
@ -5,50 +6,16 @@ from langchain.agents.load_tools import (
|
|||
_EXTRA_OPTIONAL_TOOLS,
|
||||
_LLM_TOOLS,
|
||||
)
|
||||
from langchain.tools.bing_search.tool import BingSearchRun
|
||||
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
|
||||
from langchain.tools.json.tool import JsonGetValueTool, JsonListKeysTool, JsonSpec
|
||||
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
||||
from langchain.tools.requests.tool import (
|
||||
RequestsDeleteTool,
|
||||
RequestsGetTool,
|
||||
RequestsPatchTool,
|
||||
RequestsPostTool,
|
||||
RequestsPutTool,
|
||||
)
|
||||
from langchain.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QueryCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
)
|
||||
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
||||
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
|
||||
from langchain.tools.json.tool import JsonSpec
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.interface.tools.custom import PythonFunction
|
||||
|
||||
FILE_TOOLS = {"JsonSpec": JsonSpec}
|
||||
CUSTOM_TOOLS = {"Tool": Tool, "PythonFunction": PythonFunction}
|
||||
OTHER_TOOLS = {
|
||||
"QuerySQLDataBaseTool": QuerySQLDataBaseTool,
|
||||
"InfoSQLDatabaseTool": InfoSQLDatabaseTool,
|
||||
"ListSQLDatabaseTool": ListSQLDatabaseTool,
|
||||
"QueryCheckerTool": QueryCheckerTool,
|
||||
"BingSearchRun": BingSearchRun,
|
||||
"GoogleSearchRun": GoogleSearchRun,
|
||||
"GoogleSearchResults": GoogleSearchResults,
|
||||
"JsonListKeysTool": JsonListKeysTool,
|
||||
"JsonGetValueTool": JsonGetValueTool,
|
||||
"PythonREPLTool": PythonREPLTool,
|
||||
"PythonAstREPLTool": PythonAstREPLTool,
|
||||
"RequestsGetTool": RequestsGetTool,
|
||||
"RequestsPostTool": RequestsPostTool,
|
||||
"RequestsPatchTool": RequestsPatchTool,
|
||||
"RequestsPutTool": RequestsPutTool,
|
||||
"RequestsDeleteTool": RequestsDeleteTool,
|
||||
"WikipediaQueryRun": WikipediaQueryRun,
|
||||
"WolframAlphaQueryRun": WolframAlphaQueryRun,
|
||||
}
|
||||
|
||||
OTHER_TOOLS = {tool: import_class(f"langchain.tools.{tool}") for tool in tools.__all__}
|
||||
|
||||
ALL_TOOLS_NAMES = {
|
||||
**_BASE_TOOLS,
|
||||
**_LLM_TOOLS, # type: ignore
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue