From 48d2ab27daeed601ab5f6532f3e055f9a6436ba1 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Fri, 31 Mar 2023 23:16:16 -0300 Subject: [PATCH] fix: added options and cls to base_classes --- src/backend/langflow/template/base.py | 92 ++++++++++++++++++++++++++- src/backend/langflow/utils/util.py | 4 +- 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index bcd6ed162..e050be55e 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -1,5 +1,6 @@ from abc import ABC -from typing import Any, Union +from typing import Any, Optional, Union, Dict +from langflow.utils import constants from pydantic import BaseModel @@ -16,6 +17,7 @@ class TemplateFieldCreator(BaseModel, ABC): file_types: list[str] = [] content: Union[str, None] = None password: bool = False + options: list[str] = [] # _name will be used to store the name of the field # in the template name: str = "" @@ -36,6 +38,88 @@ class TemplateFieldCreator(BaseModel, ABC): result["content"] = self.content return result + def process_field( + self, key: str, value: Dict[str, Any], name: Optional[str] = None + ) -> None: + _type = value["type"] + + # Remove 'Optional' wrapper + if "Optional" in _type: + _type = _type.replace("Optional[", "")[:-1] + + # Check for list type + if "List" in _type: + _type = _type.replace("List[", "")[:-1] + self.is_list = True + else: + self.is_list = False + + # Replace 'Mapping' with 'dict' + if "Mapping" in _type: + _type = _type.replace("Mapping", "dict") + + # Change type from str to Tool + self.field_type = "Tool" if key in ["allowed_tools"] else _type + + self.field_type = "int" if key in ["max_value_length"] else self.field_type + + # Show or not field + self.show = bool( + (self.required and key not in ["input_variables"]) + or key + in [ + "allowed_tools", + "memory", + "prefix", + "examples", + "temperature", + "model_name", + "headers", + "max_value_length", + ] + or "api_key" in key + ) + + # Add password field + self.password = any( + text in key.lower() for text in ["password", "token", "api", "key"] + ) + + # Add multline + self.multiline = key in [ + "suffix", + "prefix", + "template", + "examples", + "code", + "headers", + ] + + # Replace dict type with str + if "dict" in self.field_type.lower(): + self.field_type = "code" + + if key == "dict_": + self.field_type = "file" + self.suffixes = [".json", ".yaml", ".yml"] + self.file_types = ["json", "yaml", "yml"] + + # Replace default value with actual value + if "default" in value: + self.value = value["default"] + + if key == "headers": + self.value = """{'Authorization': + 'Bearer '}""" + + # Add options to openai + if name == "OpenAI" and key == "model_name": + self.options = constants.OPENAI_MODELS + self.is_list = True + elif name == "OpenAIChat" and key == "model_name": + self.options = constants.CHAT_OPENAI_MODELS + self.is_list = True + class TemplateField(TemplateFieldCreator): pass @@ -45,7 +129,13 @@ class Template(BaseModel): type_name: str fields: list[TemplateField] + def process_fields(self, name: Optional[str] = None) -> None: + for field in self.fields: + signature = field.to_dict() + field.process_field(field.name, signature, name) + def to_dict(self): + self.process_fields(self.type_name) result = {field.name: field.to_dict() for field in self.fields} result["_type"] = self.type_name # type: ignore return result diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 961058950..2fee19475 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -173,7 +173,7 @@ def get_base_classes(cls): result = [cls.__name__] if not result: result = [cls.__name__] - return list(set(result)) + return list(set(result + [cls.__name__])) def get_default_factory(module: str, function: str): @@ -333,8 +333,10 @@ def format_dict(d, name: Optional[str] = None): # Add options to openai if name == "OpenAI" and key == "model_name": value["options"] = constants.OPENAI_MODELS + value["list"] = True elif name == "OpenAIChat" and key == "model_name": value["options"] = constants.CHAT_OPENAI_MODELS + value["list"] = True return d