refac: formatting, naming improvements, error improvements

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-02-21 19:19:14 +00:00
commit 293442f303
6 changed files with 69 additions and 92 deletions

View file

@ -1,6 +1,6 @@
from fastapi import FastAPI
from endpoints import router as endpoints_router
from list import router as list_router
from list_endpoints import router as list_router
from signature import router as signatures_router
from fastapi.middleware.cors import CORSMiddleware

View file

@ -1,44 +1,43 @@
from fastapi import APIRouter, FastAPI
from langchain import OpenAI
from langchain.agents import initialize_agent
from interface import (
DictableChain,
DictableConversationalAgent,
DictableMemory,
DictableTool,
)
from fastapi import APIRouter
import signature
import list
import list_endpoints
# build router
router = APIRouter()
def get_type_list():
all = get_all()
all_types = get_all()
all.pop("tools")
all_types.pop("tools")
for key, value in all.items():
all[key] = [item["template"]["_type"] for item in value.values()]
for key, value in all_types.items():
all_types[key] = [item["template"]["_type"] for item in value.values()]
return all
return all_types
@router.get("/")
def get_all():
return {
"chains": {chain: signature.chain(chain) for chain in list.list_chains()},
"agents": {agent: signature.agent(agent) for agent in list.list_agents()},
"prompts": {prompt: signature.prompt(prompt) for prompt in list.list_prompts()},
"llms": {llm: signature.llm(llm) for llm in list.list_llms()},
"chains": {
chain: signature.chain(chain) for chain in list_endpoints.list_chains()
},
"agents": {
agent: signature.agent(agent) for agent in list_endpoints.list_agents()
},
"prompts": {
prompt: signature.prompt(prompt) for prompt in list_endpoints.list_prompts()
},
"llms": {llm: signature.llm(llm) for llm in list_endpoints.list_llms()},
# "utilities": {
# "template": {
# # utility: templates.utility(utility) for utility in list.list_utilities()
# }
# },
"memories": {
memory: signature.memory(memory) for memory in list.list_memories()
memory: signature.memory(memory)
for memory in list_endpoints.list_memories()
},
# "document_loaders": {
# "template": {
@ -52,7 +51,7 @@ def get_all():
# tool: {"template": signature.tool(tool), **values}
# for tool, values in tools.items()
# },
"tools": {tool: signature.tool(tool) for tool in list.list_tools()},
"tools": {tool: signature.tool(tool) for tool in list_endpoints.list_tools()},
}

View file

@ -1,45 +0,0 @@
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.agents import Agent, ConversationalAgent, Tool, initialize_agent
class Dictable(object):
"""A mixin that allows an object to be converted to and from a dict"""
@classmethod
def from_dict(cls, d):
"""Convert a dict to an object"""
return cls(**d)
def to_dict(self):
return self.__dict__
class DictableChain(Dictable, ConversationChain):
"""A ConversationChain that is also Dictable"""
pass
class DictableMemory(Dictable, ConversationBufferMemory):
"""A ConversationBufferMemory that is also Dictable"""
pass
class DictableAgent(Dictable, Agent):
"""An Agent that is also Dictable"""
pass
class DictableConversationalAgent(Dictable, ConversationalAgent):
"""A ConversationalAgent that is also Dictable"""
pass
class DictableTool(Dictable, Tool):
"""A Tool that is also Dictable"""
pass

View file

