From 95e99899cc9d109c71fb88d8612031f907522e7d Mon Sep 17 00:00:00 2001 From: Ibis Prevedello Date: Sat, 11 Feb 2023 20:33:00 -0300 Subject: [PATCH] feat: add template --- src/main.py | 2 ++ src/templates.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 src/templates.py diff --git a/src/main.py b/src/main.py index 9586d6f65..40b0cf19c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,7 @@ from fastapi import FastAPI from endpoints import router as endpoints_router from list import router as list_router +from templates import router as templates_router def create_app(): @@ -8,6 +9,7 @@ def create_app(): app = FastAPI() app.include_router(endpoints_router) app.include_router(list_router) + app.include_router(templates_router) return app diff --git a/src/templates.py b/src/templates.py new file mode 100644 index 000000000..1f32bdc16 --- /dev/null +++ b/src/templates.py @@ -0,0 +1,87 @@ +from fastapi import APIRouter + +from langchain import chains +from langchain import agents +from langchain import prompts +from langchain import llms +from langchain import utilities +from langchain.chains.conversation import memory + + +# build router +router = APIRouter( + prefix="/templates", + tags=["templates"], +) + + +@router.get("/chain") +def chain(name: str): + # Raise error if name is not in chains + if name not in chains.__all__: + raise Exception(f"Prompt {name} not found.") + _class = getattr(chains, name) + return { + name: {name: value for (name, value) in value.__repr_args__() if name != "name"} + for name, value in _class.__dict__["__fields__"].items() + } + + +@router.get("/agent") +def agent(name: str): + # Raise error if name is not in agents + if name not in agents.__all__: + raise Exception(f"Prompt {name} not found.") + _class = getattr(agents, name) + return { + name: {name: value for (name, value) in value.__repr_args__() if name != "name"} + for name, value in _class.__dict__["__fields__"].items() + } + + +@router.get("/prompt") +def prompt(name: str): + # Raise error if name is not in prompts + if name not in prompts.__all__: + raise Exception(f"Prompt {name} not found.") + _class = getattr(prompts, name) + return { + name: {name: value for (name, value) in value.__repr_args__() if name != "name"} + for name, value in _class.__dict__["__fields__"].items() + } + + +@router.get("/llm") +def llm(name: str): + # Raise error if name is not in llms + if name not in llms.__all__: + raise Exception(f"Prompt {name} not found.") + _class = getattr(llms, name) + return { + name: {name: value for (name, value) in value.__repr_args__() if name != "name"} + for name, value in _class.__dict__["__fields__"].items() + } + + +@router.get("/utility") +def utility(name: str): + # Raise error if name is not in utilities + if name not in utilities.__all__: + raise Exception(f"Prompt {name} not found.") + _class = getattr(utilities, name) + return { + name: {name: value for (name, value) in value.__repr_args__() if name != "name"} + for name, value in _class.__dict__["__fields__"].items() + } + + +@router.get("/memory") +def memory(name: str): + # Raise error if name is not in memory + if name not in memory.__all__: + raise Exception(f"Prompt {name} not found.") + _class = getattr(memory, name) + return { + name: {name: value for (name, value) in value.__repr_args__() if name != "name"} + for name, value in _class.__dict__["__fields__"].items() + }