Merge branch 'dev' into vecstores

This commit is contained in:
Ibis Prevedello 2023-04-07 13:06:49 -03:00
commit 616dfd0370
22 changed files with 1890 additions and 128 deletions

View file

@ -1,5 +1,7 @@
from pydantic import BaseModel, validator
from langflow.graph.utils import extract_input_variables_from_prompt
class Code(BaseModel):
code: str
@ -25,3 +27,54 @@ class CodeValidationResponse(BaseModel):
class PromptValidationResponse(BaseModel):
input_variables: list
INVALID_CHARACTERS = {
" ",
",",
".",
":",
";",
"!",
"?",
"/",
"\\",
"(",
")",
"[",
"]",
"{",
"}",
}
def validate_prompt(template: str):
input_variables = extract_input_variables_from_prompt(template)
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
return PromptValidationResponse(input_variables=input_variables)
def check_input_variables(input_variables: list):
invalid_chars = []
fixed_variables = []
for variable in input_variables:
new_var = variable
for char in INVALID_CHARACTERS:
if char in variable:
invalid_chars.append(char)
new_var = new_var.replace(char, "")
fixed_variables.append(new_var)
if new_var != variable:
input_variables.remove(variable)
input_variables.append(new_var)
# If any of the input_variables is not in the fixed_variables, then it means that
# there are invalid characters in the input_variables
if any(var not in fixed_variables for var in input_variables):
raise ValueError(
f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead."
)
return input_variables

View file

@ -5,8 +5,8 @@ from langflow.api.base import (
CodeValidationResponse,
Prompt,
PromptValidationResponse,
validate_prompt,
)
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.utils.logger import logger
from langflow.utils.validate import validate_code
@ -29,8 +29,7 @@ def post_validate_code(code: Code):
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
def post_validate_prompt(prompt: Prompt):
try:
input_variables = extract_input_variables_from_prompt(prompt.template)
return PromptValidationResponse(input_variables=input_variables)
return validate_prompt(prompt.template)
except Exception as e:
logger.exception(e)
return HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -26,8 +26,9 @@ prompts:
llms:
- OpenAI
- AzureOpenAI
# - AzureOpenAI
- ChatOpenAI
- HuggingFaceHub
tools:
- Search

View file

@ -1,4 +1,4 @@
from typing import Dict, List, Union
from typing import Dict, List, Type, Union
from langflow.graph.base import Edge, Node
from langflow.graph.nodes import (
@ -25,7 +25,6 @@ from langflow.interface.prompts.base import prompt_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.interface.tools.util import get_tools_dict
from langflow.interface.vectorStore.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.utils import payload
@ -114,6 +113,29 @@ class Graph:
edges.append(Edge(source, target))
return edges
def _get_node_class(self, node_type: str, node_lc_type: str) -> Type[Node]:
node_type_map: Dict[str, Type[Node]] = {
**{t: PromptNode for t in prompt_creator.to_list()},
**{t: AgentNode for t in agent_creator.to_list()},
**{t: ChainNode for t in chain_creator.to_list()},
**{t: ToolNode for t in tool_creator.to_list()},
**{t: ToolkitNode for t in toolkits_creator.to_list()},
**{t: WrapperNode for t in wrapper_creator.to_list()},
**{t: LLMNode for t in llm_creator.to_list()},
**{t: MemoryNode for t in memory_creator.to_list()},
**{t: EmbeddingNode for t in embedding_creator.to_list()},
**{t: VectorStoreNode for t in vectorstore_creator.to_list()},
**{t: DocumentLoaderNode for t in documentloader_creator.to_list()},
}
if node_type in FILE_TOOLS:
return FileToolNode
if node_type in node_type_map:
return node_type_map[node_type]
if node_lc_type in node_type_map:
return node_type_map[node_lc_type]
return Node
def _build_nodes(self) -> List[Node]:
nodes: List[Node] = []
for node in self._nodes:
@ -121,44 +143,9 @@ class Graph:
node_type: str = node_data["type"] # type: ignore
node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore
if node_type in prompt_creator.to_list():
nodes.append(PromptNode(node))
elif (
node_type in agent_creator.to_list()
or node_lc_type in agent_creator.to_list()
):
nodes.append(AgentNode(node))
elif node_type in chain_creator.to_list():
nodes.append(ChainNode(node))
elif (
node_type in tool_creator.to_list()
or node_lc_type in get_tools_dict().keys()
):
if node_type in FILE_TOOLS:
nodes.append(FileToolNode(node))
nodes.append(ToolNode(node))
elif node_type in toolkits_creator.to_list():
nodes.append(ToolkitNode(node))
elif node_type in wrapper_creator.to_list():
nodes.append(WrapperNode(node))
elif (
node_type in llm_creator.to_list()
or node_lc_type in llm_creator.to_list()
):
nodes.append(LLMNode(node))
elif node_type in embedding_creator.to_list():
nodes.append(EmbeddingNode(node))
elif node_type in vectorstore_creator.to_list():
nodes.append(VectorStoreNode(node))
elif node_type in documentloader_creator.to_list():
nodes.append(DocumentLoaderNode(node))
elif (
node_type in memory_creator.to_list()
or node_lc_type in memory_creator.to_list()
):
nodes.append(MemoryNode(node))
else:
nodes.append(Node(node))
NodeClass = self._get_node_class(node_type, node_lc_type)
nodes.append(NodeClass(node))
return nodes
def get_children_by_node_type(self, node: Node, node_type: str) -> List[Node]:

View file

@ -19,7 +19,7 @@ class BaseCustomChain(ConversationChain):
template: Optional[str]
ai_prefix_key: Optional[str]
ai_prefix_value: Optional[str]
"""Field to use as the ai_prefix. It needs to be set and has to be in the template"""
@root_validator(pre=False)
@ -27,13 +27,13 @@ class BaseCustomChain(ConversationChain):
format_dict = {}
input_variables = extract_input_variables_from_prompt(values["template"])
if values.get("ai_prefix_key", None) is None:
values["ai_prefix_key"] = values["memory"].ai_prefix
if values.get("ai_prefix_value", None) is None:
values["ai_prefix_value"] = values["memory"].ai_prefix
for key in input_variables:
new_value = values.get(key, f"{{{key}}}")
format_dict[key] = new_value
if key == values.get("ai_prefix_key", None):
if key == values.get("ai_prefix_value", None):
values["memory"].ai_prefix = new_value
values["template"] = values["template"].format(**format_dict)
@ -62,7 +62,7 @@ Current conversation:
Human: {input}
{character}:"""
memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
ai_prefix_key: Optional[str] = "character"
ai_prefix_value: Optional[str] = "character"
"""Default memory store."""

View file

@ -178,12 +178,14 @@ class FrontendNode(BaseModel):
field.show = bool(
(field.required and key not in ["input_variables"])
or key in FORCE_SHOW_FIELDS
or "api_key" in key
or "api" in key
or ("key" in key and "input" not in key and "output" not in key)
)
# Add password field
field.password = any(
text in key.lower() for text in {"password", "token", "api", "key"}
field.password = (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and field.show
)
# Add multline

View file

@ -309,13 +309,22 @@ class PromptFrontendNode(FrontendNode):
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
# if field.field_type == "StringPromptTemplate"
# change it to str
PROMPT_FIELDS = [
"template",
"suffix",
"prefix",
"examples",
]
if field.field_type == "StringPromptTemplate" and "Message" in str(name):
field.field_type = "str"
field.field_type = "prompt"
field.multiline = True
field.value = HUMAN_PROMPT if "Human" in field.name else SYSTEM_PROMPT
if field.name == "template" and field.value == "":
field.value = DEFAULT_PROMPT
if field.name in PROMPT_FIELDS:
field.field_type = "prompt"
if (
"Union" in field.field_type
and "BaseMessagePromptTemplate" in field.field_type