From 293442f303f85175609e6804566b39c6fcbbacb4 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 21 Feb 2023 19:19:14 +0000 Subject: [PATCH] refac: formatting, naming improvements, error improvements --- src/app.py | 2 +- src/endpoints.py | 41 +++++++++-------- src/interface.py | 45 ------------------- src/{list.py => list_endpoints.py} | 1 + src/signature.py | 71 +++++++++++++++++++----------- src/util.py | 1 + 6 files changed, 69 insertions(+), 92 deletions(-) delete mode 100644 src/interface.py rename src/{list.py => list_endpoints.py} (98%) diff --git a/src/app.py b/src/app.py index c68243b79..f683f08c8 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from endpoints import router as endpoints_router -from list import router as list_router +from list_endpoints import router as list_router from signature import router as signatures_router from fastapi.middleware.cors import CORSMiddleware diff --git a/src/endpoints.py b/src/endpoints.py index 02e96f12a..efd11ad65 100644 --- a/src/endpoints.py +++ b/src/endpoints.py @@ -1,44 +1,43 @@ -from fastapi import APIRouter, FastAPI -from langchain import OpenAI -from langchain.agents import initialize_agent -from interface import ( - DictableChain, - DictableConversationalAgent, - DictableMemory, - DictableTool, -) +from fastapi import APIRouter import signature -import list +import list_endpoints # build router router = APIRouter() def get_type_list(): - all = get_all() + all_types = get_all() - all.pop("tools") + all_types.pop("tools") - for key, value in all.items(): - all[key] = [item["template"]["_type"] for item in value.values()] + for key, value in all_types.items(): + all_types[key] = [item["template"]["_type"] for item in value.values()] - return all + return all_types @router.get("/") def get_all(): return { - "chains": {chain: signature.chain(chain) for chain in list.list_chains()}, - "agents": {agent: signature.agent(agent) for agent in list.list_agents()}, - "prompts": {prompt: signature.prompt(prompt) for prompt in list.list_prompts()}, - "llms": {llm: signature.llm(llm) for llm in list.list_llms()}, + "chains": { + chain: signature.chain(chain) for chain in list_endpoints.list_chains() + }, + "agents": { + agent: signature.agent(agent) for agent in list_endpoints.list_agents() + }, + "prompts": { + prompt: signature.prompt(prompt) for prompt in list_endpoints.list_prompts() + }, + "llms": {llm: signature.llm(llm) for llm in list_endpoints.list_llms()}, # "utilities": { # "template": { # # utility: templates.utility(utility) for utility in list.list_utilities() # } # }, "memories": { - memory: signature.memory(memory) for memory in list.list_memories() + memory: signature.memory(memory) + for memory in list_endpoints.list_memories() }, # "document_loaders": { # "template": { @@ -52,7 +51,7 @@ def get_all(): # tool: {"template": signature.tool(tool), **values} # for tool, values in tools.items() # }, - "tools": {tool: signature.tool(tool) for tool in list.list_tools()}, + "tools": {tool: signature.tool(tool) for tool in list_endpoints.list_tools()}, } diff --git a/src/interface.py b/src/interface.py deleted file mode 100644 index e1afefe8e..000000000 --- a/src/interface.py +++ /dev/null @@ -1,45 +0,0 @@ -from langchain.chains import ConversationChain -from langchain.chains.conversation.memory import ConversationBufferMemory -from langchain.agents import Agent, ConversationalAgent, Tool, initialize_agent - - -class Dictable(object): - """A mixin that allows an object to be converted to and from a dict""" - - @classmethod - def from_dict(cls, d): - """Convert a dict to an object""" - return cls(**d) - - def to_dict(self): - return self.__dict__ - - -class DictableChain(Dictable, ConversationChain): - """A ConversationChain that is also Dictable""" - - pass - - -class DictableMemory(Dictable, ConversationBufferMemory): - """A ConversationBufferMemory that is also Dictable""" - - pass - - -class DictableAgent(Dictable, Agent): - """An Agent that is also Dictable""" - - pass - - -class DictableConversationalAgent(Dictable, ConversationalAgent): - """A ConversationalAgent that is also Dictable""" - - pass - - -class DictableTool(Dictable, Tool): - """A Tool that is also Dictable""" - - pass diff --git a/src/list.py b/src/list_endpoints.py similarity index 98% rename from src/list.py rename to src/list_endpoints.py index 4ccbdf358..53a2dc82d 100644 --- a/src/list.py +++ b/src/list_endpoints.py @@ -22,6 +22,7 @@ router = APIRouter( @router.get("/") def read_items(): + """List all components""" return [ "chains", "agents", diff --git a/src/signature.py b/src/signature.py index 0a46ba9d6..e72c968a6 100644 --- a/src/signature.py +++ b/src/signature.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from langchain import chains from langchain import agents @@ -15,8 +15,6 @@ from langchain.agents.load_tools import ( _EXTRA_OPTIONAL_TOOLS, ) import util -import list -import inspect # build router router = APIRouter( @@ -25,14 +23,16 @@ router = APIRouter( ) -def build_template_from_function(name: str, dict: dict): - classes = [item.__annotations__["return"].__name__ for item in dict.values()] +def build_template_from_function(name: str, type_to_loader_dict: dict): + classes = [ + item.__annotations__["return"].__name__ for item in type_to_loader_dict.values() + ] # Raise error if name is not in chains if name not in classes: - raise Exception(f"{name} not found.") + raise ValueError(f"{name} not found") - for _type, v in dict.items(): + for _type, v in type_to_loader_dict.items(): if v.__annotations__["return"].__name__ == name: _class = v.__annotations__["return"] @@ -49,7 +49,7 @@ def build_template_from_function(name: str, dict: dict): variables[name]["default"] = util.get_default_factory( module=_class.__base__.__module__, function=value_ ) - except: + except Exception: variables[name]["default"] = None elif name_ not in ["name"]: variables[name][name_] = value_ @@ -65,14 +65,14 @@ def build_template_from_function(name: str, dict: dict): } -def build_template_from_class(name: str, dict: dict): - classes = [item.__name__ for item in dict.values()] +def build_template_from_class(name: str, type_to_cls_dict: dict): + classes = [item.__name__ for item in type_to_cls_dict.values()] # Raise error if name is not in chains if name not in classes: - raise Exception(f"{name} not found.") + raise ValueError(f"{name} not found.") - for _type, v in dict.items(): + for _type, v in type_to_cls_dict.items(): if v.__name__ == name: _class = v @@ -89,7 +89,7 @@ def build_template_from_class(name: str, dict: dict): variables[name]["default"] = util.get_default_factory( module=_class.__base__.__module__, function=value_ ) - except: + except Exception: variables[name]["default"] = None elif name_ not in ["name"]: variables[name][name_] = value_ @@ -106,23 +106,39 @@ def build_template_from_class(name: str, dict: dict): @router.get("/chain") -def chain(name: str): - return build_template_from_function(name, chains.loading.type_to_loader_dict) +def get_chain(name: str): + """Get the signature of a chain.""" + try: + return build_template_from_function(name, chains.loading.type_to_loader_dict) + except ValueError as exc: + raise HTTPException(status_code=404, detail="Chain not found") from exc @router.get("/agent") -def agent(name: str): - return build_template_from_class(name, agents.loading.AGENT_TO_CLASS) +def get_agent(name: str): + """Get the signature of an agent.""" + try: + return build_template_from_class(name, agents.loading.AGENT_TO_CLASS) + except ValueError as exc: + raise HTTPException(status_code=404, detail="Agent not found") from exc @router.get("/prompt") -def prompt(name: str): - return build_template_from_function(name, prompts.loading.type_to_loader_dict) +def get_prompt(name: str): + """Get the signature of a prompt.""" + try: + return build_template_from_function(name, prompts.loading.type_to_loader_dict) + except ValueError as exc: + raise HTTPException(status_code=404, detail="Prompt not found") from exc @router.get("/llm") -def llm(name: str): - return build_template_from_class(name, llms.type_to_cls_dict) +def get_llm(name: str): + """Get the signature of an llm.""" + try: + return build_template_from_class(name, llms.type_to_cls_dict) + except ValueError as exc: + raise HTTPException(status_code=404, detail="LLM not found") from exc # @router.get("/utility") @@ -138,8 +154,12 @@ def llm(name: str): @router.get("/memory") -def memory(name: str): - return build_template_from_class(name, memories.type_to_cls_dict) +def get_memory(name: str): + """Get the signature of a memory.""" + try: + return build_template_from_class(name, memories.type_to_cls_dict) + except ValueError as exc: + raise HTTPException(status_code=404, detail="Memory not found") from exc # @router.get("/document_loader") @@ -155,10 +175,11 @@ def memory(name: str): @router.get("/tool") -def tool(name: str): +def get_tool(name: str): + """Get the signature of a tool.""" # Raise error if name is not in tools if name not in get_all_tool_names(): - raise Exception(f"Tool {name} not found.") + raise HTTPException(status_code=404, detail=f"Tool {name} not found.") type_dict = { "str": { diff --git a/src/util.py b/src/util.py index 8a3f8c0ab..dbee3503c 100644 --- a/src/util.py +++ b/src/util.py @@ -35,6 +35,7 @@ def get_default_factory(module: str, function: str): def get_tools_dict(name: Optional[str] = None): + """Get the tools dictionary.""" tools = { **_BASE_TOOLS, **_LLM_TOOLS,