refactor settings

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-04 00:04:35 -03:00
commit b81b596b55
23 changed files with 109 additions and 89 deletions

View file

@ -39,10 +39,10 @@ def get_all():
# custom_components is a list of dicts
# need to merge all the keys into one dict
custom_components_from_file = {}
if settings.components_path:
if settings.COMPONENTS_PATH:
custom_component_dicts = [
build_langchain_custom_component_list_from_path(str(path))
for path in settings.components_path
for path in settings.COMPONENTS_PATH
]
for custom_component_dict in custom_component_dicts:
custom_components_from_file = merge_nested_dicts(

View file

@ -61,7 +61,7 @@ def update_flow(
if not db_flow:
raise HTTPException(status_code=404, detail="Flow not found")
flow_data = flow.dict(exclude_unset=True)
if settings.remove_api_keys:
if settings.REMOVE_API_KEYS:
flow_data = remove_api_keys(flow_data)
for key, value in flow_data.items():
setattr(db_flow, key, value)

View file

@ -59,7 +59,7 @@ class AgentCreator(LangChainTypeCreator):
if hasattr(agent, "function_name")
else agent.__name__
)
if agent_name in settings.agents or settings.dev:
if agent_name in settings.AGENTS or settings.DEV:
names.append(agent_name)
return names

View file

@ -28,7 +28,7 @@ class LangChainTypeCreator(BaseModel, ABC):
"""A dict with the name of the component as key and the documentation link as value."""
if self.name_docs_dict is None:
try:
type_settings = getattr(settings, self.type_name)
type_settings = getattr(settings, self.type_name.upper())
self.name_docs_dict = {
name: value_dict["documentation"]
for name, value_dict in type_settings.items()

View file

@ -43,7 +43,7 @@ class ChainCreator(LangChainTypeCreator):
self.type_dict = {
name: chain
for name, chain in self.type_dict.items()
if name in settings.chains or settings.dev
if name in settings.CHAINS or settings.DEV
}
return self.type_dict

View file

@ -33,7 +33,7 @@ class DocumentLoaderCreator(LangChainTypeCreator):
return [
documentloader.__name__
for documentloader in self.type_to_loader_dict.values()
if documentloader.__name__ in settings.documentloaders or settings.dev
if documentloader.__name__ in settings.DOCUMENTLOADERS or settings.DEV
]

View file

@ -35,7 +35,7 @@ class EmbeddingCreator(LangChainTypeCreator):
return [
embedding.__name__
for embedding in self.type_to_loader_dict.values()
if embedding.__name__ in settings.embeddings or settings.dev
if embedding.__name__ in settings.EMBEDDINGS or settings.DEV
]

View file

@ -36,7 +36,7 @@ class LLMCreator(LangChainTypeCreator):
return [
llm.__name__
for llm in self.type_to_loader_dict.values()
if llm.__name__ in settings.llms or settings.dev
if llm.__name__ in settings.LLMS or settings.DEV
]

View file

@ -51,7 +51,7 @@ class MemoryCreator(LangChainTypeCreator):
return [
memory.__name__
for memory in self.type_to_loader_dict.values()
if memory.__name__ in settings.memories or settings.dev
if memory.__name__ in settings.MEMORIES or settings.DEV
]

View file

@ -33,7 +33,7 @@ class OutputParserCreator(LangChainTypeCreator):
self.type_dict = {
name: output_parser
for name, output_parser in self.type_dict.items()
if name in settings.output_parsers or settings.dev
if name in settings.OUTPUT_PARSERS or settings.DEV
}
return self.type_dict

View file

@ -34,7 +34,7 @@ class PromptCreator(LangChainTypeCreator):
self.type_dict = {
name: prompt
for name, prompt in self.type_dict.items()
if name in settings.prompts or settings.dev
if name in settings.PROMPTS or settings.DEV
}
return self.type_dict

View file

@ -51,7 +51,7 @@ class RetrieverCreator(LangChainTypeCreator):
return [
retriever
for retriever in self.type_to_loader_dict.keys()
if retriever in settings.retrievers or settings.dev
if retriever in settings.RETRIEVERS or settings.DEV
]

View file

@ -33,7 +33,7 @@ class TextSplitterCreator(LangChainTypeCreator):
return [
textsplitter.__name__
for textsplitter in self.type_to_loader_dict.values()
if textsplitter.__name__ in settings.textsplitters or settings.dev
if textsplitter.__name__ in settings.TEXTSPLITTERS or settings.DEV
]

View file

@ -35,7 +35,7 @@ class ToolkitCreator(LangChainTypeCreator):
)
# if toolkit_name is not lower case it is a class
for toolkit_name in agent_toolkits.__all__
if not toolkit_name.islower() and toolkit_name in settings.toolkits
if not toolkit_name.islower() and toolkit_name in settings.TOOLKITS
}
return self.type_dict

View file

