Implement Memories, validation and other fixes
This commit is contained in:
commit
54add63b40
29 changed files with 525 additions and 121 deletions
15
Makefile
15
Makefile
|
|
@ -41,14 +41,13 @@ build:
|
|||
|
||||
dev:
|
||||
make install_frontend
|
||||
ifeq ($(build),1)
|
||||
@echo 'Running docker compose up with build'
|
||||
docker compose up --build
|
||||
else
|
||||
@echo 'Running docker compose up without build'
|
||||
docker compose up
|
||||
endif
|
||||
|
||||
ifeq ($(build),1)
|
||||
@echo 'Running docker compose up with build'
|
||||
docker compose up $(if $(debug),-f docker-compose.debug.yml) --build
|
||||
else
|
||||
@echo 'Running docker compose up without build'
|
||||
docker compose up $(if $(debug),-f docker-compose.debug.yml)
|
||||
endif
|
||||
|
||||
publish:
|
||||
make build
|
||||
|
|
|
|||
28
docker-compose.debug.yml
Normal file
28
docker-compose.debug.yml
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
version: '3.4'
|
||||
|
||||
services:
|
||||
backend:
|
||||
volumes:
|
||||
- ./:/app
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: ./dev.Dockerfile
|
||||
command: ["sh", "-c", "pip install debugpy -t /tmp && python /tmp/debugpy --wait-for-client --listen 0.0.0.0:5678 -m uvicorn langflow.main:app --host 0.0.0.0 --port 7860 --reload"]
|
||||
ports:
|
||||
- 7860:7860
|
||||
- 5678:5678
|
||||
restart: on-failure
|
||||
|
||||
frontend:
|
||||
build:
|
||||
context: ./src/frontend
|
||||
dockerfile: ./dev.Dockerfile
|
||||
args:
|
||||
- BACKEND_URL=http://backend:7860
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./src/frontend/public:/home/node/app/public
|
||||
- ./src/frontend/src:/home/node/app/src
|
||||
- ./src/frontend/package.json:/home/node/app/package.json
|
||||
restart: on-failure
|
||||
19
poetry.lock
generated
19
poetry.lock
generated
|
|
@ -1199,14 +1199,14 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
|
|||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "0.0.127"
|
||||
version = "0.0.131"
|
||||
description = "Building applications with LLMs through composability"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langchain-0.0.127-py3-none-any.whl", hash = "sha256:04ba053881e6098e80e0f4afc8922f3fe78923b160fd12d856aebce49c261918"},
|
||||
{file = "langchain-0.0.127.tar.gz", hash = "sha256:e8a3b67fd86a6f79c4334f0a7588c9476fcb57b27a8fb0e617f47c01eaab8be8"},
|
||||
{file = "langchain-0.0.131-py3-none-any.whl", hash = "sha256:3564a759e85095c9d71a78817da9cec1e2a8a0cda1bdd94ef8ac7008e432717a"},
|
||||
{file = "langchain-0.0.131.tar.gz", hash = "sha256:61baf67fbec561ce38d187915a46e1c41139270826453600951760fde1a5d98a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -1220,8 +1220,11 @@ SQLAlchemy = ">=1,<2"
|
|||
tenacity = ">=8.1.0,<9.0.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.2.4,<0.3.0)", "beautifulsoup4 (>=4,<5)", "boto3 (>=1.26.96,<2.0.0)", "cohere (>=3,<4)", "deeplake (>=3.2.9,<4.0.0)", "elasticsearch (>=8,<9)", "faiss-cpu (>=1,<2)", "google-api-python-client (==2.70.0)", "google-search-results (>=2,<3)", "huggingface_hub (>=0,<1)", "jina (>=3.14,<4.0)", "jinja2 (>=3,<4)", "manifest-ml (>=0.0.1,<0.0.2)", "networkx (>=2.6.3,<3.0.0)", "nlpcloud (>=1,<2)", "nltk (>=3,<4)", "nomic (>=1.0.43,<2.0.0)", "openai (>=0,<1)", "opensearch-py (>=2.0.0,<3.0.0)", "pgvector (>=0.1.6,<0.2.0)", "pinecone-client (>=2,<3)", "psycopg2-binary (>=2.9.5,<3.0.0)", "pyowm (>=3.3.0,<4.0.0)", "pypdf (>=3.4.0,<4.0.0)", "qdrant-client (>=1.0.4,<2.0.0)", "redis (>=4,<5)", "sentence-transformers (>=2,<3)", "spacy (>=3,<4)", "tensorflow-text (>=2.11.0,<3.0.0)", "tiktoken (>=0.3.2,<0.4.0)", "torch (>=1,<2)", "transformers (>=4,<5)", "weaviate-client (>=3,<4)", "wikipedia (>=1,<2)", "wolframalpha (==5.0.0)"]
|
||||
all = ["aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.2.4,<0.3.0)", "beautifulsoup4 (>=4,<5)", "boto3 (>=1.26.96,<2.0.0)", "cohere (>=3,<4)", "deeplake (>=3.2.9,<4.0.0)", "elasticsearch (>=8,<9)", "faiss-cpu (>=1,<2)", "google-api-python-client (==2.70.0)", "google-search-results (>=2,<3)", "huggingface_hub (>=0,<1)", "jina (>=3.14,<4.0)", "jinja2 (>=3,<4)", "manifest-ml (>=0.0.1,<0.0.2)", "networkx (>=2.6.3,<3.0.0)", "nlpcloud (>=1,<2)", "nltk (>=3,<4)", "nomic (>=1.0.43,<2.0.0)", "openai (>=0,<1)", "opensearch-py (>=2.0.0,<3.0.0)", "pgvector (>=0.1.6,<0.2.0)", "pinecone-client (>=2,<3)", "psycopg2-binary (>=2.9.5,<3.0.0)", "pyowm (>=3.3.0,<4.0.0)", "pypdf (>=3.4.0,<4.0.0)", "qdrant-client (>=1.1.1,<2.0.0)", "redis (>=4,<5)", "sentence-transformers (>=2,<3)", "spacy (>=3,<4)", "tensorflow-text (>=2.11.0,<3.0.0)", "tiktoken (>=0.3.2,<0.4.0)", "torch (>=1,<2)", "transformers (>=4,<5)", "weaviate-client (>=3,<4)", "wikipedia (>=1,<2)", "wolframalpha (==5.0.0)"]
|
||||
cohere = ["cohere (>=3,<4)"]
|
||||
llms = ["anthropic (>=0.2.4,<0.3.0)", "cohere (>=3,<4)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (>=0,<1)", "torch (>=1,<2)", "transformers (>=4,<5)"]
|
||||
openai = ["openai (>=0,<1)"]
|
||||
qdrant = ["qdrant-client (>=1.1.1,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
|
|
@ -1506,14 +1509,14 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "0.27.2"
|
||||
version = "0.27.4"
|
||||
description = "Python client library for the OpenAI API"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-0.27.2-py3-none-any.whl", hash = "sha256:6df674cf257e9e0504f1fd191c333d3f6a2442b13218d0eccf06230eb24d320e"},
|
||||
{file = "openai-0.27.2.tar.gz", hash = "sha256:5869fdfa34b0ec66c39afa22f4a0fb83a135dff81f6505f52834c6ab3113f762"},
|
||||
{file = "openai-0.27.4-py3-none-any.whl", hash = "sha256:3b82c867d531e1fd2003d9de2131e1c4bfd4c70b1a3149e0543a555b30807b70"},
|
||||
{file = "openai-0.27.4.tar.gz", hash = "sha256:9f9d27d26e62c6068f516c0729449954b5ef6994be1a6cbfe7dbefbc84423a04"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -2712,4 +2715,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "2b523f3d737ef8f7082e8156f096bce6f4f84a8bee9d07bd4ed23a29d3dcfab1"
|
||||
content-hash = "91c68c5a5673f7b2bd0833af35da1262afd21d631cc62ec6ff9c65f69a96af0a"
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ google-search-results = "^2.4.1"
|
|||
google-api-python-client = "^2.79.0"
|
||||
typer = "^0.7.0"
|
||||
gunicorn = "^20.1.0"
|
||||
langchain = "^0.0.127"
|
||||
langchain = "^0.0.131"
|
||||
openai = "^0.27.2"
|
||||
types-pyyaml = "^6.0.12.8"
|
||||
dill = "^0.3.6"
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from pydantic import BaseModel, validator
|
|||
class Code(BaseModel):
|
||||
code: str
|
||||
|
||||
@validator("code")
|
||||
def validate_code(cls, v):
|
||||
return v
|
||||
|
||||
class Prompt(BaseModel):
|
||||
template: str
|
||||
|
||||
|
||||
# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
class ValidationResponse(BaseModel):
|
||||
class CodeValidationResponse(BaseModel):
|
||||
imports: dict
|
||||
function: dict
|
||||
|
||||
|
|
@ -21,3 +21,7 @@ class ValidationResponse(BaseModel):
|
|||
@validator("function")
|
||||
def validate_function(cls, v):
|
||||
return v or {"errors": []}
|
||||
|
||||
|
||||
class PromptValidationResponse(BaseModel):
|
||||
input_variables: list
|
||||
|
|
|
|||
|
|
@ -3,10 +3,8 @@ from typing import Any, Dict
|
|||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.api.base import Code, ValidationResponse
|
||||
from langflow.interface.run import process_graph
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
# build router
|
||||
router = APIRouter()
|
||||
|
|
@ -26,15 +24,3 @@ def get_load(data: Dict[str, Any]):
|
|||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/validate", status_code=200, response_model=ValidationResponse)
|
||||
def post_validate_code(code: Code):
|
||||
try:
|
||||
errors = validate_code(code.code)
|
||||
return ValidationResponse(
|
||||
imports=errors.get("imports", {}),
|
||||
function=errors.get("function", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
36
src/backend/langflow/api/validate.py
Normal file
36
src/backend/langflow/api/validate.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.api.base import (
|
||||
Code,
|
||||
CodeValidationResponse,
|
||||
Prompt,
|
||||
PromptValidationResponse,
|
||||
)
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
# build router
|
||||
router = APIRouter(prefix="/validate", tags=["validate"])
|
||||
|
||||
|
||||
@router.post("/code", status_code=200, response_model=CodeValidationResponse)
|
||||
def post_validate_code(code: Code):
|
||||
try:
|
||||
errors = validate_code(code.code)
|
||||
return CodeValidationResponse(
|
||||
imports=errors.get("imports", {}),
|
||||
function=errors.get("function", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
|
||||
def post_validate_prompt(prompt: Prompt):
|
||||
try:
|
||||
input_variables = extract_input_variables_from_prompt(prompt.template)
|
||||
return PromptValidationResponse(input_variables=input_variables)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -3,6 +3,9 @@ chains:
|
|||
- LLMMathChain
|
||||
- LLMCheckerChain
|
||||
- ConversationChain
|
||||
- SeriesCharacterChain
|
||||
- MidJourneyPromptChain
|
||||
- TimeTravelGuideChain
|
||||
|
||||
agents:
|
||||
- ZeroShotAgent
|
||||
|
|
@ -13,9 +16,15 @@ agents:
|
|||
prompts:
|
||||
- PromptTemplate
|
||||
- FewShotPromptTemplate
|
||||
- ZeroShotPrompt
|
||||
# Wait more tests
|
||||
# - ChatPromptTemplate
|
||||
# - SystemMessagePromptTemplate
|
||||
# - HumanMessagePromptTemplate
|
||||
|
||||
llms:
|
||||
- OpenAI
|
||||
- AzureOpenAI
|
||||
- ChatOpenAI
|
||||
|
||||
tools:
|
||||
|
|
@ -36,6 +45,7 @@ toolkits:
|
|||
|
||||
memories:
|
||||
- ConversationBufferMemory
|
||||
- ConversationSummaryMemory
|
||||
|
||||
embeddings: []
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from langflow.template import nodes
|
||||
|
||||
# These should always be instantiated
|
||||
CUSTOM_NODES = {
|
||||
"prompts": {"ZeroShotPrompt": nodes.ZeroShotPromptNode()},
|
||||
"tools": {"PythonFunction": nodes.PythonFunctionNode(), "Tool": nodes.ToolNode()},
|
||||
|
|
|
|||
|
|
@ -75,7 +75,9 @@ class PromptNode(Node):
|
|||
for param in prompt_params:
|
||||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
|
||||
self._build()
|
||||
return deepcopy(self._built_object)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@ class AgentCreator(LangChainTypeCreator):
|
|||
self.type_dict = loading.AGENT_TO_CLASS
|
||||
# Add JsonAgent to the list of agents
|
||||
for name, agent in CUSTOM_AGENTS.items():
|
||||
self.type_dict[name] = agent
|
||||
# TODO: validate AgentType
|
||||
self.type_dict[name] = agent # type: ignore
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
|
|
|
|||
|
|
@ -111,7 +111,8 @@ class InitializeAgent(AgentExecutor):
|
|||
return initialize_agent(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
agent=agent,
|
||||
# LangChain now uses Enum for agent, but we still support string
|
||||
agent=agent, # type: ignore
|
||||
memory=memory,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -12,6 +12,11 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
type_name: str
|
||||
type_dict: Optional[Dict] = None
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[FrontendNode]:
|
||||
"""The class type of the FrontendNode created in frontend_node."""
|
||||
return FrontendNode
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
|
|
@ -62,7 +67,7 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
if key != "_type"
|
||||
]
|
||||
template = Template(type_name=name, fields=fields)
|
||||
return FrontendNode(
|
||||
return self.frontend_node_class(
|
||||
template=template,
|
||||
description=signature.get("description", ""),
|
||||
base_classes=signature["base_classes"],
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import chain_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
|
|
@ -15,19 +16,27 @@ class ChainCreator(LangChainTypeCreator):
|
|||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = chain_type_to_cls_dict
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
self.type_dict.update(CUSTOM_CHAINS)
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
try:
|
||||
return build_template_from_class(name, chain_type_to_cls_dict)
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
return build_template_from_class(name, self.type_to_loader_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Memory not found") from exc
|
||||
raise ValueError("Chain not found") from exc
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
custom_chains = list(get_custom_nodes("chains").keys())
|
||||
default_chains = list(self.type_to_loader_dict.keys())
|
||||
# Check if the chain is in the settings
|
||||
return [
|
||||
chain.__name__
|
||||
for chain in self.type_to_loader_dict.values()
|
||||
if chain.__name__ in settings.chains or settings.dev
|
||||
chain
|
||||
for chain in default_chains + custom_chains
|
||||
if chain in settings.chains or settings.dev
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
101
src/backend/langflow/interface/chains/custom.py
Normal file
101
src/backend/langflow/interface/chains/custom.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
from typing import Dict, Optional, Type
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.schema import BaseMemory
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
|
||||
DEFAULT_SUFFIX = """"
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
{ai_prefix}"""
|
||||
|
||||
|
||||
class BaseCustomChain(ConversationChain):
|
||||
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
|
||||
|
||||
template: Optional[str]
|
||||
|
||||
ai_prefix_key: Optional[str]
|
||||
"""Field to use as the ai_prefix. It needs to be set and has to be in the template"""
|
||||
|
||||
@root_validator(pre=False)
|
||||
def build_template(cls, values):
|
||||
format_dict = {}
|
||||
input_variables = extract_input_variables_from_prompt(values["template"])
|
||||
|
||||
if values.get("ai_prefix_key", None) is None:
|
||||
values["ai_prefix_key"] = values["memory"].ai_prefix
|
||||
|
||||
for key in input_variables:
|
||||
new_value = values.get(key, f"{{{key}}}")
|
||||
format_dict[key] = new_value
|
||||
if key == values.get("ai_prefix_key", None):
|
||||
values["memory"].ai_prefix = new_value
|
||||
|
||||
values["template"] = values["template"].format(**format_dict)
|
||||
|
||||
values["template"] = values["template"]
|
||||
values["input_variables"] = extract_input_variables_from_prompt(
|
||||
values["template"]
|
||||
)
|
||||
values["prompt"].template = values["template"]
|
||||
values["prompt"].input_variables = values["input_variables"]
|
||||
return values
|
||||
|
||||
|
||||
class SeriesCharacterChain(BaseCustomChain):
|
||||
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
|
||||
|
||||
character: str
|
||||
series: str
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act like {character} from {series}.
|
||||
I want you to respond and answer like {character}. do not write any explanations. only answer like {character}.
|
||||
You must know all of the knowledge of {character}.
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
{character}:"""
|
||||
memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
|
||||
ai_prefix_key: Optional[str] = "character"
|
||||
"""Default memory store."""
|
||||
|
||||
|
||||
class MidJourneyPromptChain(BaseCustomChain):
|
||||
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
|
||||
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program.
|
||||
Your job is to provide detailed and creative descriptions that will inspire unique and interesting images from the AI.
|
||||
Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible.
|
||||
For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures.
|
||||
The more detailed and imaginative your description, the more interesting the resulting image will be. Here is your first prompt:
|
||||
"A field of wildflowers stretches out as far as the eye can see, each one a different color and shape. In the distance, a massive tree towers over the landscape, its branches reaching up to the sky like tentacles.\"
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
class TimeTravelGuideChain(BaseCustomChain):
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
|
||||
"SeriesCharacterChain": SeriesCharacterChain,
|
||||
"MidJourneyPromptChain": MidJourneyPromptChain,
|
||||
"TimeTravelGuideChain": TimeTravelGuideChain,
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
# This module is used to import any langchain class by name.
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import Any, Type
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.agents import Agent
|
||||
|
|
@ -65,10 +65,14 @@ def import_class(class_path: str) -> Any:
|
|||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_prompt(prompt: str) -> PromptTemplate:
|
||||
def import_prompt(prompt: str) -> Type[PromptTemplate]:
|
||||
from langflow.interface.prompts.custom import CUSTOM_PROMPTS
|
||||
|
||||
"""Import prompt from prompt name"""
|
||||
if prompt == "ZeroShotPrompt":
|
||||
return import_class("langchain.prompts.PromptTemplate")
|
||||
elif prompt in CUSTOM_PROMPTS:
|
||||
return CUSTOM_PROMPTS[prompt]
|
||||
return import_class(f"langchain.prompts.{prompt}")
|
||||
|
||||
|
||||
|
|
@ -100,6 +104,10 @@ def import_tool(tool: str) -> BaseTool:
|
|||
return get_tool_by_name(tool)
|
||||
|
||||
|
||||
def import_chain(chain: str) -> Chain:
|
||||
def import_chain(chain: str) -> Type[Chain]:
|
||||
"""Import chain from chain name"""
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
if chain in CUSTOM_CHAINS:
|
||||
return CUSTOM_CHAINS[chain]
|
||||
return import_class(f"langchain.chains.{chain}")
|
||||
|
|
|
|||
|
|
@ -1,14 +1,21 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import memory_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.template.base import FrontendNode
|
||||
from langflow.template.nodes import MemoryFrontendNode
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
class MemoryCreator(LangChainTypeCreator):
|
||||
type_name: str = "memories"
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[FrontendNode]:
|
||||
"""The class type of the FrontendNode created in frontend_node."""
|
||||
return MemoryFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
|
|
|
|||
|
|
@ -1,39 +1,54 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain.prompts import loading
|
||||
from langchain import prompts
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.util import build_template_from_function
|
||||
from langflow.template.nodes import PromptFrontendNode
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
class PromptCreator(LangChainTypeCreator):
|
||||
type_name: str = "prompts"
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[PromptFrontendNode]:
|
||||
return PromptFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = loading.type_to_loader_dict
|
||||
self.type_dict = {
|
||||
prompt_name: import_class(f"langchain.prompts.{prompt_name}")
|
||||
# if prompt_name is not lower case it is a class
|
||||
for prompt_name in prompts.__all__
|
||||
if not prompt_name.islower() and prompt_name in settings.prompts
|
||||
}
|
||||
# Merge CUSTOM_PROMPTS into self.type_dict
|
||||
from langflow.interface.prompts.custom import CUSTOM_PROMPTS
|
||||
|
||||
self.type_dict.update(CUSTOM_PROMPTS)
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
return build_template_from_function(name, self.type_to_loader_dict)
|
||||
return build_template_from_class(name, self.type_to_loader_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Prompt not found") from exc
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
custom_prompts = get_custom_nodes("prompts")
|
||||
library_prompts = [
|
||||
prompt.__annotations__["return"].__name__
|
||||
for prompt in self.type_to_loader_dict.values()
|
||||
if prompt.__annotations__["return"].__name__ in settings.prompts
|
||||
or settings.dev
|
||||
]
|
||||
return library_prompts + list(custom_prompts.keys())
|
||||
# library_prompts = [
|
||||
# prompt.__annotations__["return"].__name__
|
||||
# for prompt in self.type_to_loader_dict.values()
|
||||
# if prompt.__annotations__["return"].__name__ in settings.prompts
|
||||
# or settings.dev
|
||||
# ]
|
||||
return list(self.type_to_loader_dict.keys()) + list(custom_prompts.keys())
|
||||
|
||||
|
||||
prompt_creator = PromptCreator()
|
||||
|
|
|
|||
|
|
@ -1,49 +1,52 @@
|
|||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pydantic import root_validator
|
||||
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from langflow.template.base import Template, TemplateField
|
||||
from langflow.template.nodes import PromptTemplateNode
|
||||
|
||||
CHARACTER_PROMPT = """I want you to act like {character} from {series}.
|
||||
I want you to respond and answer like {character}. do not write any explanations. only answer like {character}.
|
||||
You must know all of the knowledge of {character}."""
|
||||
# Steps to create a BaseCustomPrompt:
|
||||
# 1. Create a prompt template that endes with:
|
||||
# Current conversation:
|
||||
# {history}
|
||||
# Human: {input}
|
||||
# {ai_prefix}:
|
||||
# 2. Create a class that inherits from BaseCustomPrompt
|
||||
# 3. Add the following class attributes:
|
||||
# template: str = ""
|
||||
# description: Optional[str]
|
||||
# ai_prefix: Optional[str] = "{ai_prefix}"
|
||||
# 3.1. The ai_prefix should be a value in input_variables
|
||||
# SeriesCharacterPrompt is a working example
|
||||
# If used in a LLMChain, with a Memory module, it will work as expected
|
||||
# We should consider creating ConversationalChains that expose custom parameters
|
||||
# That way it will be easier to create custom prompts
|
||||
|
||||
|
||||
class BaseCustomPrompt(PromptTemplate):
|
||||
template: str = ""
|
||||
description: Optional[str]
|
||||
human_text: str = "\n {input}"
|
||||
ai_prefix: Optional[str]
|
||||
|
||||
@root_validator(pre=False)
|
||||
def build_template(cls, values):
|
||||
format_dict = {}
|
||||
ai_prefix_format_dict = {}
|
||||
for key in values.get("input_variables", []):
|
||||
new_value = values[key]
|
||||
new_value = values.get(key, f"{{{key}}}")
|
||||
format_dict[key] = new_value
|
||||
if key in values["ai_prefix"]:
|
||||
ai_prefix_format_dict[key] = new_value
|
||||
|
||||
values["ai_prefix"] = values["ai_prefix"].format(**ai_prefix_format_dict)
|
||||
values["template"] = values["template"].format(**format_dict)
|
||||
|
||||
values["template"] = values["template"] + values["human_text"]
|
||||
values["template"] = values["template"]
|
||||
values["input_variables"] = extract_input_variables_from_prompt(
|
||||
values["template"]
|
||||
)
|
||||
return values
|
||||
|
||||
def build_frontend_node(self) -> PromptTemplateNode:
|
||||
return PromptTemplateNode(
|
||||
template=Template(
|
||||
type_name="test",
|
||||
fields=[
|
||||
TemplateField(name=field, field_type="str", required=True)
|
||||
for field in self.input_variables
|
||||
],
|
||||
),
|
||||
description=self.description or "",
|
||||
)
|
||||
|
||||
|
||||
class SeriesCharacterPrompt(BaseCustomPrompt):
|
||||
# Add a very descriptive description for the prompt generator
|
||||
|
|
@ -52,14 +55,23 @@ class SeriesCharacterPrompt(BaseCustomPrompt):
|
|||
] = "A prompt that asks the AI to act like a character from a series."
|
||||
character: str
|
||||
series: str
|
||||
human_text: str = "\n {input}"
|
||||
template: str = CHARACTER_PROMPT
|
||||
template: str = """I want you to act like {character} from {series}.
|
||||
I want you to respond and answer like {character}. do not write any explanations. only answer like {character}.
|
||||
You must know all of the knowledge of {character}.
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
{character}:"""
|
||||
|
||||
ai_prefix: str = "{character}"
|
||||
input_variables: List[str] = ["character", "series"]
|
||||
|
||||
|
||||
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {
|
||||
"SeriesCharacterPrompt": SeriesCharacterPrompt
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt = SeriesCharacterPrompt(character="Walter White", series="Breaking Bad")
|
||||
user_input = "I am the one who knocks"
|
||||
full_prompt = prompt.format(input=user_input)
|
||||
print(full_prompt)
|
||||
prompt = SeriesCharacterPrompt(character="Harry Potter", series="Harry Potter")
|
||||
print(prompt.template)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,12 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
|
|
@ -66,34 +72,82 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(loaded_langchain, message: str):
|
||||
def fix_memory_inputs(langchain_object):
|
||||
"""
|
||||
Fix memory inputs by replacing the memory key with the input key.
|
||||
"""
|
||||
# Possible memory keys
|
||||
# "chat_history", "history"
|
||||
# if memory_key is "chat_history" and input_keys has "history"
|
||||
# we need to replace "chat_history" with "history"
|
||||
mem_key_dict = {
|
||||
"chat_history": "history",
|
||||
"history": "chat_history",
|
||||
}
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
possible_new_mem_key = mem_key_dict.get(memory_key)
|
||||
if possible_new_mem_key is not None:
|
||||
# get input_key
|
||||
input_key = [
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key not in [memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
# get output_key
|
||||
output_key = [
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key not in [memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
# set input_key and output_key in memory
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
for input_key in langchain_object.input_keys:
|
||||
if input_key == possible_new_mem_key:
|
||||
langchain_object.memory.memory_key = possible_new_mem_key
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
loaded_langchain.verbose = True
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
chat_input = None
|
||||
for key in loaded_langchain.input_keys:
|
||||
if key == "chat_history" and hasattr(loaded_langchain, "memory"):
|
||||
loaded_langchain.memory.memory_key = "chat_history"
|
||||
else:
|
||||
memory_key = ""
|
||||
if (
|
||||
hasattr(langchain_object, "memory")
|
||||
and langchain_object.memory is not None
|
||||
):
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
|
||||
if hasattr(loaded_langchain, "return_intermediate_steps"):
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
loaded_langchain.return_intermediate_steps = False
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
try:
|
||||
output = loaded_langchain(chat_input)
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
logger.debug("Error: %s", str(exc))
|
||||
output = loaded_langchain.run(chat_input)
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(loaded_langchain.output_keys[0])
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
|
|
@ -110,16 +164,16 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str):
|
|||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
loaded_langchain = loading.load_langchain_type_from_config(
|
||||
langchain_object = loading.load_langchain_type_from_config(
|
||||
config=extracted_json
|
||||
)
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
output = loaded_langchain(message)
|
||||
output = langchain_object(message)
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
result = (
|
||||
output.get(loaded_langchain.output_keys[0])
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class WrapperCreator(LangChainTypeCreator):
|
|||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = {
|
||||
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
|
||||
wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper]
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from langflow.api.endpoints import router as endpoints_router
|
||||
from langflow.api.validate import router as validate_router
|
||||
|
||||
|
||||
def create_app():
|
||||
|
|
@ -21,6 +22,7 @@ def create_app():
|
|||
)
|
||||
|
||||
app.include_router(endpoints_router)
|
||||
app.include_router(validate_router)
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
password: bool = False
|
||||
options: list[str] = []
|
||||
name: str = ""
|
||||
display_name: Optional[str] = None
|
||||
|
||||
def to_dict(self):
|
||||
result = self.dict()
|
||||
|
|
|
|||
|
|
@ -9,3 +9,24 @@ FORCE_SHOW_FIELDS = [
|
|||
"max_value_length",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
I want you to act as a naming consultant for new companies.
|
||||
|
||||
Here are some examples of good company names:
|
||||
|
||||
- search engine, Google
|
||||
- social media, Facebook
|
||||
- video sharing, YouTube
|
||||
|
||||
The name should be short, catchy and easy to remember.
|
||||
|
||||
What is a good name for a company that makes {product}?
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that talks casually about life in general.
|
||||
You are a good listener and you can talk about anything.
|
||||
"""
|
||||
|
||||
HUMAN_PROMPT = "{input}"
|
||||
|
|
|
|||
|
|
@ -4,10 +4,27 @@ from langchain.agents import loading
|
|||
from langchain.agents.mrkl import prompt
|
||||
|
||||
from langflow.template.base import FrontendNode, Template, TemplateField
|
||||
from langflow.template.constants import DEFAULT_PROMPT, HUMAN_PROMPT, SYSTEM_PROMPT
|
||||
from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION
|
||||
|
||||
NON_CHAT_AGENTS = {
|
||||
agent_type: agent_class
|
||||
for agent_type, agent_class in loading.AGENT_TO_CLASS.items()
|
||||
if "chat" not in agent_type.value
|
||||
}
|
||||
|
||||
class ZeroShotPromptNode(FrontendNode):
|
||||
|
||||
class BasePromptFrontendNode(FrontendNode):
|
||||
name: str
|
||||
template: Template
|
||||
description: str
|
||||
base_classes: list[str]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
|
||||
class ZeroShotPromptNode(BasePromptFrontendNode):
|
||||
name: str = "ZeroShotPrompt"
|
||||
template: Template = Template(
|
||||
type_name="zero_shot",
|
||||
|
|
@ -165,8 +182,8 @@ class InitializeAgentNode(FrontendNode):
|
|||
is_list=True,
|
||||
show=True,
|
||||
multiline=False,
|
||||
options=list(loading.AGENT_TO_CLASS.keys()),
|
||||
value=list(loading.AGENT_TO_CLASS.keys())[0],
|
||||
options=list(NON_CHAT_AGENTS.keys()),
|
||||
value=list(NON_CHAT_AGENTS.keys())[0],
|
||||
name="agent",
|
||||
),
|
||||
TemplateField(
|
||||
|
|
@ -229,3 +246,37 @@ class CSVAgentNode(FrontendNode):
|
|||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
|
||||
class PromptFrontendNode(FrontendNode):
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
# if field.field_type == "StringPromptTemplate"
|
||||
# change it to str
|
||||
if field.field_type == "StringPromptTemplate" and "Message" in str(name):
|
||||
field.field_type = "str"
|
||||
field.multiline = True
|
||||
field.value = HUMAN_PROMPT if "Human" in field.name else SYSTEM_PROMPT
|
||||
if field.name == "template" and field.value == "":
|
||||
field.value = DEFAULT_PROMPT
|
||||
|
||||
if (
|
||||
"Union" in field.field_type
|
||||
and "BaseMessagePromptTemplate" in field.field_type
|
||||
):
|
||||
field.field_type = "BaseMessagePromptTemplate"
|
||||
|
||||
|
||||
class MemoryFrontendNode(FrontendNode):
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
||||
if not isinstance(field.value, str):
|
||||
field.value = None
|
||||
if field.name == "k":
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.field_type = "int"
|
||||
field.value = 10
|
||||
field.display_name = "Memory Size"
|
||||
|
|
|
|||
|
|
@ -60,4 +60,4 @@
|
|||
]
|
||||
},
|
||||
"proxy": "http://backend:7860"
|
||||
}
|
||||
}
|
||||
|
|
@ -12,5 +12,5 @@ export async function sendAll(data:sendAllProps) {
|
|||
|
||||
export async function checkCode(code:string):Promise<AxiosResponse<errorsTypeAPI>>{
|
||||
|
||||
return await axios.post('/validate',{code})
|
||||
return await axios.post('/validate/code',{code})
|
||||
}
|
||||
|
|
@ -74,7 +74,7 @@
|
|||
"multiline": false,
|
||||
"password": false,
|
||||
"name": "requests_wrapper",
|
||||
"type": "RequestsWrapper",
|
||||
"type": "TextRequestsWrapper",
|
||||
"list": false
|
||||
},
|
||||
"_type": "OpenAPIToolkit"
|
||||
|
|
@ -154,7 +154,7 @@
|
|||
"y": 532.9920887988924
|
||||
},
|
||||
"data": {
|
||||
"type": "RequestsWrapper",
|
||||
"type": "TextRequestsWrapper",
|
||||
"node": {
|
||||
"template": {
|
||||
"headers": {
|
||||
|
|
@ -178,11 +178,11 @@
|
|||
"type": "ClientSession",
|
||||
"list": false
|
||||
},
|
||||
"_type": "RequestsWrapper"
|
||||
"_type": "TextRequestsWrapper"
|
||||
},
|
||||
"description": "Lightweight wrapper around requests library.",
|
||||
"base_classes": [
|
||||
"RequestsWrapper"
|
||||
"TextRequestsWrapper"
|
||||
]
|
||||
},
|
||||
"id": "dndnode_34",
|
||||
|
|
@ -405,11 +405,11 @@
|
|||
},
|
||||
{
|
||||
"source": "dndnode_34",
|
||||
"sourceHandle": "RequestsWrapper|dndnode_34|RequestsWrapper",
|
||||
"sourceHandle": "TextRequestsWrapper|dndnode_34|TextRequestsWrapper",
|
||||
"target": "dndnode_32",
|
||||
"targetHandle": "RequestsWrapper|requests_wrapper|dndnode_32",
|
||||
"targetHandle": "TextRequestsWrapper|requests_wrapper|dndnode_32",
|
||||
"className": "animate-pulse",
|
||||
"id": "reactflow__edge-dndnode_34RequestsWrapper|dndnode_34|RequestsWrapper-dndnode_32RequestsWrapper|requests_wrapper|dndnode_32",
|
||||
"id": "reactflow__edge-dndnode_34RequestsWrapper|dndnode_34|TextRequestsWrapper-dndnode_32RequestsWrapper|requests_wrapper|dndnode_32",
|
||||
"selected": false
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ import math
|
|||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
response1 = client.post("/validate", json={"code": code1})
|
||||
response1 = client.post("/validate/code", json={"code": code1})
|
||||
assert response1.status_code == 200
|
||||
assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
|
|
@ -31,7 +32,7 @@ import non_existent_module
|
|||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
response2 = client.post("/validate", json={"code": code2})
|
||||
response2 = client.post("/validate/code", json={"code": code2})
|
||||
assert response2.status_code == 200
|
||||
assert response2.json() == {
|
||||
"imports": {"errors": ["No module named 'non_existent_module'"]},
|
||||
|
|
@ -45,7 +46,7 @@ import math
|
|||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
response3 = client.post("/validate", json={"code": code3})
|
||||
response3 = client.post("/validate/code", json={"code": code3})
|
||||
assert response3.status_code == 200
|
||||
assert response3.json() == {
|
||||
"imports": {"errors": []},
|
||||
|
|
@ -53,11 +54,11 @@ def square(x)
|
|||
}
|
||||
|
||||
# Test case with invalid JSON payload
|
||||
response4 = client.post("/validate", json={"invalid_key": code1})
|
||||
response4 = client.post("/validate/code", json={"invalid_key": code1})
|
||||
assert response4.status_code == 422
|
||||
|
||||
# Test case with an empty code string
|
||||
response5 = client.post("/validate", json={"code": ""})
|
||||
response5 = client.post("/validate/code", json={"code": ""})
|
||||
assert response5.status_code == 200
|
||||
assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
|
|
@ -68,9 +69,55 @@ import math
|
|||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
response6 = client.post("/validate", json={"code": code6})
|
||||
response6 = client.post("/validate/code", json={"code": code6})
|
||||
assert response6.status_code == 200
|
||||
assert response6.json() == {
|
||||
"imports": {"errors": []},
|
||||
"function": {"errors": ["expected ':' (<unknown>, line 4)"]},
|
||||
}
|
||||
|
||||
|
||||
VALID_PROMPT = """
|
||||
I want you to act as a naming consultant for new companies.
|
||||
|
||||
Here are some examples of good company names:
|
||||
|
||||
- search engine, Google
|
||||
- social media, Facebook
|
||||
- video sharing, YouTube
|
||||
|
||||
The name should be short, catchy and easy to remember.
|
||||
|
||||
What is a good name for a company that makes {product}?
|
||||
"""
|
||||
|
||||
INVALID_PROMPT = "This is an invalid prompt without any input variable."
|
||||
|
||||
|
||||
def test_valid_prompt(client: TestClient):
|
||||
response = client.post("/validate/prompt", json={"template": VALID_PROMPT})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"input_variables": ["product"]}
|
||||
|
||||
|
||||
def test_invalid_prompt(client: TestClient):
|
||||
response = client.post("/validate/prompt", json={"template": INVALID_PROMPT})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"input_variables": []}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,expected_input_variables",
|
||||
[
|
||||
("{color} is my favorite color.", ["color"]),
|
||||
("The weather is {weather} today.", ["weather"]),
|
||||
("This prompt has no variables.", []),
|
||||
("{a}, {b}, and {c} are variables.", ["a", "b", "c"]),
|
||||
],
|
||||
)
|
||||
def test_various_prompts(client, prompt, expected_input_variables):
|
||||
response = client.post("/validate/prompt", json={"template": prompt})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"input_variables": expected_input_variables,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue