fix: added options and cls to base_classes

This commit is contained in:
Gabriel Almeida 2023-03-31 23:16:16 -03:00
commit 48d2ab27da
2 changed files with 94 additions and 2 deletions

View file

@ -1,5 +1,6 @@
from abc import ABC
from typing import Any, Union
from typing import Any, Optional, Union, Dict
from langflow.utils import constants
from pydantic import BaseModel
@ -16,6 +17,7 @@ class TemplateFieldCreator(BaseModel, ABC):
file_types: list[str] = []
content: Union[str, None] = None
password: bool = False
options: list[str] = []
# _name will be used to store the name of the field
# in the template
name: str = ""
@ -36,6 +38,88 @@ class TemplateFieldCreator(BaseModel, ABC):
result["content"] = self.content
return result
def process_field(
self, key: str, value: Dict[str, Any], name: Optional[str] = None
) -> None:
_type = value["type"]
# Remove 'Optional' wrapper
if "Optional" in _type:
_type = _type.replace("Optional[", "")[:-1]
# Check for list type
if "List" in _type:
_type = _type.replace("List[", "")[:-1]
self.is_list = True
else:
self.is_list = False
# Replace 'Mapping' with 'dict'
if "Mapping" in _type:
_type = _type.replace("Mapping", "dict")
# Change type from str to Tool
self.field_type = "Tool" if key in ["allowed_tools"] else _type
self.field_type = "int" if key in ["max_value_length"] else self.field_type
# Show or not field
self.show = bool(
(self.required and key not in ["input_variables"])
or key
in [
"allowed_tools",
"memory",
"prefix",
"examples",
"temperature",
"model_name",
"headers",
"max_value_length",
]
or "api_key" in key
)
# Add password field
self.password = any(
text in key.lower() for text in ["password", "token", "api", "key"]
)
# Add multline
self.multiline = key in [
"suffix",
"prefix",
"template",
"examples",
"code",
"headers",
]
# Replace dict type with str
if "dict" in self.field_type.lower():
self.field_type = "code"
if key == "dict_":
self.field_type = "file"
self.suffixes = [".json", ".yaml", ".yml"]
self.file_types = ["json", "yaml", "yml"]
# Replace default value with actual value
if "default" in value:
self.value = value["default"]
if key == "headers":
self.value = """{'Authorization':
'Bearer <token>'}"""
# Add options to openai
if name == "OpenAI" and key == "model_name":
self.options = constants.OPENAI_MODELS
self.is_list = True
elif name == "OpenAIChat" and key == "model_name":
self.options = constants.CHAT_OPENAI_MODELS
self.is_list = True
class TemplateField(TemplateFieldCreator):
pass
@ -45,7 +129,13 @@ class Template(BaseModel):
type_name: str
fields: list[TemplateField]
def process_fields(self, name: Optional[str] = None) -> None:
for field in self.fields:
signature = field.to_dict()
field.process_field(field.name, signature, name)
def to_dict(self):
self.process_fields(self.type_name)
result = {field.name: field.to_dict() for field in self.fields}
result["_type"] = self.type_name # type: ignore
return result

View file

@ -173,7 +173,7 @@ def get_base_classes(cls):
result = [cls.__name__]
if not result:
result = [cls.__name__]
return list(set(result))
return list(set(result + [cls.__name__]))
def get_default_factory(module: str, function: str):
@ -333,8 +333,10 @@ def format_dict(d, name: Optional[str] = None):
# Add options to openai
if name == "OpenAI" and key == "model_name":
value["options"] = constants.OPENAI_MODELS
value["list"] = True
elif name == "OpenAIChat" and key == "model_name":
value["options"] = constants.CHAT_OPENAI_MODELS
value["list"] = True
return d