Merge branch 'logspace-ai:dev' into gabfr-patch-1

This commit is contained in:
Gabriel Ferreira Rosalino 2023-05-27 14:15:16 -03:00 committed by GitHub
commit aac6e3f2e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 709 additions and 381 deletions

View file

@ -108,6 +108,8 @@ utilities:
- SQLDatabase
vectorstores:
- Chroma
- Qdrant
- Weaviate
wrappers:
- RequestsWrapper # Wait more tests
# - ChatPromptTemplate

View file

@ -1,4 +1,3 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from langflow.graph.base import Node

View file

@ -42,7 +42,7 @@ class LangChainTypeCreator(BaseModel, ABC):
# so we should update the result dict
node = self.frontend_node(name)
if node is not None:
node = node.to_dict()
node = node.to_dict() # type: ignore
result[self.type_name].update(node)
return result

View file

@ -10,7 +10,6 @@ from langchain import (
requests,
text_splitter,
utilities,
vectorstores,
)
from langchain.agents import agent_toolkits
from langchain.chat_models import ChatOpenAI

View file

@ -1,5 +1,6 @@
import re
from abc import ABC
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, List, Optional, Union
from pydantic import BaseModel
@ -41,76 +42,6 @@ class TemplateFieldCreator(BaseModel, ABC):
result["content"] = self.content
return result
def process_field(
self, key: str, value: Dict[str, Any], name: Optional[str] = None
) -> None:
_type = value["type"]
# Remove 'Optional' wrapper
if "Optional" in _type:
_type = _type.replace("Optional[", "")[:-1]
# Check for list type
if "List" in _type:
_type = _type.replace("List[", "")[:-1]
self.is_list = True
# Replace 'Mapping' with 'dict'
if "Mapping" in _type:
_type = _type.replace("Mapping", "dict")
# Change type from str to Tool
self.field_type = "Tool" if key in {"allowed_tools"} else self.field_type
self.field_type = "int" if key in {"max_value_length"} else self.field_type
# Show or not field
self.show = bool(
(self.required and key not in ["input_variables"])
or key in FORCE_SHOW_FIELDS
or "api_key" in key
)
# Add password field
self.password = any(
text in key.lower() for text in {"password", "token", "api", "key"}
)
# Add multline
self.multiline = key in {
"suffix",
"prefix",
"template",
"examples",
"code",
"headers",
}
# Replace dict type with str
if "dict" in self.field_type.lower():
self.field_type = "code"
if key == "dict_":
self.field_type = "file"
self.suffixes = [".json", ".yaml", ".yml"]
self.file_types = ["json", "yaml", "yml"]
# Replace default value with actual value
if "default" in value:
self.value = value["default"]
if key == "headers":
self.value = """{'Authorization':
'Bearer <token>'}"""
# Add options to openai
if name == "OpenAI" and key == "model_name":
self.options = constants.OPENAI_MODELS
self.is_list = True
elif name == "ChatOpenAI" and key == "model_name":
self.options = constants.CHAT_OPENAI_MODELS
self.is_list = True
class TemplateField(TemplateFieldCreator):
pass
@ -139,10 +70,10 @@ class Template(BaseModel):
class FrontendNode(BaseModel):
template: Template
description: str
base_classes: list
base_classes: List[str]
name: str = ""
def to_dict(self):
def to_dict(self) -> dict:
return {
self.name: {
"template": self.template.to_dict(self.format_field),
@ -153,53 +84,145 @@ class FrontendNode(BaseModel):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
"""Formats a given field based on its attributes and value."""
SPECIAL_FIELD_HANDLERS = {
"allowed_tools": lambda field: "Tool",
"max_value_length": lambda field: "int",
}
key = field.name
value = field.to_dict()
_type = value["type"]
# Remove 'Optional' wrapper
if "Optional" in _type:
_type = _type.replace("Optional[", "")[:-1]
_type = FrontendNode.remove_optional(_type)
_type, is_list = FrontendNode.check_for_list_type(_type)
field.is_list = is_list or field.is_list
_type = FrontendNode.replace_mapping_with_dict(_type)
_type = FrontendNode.handle_union_type(_type)
# Check for list type
if "List" in _type or "Sequence" in _type:
_type = _type.replace("List[", "")
_type = _type.replace("Sequence[", "")[:-1]
field.is_list = True
field.field_type = FrontendNode.handle_special_field(
field, key, _type, SPECIAL_FIELD_HANDLERS
)
field.field_type = FrontendNode.handle_dict_type(field, _type)
field.show = FrontendNode.should_show_field(key, field.required)
field.password = FrontendNode.should_be_password(key, field.show)
field.multiline = FrontendNode.should_be_multiline(key)
# Replace 'Mapping' with 'dict'
if "Mapping" in _type:
_type = _type.replace("Mapping", "dict")
FrontendNode.replace_default_value(field, value)
FrontendNode.handle_specific_field_values(field, key, name)
FrontendNode.handle_kwargs_field(field)
FrontendNode.handle_api_key_field(field, key)
# {'type': 'Union[float, Tuple[float, float], NoneType]'} != {'type': 'float'}
@staticmethod
def remove_optional(_type: str) -> str:
"""Removes 'Optional' wrapper from the type if present."""
return re.sub(r"Optional\[(.*)\]", r"\1", _type)
@staticmethod
def check_for_list_type(_type: str) -> tuple:
"""Checks for list type and returns the modified type and a boolean indicating if it's a list."""
is_list = "List" in _type or "Sequence" in _type
if is_list:
_type = re.sub(r"(List|Sequence)\[(.*)\]", r"\2", _type)
return _type, is_list
@staticmethod
def replace_mapping_with_dict(_type: str) -> str:
"""Replaces 'Mapping' with 'dict'."""
return _type.replace("Mapping", "dict")
@staticmethod
def handle_union_type(_type: str) -> str:
"""Simplifies the 'Union' type to the first type in the Union."""
if "Union" in _type:
_type = _type.replace("Union[", "")[:-1]
_type = _type.split(",")[0]
_type = _type.replace("]", "").replace("[", "")
return _type
field.field_type = _type
@staticmethod
def handle_special_field(
field, key: str, _type: str, SPECIAL_FIELD_HANDLERS
) -> str:
"""Handles special field by using the respective handler if present."""
handler = SPECIAL_FIELD_HANDLERS.get(key)
return handler(field) if handler else _type
# Change type from str to Tool
field.field_type = "Tool" if key in {"allowed_tools"} else field.field_type
@staticmethod
def handle_dict_type(field: TemplateField, _type: str) -> str:
"""Handles 'dict' type by replacing it with 'code' or 'file' based on the field name."""
if "dict" in _type.lower():
if field.name == "dict_":
field.field_type = "file"
field.suffixes = [".json", ".yaml", ".yml"]
field.file_types = ["json", "yaml", "yml"]
else:
field.field_type = "code"
return _type
field.field_type = "int" if key in {"max_value_length"} else field.field_type
@staticmethod
def replace_default_value(field: TemplateField, value: dict) -> None:
"""Replaces default value with actual value if 'default' is present in value."""
if "default" in value:
field.value = value["default"]
# Show or not field
field.show = bool(
(field.required and key not in ["input_variables"])
@staticmethod
def handle_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
"""Handles specific field values for certain fields."""
if key == "headers":
field.value = """{'Authorization':
'Bearer <token>'}"""
if name == "OpenAI" and key == "model_name":
field.options = constants.OPENAI_MODELS
field.is_list = True
elif name == "ChatOpenAI" and key == "model_name":
field.options = constants.CHAT_OPENAI_MODELS
field.is_list = True
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
field.required = False
if field.value is None:
field.value = ""
@staticmethod
def handle_kwargs_field(field: TemplateField) -> None:
"""Handles kwargs field by setting certain attributes."""
if "kwargs" in field.name.lower():
field.advanced = True
field.required = False
field.show = False
@staticmethod
def handle_api_key_field(field: TemplateField, key: str) -> None:
"""Handles api key field by setting certain attributes."""
if "api" in key.lower() and "key" in key.lower():
field.required = False
field.advanced = False
@staticmethod
def should_show_field(key: str, required: bool) -> bool:
"""Determines whether the field should be shown."""
return (
(required and key not in ["input_variables"])
or key in FORCE_SHOW_FIELDS
or "api" in key
or ("key" in key and "input" not in key and "output" not in key)
)
# Add password field
field.password = (
@staticmethod
def should_be_password(key: str, show: bool) -> bool:
"""Determines whether the field should be a password field."""
return (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and field.show
and show
)
# Add multline
field.multiline = key in {
@staticmethod
def should_be_multiline(key: str) -> bool:
"""Determines whether the field should be multiline."""
return key in {
"suffix",
"prefix",
"template",
@ -209,43 +232,24 @@ class FrontendNode(BaseModel):
"description",
}
# Replace dict type with str
if "dict" in field.field_type.lower():
field.field_type = "code"
@staticmethod
def replace_dict_with_code_or_file(
field: TemplateField, _type: str, key: str
) -> str:
"""Replaces 'dict' type with 'code' or 'file'."""
if "dict" in _type.lower():
if key == "dict_":
field.field_type = "file"
field.suffixes = [".json", ".yaml", ".yml"]
field.file_types = ["json", "yaml", "yml"]
else:
field.field_type = "code"
return field.field_type
if key == "dict_":
field.field_type = "file"
field.suffixes = [".json", ".yaml", ".yml"]
field.file_types = ["json", "yaml", "yml"]
# Replace default value with actual value
@staticmethod
def set_field_default_value(field: TemplateField, value: dict, key: str) -> None:
"""Sets the field value with the default value if present."""
if "default" in value:
field.value = value["default"]
if key == "headers":
field.value = """{'Authorization':
'Bearer <token>'}"""
# Add options to openai
if name == "OpenAI" and key == "model_name":
field.options = constants.OPENAI_MODELS
field.is_list = True
elif name == "ChatOpenAI":
if key == "model_name":
field.options = constants.CHAT_OPENAI_MODELS
field.is_list = True
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
field.required = False
if field.value is None:
field.value = ""
if "kwargs" in field.name.lower():
field.advanced = True
field.required = False
field.show = False
# If the field.name contains api or api and key, then it might be an api key
# other conditions are to make sure that it is not an input or output variable
if "api" in key.lower() and "key" in key.lower():
field.required = False
field.advanced = False
field.value = """{'Authorization': 'Bearer <token>'}"""

View file

@ -634,6 +634,26 @@ class VectorStoreFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
# Define common field attributes
basic_fields = ["work_dir", "collection_name", "api_key", "location"]
advanced_fields = [
"n_dim",
"key",
"prefix",
"distance_func",
"content_payload_key",
"metadata_payload_key",
"timeout",
"host",
"path",
"url",
"port",
"https",
"prefer_grpc",
"grpc_port",
]
# Check and set field attributes
if field.name == "texts":
field.name = "documents"
field.field_type = "TextSplitter"
@ -642,7 +662,7 @@ class VectorStoreFrontendNode(FrontendNode):
field.show = True
field.advanced = False
if "embedding" in field.name:
elif "embedding" in field.name:
# for backwards compatibility
field.name = "embedding"
field.required = True
@ -651,9 +671,21 @@ class VectorStoreFrontendNode(FrontendNode):
field.display_name = "Embedding"
field.field_type = "Embeddings"
elif field.name == "n_dim":
field.show = True
field.advanced = True
elif field.name == "work_dir":
elif field.name in basic_fields:
field.show = True
field.advanced = False
if field.name == "api_key":
field.display_name = "API Key"
field.password = True
elif field.name == "location":
field.value = ":memory:"
field.placeholder = ":memory:"
elif field.name in advanced_fields:
field.show = True
field.advanced = True
if "key" in field.name:
field.password = False
# TODO: Weaviate requires weaviate_url to be passed as it is not part of
# the class or from_texts method. We need the add_extra_fields to fix this