refac: removed all logic from api endpoints
This commit is contained in:
parent
5b1d3d7394
commit
9bc9d9a7eb
9 changed files with 411 additions and 292 deletions
|
|
@ -1,186 +1,21 @@
|
|||
from fastapi import APIRouter
|
||||
from langflow_backend.api import signature
|
||||
from langflow_backend.api import list_endpoints
|
||||
from langflow_backend.utils import payload
|
||||
from langchain.agents.loading import load_agent_executor_from_config
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from langchain.llms.loading import load_llm_from_config
|
||||
|
||||
from langchain.prompts.loading import load_prompt_from_config
|
||||
from typing import Any
|
||||
import io
|
||||
import contextlib
|
||||
import re
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from langflow_backend.interface.types import build_langchain_types_dict, get_type_list
|
||||
from langflow_backend.interface.loading import process_data_graph
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
# build router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_type_list():
|
||||
all_types = get_all()
|
||||
|
||||
all_types.pop("tools")
|
||||
|
||||
for key, value in all_types.items():
|
||||
all_types[key] = [item["template"]["_type"] for item in value.values()]
|
||||
|
||||
return all_types
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
def get_all():
|
||||
return {
|
||||
"chains": {
|
||||
chain: signature.get_chain(chain) for chain in list_endpoints.list_chains()
|
||||
},
|
||||
"agents": {
|
||||
agent: signature.get_agent(agent) for agent in list_endpoints.list_agents()
|
||||
},
|
||||
"prompts": {
|
||||
prompt: signature.get_prompt(prompt)
|
||||
for prompt in list_endpoints.list_prompts()
|
||||
},
|
||||
"llms": {llm: signature.get_llm(llm) for llm in list_endpoints.list_llms()},
|
||||
"tools": {
|
||||
tool: signature.get_tool(tool) for tool in list_endpoints.list_tools()
|
||||
},
|
||||
}
|
||||
return build_langchain_types_dict()
|
||||
|
||||
|
||||
@router.post("/predict")
|
||||
def get_load(data: dict[str, Any]):
|
||||
# Get type list
|
||||
type_list = get_type_list()
|
||||
|
||||
# Substitute ZeroShotPromt with PromptTemplate
|
||||
for node in data["nodes"]:
|
||||
if node["data"]["type"] == "ZeroShotPrompt":
|
||||
# Build Prompt Template
|
||||
tools = [
|
||||
tool
|
||||
for tool in data["nodes"]
|
||||
if tool["type"] != "chatOutputNode"
|
||||
and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
]
|
||||
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
|
||||
break
|
||||
|
||||
# Add input variables
|
||||
data = payload.extract_input_variables(data)
|
||||
|
||||
# Nodes, edges and root node
|
||||
message = data["message"]
|
||||
nodes = data["nodes"]
|
||||
edges = data["edges"]
|
||||
root = payload.get_root_node(data)
|
||||
|
||||
extracted_json = payload.build_json(root, nodes, edges)
|
||||
|
||||
# Build json
|
||||
if extracted_json["_type"] in type_list["agents"]:
|
||||
loaded = load_agent_executor_from_config(extracted_json)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded.run(message)
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
elif extracted_json["_type"] in type_list["chains"]:
|
||||
loaded = load_chain_from_config(extracted_json)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded.run(message)
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
elif extracted_json["_type"] in type_list["llms"]:
|
||||
loaded = load_llm_from_config(extracted_json)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded(message)
|
||||
thought = output_buffer.getvalue()
|
||||
else:
|
||||
result = "Error: Type should be either agent, chain or llm"
|
||||
thought = ""
|
||||
|
||||
# Remove unnecessary data from response
|
||||
begin = thought.rfind(message)
|
||||
thought = thought[(begin + len(message)) :]
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"thought": re.sub(
|
||||
r"\x1b\[([0-9,A-Z]{1,2}(;[0-9,A-Z]{1,2})?)?[m|K]", "", thought
|
||||
).strip(),
|
||||
}
|
||||
|
||||
|
||||
def build_prompt_template(prompt, tools):
|
||||
prefix = prompt["node"]["template"]["prefix"]["value"]
|
||||
suffix = prompt["node"]["template"]["suffix"]["value"]
|
||||
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
|
||||
|
||||
tool_strings = "\n".join(
|
||||
[
|
||||
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
value = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
|
||||
prompt["type"] = "PromptTemplate"
|
||||
|
||||
prompt["node"] = {
|
||||
"template": {
|
||||
"_type": "prompt",
|
||||
"input_variables": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": True,
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
},
|
||||
"output_parser": {
|
||||
"type": "BaseOutputParser",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": None,
|
||||
},
|
||||
"template": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": True,
|
||||
"multiline": True,
|
||||
"value": value,
|
||||
},
|
||||
"template_format": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": "f-string",
|
||||
},
|
||||
"validate_template": {
|
||||
"type": "bool",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": True,
|
||||
},
|
||||
},
|
||||
"description": "Schema to represent a prompt for an LLM.",
|
||||
"base_classes": ["BasePromptTemplate"],
|
||||
}
|
||||
|
||||
return prompt
|
||||
def get_load(data: Dict[str, Any]):
|
||||
try:
|
||||
return process_data_graph(data)
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|
@ -1,14 +1,6 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from langchain import chains
|
||||
from langchain import agents
|
||||
from langchain import prompts
|
||||
from langchain import llms
|
||||
from langchain.chains.conversation import memory as memories
|
||||
from langchain.agents.load_tools import get_all_tool_names
|
||||
from langflow_backend.utils import util, allowed_components
|
||||
from langflow_backend.custom import customs
|
||||
|
||||
from langflow_backend.interface.listing import list_type
|
||||
|
||||
# build router
|
||||
router = APIRouter(
|
||||
|
|
@ -32,61 +24,35 @@ def read_items():
|
|||
@router.get("/chains")
|
||||
def list_chains():
|
||||
"""List all chain types"""
|
||||
return [
|
||||
chain.__annotations__["return"].__name__
|
||||
for chain in chains.loading.type_to_loader_dict.values()
|
||||
if chain.__annotations__["return"].__name__ in allowed_components.CHAINS
|
||||
]
|
||||
return list_type("chains")
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
def list_agents():
|
||||
"""List all agent types"""
|
||||
# return list(agents.loading.AGENT_TO_CLASS.keys())
|
||||
return [
|
||||
agent.__name__
|
||||
for agent in agents.loading.AGENT_TO_CLASS.values()
|
||||
if agent.__name__ in allowed_components.AGENTS
|
||||
]
|
||||
return list_type("agents")
|
||||
|
||||
|
||||
@router.get("/prompts")
|
||||
def list_prompts():
|
||||
"""List all prompt types"""
|
||||
custom_prompts = customs.get_custom_prompts()
|
||||
library_prompts = [
|
||||
prompt.__annotations__["return"].__name__
|
||||
for prompt in prompts.loading.type_to_loader_dict.values()
|
||||
if prompt.__annotations__["return"].__name__ in allowed_components.PROMPTS
|
||||
]
|
||||
return library_prompts + list(custom_prompts.keys())
|
||||
return list_type("prompts")
|
||||
|
||||
|
||||
@router.get("/llms")
|
||||
def list_llms():
|
||||
"""List all llm types"""
|
||||
return [
|
||||
llm.__name__
|
||||
for llm in llms.type_to_cls_dict.values()
|
||||
if llm.__name__ in allowed_components.LLMS
|
||||
]
|
||||
return list_type("llms")
|
||||
|
||||
|
||||
@router.get("/memories")
|
||||
def list_memories():
|
||||
"""List all memory types"""
|
||||
return [memory.__name__ for memory in memories.type_to_cls_dict.values()]
|
||||
return list_type("memories")
|
||||
|
||||
|
||||
@router.get("/tools")
|
||||
def list_tools():
|
||||
"""List all load tools"""
|
||||
|
||||
tools = []
|
||||
|
||||
for tool in get_all_tool_names():
|
||||
tool_params = util.get_tool_params(util.get_tools_dict(tool))
|
||||
if tool_params and tool_params["name"] in allowed_components.TOOLS:
|
||||
tools.append(tool_params["name"])
|
||||
|
||||
return tools
|
||||
return list_type("tools")
|
||||
|
|
|
|||
|
|
@ -1,17 +1,6 @@
|
|||
from typing import Any, Dict
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from langchain import agents, chains, llms, prompts
|
||||
from langchain.agents.load_tools import (
|
||||
_BASE_TOOLS,
|
||||
_EXTRA_LLM_TOOLS,
|
||||
_EXTRA_OPTIONAL_TOOLS,
|
||||
_LLM_TOOLS,
|
||||
get_all_tool_names,
|
||||
)
|
||||
from langchain.chains.conversation import memory as memories
|
||||
|
||||
from langflow_backend.utils import util
|
||||
from langflow_backend.custom import customs
|
||||
from langflow_backend.interface.signature import get_signature
|
||||
|
||||
# build router
|
||||
router = APIRouter(
|
||||
|
|
@ -24,9 +13,7 @@ router = APIRouter(
|
|||
def get_chain(name: str):
|
||||
"""Get the signature of a chain."""
|
||||
try:
|
||||
return util.build_template_from_function(
|
||||
name, chains.loading.type_to_loader_dict
|
||||
)
|
||||
return get_signature(name, "chains")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail="Chain not found") from exc
|
||||
|
||||
|
|
@ -35,7 +22,7 @@ def get_chain(name: str):
|
|||
def get_agent(name: str):
|
||||
"""Get the signature of an agent."""
|
||||
try:
|
||||
return util.build_template_from_class(name, agents.loading.AGENT_TO_CLASS)
|
||||
return get_signature(name, "agents")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail="Agent not found") from exc
|
||||
|
||||
|
|
@ -44,11 +31,7 @@ def get_agent(name: str):
|
|||
def get_prompt(name: str):
|
||||
"""Get the signature of a prompt."""
|
||||
try:
|
||||
if name in customs.get_custom_prompts().keys():
|
||||
return customs.get_custom_prompts()[name]
|
||||
return util.build_template_from_function(
|
||||
name, prompts.loading.type_to_loader_dict
|
||||
)
|
||||
return get_signature(name, "prompts")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found") from exc
|
||||
|
||||
|
|
@ -57,7 +40,7 @@ def get_prompt(name: str):
|
|||
def get_llm(name: str):
|
||||
"""Get the signature of an llm."""
|
||||
try:
|
||||
return util.build_template_from_class(name, llms.type_to_cls_dict)
|
||||
return get_signature(name, "llms")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail="LLM not found") from exc
|
||||
|
||||
|
|
@ -66,7 +49,7 @@ def get_llm(name: str):
|
|||
def get_memory(name: str):
|
||||
"""Get the signature of a memory."""
|
||||
try:
|
||||
return util.build_template_from_class(name, memories.type_to_cls_dict)
|
||||
return get_signature(name, "memories")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail="Memory not found") from exc
|
||||
|
||||
|
|
@ -74,51 +57,7 @@ def get_memory(name: str):
|
|||
@router.get("/tool")
|
||||
def get_tool(name: str):
|
||||
"""Get the signature of a tool."""
|
||||
|
||||
all_tools = {}
|
||||
for tool in get_all_tool_names():
|
||||
if tool_params := util.get_tool_params(util.get_tools_dict(tool)):
|
||||
all_tools[tool_params["name"]] = tool
|
||||
|
||||
# Raise error if name is not in tools
|
||||
if name not in all_tools.keys():
|
||||
raise HTTPException(status_code=404, detail=f"Tool {name} not found.")
|
||||
|
||||
type_dict = {
|
||||
"str": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"placeholder": "",
|
||||
"value": "",
|
||||
},
|
||||
"llm": {"type": "BaseLLM", "required": True, "list": False, "show": True},
|
||||
}
|
||||
|
||||
tool_type = all_tools[name]
|
||||
|
||||
if tool_type in _BASE_TOOLS:
|
||||
params = []
|
||||
elif tool_type in _LLM_TOOLS:
|
||||
params = ["llm"]
|
||||
elif tool_type in _EXTRA_LLM_TOOLS:
|
||||
_, extra_keys = _EXTRA_LLM_TOOLS[tool_type]
|
||||
params = ["llm"] + extra_keys
|
||||
elif tool_type in _EXTRA_OPTIONAL_TOOLS:
|
||||
_, extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type]
|
||||
params = extra_keys
|
||||
else:
|
||||
params = []
|
||||
|
||||
template = {
|
||||
param: (type_dict[param] if param == "llm" else type_dict["str"])
|
||||
for param in params
|
||||
} # type: Dict[str, Any]
|
||||
template["_type"] = tool_type
|
||||
|
||||
return {
|
||||
"template": template,
|
||||
**util.get_tool_params(util.get_tools_dict(tool_type)),
|
||||
"base_classes": ["Tool"],
|
||||
}
|
||||
try:
|
||||
return get_signature(name, "tools")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail="Tool not found") from exc
|
||||
|
|
|
|||
0
langflow/backend/langflow_backend/interface/__init__.py
Normal file
0
langflow/backend/langflow_backend/interface/__init__.py
Normal file
74
langflow/backend/langflow_backend/interface/listing.py
Normal file
74
langflow/backend/langflow_backend/interface/listing.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
from langchain import chains, agents, prompts, llms
|
||||
from langflow_backend.custom import customs
|
||||
from langflow_backend.utils import util, allowed_components
|
||||
from langchain.agents.load_tools import get_all_tool_names
|
||||
from langchain.chains.conversation import memory as memories
|
||||
|
||||
|
||||
def list_type(object_type: str):
|
||||
"""List all components"""
|
||||
return {
|
||||
"chains": list_chain_types,
|
||||
"agents": list_agents,
|
||||
"prompts": list_prompts,
|
||||
"llms": list_llms,
|
||||
"tools": list_tools,
|
||||
"memories": list_memories,
|
||||
}.get(object_type, lambda: "Invalid type")()
|
||||
|
||||
|
||||
def list_agents():
|
||||
"""List all agent types"""
|
||||
# return list(agents.loading.AGENT_TO_CLASS.keys())
|
||||
return [
|
||||
agent.__name__
|
||||
for agent in agents.loading.AGENT_TO_CLASS.values()
|
||||
if agent.__name__ in allowed_components.AGENTS
|
||||
]
|
||||
|
||||
|
||||
def list_prompts():
|
||||
"""List all prompt types"""
|
||||
custom_prompts = customs.get_custom_prompts()
|
||||
library_prompts = [
|
||||
prompt.__annotations__["return"].__name__
|
||||
for prompt in prompts.loading.type_to_loader_dict.values()
|
||||
if prompt.__annotations__["return"].__name__ in allowed_components.PROMPTS
|
||||
]
|
||||
return library_prompts + list(custom_prompts.keys())
|
||||
|
||||
|
||||
def list_tools():
|
||||
"""List all load tools"""
|
||||
|
||||
tools = []
|
||||
|
||||
for tool in get_all_tool_names():
|
||||
tool_params = util.get_tool_params(util.get_tools_dict(tool))
|
||||
if tool_params and tool_params["name"] in allowed_components.TOOLS:
|
||||
tools.append(tool_params["name"])
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def list_llms():
|
||||
"""List all llm types"""
|
||||
return [
|
||||
llm.__name__
|
||||
for llm in llms.type_to_cls_dict.values()
|
||||
if llm.__name__ in allowed_components.LLMS
|
||||
]
|
||||
|
||||
|
||||
def list_chain_types():
|
||||
"""List all chain types"""
|
||||
return [
|
||||
chain.__annotations__["return"].__name__
|
||||
for chain in chains.loading.type_to_loader_dict.values()
|
||||
if chain.__annotations__["return"].__name__ in allowed_components.CHAINS
|
||||
]
|
||||
|
||||
|
||||
def list_memories():
|
||||
"""List all memory types"""
|
||||
return [memory.__name__ for memory in memories.type_to_cls_dict.values()]
|
||||
151
langflow/backend/langflow_backend/interface/loading.py
Normal file
151
langflow/backend/langflow_backend/interface/loading.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
import contextlib
|
||||
import re
|
||||
import io
|
||||
from typing import Any, Dict
|
||||
from langflow_backend.interface.types import get_type_list
|
||||
from langchain.agents.loading import load_agent_executor_from_config
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from langchain.llms.loading import load_llm_from_config
|
||||
from langflow_backend.utils import payload
|
||||
|
||||
|
||||
def replace_zero_shot_prompt_with_prompt_template(nodes):
|
||||
for node in nodes:
|
||||
if node["data"]["type"] == "ZeroShotPrompt":
|
||||
# Build Prompt Template
|
||||
tools = [
|
||||
tool
|
||||
for tool in nodes
|
||||
if tool["type"] != "chatOutputNode"
|
||||
and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
]
|
||||
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
|
||||
break
|
||||
|
||||
|
||||
def process_data_graph(data_graph: Dict[str, Any]):
|
||||
nodes = data_graph["nodes"]
|
||||
# Substitute ZeroShotPrompt with PromptTemplate
|
||||
nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
|
||||
# Add input variables
|
||||
data_graph = payload.extract_input_variables(data_graph)
|
||||
# Nodes, edges and root node
|
||||
message = data_graph["message"]
|
||||
edges = data_graph["edges"]
|
||||
root = payload.get_root_node(data_graph)
|
||||
extracted_json = payload.build_json(root, nodes, edges)
|
||||
|
||||
# Process json
|
||||
result, thought = get_result_and_thought(extracted_json, message)
|
||||
|
||||
# Remove unnecessary data from response
|
||||
begin = thought.rfind(message)
|
||||
thought = thought[(begin + len(message)) :]
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"thought": re.sub(
|
||||
r"\x1b\[([0-9,A-Z]{1,2}(;[0-9,A-Z]{1,2})?)?[m|K]", "", thought
|
||||
).strip(),
|
||||
}
|
||||
|
||||
|
||||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
# Get type list
|
||||
type_list = get_type_list()
|
||||
if extracted_json["_type"] in type_list["agents"]:
|
||||
loaded = load_agent_executor_from_config(extracted_json)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded.run(message)
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
elif extracted_json["_type"] in type_list["chains"]:
|
||||
loaded = load_chain_from_config(extracted_json)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded.run(message)
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
elif extracted_json["_type"] in type_list["llms"]:
|
||||
loaded = load_llm_from_config(extracted_json)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded(message)
|
||||
thought = output_buffer.getvalue()
|
||||
else:
|
||||
result = "Error: Type should be either agent, chain or llm"
|
||||
thought = ""
|
||||
return result, thought
|
||||
|
||||
|
||||
def build_prompt_template(prompt, tools):
|
||||
prefix = prompt["node"]["template"]["prefix"]["value"]
|
||||
suffix = prompt["node"]["template"]["suffix"]["value"]
|
||||
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
|
||||
|
||||
tool_strings = "\n".join(
|
||||
[
|
||||
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
value = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
|
||||
prompt["type"] = "PromptTemplate"
|
||||
|
||||
prompt["node"] = {
|
||||
"template": {
|
||||
"_type": "prompt",
|
||||
"input_variables": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": True,
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
},
|
||||
"output_parser": {
|
||||
"type": "BaseOutputParser",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": None,
|
||||
},
|
||||
"template": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": True,
|
||||
"multiline": True,
|
||||
"value": value,
|
||||
},
|
||||
"template_format": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": "f-string",
|
||||
},
|
||||
"validate_template": {
|
||||
"type": "bool",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": True,
|
||||
},
|
||||
},
|
||||
"description": "Schema to represent a prompt for an LLM.",
|
||||
"base_classes": ["BasePromptTemplate"],
|
||||
}
|
||||
|
||||
return prompt
|
||||
123
langflow/backend/langflow_backend/interface/signature.py
Normal file
123
langflow/backend/langflow_backend/interface/signature.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
from typing import Dict, Any # noqa: F401
|
||||
from langchain import agents, chains, llms, prompts
|
||||
from langchain.agents.load_tools import (
|
||||
_BASE_TOOLS,
|
||||
_EXTRA_LLM_TOOLS,
|
||||
_EXTRA_OPTIONAL_TOOLS,
|
||||
_LLM_TOOLS,
|
||||
get_all_tool_names,
|
||||
)
|
||||
from langchain.chains.conversation import memory as memories
|
||||
|
||||
from langflow_backend.utils import util
|
||||
from langflow_backend.custom import customs
|
||||
|
||||
|
||||
def get_signature(name: str, object_type: str):
|
||||
"""Get the signature of an object."""
|
||||
return {
|
||||
"chains": get_chain_signature,
|
||||
"agents": get_agent_signature,
|
||||
"prompts": get_prompt_signature,
|
||||
"llms": get_llm_signature,
|
||||
"tools": get_tool_signature,
|
||||
"memories": get_memory_signature,
|
||||
}.get(object_type, lambda name: f"Invalid type: {name}")(name)
|
||||
|
||||
|
||||
def get_chain_signature(name: str):
|
||||
"""Get the chain type by signature."""
|
||||
try:
|
||||
return util.build_template_from_function(
|
||||
name, chains.loading.type_to_loader_dict
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Chain not found") from exc
|
||||
|
||||
|
||||
def get_agent_signature(name: str):
|
||||
"""Get the signature of an agent."""
|
||||
try:
|
||||
return util.build_template_from_class(name, agents.loading.AGENT_TO_CLASS)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Agent not found") from exc
|
||||
|
||||
|
||||
def get_prompt_signature(name: str):
|
||||
"""Get the signature of a prompt."""
|
||||
try:
|
||||
if name in customs.get_custom_prompts().keys():
|
||||
return customs.get_custom_prompts()[name]
|
||||
return util.build_template_from_function(
|
||||
name, prompts.loading.type_to_loader_dict
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Prompt not found") from exc
|
||||
|
||||
|
||||
def get_llm_signature(name: str):
|
||||
"""Get the signature of an llm."""
|
||||
try:
|
||||
return util.build_template_from_class(name, llms.type_to_cls_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError("LLM not found") from exc
|
||||
|
||||
|
||||
def get_memory_signature(name: str):
|
||||
"""Get the signature of a memory."""
|
||||
try:
|
||||
return util.build_template_from_class(name, memories.type_to_cls_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Memory not found") from exc
|
||||
|
||||
|
||||
def get_tool_signature(name: str):
|
||||
"""Get the signature of a tool."""
|
||||
|
||||
all_tools = {}
|
||||
for tool in get_all_tool_names():
|
||||
if tool_params := util.get_tool_params(util.get_tools_dict(tool)):
|
||||
all_tools[tool_params["name"]] = tool
|
||||
|
||||
# Raise error if name is not in tools
|
||||
if name not in all_tools.keys():
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
type_dict = {
|
||||
"str": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"placeholder": "",
|
||||
"value": "",
|
||||
},
|
||||
"llm": {"type": "BaseLLM", "required": True, "list": False, "show": True},
|
||||
}
|
||||
|
||||
tool_type = all_tools[name]
|
||||
|
||||
if tool_type in _BASE_TOOLS:
|
||||
params = []
|
||||
elif tool_type in _LLM_TOOLS:
|
||||
params = ["llm"]
|
||||
elif tool_type in _EXTRA_LLM_TOOLS:
|
||||
_, extra_keys = _EXTRA_LLM_TOOLS[tool_type]
|
||||
params = ["llm"] + extra_keys
|
||||
elif tool_type in _EXTRA_OPTIONAL_TOOLS:
|
||||
_, extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type]
|
||||
params = extra_keys
|
||||
else:
|
||||
params = []
|
||||
|
||||
template = {
|
||||
param: (type_dict[param] if param == "llm" else type_dict["str"])
|
||||
for param in params
|
||||
} # type: Dict[str, Any]
|
||||
template["_type"] = tool_type
|
||||
|
||||
return {
|
||||
"template": template,
|
||||
**util.get_tool_params(util.get_tools_dict(tool_type)),
|
||||
"base_classes": ["Tool"],
|
||||
}
|
||||
31
langflow/backend/langflow_backend/interface/types.py
Normal file
31
langflow/backend/langflow_backend/interface/types.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from langflow_backend.interface.listing import list_type
|
||||
from langflow_backend.interface.signature import get_signature
|
||||
|
||||
|
||||
def get_type_list():
|
||||
"""Get a list of all langchain types"""
|
||||
all_types = build_langchain_types_dict()
|
||||
|
||||
all_types.pop("tools")
|
||||
|
||||
for key, value in all_types.items():
|
||||
all_types[key] = [item["template"]["_type"] for item in value.values()]
|
||||
|
||||
return all_types
|
||||
|
||||
|
||||
def build_langchain_types_dict():
|
||||
"""Build a dictionary of all langchain types"""
|
||||
return {
|
||||
"chains": {
|
||||
chain: get_signature(chain, "chains") for chain in list_type("chains")
|
||||
},
|
||||
"agents": {
|
||||
agent: get_signature(agent, "agents") for agent in list_type("agents")
|
||||
},
|
||||
"prompts": {
|
||||
prompt: get_signature(prompt, "prompts") for prompt in list_type("prompts")
|
||||
},
|
||||
"llms": {llm: get_signature(llm, "llms") for llm in list_type("llms")},
|
||||
"tools": {tool: get_signature(tool, "tools") for tool in list_type("tools")},
|
||||
}
|
||||
|
|
@ -10,10 +10,10 @@ from langchain.agents.load_tools import (
|
|||
_EXTRA_LLM_TOOLS,
|
||||
_EXTRA_OPTIONAL_TOOLS,
|
||||
)
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
def build_template_from_function(name: str, type_to_loader_dict: dict):
|
||||
def build_template_from_function(name: str, type_to_loader_dict: Dict):
|
||||
classes = [
|
||||
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
|
||||
]
|
||||
|
|
@ -59,7 +59,7 @@ def build_template_from_function(name: str, type_to_loader_dict: dict):
|
|||
}
|
||||
|
||||
|
||||
def build_template_from_class(name: str, type_to_cls_dict: dict):
|
||||
def build_template_from_class(name: str, type_to_cls_dict: Dict):
|
||||
classes = [item.__name__ for item in type_to_cls_dict.values()]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
|
|
@ -130,8 +130,8 @@ def get_tools_dict(name: Optional[str] = None):
|
|||
"""Get the tools dictionary."""
|
||||
tools = {
|
||||
**_BASE_TOOLS,
|
||||
**_LLM_TOOLS,
|
||||
**{k: v[0] for k, v in _EXTRA_LLM_TOOLS.items()},
|
||||
**_LLM_TOOLS, # type: ignore
|
||||
**{k: v[0] for k, v in _EXTRA_LLM_TOOLS.items()}, # type: ignore
|
||||
**{k: v[0] for k, v in _EXTRA_OPTIONAL_TOOLS.items()},
|
||||
}
|
||||
return tools[name] if name else tools
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue