diff --git a/src/backend/langflow/alembic/env.py b/src/backend/langflow/alembic/env.py index bc470f05b..55f5afd14 100644 --- a/src/backend/langflow/alembic/env.py +++ b/src/backend/langflow/alembic/env.py @@ -81,7 +81,8 @@ def run_migrations_online() -> None: logger.error(f"Error getting database engine: {e}") url = os.getenv("LANGFLOW_DATABASE_URL") url = url or config.get_main_option("sqlalchemy.url") - config.set_main_option("sqlalchemy.url", url) + if url: + config.set_main_option("sqlalchemy.url", url) connectable = engine_from_config( config.get_section(config.config_ini_section, {}), prefix="sqlalchemy.", diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index fc9a4cb7f..2a38676d5 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -1,6 +1,6 @@ import warnings from pathlib import Path -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from fastapi import HTTPException from platformdirs import user_cache_dir @@ -20,7 +20,9 @@ API_WORDS = ["api", "key", "token"] def has_api_terms(word: str): - return "api" in word and ("key" in word or ("token" in word and "tokens" not in word)) + return "api" in word and ( + "key" in word or ("token" in word and "tokens" not in word) + ) def remove_api_keys(flow: dict): @@ -30,7 +32,11 @@ def remove_api_keys(flow: dict): node_data = node.get("data").get("node") template = node_data.get("template") for value in template.values(): - if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"): + if ( + isinstance(value, dict) + and has_api_terms(value["name"]) + and value.get("password") + ): value["value"] = None return flow @@ -51,7 +57,9 @@ def build_input_keys_response(langchain_object, artifacts): input_keys_response["input_keys"][key] = value # If the object has memory, that memory will have a memory_variables attribute # memory variables should be removed from the input keys - if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"): + if hasattr(langchain_object, "memory") and hasattr( + langchain_object.memory, "memory_variables" + ): # Remove memory variables from input keys input_keys_response["input_keys"] = { key: value @@ -61,7 +69,9 @@ def build_input_keys_response(langchain_object, artifacts): # Add memory variables to memory_keys input_keys_response["memory_keys"] = langchain_object.memory.memory_variables - if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"): + if hasattr(langchain_object, "prompt") and hasattr( + langchain_object.prompt, "template" + ): input_keys_response["template"] = langchain_object.prompt.template return input_keys_response @@ -96,7 +106,11 @@ def raw_frontend_data_is_valid(raw_frontend_data): def is_valid_data(frontend_node, raw_frontend_data): """Check if the data is valid for processing.""" - return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data) + return ( + frontend_node + and "template" in frontend_node + and raw_frontend_data_is_valid(raw_frontend_data) + ) def update_template_values(frontend_template, raw_template): @@ -136,12 +150,14 @@ def get_file_path_value(file_path): # If the path is not in the cache dir, return empty string # This is to prevent access to files outside the cache dir # If the path is not a file, return empty string - if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")): + if not path.exists() or not str(path).startswith( + user_cache_dir("langflow", "langflow") + ): return "" return file_path -def validate_is_component(flows: List["Flow"]): +def validate_is_component(flows: list["Flow"]): for flow in flows: if not flow.data or flow.is_component is not None: continue @@ -167,7 +183,9 @@ async def check_langflow_version(component: StoreComponentCreate): langflow_version = get_lf_version_from_pypi() if langflow_version is None: - raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow") + raise HTTPException( + status_code=500, detail="Unable to verify the latest version of Langflow" + ) elif langflow_version != component.last_tested_version: warnings.warn( f"Your version of Langflow ({component.last_tested_version}) is outdated. " diff --git a/src/backend/langflow/api/v1/base.py b/src/backend/langflow/api/v1/base.py index bad43c437..eed84815f 100644 --- a/src/backend/langflow/api/v1/base.py +++ b/src/backend/langflow/api/v1/base.py @@ -26,7 +26,7 @@ class FrontendNodeRequest(FrontendNode): class ValidatePromptRequest(BaseModel): name: str template: str - # optional for tweak call + custom_fields: Optional[dict] = None frontend_node: Optional[FrontendNodeRequest] = None diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index b2d175f77..88870b49c 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -105,9 +105,11 @@ async def run_flow_with_caching( """ try: if inputs is not None: - input_values_dict: dict[str, Union[str, list[str]]] = inputs.model_dump() + input_values: list[dict[str, Union[str, list[str]]]] = [ + _input.model_dump() for _input in inputs + ] else: - input_values_dict = {} + input_values = [{}] if outputs is None: outputs = [] @@ -124,7 +126,7 @@ async def run_flow_with_caching( graph=graph, flow_id=flow_id, session_id=session_id, - inputs=input_values_dict, + inputs=input_values, outputs=outputs, artifacts=artifacts, session_service=session_service, @@ -150,7 +152,7 @@ async def run_flow_with_caching( graph=graph_data, flow_id=flow_id, session_id=session_id, - inputs=input_values_dict, + inputs=input_values, outputs=outputs, artifacts={}, session_service=session_service, diff --git a/src/backend/langflow/api/v1/flows.py b/src/backend/langflow/api/v1/flows.py index dd60d5fed..558acb727 100644 --- a/src/backend/langflow/api/v1/flows.py +++ b/src/backend/langflow/api/v1/flows.py @@ -57,7 +57,7 @@ def read_flows( try: auth_settings = settings_service.auth_settings if auth_settings.AUTO_LOGIN: - flows = session.exec( + flows: list[Flow] = session.exec( select(Flow).where( (Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa ) diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index 774632412..09b10d506 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -32,11 +32,12 @@ class ConversationChainComponent(CustomComponent): else: chain = ConversationChain(llm=llm, memory=memory) result = chain.invoke({"input": input_value}) - if hasattr(result, "content") and isinstance(result.content, str): - result = result.content + if isinstance(result, dict): + result = result.get(chain.output_key) + elif isinstance(result, str): result = result else: result = result.get("response") self.status = result - return result + return str(result) diff --git a/src/backend/langflow/components/data/URL.py b/src/backend/langflow/components/data/URL.py index 8368e72be..5fe0ec8f6 100644 --- a/src/backend/langflow/components/data/URL.py +++ b/src/backend/langflow/components/data/URL.py @@ -18,7 +18,7 @@ class URLComponent(CustomComponent): async def build( self, urls: list[str], - ) -> Record: + ) -> list[Record]: loader = WebBaseLoader(web_paths=urls) docs = loader.load() records = self.to_records(docs) diff --git a/src/backend/langflow/components/experimental/MergeRecords.py b/src/backend/langflow/components/experimental/MergeRecords.py index 64582a1a0..c0b112f6b 100644 --- a/src/backend/langflow/components/experimental/MergeRecords.py +++ b/src/backend/langflow/components/experimental/MergeRecords.py @@ -13,10 +13,10 @@ class MergeRecordsComponent(CustomComponent): def build(self, records: list[Record]) -> Record: if not records: - return records + return Record() if len(records) == 1: return records[0] - merged_record = None + merged_record = Record() for record in records: if merged_record is None: merged_record = record @@ -24,3 +24,13 @@ class MergeRecordsComponent(CustomComponent): merged_record += record self.status = merged_record return merged_record + + +if __name__ == "__main__": + records = [ + Record(data={"key1": "value1"}), + Record(data={"key2": "value2"}), + ] + component = MergeRecordsComponent() + result = component.build(records) + print(result) diff --git a/src/backend/langflow/components/experimental/RunFlow.py b/src/backend/langflow/components/experimental/RunFlow.py index 8edb62cef..4efb50da5 100644 --- a/src/backend/langflow/components/experimental/RunFlow.py +++ b/src/backend/langflow/components/experimental/RunFlow.py @@ -36,11 +36,18 @@ class RunFlowComponent(CustomComponent): messages = result_data.messages records = [] for message in messages: - record = Record(text=message.get("text", ""), data={"result": result_data}) + message_dict = ( + message if isinstance(message, dict) else message.model_dump() + ) + record = Record( + text=message_dict.get("text", ""), data={"result": result_data} + ) records.append(record) return records - async def build(self, input_value: Text, flow_name: str, tweaks: NestedDict) -> Record: + async def build( + self, input_value: Text, flow_name: str, tweaks: NestedDict + ) -> Record: results: List[Optional[ResultData]] = await self.run_flow( input_value=input_value, flow_name=flow_name, tweaks=tweaks ) diff --git a/src/backend/langflow/components/vectorstores/Chroma.py b/src/backend/langflow/components/vectorstores/Chroma.py index d5afe8b81..e00a0267a 100644 --- a/src/backend/langflow/components/vectorstores/Chroma.py +++ b/src/backend/langflow/components/vectorstores/Chroma.py @@ -84,7 +84,8 @@ class ChromaComponent(CustomComponent): if chroma_server_host is not None: chroma_settings = chromadb.config.Settings( - chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None, + chroma_server_cors_allow_origins=chroma_server_cors_allow_origins + or None, chroma_server_host=chroma_server_host, chroma_server_port=chroma_server_port or None, chroma_server_grpc_port=chroma_server_grpc_port or None, @@ -98,14 +99,16 @@ class ChromaComponent(CustomComponent): index_directory = self.resolve_path(index_directory) documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: documents.append(_input) if documents is not None and embedding is not None: if len(documents) == 0: - raise ValueError("If documents are provided, there must be at least one document.") + raise ValueError( + "If documents are provided, there must be at least one document." + ) chroma = Chroma.from_documents( documents=documents, # type: ignore persist_directory=index_directory, diff --git a/src/backend/langflow/components/vectorstores/FAISS.py b/src/backend/langflow/components/vectorstores/FAISS.py index 7cdadccdb..dbdcbed9d 100644 --- a/src/backend/langflow/components/vectorstores/FAISS.py +++ b/src/backend/langflow/components/vectorstores/FAISS.py @@ -33,7 +33,7 @@ class FAISSComponent(CustomComponent): index_name: str = "langflow_index", ) -> Union[VectorStore, FAISS, BaseRetriever]: documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py b/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py index f45d55584..d94e3ad14 100644 --- a/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py +++ b/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py @@ -9,7 +9,9 @@ from langflow.schema.schema import Record class MongoDBAtlasComponent(CustomComponent): display_name = "MongoDB Atlas" - description = "Construct a `MongoDB Atlas Vector Search` vector store from raw documents." + description = ( + "Construct a `MongoDB Atlas Vector Search` vector store from raw documents." + ) icon = "MongoDB" def build_config(self): @@ -26,7 +28,7 @@ class MongoDBAtlasComponent(CustomComponent): def build( self, embedding: Embeddings, - inputs: List[Record], + inputs: Optional[List[Record]] = None, collection_name: str = "", db_name: str = "", index_name: str = "", @@ -37,14 +39,16 @@ class MongoDBAtlasComponent(CustomComponent): try: from pymongo import MongoClient except ImportError: - raise ImportError("Please install pymongo to use MongoDB Atlas Vector Store") + raise ImportError( + "Please install pymongo to use MongoDB Atlas Vector Store" + ) try: mongo_client: MongoClient = MongoClient(mongodb_atlas_cluster_uri) collection = mongo_client[db_name][collection_name] except Exception as e: raise ValueError(f"Failed to connect to MongoDB Atlas: {e}") documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py b/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py index 0c713950a..2fa588d89 100644 --- a/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py +++ b/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py @@ -39,7 +39,6 @@ class MongoDBAtlasSearchComponent(MongoDBAtlasComponent, LCVectorStoreComponent) vector_store = super().build( embedding=embedding, collection_name=collection_name, - documents=[], db_name=db_name, index_name=index_name, mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri, diff --git a/src/backend/langflow/components/vectorstores/Pinecone.py b/src/backend/langflow/components/vectorstores/Pinecone.py index c71048266..b6c022e80 100644 --- a/src/backend/langflow/components/vectorstores/Pinecone.py +++ b/src/backend/langflow/components/vectorstores/Pinecone.py @@ -45,7 +45,7 @@ class PineconeComponent(CustomComponent): self, embedding: Embeddings, pinecone_env: str, - inputs: List[Record], + inputs: Optional[List[Record]] = None, text_key: str = "text", pool_threads: int = 4, index_name: Optional[str] = None, @@ -61,7 +61,7 @@ class PineconeComponent(CustomComponent): if not index_name: raise ValueError("Index Name is required.") documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/PineconeSearch.py b/src/backend/langflow/components/vectorstores/PineconeSearch.py index 95cd28ded..56dd11196 100644 --- a/src/backend/langflow/components/vectorstores/PineconeSearch.py +++ b/src/backend/langflow/components/vectorstores/PineconeSearch.py @@ -55,7 +55,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent): vector_store = super().build( embedding=embedding, pinecone_env=pinecone_env, - documents=[], + inputs=[], text_key=text_key, pool_threads=pool_threads, index_name=index_name, diff --git a/src/backend/langflow/components/vectorstores/Qdrant.py b/src/backend/langflow/components/vectorstores/Qdrant.py index e1773268b..9670ab984 100644 --- a/src/backend/langflow/components/vectorstores/Qdrant.py +++ b/src/backend/langflow/components/vectorstores/Qdrant.py @@ -64,7 +64,7 @@ class QdrantComponent(CustomComponent): url: Optional[str] = None, ) -> Union[VectorStore, Qdrant, BaseRetriever]: documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/Redis.py b/src/backend/langflow/components/vectorstores/Redis.py index bd1b85e4f..5996ccc66 100644 --- a/src/backend/langflow/components/vectorstores/Redis.py +++ b/src/backend/langflow/components/vectorstores/Redis.py @@ -59,14 +59,16 @@ class RedisComponent(CustomComponent): - VectorStore: The Vector Store object. """ documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: documents.append(_input) if not documents: if schema is None: - raise ValueError("If no documents are provided, a schema must be provided.") + raise ValueError( + "If no documents are provided, a schema must be provided." + ) redis_vs = Redis.from_existing_index( embedding=embedding, index_name=redis_index_name, diff --git a/src/backend/langflow/components/vectorstores/SupabaseVectorStore.py b/src/backend/langflow/components/vectorstores/SupabaseVectorStore.py index 5d32388d9..bc9d49e28 100644 --- a/src/backend/langflow/components/vectorstores/SupabaseVectorStore.py +++ b/src/backend/langflow/components/vectorstores/SupabaseVectorStore.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union from langchain.schema import BaseRetriever from langchain_community.vectorstores import VectorStore @@ -28,16 +28,18 @@ class SupabaseComponent(CustomComponent): def build( self, embedding: Embeddings, - inputs: List[Record], + inputs: Optional[List[Record]] = None, query_name: str = "", search_kwargs: NestedDict = {}, supabase_service_key: str = "", supabase_url: str = "", table_name: str = "", ) -> Union[VectorStore, SupabaseVectorStore, BaseRetriever]: - supabase: Client = create_client(supabase_url, supabase_key=supabase_service_key) + supabase: Client = create_client( + supabase_url, supabase_key=supabase_service_key + ) documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/Vectara.py b/src/backend/langflow/components/vectorstores/Vectara.py index 001658cd3..c427e139c 100644 --- a/src/backend/langflow/components/vectorstores/Vectara.py +++ b/src/backend/langflow/components/vectorstores/Vectara.py @@ -15,7 +15,9 @@ from langflow.schema.schema import Record class VectaraComponent(CustomComponent): display_name: str = "Vectara" description: str = "Implementation of Vector Store using Vectara" - documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara" + documentation = ( + "https://python.langchain.com/docs/integrations/vectorstores/vectara" + ) icon = "Vectara" field_config = { "vectara_customer_id": { @@ -50,7 +52,7 @@ class VectaraComponent(CustomComponent): source = "Langflow" documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/Weaviate.py b/src/backend/langflow/components/vectorstores/Weaviate.py index de309034c..62f46c586 100644 --- a/src/backend/langflow/components/vectorstores/Weaviate.py +++ b/src/backend/langflow/components/vectorstores/Weaviate.py @@ -12,7 +12,9 @@ from langflow.schema.schema import Record class WeaviateVectorStoreComponent(CustomComponent): display_name: str = "Weaviate" description: str = "Implementation of Vector Store using Weaviate" - documentation = "https://python.langchain.com/docs/integrations/vectorstores/weaviate" + documentation = ( + "https://python.langchain.com/docs/integrations/vectorstores/weaviate" + ) field_config = { "url": {"display_name": "Weaviate URL", "value": "http://localhost:8080"}, "api_key": { @@ -79,7 +81,7 @@ class WeaviateVectorStoreComponent(CustomComponent): index_name = _to_pascal_case(index_name) if index_name else None documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/components/vectorstores/pgvector.py b/src/backend/langflow/components/vectorstores/pgvector.py index 7ab20b8df..ae3782714 100644 --- a/src/backend/langflow/components/vectorstores/pgvector.py +++ b/src/backend/langflow/components/vectorstores/pgvector.py @@ -16,7 +16,9 @@ class PGVectorComponent(CustomComponent): display_name: str = "PGVector" description: str = "Implementation of Vector Store using PostgreSQL" - documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector" + documentation = ( + "https://python.langchain.com/docs/integrations/vectorstores/pgvector" + ) def build_config(self): """ @@ -57,7 +59,7 @@ class PGVectorComponent(CustomComponent): """ documents = [] - for _input in inputs: + for _input in inputs or []: if isinstance(_input, Record): documents.append(_input.to_lc_document()) else: diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index edb053f4d..2c06ebe27 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -48,7 +48,7 @@ class Graph: self._is_state_vertices: List[str] = [] self._has_session_id_vertices: List[str] = [] self._sorted_vertices_layers: List[List[str]] = [] - self._run_id = None + self._run_id = "" self.top_level_vertices = [] for vertex in self._vertices: @@ -130,9 +130,6 @@ class Graph: self.state_manager.subscribe(run_id, vertex.update_graph_state) self._run_id = run_id - def add_state(self, state: str): - self.state_manager.append_state(self._run_id, state) - @property def sorted_vertices_layers(self) -> List[List[str]]: if not self._sorted_vertices_layers: @@ -198,7 +195,7 @@ class Graph: async def run( self, - inputs: Dict[str, Union[str, list[str]]], + inputs: list[Dict[str, Union[str, list[str]]]], outputs: list[str], session_id: str, stream: Optional[bool] = False, @@ -210,13 +207,12 @@ class Graph: # of the vertices that are inputs # if the value is a list, we need to run multiple times vertex_outputs = [] - inputs_values = inputs.get(INPUT_FIELD_NAME, "") if not isinstance(inputs_values, list): inputs_values = [inputs_values] - for input_value in inputs_values: + for input_dict in inputs_values: run_outputs = await self._run( - inputs={INPUT_FIELD_NAME: input_value}, - input_components=inputs.get("components", []), + inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME)}, + input_components=input_dict.get("components", []), outputs=outputs, stream=stream, session_id=session_id, diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 08c8d8f99..3fa59ab45 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -97,7 +97,6 @@ class Vertex: self.use_result = False self.build_times: List[float] = [] self.state = VertexStates.ACTIVE - self.graph_state = {} def update_graph_state(self, key, new_state, append: bool): if append: diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index f034656be..63668b48c 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Coroutine, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional, Tuple, Union from langchain.agents import AgentExecutor from langchain.chains.base import Chain @@ -16,6 +16,9 @@ from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.run import get_memory_key, update_memory_keys from langflow.services.session.service import SessionService +if TYPE_CHECKING: + from langflow.api.v1.schemas import Tweaks + def fix_memory_inputs(langchain_object): """ @@ -126,7 +129,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"): result = await runnable.ainvoke(inputs) else: - raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}") + raise ValueError( + f"Runnable {runnable} does not support inputs of type {type(inputs)}" + ) # Check if the result is a list of AIMessages if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result): result = [r.content for r in result] @@ -135,7 +140,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): return result -async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict): +async def process_inputs_dict( + built_object: Union[Chain, VectorStore, Runnable], inputs: dict +): if isinstance(built_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -170,7 +177,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]): return await process_runnable(built_object, inputs) -async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]): +async def generate_result( + built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]] +): if isinstance(inputs, dict): result = await process_inputs_dict(built_object, inputs) elif isinstance(inputs, List) and isinstance(built_object, Runnable): @@ -197,7 +206,7 @@ async def run_graph( flow_id: str, stream: bool, session_id: Optional[str] = None, - inputs: Optional[dict[str, Union[List[str], str]]] = None, + inputs: Optional[list[dict[str, Union[List[str], str]]]] = None, outputs: Optional[List[str]] = None, artifacts: Optional[Dict[str, Any]] = None, session_service: Optional[SessionService] = None, @@ -209,9 +218,11 @@ async def run_graph( else: graph_data = graph._graph_data if not session_id and session_service is not None: - session_id = session_service.generate_key(session_id=flow_id, data_graph=graph_data) + session_id = session_service.generate_key( + session_id=flow_id, data_graph=graph_data + ) if inputs is None: - inputs = {} + inputs = [{}] outputs = await graph.run( inputs, @@ -224,14 +235,18 @@ async def run_graph( return outputs, session_id -def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]: +def validate_input( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> List[Dict[str, Any]]: if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): raise ValueError("graph_data and tweaks should be dictionaries") nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") if not isinstance(nodes, list): - raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key") + raise ValueError( + "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" + ) return nodes @@ -240,7 +255,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: template_data = node.get("data", {}).get("node", {}).get("template") if not isinstance(template_data, dict): - logger.warning(f"Template data for node {node.get('id')} should be a dictionary") + logger.warning( + f"Template data for node {node.get('id')} should be a dictionary" + ) return for tweak_name, tweak_value in node_tweaks.items(): @@ -255,7 +272,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None: vertex.params[tweak_name] = tweak_value -def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: +def process_tweaks( + graph_data: Dict[str, Any], tweaks: Union["Tweaks", Dict[str, Dict[str, Any]]] +) -> Dict[str, Any]: """ This function is used to tweak the graph data using the node id and the tweaks dict. @@ -291,6 +310,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]): if node_tweaks := tweaks.get(node_id): apply_tweaks_on_vertex(vertex, node_tweaks) else: - logger.warning("Each node should be a Vertex with an 'id' attribute of type str") + logger.warning( + "Each node should be a Vertex with an 'id' attribute of type str" + ) return graph