@ -22,6 +22,7 @@ router = APIRouter(
@router.get("/")
def read_items():
"""List all components"""
return [
"chains",
"agents",

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from langchain import chains
from langchain import agents
@ -15,8 +15,6 @@ from langchain.agents.load_tools import (
_EXTRA_OPTIONAL_TOOLS,
)
import util
import list
import inspect
# build router
router = APIRouter(
@ -25,14 +23,16 @@ router = APIRouter(
)
def build_template_from_function(name: str, dict: dict):
classes = [item.__annotations__["return"].__name__ for item in dict.values()]
def build_template_from_function(name: str, type_to_loader_dict: dict):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
# Raise error if name is not in chains
if name not in classes:
raise Exception(f"{name} not found.")
raise ValueError(f"{name} not found")
for _type, v in dict.items():
for _type, v in type_to_loader_dict.items():
if v.__annotations__["return"].__name__ == name:
_class = v.__annotations__["return"]
@ -49,7 +49,7 @@ def build_template_from_function(name: str, dict: dict):
variables[name]["default"] = util.get_default_factory(
module=_class.__base__.__module__, function=value_
)
except:
except Exception:
variables[name]["default"] = None
elif name_ not in ["name"]:
variables[name][name_] = value_
@ -65,14 +65,14 @@ def build_template_from_function(name: str, dict: dict):
}
def build_template_from_class(name: str, dict: dict):
classes = [item.__name__ for item in dict.values()]
def build_template_from_class(name: str, type_to_cls_dict: dict):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
if name not in classes:
raise Exception(f"{name} not found.")
raise ValueError(f"{name} not found.")
for _type, v in dict.items():
for _type, v in type_to_cls_dict.items():
if v.__name__ == name:
_class = v
@ -89,7 +89,7 @@ def build_template_from_class(name: str, dict: dict):
variables[name]["default"] = util.get_default_factory(
module=_class.__base__.__module__, function=value_
)
except:
except Exception:
variables[name]["default"] = None
elif name_ not in ["name"]:
variables[name][name_] = value_
@ -106,23 +106,39 @@ def build_template_from_class(name: str, dict: dict):
@router.get("/chain")
def chain(name: str):
return build_template_from_function(name, chains.loading.type_to_loader_dict)
def get_chain(name: str):
"""Get the signature of a chain."""
try:
return build_template_from_function(name, chains.loading.type_to_loader_dict)
except ValueError as exc:
raise HTTPException(status_code=404, detail="Chain not found") from exc
@router.get("/agent")
def agent(name: str):
return build_template_from_class(name, agents.loading.AGENT_TO_CLASS)
def get_agent(name: str):
"""Get the signature of an agent."""
try:
return build_template_from_class(name, agents.loading.AGENT_TO_CLASS)
except ValueError as exc:
raise HTTPException(status_code=404, detail="Agent not found") from exc
@router.get("/prompt")
def prompt(name: str):
return build_template_from_function(name, prompts.loading.type_to_loader_dict)
def get_prompt(name: str):
"""Get the signature of a prompt."""
try:
return build_template_from_function(name, prompts.loading.type_to_loader_dict)
except ValueError as exc:
raise HTTPException(status_code=404, detail="Prompt not found") from exc
@router.get("/llm")
def llm(name: str):
return build_template_from_class(name, llms.type_to_cls_dict)
def get_llm(name: str):
"""Get the signature of an llm."""
try:
return build_template_from_class(name, llms.type_to_cls_dict)
except ValueError as exc:
raise HTTPException(status_code=404, detail="LLM not found") from exc
# @router.get("/utility")
@ -138,8 +154,12 @@ def llm(name: str):
@router.get("/memory")
def memory(name: str):
return build_template_from_class(name, memories.type_to_cls_dict)
def get_memory(name: str):
"""Get the signature of a memory."""
try:
return build_template_from_class(name, memories.type_to_cls_dict)
except ValueError as exc:
raise HTTPException(status_code=404, detail="Memory not found") from exc
# @router.get("/document_loader")
@ -155,10 +175,11 @@ def memory(name: str):
@router.get("/tool")
def tool(name: str):
def get_tool(name: str):
"""Get the signature of a tool."""
# Raise error if name is not in tools
if name not in get_all_tool_names():
raise Exception(f"Tool {name} not found.")
raise HTTPException(status_code=404, detail=f"Tool {name} not found.")
type_dict = {
"str": {

View file

@ -35,6 +35,7 @@ def get_default_factory(module: str, function: str):
def get_tools_dict(name: Optional[str] = None):
"""Get the tools dictionary."""
tools = {
**_BASE_TOOLS,
**_LLM_TOOLS,