From 08d60f18ea71d441b982c6bacb7cc0aa0ec54a48 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Mon, 3 Apr 2023 17:06:15 -0300 Subject: [PATCH] feat: frontend_node_class property and other prompts --- src/backend/langflow/config.yaml | 4 +++ src/backend/langflow/custom/customs.py | 1 + src/backend/langflow/interface/base.py | 8 ++++- .../langflow/interface/prompts/base.py | 33 ++++++++++++------- src/backend/langflow/template/base.py | 2 ++ src/backend/langflow/template/nodes.py | 22 ++++++++++++- 6 files changed, 57 insertions(+), 13 deletions(-) diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 08beaae08..af369f193 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -13,6 +13,10 @@ agents: prompts: - PromptTemplate - FewShotPromptTemplate + - ChatPromptTemplate + - SystemMessagePromptTemplate + - AIMessagePromptTemplate + - HumanMessagePromptTemplate llms: - OpenAI diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index 112b8db26..22e833362 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -1,5 +1,6 @@ from langflow.template import nodes +# These should always be instantiated CUSTOM_NODES = { "prompts": {"ZeroShotPrompt": nodes.ZeroShotPromptNode()}, "tools": {"PythonFunction": nodes.PythonFunctionNode(), "Tool": nodes.ToolNode()}, diff --git a/src/backend/langflow/interface/base.py b/src/backend/langflow/interface/base.py index ad8ccfc6a..87f716ac2 100644 --- a/src/backend/langflow/interface/base.py +++ b/src/backend/langflow/interface/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import abc from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel @@ -12,6 +13,11 @@ class LangChainTypeCreator(BaseModel, ABC): type_name: str type_dict: Optional[Dict] = None + @property + def frontend_node_class(self) -> str: + """The class type of the FrontendNode created in frontend_node.""" + return FrontendNode + @property @abstractmethod def type_to_loader_dict(self) -> Dict: @@ -62,7 +68,7 @@ class LangChainTypeCreator(BaseModel, ABC): if key != "_type" ] template = Template(type_name=name, fields=fields) - return FrontendNode( + return self.frontend_node_class( template=template, description=signature.get("description", ""), base_classes=signature["base_classes"], diff --git a/src/backend/langflow/interface/prompts/base.py b/src/backend/langflow/interface/prompts/base.py index f730481a9..b24522d5c 100644 --- a/src/backend/langflow/interface/prompts/base.py +++ b/src/backend/langflow/interface/prompts/base.py @@ -1,39 +1,50 @@ from typing import Dict, List from langchain.prompts import loading - +from langchain import prompts from langflow.custom.customs import get_custom_nodes from langflow.interface.base import LangChainTypeCreator +from langflow.interface.importing.utils import import_class from langflow.settings import settings -from langflow.utils.util import build_template_from_function +from langflow.template.nodes import PromptFrontendNode +from langflow.utils.util import build_template_from_class class PromptCreator(LangChainTypeCreator): type_name: str = "prompts" + @property + def frontend_node_class(self) -> str: + return PromptFrontendNode + @property def type_to_loader_dict(self) -> Dict: if self.type_dict is None: - self.type_dict = loading.type_to_loader_dict + self.type_dict = { + prompt_name: import_class(f"langchain.prompts.{prompt_name}") + # if prompt_name is not lower case it is a class + for prompt_name in prompts.__all__ + if not prompt_name.islower() and prompt_name in settings.prompts + } return self.type_dict def get_signature(self, name: str) -> Dict | None: try: if name in get_custom_nodes(self.type_name).keys(): return get_custom_nodes(self.type_name)[name] - return build_template_from_function(name, self.type_to_loader_dict) + return build_template_from_class(name, self.type_to_loader_dict) except ValueError as exc: raise ValueError("Prompt not found") from exc def to_list(self) -> List[str]: custom_prompts = get_custom_nodes("prompts") - library_prompts = [ - prompt.__annotations__["return"].__name__ - for prompt in self.type_to_loader_dict.values() - if prompt.__annotations__["return"].__name__ in settings.prompts - or settings.dev - ] - return library_prompts + list(custom_prompts.keys()) + # library_prompts = [ + # prompt.__annotations__["return"].__name__ + # for prompt in self.type_to_loader_dict.values() + # if prompt.__annotations__["return"].__name__ in settings.prompts + # or settings.dev + # ] + return list(self.type_to_loader_dict.keys()) + list(custom_prompts.keys()) prompt_creator = PromptCreator() diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index 887ab187f..da2c8312d 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -219,3 +219,5 @@ class FrontendNode(BaseModel): elif name == "ChatOpenAI" and key == "model_name": field.options = constants.CHAT_OPENAI_MODELS field.is_list = True + + diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 6bd23d59a..ad408a6c8 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -6,7 +6,17 @@ from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION from langchain.agents import loading -class ZeroShotPromptNode(FrontendNode): +class BasePromptFrontendNode(FrontendNode): + name: str + template: Template + description: str + base_classes: list[str] + + def to_dict(self): + return super().to_dict() + + +class ZeroShotPromptNode(BasePromptFrontendNode): name: str = "ZeroShotPrompt" template: Template = Template( type_name="zero_shot", @@ -227,3 +237,13 @@ class CSVAgentNode(FrontendNode): def to_dict(self): return super().to_dict() + + +class PromptFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + # if field.field_type == "StringPromptTemplate" + # change it to str + if field.field_type == "StringPromptTemplate": + field.field_type = "str" + field.multiline = True