feat: frontend_node_class property and other prompts

This commit is contained in:
Gabriel Almeida 2023-04-03 17:06:15 -03:00
commit 08d60f18ea
6 changed files with 57 additions and 13 deletions

View file

@ -13,6 +13,10 @@ agents:
prompts:
- PromptTemplate
- FewShotPromptTemplate
- ChatPromptTemplate
- SystemMessagePromptTemplate
- AIMessagePromptTemplate
- HumanMessagePromptTemplate
llms:
- OpenAI

View file

@ -1,5 +1,6 @@
from langflow.template import nodes
# These should always be instantiated
CUSTOM_NODES = {
"prompts": {"ZeroShotPrompt": nodes.ZeroShotPromptNode()},
"tools": {"PythonFunction": nodes.PythonFunctionNode(), "Tool": nodes.ToolNode()},

View file

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import abc
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
@ -12,6 +13,11 @@ class LangChainTypeCreator(BaseModel, ABC):
type_name: str
type_dict: Optional[Dict] = None
@property
def frontend_node_class(self) -> str:
"""The class type of the FrontendNode created in frontend_node."""
return FrontendNode
@property
@abstractmethod
def type_to_loader_dict(self) -> Dict:
@ -62,7 +68,7 @@ class LangChainTypeCreator(BaseModel, ABC):
if key != "_type"
]
template = Template(type_name=name, fields=fields)
return FrontendNode(
return self.frontend_node_class(
template=template,
description=signature.get("description", ""),
base_classes=signature["base_classes"],

View file

@ -1,39 +1,50 @@
from typing import Dict, List
from langchain.prompts import loading
from langchain import prompts
from langflow.custom.customs import get_custom_nodes
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.importing.utils import import_class
from langflow.settings import settings
from langflow.utils.util import build_template_from_function
from langflow.template.nodes import PromptFrontendNode
from langflow.utils.util import build_template_from_class
class PromptCreator(LangChainTypeCreator):
type_name: str = "prompts"
@property
def frontend_node_class(self) -> str:
return PromptFrontendNode
@property
def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
self.type_dict = loading.type_to_loader_dict
self.type_dict = {
prompt_name: import_class(f"langchain.prompts.{prompt_name}")
# if prompt_name is not lower case it is a class
for prompt_name in prompts.__all__
if not prompt_name.islower() and prompt_name in settings.prompts
}
return self.type_dict
def get_signature(self, name: str) -> Dict | None:
try:
if name in get_custom_nodes(self.type_name).keys():
return get_custom_nodes(self.type_name)[name]
return build_template_from_function(name, self.type_to_loader_dict)
return build_template_from_class(name, self.type_to_loader_dict)
except ValueError as exc:
raise ValueError("Prompt not found") from exc
def to_list(self) -> List[str]:
custom_prompts = get_custom_nodes("prompts")
library_prompts = [
prompt.__annotations__["return"].__name__
for prompt in self.type_to_loader_dict.values()
if prompt.__annotations__["return"].__name__ in settings.prompts
or settings.dev
]
return library_prompts + list(custom_prompts.keys())
# library_prompts = [
# prompt.__annotations__["return"].__name__
# for prompt in self.type_to_loader_dict.values()
# if prompt.__annotations__["return"].__name__ in settings.prompts
# or settings.dev
# ]
return list(self.type_to_loader_dict.keys()) + list(custom_prompts.keys())
prompt_creator = PromptCreator()

View file

@ -219,3 +219,5 @@ class FrontendNode(BaseModel):
elif name == "ChatOpenAI" and key == "model_name":
field.options = constants.CHAT_OPENAI_MODELS
field.is_list = True

View file

@ -6,7 +6,17 @@ from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION
from langchain.agents import loading
class ZeroShotPromptNode(FrontendNode):
class BasePromptFrontendNode(FrontendNode):
name: str
template: Template
description: str
base_classes: list[str]
def to_dict(self):
return super().to_dict()
class ZeroShotPromptNode(BasePromptFrontendNode):
name: str = "ZeroShotPrompt"
template: Template = Template(
type_name="zero_shot",
@ -227,3 +237,13 @@ class CSVAgentNode(FrontendNode):
def to_dict(self):
return super().to_dict()
class PromptFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
# if field.field_type == "StringPromptTemplate"
# change it to str
if field.field_type == "StringPromptTemplate":
field.field_type = "str"
field.multiline = True