Refactor code and update dependencies

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-07 19:54:14 -03:00
commit 69145f35ca
24 changed files with 142 additions and 71 deletions

View file

@@ -81,7 +81,8 @@ def run_migrations_online() -> None:
logger.error(f"Error getting database engine: {e}")
url = os.getenv("LANGFLOW_DATABASE_URL")
url = url or config.get_main_option("sqlalchemy.url")
config.set_main_option("sqlalchemy.url", url)
if url:
config.set_main_option("sqlalchemy.url", url)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",

View file

@@ -1,6 +1,6 @@
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from fastapi import HTTPException
from platformdirs import user_cache_dir
@@ -20,7 +20,9 @@ API_WORDS = ["api", "key", "token"]
def has_api_terms(word: str):
return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))
return "api" in word and (
"key" in word or ("token" in word and "tokens" not in word)
)
def remove_api_keys(flow: dict):
@@ -30,7 +32,11 @@ def remove_api_keys(flow: dict):
node_data = node.get("data").get("node")
template = node_data.get("template")
for value in template.values():
if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"):
if (
isinstance(value, dict)
and has_api_terms(value["name"])
and value.get("password")
):
value["value"] = None
return flow
@@ -51,7 +57,9 @@ def build_input_keys_response(langchain_object, artifacts):
input_keys_response["input_keys"][key] = value
# If the object has memory, that memory will have a memory_variables attribute
# memory variables should be removed from the input keys
if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"):
if hasattr(langchain_object, "memory") and hasattr(
langchain_object.memory, "memory_variables"
):
# Remove memory variables from input keys
input_keys_response["input_keys"] = {
key: value
@@ -61,7 +69,9 @@ def build_input_keys_response(langchain_object, artifacts):
# Add memory variables to memory_keys
input_keys_response["memory_keys"] = langchain_object.memory.memory_variables
if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"):
if hasattr(langchain_object, "prompt") and hasattr(
langchain_object.prompt, "template"
):
input_keys_response["template"] = langchain_object.prompt.template
return input_keys_response
@@ -96,7 +106,11 @@ def raw_frontend_data_is_valid(raw_frontend_data):
def is_valid_data(frontend_node, raw_frontend_data):
"""Check if the data is valid for processing."""
return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data)
return (
frontend_node
and "template" in frontend_node
and raw_frontend_data_is_valid(raw_frontend_data)
)
def update_template_values(frontend_template, raw_template):
@@ -136,12 +150,14 @@ def get_file_path_value(file_path):
# If the path is not in the cache dir, return empty string
# This is to prevent access to files outside the cache dir
# If the path is not a file, return empty string
if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")):
if not path.exists() or not str(path).startswith(
user_cache_dir("langflow", "langflow")
):
return ""
return file_path
def validate_is_component(flows: List["Flow"]):
def validate_is_component(flows: list["Flow"]):
for flow in flows:
if not flow.data or flow.is_component is not None:
continue
@@ -167,7 +183,9 @@ async def check_langflow_version(component: StoreComponentCreate):
langflow_version = get_lf_version_from_pypi()
if langflow_version is None:
raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow")
raise HTTPException(
status_code=500, detail="Unable to verify the latest version of Langflow"
)
elif langflow_version != component.last_tested_version:
warnings.warn(
f"Your version of Langflow ({component.last_tested_version}) is outdated. "

View file

@@ -26,7 +26,7 @@ class FrontendNodeRequest(FrontendNode):
class ValidatePromptRequest(BaseModel):
name: str
template: str
# optional for tweak call
custom_fields: Optional[dict] = None
frontend_node: Optional[FrontendNodeRequest] = None

View file

@@ -105,9 +105,11 @@ async def run_flow_with_caching(
"""
try:
if inputs is not None:
input_values_dict: dict[str, Union[str, list[str]]] = inputs.model_dump()
input_values: list[dict[str, Union[str, list[str]]]] = [
_input.model_dump() for _input in inputs
]
else:
input_values_dict = {}
input_values = [{}]
if outputs is None:
outputs = []
@@ -124,7 +126,7 @@ async def run_flow_with_caching(
graph=graph,
flow_id=flow_id,
session_id=session_id,
inputs=input_values_dict,
inputs=input_values,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
@@ -150,7 +152,7 @@ async def run_flow_with_caching(
graph=graph_data,
flow_id=flow_id,
session_id=session_id,
inputs=input_values_dict,
inputs=input_values,
outputs=outputs,
artifacts={},
session_service=session_service,

View file

@@ -57,7 +57,7 @@ def read_flows(
try:
auth_settings = settings_service.auth_settings
if auth_settings.AUTO_LOGIN:
flows = session.exec(
flows: list[Flow] = session.exec(
select(Flow).where(
(Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa
)

View file

@@ -32,11 +32,12 @@ class ConversationChainComponent(CustomComponent):
else:
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke({"input": input_value})
if hasattr(result, "content") and isinstance(result.content, str):
result = result.content
if isinstance(result, dict):
result = result.get(chain.output_key)
elif isinstance(result, str):
result = result
else:
result = result.get("response")
self.status = result
return result
return str(result)

View file

@@ -18,7 +18,7 @@ class URLComponent(CustomComponent):
async def build(
self,
urls: list[str],
) -> Record:
) -> list[Record]:
loader = WebBaseLoader(web_paths=urls)
docs = loader.load()
records = self.to_records(docs)

View file

@@ -13,10 +13,10 @@ class MergeRecordsComponent(CustomComponent):
def build(self, records: list[Record]) -> Record:
if not records:
return records
return Record()
if len(records) == 1:
return records[0]
merged_record = None
merged_record = Record()
for record in records:
if merged_record is None:
merged_record = record
@@ -24,3 +24,13 @@ class MergeRecordsComponent(CustomComponent):
merged_record += record
self.status = merged_record
return merged_record
if __name__ == "__main__":
records = [
Record(data={"key1": "value1"}),
Record(data={"key2": "value2"}),
]
component = MergeRecordsComponent()
result = component.build(records)
print(result)

View file

@@ -36,11 +36,18 @@ class RunFlowComponent(CustomComponent):
messages = result_data.messages
records = []
for message in messages:
record = Record(text=message.get("text", ""), data={"result": result_data})
message_dict = (
message if isinstance(message, dict) else message.model_dump()
)
record = Record(
text=message_dict.get("text", ""), data={"result": result_data}
)
records.append(record)
return records
async def build(self, input_value: Text, flow_name: str, tweaks: NestedDict) -> Record:
async def build(
self, input_value: Text, flow_name: str, tweaks: NestedDict
) -> Record:
results: List[Optional[ResultData]] = await self.run_flow(
input_value=input_value, flow_name=flow_name, tweaks=tweaks
)

View file

@@ -84,7 +84,8 @@ class ChromaComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,
@@ -98,14 +99,16 @@ class ChromaComponent(CustomComponent):
index_directory = self.resolve_path(index_directory)
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
if documents is not None and embedding is not None:
if len(documents) == 0:
raise ValueError("If documents are provided, there must be at least one document.")
raise ValueError(
"If documents are provided, there must be at least one document."
)
chroma = Chroma.from_documents(
documents=documents, # type: ignore
persist_directory=index_directory,

View file

@@ -33,7 +33,7 @@ class FAISSComponent(CustomComponent):
index_name: str = "langflow_index",
) -> Union[VectorStore, FAISS, BaseRetriever]:
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -9,7 +9,9 @@ from langflow.schema.schema import Record
class MongoDBAtlasComponent(CustomComponent):
display_name = "MongoDB Atlas"
description = "Construct a `MongoDB Atlas Vector Search` vector store from raw documents."
description = (
"Construct a `MongoDB Atlas Vector Search` vector store from raw documents."
)
icon = "MongoDB"
def build_config(self):
@@ -26,7 +28,7 @@ class MongoDBAtlasComponent(CustomComponent):
def build(
self,
embedding: Embeddings,
inputs: List[Record],
inputs: Optional[List[Record]] = None,
collection_name: str = "",
db_name: str = "",
index_name: str = "",
@@ -37,14 +39,16 @@
try:
from pymongo import MongoClient
except ImportError:
raise ImportError("Please install pymongo to use MongoDB Atlas Vector Store")
raise ImportError(
"Please install pymongo to use MongoDB Atlas Vector Store"
)
try:
mongo_client: MongoClient = MongoClient(mongodb_atlas_cluster_uri)
collection = mongo_client[db_name][collection_name]
except Exception as e:
raise ValueError(f"Failed to connect to MongoDB Atlas: {e}")
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -39,7 +39,6 @@ class MongoDBAtlasSearchComponent(MongoDBAtlasComponent, LCVectorStoreComponent)
vector_store = super().build(
embedding=embedding,
collection_name=collection_name,
documents=[],
db_name=db_name,
index_name=index_name,
mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri,

View file

@@ -45,7 +45,7 @@ class PineconeComponent(CustomComponent):
self,
embedding: Embeddings,
pinecone_env: str,
inputs: List[Record],
inputs: Optional[List[Record]] = None,
text_key: str = "text",
pool_threads: int = 4,
index_name: Optional[str] = None,
@@ -61,7 +61,7 @@
if not index_name:
raise ValueError("Index Name is required.")
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -55,7 +55,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent):
vector_store = super().build(
embedding=embedding,
pinecone_env=pinecone_env,
documents=[],
inputs=[],
text_key=text_key,
pool_threads=pool_threads,
index_name=index_name,

View file

@@ -64,7 +64,7 @@ class QdrantComponent(CustomComponent):
url: Optional[str] = None,
) -> Union[VectorStore, Qdrant, BaseRetriever]:
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -59,14 +59,16 @@ class RedisComponent(CustomComponent):
- VectorStore: The Vector Store object.
"""
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
if not documents:
if schema is None:
raise ValueError("If no documents are provided, a schema must be provided.")
raise ValueError(
"If no documents are provided, a schema must be provided."
)
redis_vs = Redis.from_existing_index(
embedding=embedding,
index_name=redis_index_name,

View file

@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Optional, Union
from langchain.schema import BaseRetriever
from langchain_community.vectorstores import VectorStore
@@ -28,16 +28,18 @@ class SupabaseComponent(CustomComponent):
def build(
self,
embedding: Embeddings,
inputs: List[Record],
inputs: Optional[List[Record]] = None,
query_name: str = "",
search_kwargs: NestedDict = {},
supabase_service_key: str = "",
supabase_url: str = "",
table_name: str = "",
) -> Union[VectorStore, SupabaseVectorStore, BaseRetriever]:
supabase: Client = create_client(supabase_url, supabase_key=supabase_service_key)
supabase: Client = create_client(
supabase_url, supabase_key=supabase_service_key
)
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -15,7 +15,9 @@ from langflow.schema.schema import Record
class VectaraComponent(CustomComponent):
display_name: str = "Vectara"
description: str = "Implementation of Vector Store using Vectara"
documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/vectara"
)
icon = "Vectara"
field_config = {
"vectara_customer_id": {
@@ -50,7 +52,7 @@
source = "Langflow"
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -12,7 +12,9 @@ from langflow.schema.schema import Record
class WeaviateVectorStoreComponent(CustomComponent):
display_name: str = "Weaviate"
description: str = "Implementation of Vector Store using Weaviate"
documentation = "https://python.langchain.com/docs/integrations/vectorstores/weaviate"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/weaviate"
)
field_config = {
"url": {"display_name": "Weaviate URL", "value": "http://localhost:8080"},
"api_key": {
@@ -79,7 +81,7 @@ class WeaviateVectorStoreComponent(CustomComponent):
index_name = _to_pascal_case(index_name) if index_name else None
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -16,7 +16,9 @@ class PGVectorComponent(CustomComponent):
display_name: str = "PGVector"
description: str = "Implementation of Vector Store using PostgreSQL"
documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/pgvector"
)
def build_config(self):
"""
@@ -57,7 +59,7 @@
"""
documents = []
for _input in inputs:
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:

View file

@@ -48,7 +48,7 @@ class Graph:
self._is_state_vertices: List[str] = []
self._has_session_id_vertices: List[str] = []
self._sorted_vertices_layers: List[List[str]] = []
self._run_id = None
self._run_id = ""
self.top_level_vertices = []
for vertex in self._vertices:
@@ -130,9 +130,6 @@ class Graph:
self.state_manager.subscribe(run_id, vertex.update_graph_state)
self._run_id = run_id
def add_state(self, state: str):
self.state_manager.append_state(self._run_id, state)
@property
def sorted_vertices_layers(self) -> List[List[str]]:
if not self._sorted_vertices_layers:
@@ -198,7 +195,7 @@
async def run(
self,
inputs: Dict[str, Union[str, list[str]]],
inputs: list[Dict[str, Union[str, list[str]]]],
outputs: list[str],
session_id: str,
stream: Optional[bool] = False,
@@ -210,13 +207,12 @@
# of the vertices that are inputs
# if the value is a list, we need to run multiple times
vertex_outputs = []
inputs_values = inputs.get(INPUT_FIELD_NAME, "")
if not isinstance(inputs_values, list):
inputs_values = [inputs_values]
for input_value in inputs_values:
for input_dict in inputs_values:
run_outputs = await self._run(
inputs={INPUT_FIELD_NAME: input_value},
input_components=inputs.get("components", []),
inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME)},
input_components=input_dict.get("components", []),
outputs=outputs,
stream=stream,
session_id=session_id,

View file

@@ -97,7 +97,6 @@ class Vertex:
self.use_result = False
self.build_times: List[float] = []
self.state = VertexStates.ACTIVE
self.graph_state = {}
def update_graph_state(self, key, new_state, append: bool):
if append:

View file

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Coroutine, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional, Tuple, Union
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
@@ -16,6 +16,9 @@ from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.run import get_memory_key, update_memory_keys
from langflow.services.session.service import SessionService
if TYPE_CHECKING:
from langflow.api.v1.schemas import Tweaks
def fix_memory_inputs(langchain_object):
"""
@@ -126,7 +129,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
result = await runnable.ainvoke(inputs)
else:
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
raise ValueError(
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
)
# Check if the result is a list of AIMessages
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
result = [r.content for r in result]
@@ -135,7 +140,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
return result
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
async def process_inputs_dict(
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
):
if isinstance(built_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@@ -170,7 +177,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
return await process_runnable(built_object, inputs)
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
async def generate_result(
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
):
if isinstance(inputs, dict):
result = await process_inputs_dict(built_object, inputs)
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
@@ -197,7 +206,7 @@ async def run_graph(
flow_id: str,
stream: bool,
session_id: Optional[str] = None,
inputs: Optional[dict[str, Union[List[str], str]]] = None,
inputs: Optional[list[dict[str, Union[List[str], str]]]] = None,
outputs: Optional[List[str]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
@@ -209,9 +218,11 @@
else:
graph_data = graph._graph_data
if not session_id and session_service is not None:
session_id = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
session_id = session_service.generate_key(
session_id=flow_id, data_graph=graph_data
)
if inputs is None:
inputs = {}
inputs = [{}]
outputs = await graph.run(
inputs,
@@ -224,14 +235,18 @@
return outputs, session_id
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
def validate_input(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
raise ValueError("graph_data and tweaks should be dictionaries")
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
if not isinstance(nodes, list):
raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
raise ValueError(
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
)
return nodes
@@ -240,7 +255,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data = node.get("data", {}).get("node", {}).get("template")
if not isinstance(template_data, dict):
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
logger.warning(
f"Template data for node {node.get('id')} should be a dictionary"
)
return
for tweak_name, tweak_value in node_tweaks.items():
@@ -255,7 +272,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
vertex.params[tweak_name] = tweak_value
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
def process_tweaks(
graph_data: Dict[str, Any], tweaks: Union["Tweaks", Dict[str, Dict[str, Any]]]
) -> Dict[str, Any]:
"""
This function is used to tweak the graph data using the node id and the tweaks dict.
@@ -291,6 +310,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
if node_tweaks := tweaks.get(node_id):
apply_tweaks_on_vertex(vertex, node_tweaks)
else:
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
logger.warning(
"Each node should be a Vertex with an 'id' attribute of type str"
)
return graph