@ -74,7 +74,7 @@ class ToolCreator(LangChainTypeCreator):
tool_name = tool_params.get("name") or tool
if tool_name in settings.tools or settings.dev:
if tool_name in settings.TOOLS or settings.DEV:
if tool_name == "JsonSpec":
tool_params["path"] = tool_params.pop("dict_") # type: ignore
all_tools[tool_name] = {

View file

@ -35,7 +35,7 @@ class UtilityCreator(LangChainTypeCreator):
self.type_dict = {
name: utility
for name, utility in self.type_dict.items()
if name in settings.utilities or settings.dev
if name in settings.UTILITIES or settings.DEV
}
return self.type_dict

View file

@ -69,7 +69,7 @@ def setup_llm_caching():
try:
set_langchain_cache(settings)
except ImportError:
logger.warning(f"Could not import {settings.cache}. ")
logger.warning(f"Could not import {settings.CACHE}. ")
except Exception as exc:
logger.warning(f"Could not setup LLM caching. Error: {exc}")

View file

@ -46,7 +46,7 @@ class VectorstoreCreator(LangChainTypeCreator):
return [
vectorstore
for vectorstore in self.type_to_loader_dict.keys()
if vectorstore in settings.vectorstores or settings.dev
if vectorstore in settings.VECTORSTORES or settings.DEV
]

View file

@ -3,69 +3,76 @@ from typing import Optional, List
from pathlib import Path
import yaml
from pydantic import BaseSettings, root_validator
from pydantic import BaseSettings, root_validator, validator
from langflow.utils.logger import logger
BASE_COMPONENTS_PATH = Path(__file__).parent / "components"
class Settings(BaseSettings):
chains: dict = {}
agents: dict = {}
prompts: dict = {}
llms: dict = {}
tools: dict = {}
memories: dict = {}
embeddings: dict = {}
vectorstores: dict = {}
documentloaders: dict = {}
wrappers: dict = {}
retrievers: dict = {}
toolkits: dict = {}
textsplitters: dict = {}
utilities: dict = {}
output_parsers: dict = {}
custom_components: dict = {}
CHAINS: dict = {}
AGENTS: dict = {}
PROMPTS: dict = {}
LLMS: dict = {}
TOOLS: dict = {}
MEMORIES: dict = {}
EMBEDDINGS: dict = {}
VECTORSTORES: dict = {}
DOCUMENTLOADERS: dict = {}
WRAPPERS: dict = {}
RETRIEVERS: dict = {}
TOOLKITS: dict = {}
TEXTSPLITTERS: dict = {}
UTILITIES: dict = {}
OUTPUT_PARSERS: dict = {}
CUSTOM_COMPONENTS: dict = {}
dev: bool = False
database_url: Optional[str] = None
cache: str = "InMemoryCache"
remove_api_keys: bool = False
components_path: List[Path]
DEV: bool = False
DATABASE_URL: Optional[str] = None
CACHE: str = "InMemoryCache"
REMOVE_API_KEYS: bool = False
COMPONENTS_PATH: List[Path] = []
@root_validator(pre=True)
def set_env_variables(cls, values):
if "database_url" not in values:
@validator("DATABASE_URL", pre=True)
def set_database_url(cls, value):
if not value:
logger.debug(
"No database_url provided, trying LANGFLOW_DATABASE_URL env variable"
)
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
values["database_url"] = langflow_database_url
value = langflow_database_url
logger.debug("Using LANGFLOW_DATABASE_URL env variable.")
else:
logger.debug("No DATABASE_URL env variable, using sqlite database")
values["database_url"] = "sqlite:///./langflow.db"
value = "sqlite:///./langflow.db"
if not values.get("components_path"):
values["components_path"] = [BASE_COMPONENTS_PATH]
logger.debug("No components_path provided, using default components path")
elif BASE_COMPONENTS_PATH not in values["components_path"]:
values["components_path"].append(BASE_COMPONENTS_PATH)
logger.debug("Adding default components path to components_path")
return value
@validator("COMPONENTS_PATH", pre=True)
def set_components_path(cls, value):
if os.getenv("LANGFLOW_COMPONENTS_PATH"):
logger.debug("Adding LANGFLOW_COMPONENTS_PATH to components_path")
langflow_component_path = Path(os.getenv("LANGFLOW_COMPONENTS_PATH"))
if (
langflow_component_path.exists()
and langflow_component_path not in values["components_path"]
and langflow_component_path not in value
):
values["components_path"].append(langflow_component_path)
value.append(langflow_component_path)
logger.debug(f"Adding {langflow_component_path} to components_path")
return values
if not value:
value = [BASE_COMPONENTS_PATH]
logger.debug("No components_path provided, using default components path")
elif BASE_COMPONENTS_PATH not in value:
value.append(BASE_COMPONENTS_PATH)
logger.debug("Adding default components path to components_path")
return value
class Config:
validate_assignment = True
extra = "ignore"
env_prefix = "LANGFLOW_"
@root_validator(allow_reuse=True)
def validate_lists(cls, values):
@ -76,35 +83,43 @@ class Settings(BaseSettings):
def update_from_yaml(self, file_path: str, dev: bool = False):
new_settings = load_settings_from_yaml(file_path)
self.chains = new_settings.chains or {}
self.agents = new_settings.agents or {}
self.prompts = new_settings.prompts or {}
self.llms = new_settings.llms or {}
self.tools = new_settings.tools or {}
self.memories = new_settings.memories or {}
self.wrappers = new_settings.wrappers or {}
self.toolkits = new_settings.toolkits or {}
self.textsplitters = new_settings.textsplitters or {}
self.utilities = new_settings.utilities or {}
self.embeddings = new_settings.embeddings or {}
self.vectorstores = new_settings.vectorstores or {}
self.documentloaders = new_settings.documentloaders or {}
self.retrievers = new_settings.retrievers or {}
self.output_parsers = new_settings.output_parsers or {}
self.custom_components = new_settings.custom_components or {}
self.components_path = new_settings.components_path or []
self.dev = dev
self.CHAINS = new_settings.CHAINS or {}
self.AGENTS = new_settings.AGENTS or {}
self.PROMPTS = new_settings.PROMPTS or {}
self.LLMS = new_settings.LLMS or {}
self.TOOLS = new_settings.TOOLS or {}
self.MEMORIES = new_settings.MEMORIES or {}
self.WRAPPERS = new_settings.WRAPPERS or {}
self.TOOLKITS = new_settings.TOOLKITS or {}
self.TEXTSPLITTERS = new_settings.TEXTSPLITTERS or {}
self.UTILITIES = new_settings.UTILITIES or {}
self.EMBEDDINGS = new_settings.EMBEDDINGS or {}
self.VECTORSTORES = new_settings.VECTORSTORES or {}
self.DOCUMENTLOADERS = new_settings.DOCUMENTLOADERS or {}
self.RETRIEVERS = new_settings.RETRIEVERS or {}
self.OUTPUT_PARSERS = new_settings.OUTPUT_PARSERS or {}
self.CUSTOM_COMPONENTS = new_settings.CUSTOM_COMPONENTS or {}
self.COMPONENTS_PATH = new_settings.COMPONENTS_PATH or []
self.DEV = dev
def update_settings(self, **kwargs):
logger.debug("Updating settings")
for key, value in kwargs.items():
if hasattr(self, key):
if isinstance(getattr(self, key), list):
if isinstance(value, list):
getattr(self, key).extend(value)
else:
getattr(self, key).append(value)
# value may contain sensitive information, so we don't want to log it
if not hasattr(self, key):
logger.debug(f"Key {key} not found in settings")
continue
logger.debug(f"Updating {key}")
if isinstance(getattr(self, key), list):
if isinstance(value, list):
getattr(self, key).extend(value)
logger.debug(f"Extended {key}")
else:
setattr(self, key, value)
getattr(self, key).append(value)
logger.debug(f"Appended {key}")
else:
setattr(self, key, value)
logger.debug(f"Updated {key}")
def save_settings_to_yaml(settings: Settings, file_path: str):
@ -123,6 +138,12 @@ def load_settings_from_yaml(file_path: str) -> Settings:
with open(file_path, "r") as f:
settings_dict = yaml.safe_load(f)
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
for key in settings_dict:
if key not in Settings.__fields__.keys():
raise KeyError(f"Key {key} not found in settings")
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")
return Settings(**settings_dict)

View file

@ -12,7 +12,6 @@ from langflow.graph.vertex.types import (
FileToolVertex,
LLMVertex,
ToolkitVertex,
WrapperVertex,
)
from langflow.processing.process import get_result_and_thought
from langflow.utils.payload import get_root_node
@ -292,11 +291,11 @@ def test_file_tool_node_build(openapi_graph):
assert not Path(file_path).exists()
def test_wrapper_node_build(openapi_graph):
wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
assert wrapper_node is not None
built_object = wrapper_node.build()
assert built_object is not None
# def test_wrapper_node_build(openapi_graph):
# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex)
# assert wrapper_node is not None
# built_object = wrapper_node.build()
# assert built_object is not None
def test_get_result_and_thought(basic_graph):

View file

@ -7,7 +7,7 @@ def test_llms_settings(client: TestClient):
assert response.status_code == 200
json_response = response.json()
llms = json_response["llms"]
assert set(llms.keys()) == set(settings.llms)
assert set(llms.keys()) == set(settings.LLMS)
# def test_hugging_face_hub(client: TestClient):

View file

@ -7,7 +7,7 @@ def test_prompts_settings(client: TestClient):
assert response.status_code == 200
json_response = response.json()
prompts = json_response["prompts"]
assert set(prompts.keys()) == set(settings.prompts)
assert set(prompts.keys()) == set(settings.PROMPTS)
def test_prompt_template(client: TestClient):

View file

@ -9,4 +9,4 @@ def test_vectorstores_settings(client: TestClient):
assert response.status_code == 200
json_response = response.json()
vectorstores = json_response["vectorstores"]
assert set(vectorstores.keys()) == set(settings.vectorstores)
assert set(vectorstores.keys()) == set(settings.VECTORSTORES)