Refactor code and update dependencies

parent 3363729515
commit 69145f35ca

24 changed files with 142 additions and 71 deletions
@@ -81,7 +81,8 @@ def run_migrations_online() -> None:
         logger.error(f"Error getting database engine: {e}")
         url = os.getenv("LANGFLOW_DATABASE_URL")
         url = url or config.get_main_option("sqlalchemy.url")
-        config.set_main_option("sqlalchemy.url", url)
+        if url:
+            config.set_main_option("sqlalchemy.url", url)
     connectable = engine_from_config(
         config.get_section(config.config_ini_section, {}),
         prefix="sqlalchemy.",

@@ -1,6 +1,6 @@
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from fastapi import HTTPException
 from platformdirs import user_cache_dir

@@ -20,7 +20,9 @@ API_WORDS = ["api", "key", "token"]


 def has_api_terms(word: str):
-    return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))
+    return "api" in word and (
+        "key" in word or ("token" in word and "tokens" not in word)
+    )


 def remove_api_keys(flow: dict):

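Note: the wrapped predicate is logically identical to the original one-liner. An illustrative check of its behavior (not part of the commit):

    has_api_terms("openai_api_key")  # True: contains "api" and "key"
    has_api_terms("api_token")       # True: "token" without the plural "tokens"
    has_api_terms("api_tokens")      # False: plural "tokens" is excluded
    has_api_terms("token")           # False: no "api" in the word
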
@@ -30,7 +32,11 @@ def remove_api_keys(flow: dict):
         node_data = node.get("data").get("node")
         template = node_data.get("template")
         for value in template.values():
-            if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"):
+            if (
+                isinstance(value, dict)
+                and has_api_terms(value["name"])
+                and value.get("password")
+            ):
                 value["value"] = None

     return flow

@@ -51,7 +57,9 @@ def build_input_keys_response(langchain_object, artifacts):
             input_keys_response["input_keys"][key] = value
     # If the object has memory, that memory will have a memory_variables attribute
    # memory variables should be removed from the input keys
-    if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"):
+    if hasattr(langchain_object, "memory") and hasattr(
+        langchain_object.memory, "memory_variables"
+    ):
         # Remove memory variables from input keys
         input_keys_response["input_keys"] = {
             key: value

@@ -61,7 +69,9 @@ def build_input_keys_response(langchain_object, artifacts):
         # Add memory variables to memory_keys
         input_keys_response["memory_keys"] = langchain_object.memory.memory_variables

-    if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"):
+    if hasattr(langchain_object, "prompt") and hasattr(
+        langchain_object.prompt, "template"
+    ):
         input_keys_response["template"] = langchain_object.prompt.template

     return input_keys_response

@@ -96,7 +106,11 @@ def raw_frontend_data_is_valid(raw_frontend_data):
 def is_valid_data(frontend_node, raw_frontend_data):
     """Check if the data is valid for processing."""

-    return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data)
+    return (
+        frontend_node
+        and "template" in frontend_node
+        and raw_frontend_data_is_valid(raw_frontend_data)
+    )


 def update_template_values(frontend_template, raw_template):

@@ -136,12 +150,14 @@ def get_file_path_value(file_path):
     # If the path is not in the cache dir, return empty string
     # This is to prevent access to files outside the cache dir
     # If the path is not a file, return empty string
-    if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")):
+    if not path.exists() or not str(path).startswith(
+        user_cache_dir("langflow", "langflow")
+    ):
         return ""
     return file_path


-def validate_is_component(flows: List["Flow"]):
+def validate_is_component(flows: list["Flow"]):
     for flow in flows:
         if not flow.data or flow.is_component is not None:
             continue

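Note: the reflowed condition keeps the same containment guard: a path is honored only if it exists and its string form starts with the Langflow cache directory, which blocks reads outside that directory. A minimal sketch of the pattern (file name is hypothetical):

    from pathlib import Path
    from platformdirs import user_cache_dir

    cache_dir = user_cache_dir("langflow", "langflow")
    path = Path(cache_dir) / "example.png"  # hypothetical file
    allowed = path.exists() and str(path).startswith(cache_dir)
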
@@ -167,7 +183,9 @@ async def check_langflow_version(component: StoreComponentCreate):

     langflow_version = get_lf_version_from_pypi()
     if langflow_version is None:
-        raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow")
+        raise HTTPException(
+            status_code=500, detail="Unable to verify the latest version of Langflow"
+        )
     elif langflow_version != component.last_tested_version:
         warnings.warn(
             f"Your version of Langflow ({component.last_tested_version}) is outdated. "

@@ -26,7 +26,7 @@ class FrontendNodeRequest(FrontendNode):
 class ValidatePromptRequest(BaseModel):
     name: str
     template: str
     # optional for tweak call
     custom_fields: Optional[dict] = None
     frontend_node: Optional[FrontendNodeRequest] = None

@@ -105,9 +105,11 @@ async def run_flow_with_caching(
     """
     try:
         if inputs is not None:
-            input_values_dict: dict[str, Union[str, list[str]]] = inputs.model_dump()
+            input_values: list[dict[str, Union[str, list[str]]]] = [
+                _input.model_dump() for _input in inputs
+            ]
         else:
-            input_values_dict = {}
+            input_values = [{}]

         if outputs is None:
             outputs = []

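Note: `input_values_dict` (one dict) becomes `input_values` (a list of dicts), so a single request can carry several inputs; each element is later handed to `Graph.run`, which now expects a list (see the Graph hunks below). Assumed payload shapes, for illustration only (the field name comes from INPUT_FIELD_NAME and is assumed here):

    # before: a single input
    input_values_dict = {"input_value": "hello", "components": []}
    # after: one dict per input
    input_values = [
        {"input_value": "hello", "components": []},
        {"input_value": "goodbye", "components": []},
    ]
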
@@ -124,7 +126,7 @@ async def run_flow_with_caching(
                 graph=graph,
                 flow_id=flow_id,
                 session_id=session_id,
-                inputs=input_values_dict,
+                inputs=input_values,
                 outputs=outputs,
                 artifacts=artifacts,
                 session_service=session_service,

@@ -150,7 +152,7 @@ async def run_flow_with_caching(
                 graph=graph_data,
                 flow_id=flow_id,
                 session_id=session_id,
-                inputs=input_values_dict,
+                inputs=input_values,
                 outputs=outputs,
                 artifacts={},
                 session_service=session_service,

@@ -57,7 +57,7 @@ def read_flows(
     try:
         auth_settings = settings_service.auth_settings
         if auth_settings.AUTO_LOGIN:
-            flows = session.exec(
+            flows: list[Flow] = session.exec(
                 select(Flow).where(
                     (Flow.user_id == None) | (Flow.user_id == current_user.id)  # noqa
                 )

@@ -32,11 +32,12 @@ class ConversationChainComponent(CustomComponent):
         else:
             chain = ConversationChain(llm=llm, memory=memory)
         result = chain.invoke({"input": input_value})
-        if hasattr(result, "content") and isinstance(result.content, str):
-            result = result.content
+        if isinstance(result, dict):
+            result = result.get(chain.output_key)
+
         elif isinstance(result, str):
             result = result
         else:
             result = result.get("response")
         self.status = result
-        return result
+        return str(result)

@@ -18,7 +18,7 @@ class URLComponent(CustomComponent):
     async def build(
         self,
         urls: list[str],
-    ) -> Record:
+    ) -> list[Record]:
         loader = WebBaseLoader(web_paths=urls)
         docs = loader.load()
         records = self.to_records(docs)

@@ -13,10 +13,10 @@ class MergeRecordsComponent(CustomComponent):

     def build(self, records: list[Record]) -> Record:
         if not records:
-            return records
+            return Record()
         if len(records) == 1:
             return records[0]
-        merged_record = None
+        merged_record = Record()
         for record in records:
             if merged_record is None:
                 merged_record = record

@@ -24,3 +24,13 @@ class MergeRecordsComponent(CustomComponent):
                 merged_record += record
         self.status = merged_record
         return merged_record
+
+
+if __name__ == "__main__":
+    records = [
+        Record(data={"key1": "value1"}),
+        Record(data={"key2": "value2"}),
+    ]
+    component = MergeRecordsComponent()
+    result = component.build(records)
+    print(result)

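Note: since `merged_record` now starts as `Record()` rather than None, the `if merged_record is None:` branch kept as context above can no longer trigger, and the loop is effectively (illustrative simplification):

    merged_record = Record()
    for record in records:
        merged_record += record  # Record supports in-place merge, as used above
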
@@ -36,11 +36,18 @@ class RunFlowComponent(CustomComponent):
         messages = result_data.messages
         records = []
         for message in messages:
-            record = Record(text=message.get("text", ""), data={"result": result_data})
+            message_dict = (
+                message if isinstance(message, dict) else message.model_dump()
+            )
+            record = Record(
+                text=message_dict.get("text", ""), data={"result": result_data}
+            )
             records.append(record)
         return records

-    async def build(self, input_value: Text, flow_name: str, tweaks: NestedDict) -> Record:
+    async def build(
+        self, input_value: Text, flow_name: str, tweaks: NestedDict
+    ) -> Record:
         results: List[Optional[ResultData]] = await self.run_flow(
             input_value=input_value, flow_name=flow_name, tweaks=tweaks
         )

@@ -84,7 +84,8 @@ class ChromaComponent(CustomComponent):

         if chroma_server_host is not None:
             chroma_settings = chromadb.config.Settings(
-                chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
+                chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
+                or None,
                 chroma_server_host=chroma_server_host,
                 chroma_server_port=chroma_server_port or None,
                 chroma_server_grpc_port=chroma_server_grpc_port or None,

@@ -98,14 +99,16 @@ class ChromaComponent(CustomComponent):
         index_directory = self.resolve_path(index_directory)

         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:
                 documents.append(_input)
         if documents is not None and embedding is not None:
             if len(documents) == 0:
-                raise ValueError("If documents are provided, there must be at least one document.")
+                raise ValueError(
+                    "If documents are provided, there must be at least one document."
+                )
             chroma = Chroma.from_documents(
                 documents=documents,  # type: ignore
                 persist_directory=index_directory,

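Note: `for _input in inputs or []:` is the guard repeated across the vector-store components in this commit: with `inputs` now defaulting to None, iterating over `inputs or []` turns a missing list into an empty loop instead of a `TypeError: 'NoneType' object is not iterable`. The pattern in isolation (illustrative sketch):

    def collect(inputs=None):
        documents = []
        for _input in inputs or []:  # safe for both None and []
            documents.append(_input)
        return documents
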
@@ -33,7 +33,7 @@ class FAISSComponent(CustomComponent):
         index_name: str = "langflow_index",
     ) -> Union[VectorStore, FAISS, BaseRetriever]:
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -9,7 +9,9 @@ from langflow.schema.schema import Record

 class MongoDBAtlasComponent(CustomComponent):
     display_name = "MongoDB Atlas"
-    description = "Construct a `MongoDB Atlas Vector Search` vector store from raw documents."
+    description = (
+        "Construct a `MongoDB Atlas Vector Search` vector store from raw documents."
+    )
     icon = "MongoDB"

     def build_config(self):

@@ -26,7 +28,7 @@ class MongoDBAtlasComponent(CustomComponent):
     def build(
         self,
         embedding: Embeddings,
-        inputs: List[Record],
+        inputs: Optional[List[Record]] = None,
         collection_name: str = "",
         db_name: str = "",
         index_name: str = "",

@@ -37,14 +39,16 @@ class MongoDBAtlasComponent(CustomComponent):
         try:
             from pymongo import MongoClient
         except ImportError:
-            raise ImportError("Please install pymongo to use MongoDB Atlas Vector Store")
+            raise ImportError(
+                "Please install pymongo to use MongoDB Atlas Vector Store"
+            )
         try:
             mongo_client: MongoClient = MongoClient(mongodb_atlas_cluster_uri)
             collection = mongo_client[db_name][collection_name]
         except Exception as e:
             raise ValueError(f"Failed to connect to MongoDB Atlas: {e}")
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -39,7 +39,6 @@ class MongoDBAtlasSearchComponent(MongoDBAtlasComponent, LCVectorStoreComponent)
         vector_store = super().build(
             embedding=embedding,
             collection_name=collection_name,
-            documents=[],
             db_name=db_name,
             index_name=index_name,
             mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri,

@@ -45,7 +45,7 @@ class PineconeComponent(CustomComponent):
         self,
         embedding: Embeddings,
         pinecone_env: str,
-        inputs: List[Record],
+        inputs: Optional[List[Record]] = None,
         text_key: str = "text",
         pool_threads: int = 4,
         index_name: Optional[str] = None,

@@ -61,7 +61,7 @@ class PineconeComponent(CustomComponent):
         if not index_name:
             raise ValueError("Index Name is required.")
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -55,7 +55,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent):
         vector_store = super().build(
             embedding=embedding,
             pinecone_env=pinecone_env,
-            documents=[],
+            inputs=[],
             text_key=text_key,
             pool_threads=pool_threads,
             index_name=index_name,

@@ -64,7 +64,7 @@ class QdrantComponent(CustomComponent):
         url: Optional[str] = None,
     ) -> Union[VectorStore, Qdrant, BaseRetriever]:
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -59,14 +59,16 @@ class RedisComponent(CustomComponent):
         - VectorStore: The Vector Store object.
         """
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:
                 documents.append(_input)
         if not documents:
             if schema is None:
-                raise ValueError("If no documents are provided, a schema must be provided.")
+                raise ValueError(
+                    "If no documents are provided, a schema must be provided."
+                )
             redis_vs = Redis.from_existing_index(
                 embedding=embedding,
                 index_name=redis_index_name,

@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import List, Optional, Union

 from langchain.schema import BaseRetriever
 from langchain_community.vectorstores import VectorStore

@@ -28,16 +28,18 @@ class SupabaseComponent(CustomComponent):
     def build(
         self,
         embedding: Embeddings,
-        inputs: List[Record],
+        inputs: Optional[List[Record]] = None,
         query_name: str = "",
         search_kwargs: NestedDict = {},
         supabase_service_key: str = "",
         supabase_url: str = "",
         table_name: str = "",
     ) -> Union[VectorStore, SupabaseVectorStore, BaseRetriever]:
-        supabase: Client = create_client(supabase_url, supabase_key=supabase_service_key)
+        supabase: Client = create_client(
+            supabase_url, supabase_key=supabase_service_key
+        )
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

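Note (pre-existing, untouched by this commit): `search_kwargs: NestedDict = {}` is a mutable default argument, evaluated once at definition time and shared across calls. A safer variant, should it ever be revisited (sketch):

    def build(search_kwargs: Optional[dict] = None):
        search_kwargs = search_kwargs or {}  # fresh dict per call
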
@@ -15,7 +15,9 @@ from langflow.schema.schema import Record
 class VectaraComponent(CustomComponent):
     display_name: str = "Vectara"
     description: str = "Implementation of Vector Store using Vectara"
-    documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara"
+    documentation = (
+        "https://python.langchain.com/docs/integrations/vectorstores/vectara"
+    )
     icon = "Vectara"
     field_config = {
         "vectara_customer_id": {

@@ -50,7 +52,7 @@ class VectaraComponent(CustomComponent):
             source = "Langflow"

         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -12,7 +12,9 @@ from langflow.schema.schema import Record
 class WeaviateVectorStoreComponent(CustomComponent):
     display_name: str = "Weaviate"
     description: str = "Implementation of Vector Store using Weaviate"
-    documentation = "https://python.langchain.com/docs/integrations/vectorstores/weaviate"
+    documentation = (
+        "https://python.langchain.com/docs/integrations/vectorstores/weaviate"
+    )
     field_config = {
         "url": {"display_name": "Weaviate URL", "value": "http://localhost:8080"},
         "api_key": {

@@ -79,7 +81,7 @@ class WeaviateVectorStoreComponent(CustomComponent):

         index_name = _to_pascal_case(index_name) if index_name else None
         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -16,7 +16,9 @@ class PGVectorComponent(CustomComponent):

     display_name: str = "PGVector"
     description: str = "Implementation of Vector Store using PostgreSQL"
-    documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
+    documentation = (
+        "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
+    )

     def build_config(self):
         """

@@ -57,7 +59,7 @@ class PGVectorComponent(CustomComponent):
         """

         documents = []
-        for _input in inputs:
+        for _input in inputs or []:
             if isinstance(_input, Record):
                 documents.append(_input.to_lc_document())
             else:

@@ -48,7 +48,7 @@ class Graph:
         self._is_state_vertices: List[str] = []
         self._has_session_id_vertices: List[str] = []
         self._sorted_vertices_layers: List[List[str]] = []
-        self._run_id = None
+        self._run_id = ""

         self.top_level_vertices = []
         for vertex in self._vertices:

@@ -130,9 +130,6 @@ class Graph:
             self.state_manager.subscribe(run_id, vertex.update_graph_state)
         self._run_id = run_id

-    def add_state(self, state: str):
-        self.state_manager.append_state(self._run_id, state)
-
     @property
     def sorted_vertices_layers(self) -> List[List[str]]:
         if not self._sorted_vertices_layers:

@@ -198,7 +195,7 @@ class Graph:

     async def run(
         self,
-        inputs: Dict[str, Union[str, list[str]]],
+        inputs: list[Dict[str, Union[str, list[str]]]],
         outputs: list[str],
         session_id: str,
         stream: Optional[bool] = False,

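Note: `Graph.run` now accepts a list of input dicts, one per run, mirroring the `input_values` list built in run_flow_with_caching above. A hedged usage sketch (the field and output names are placeholders, not taken from the commit):

    outputs = await graph.run(
        inputs=[{"input_value": "hi", "components": []}],  # placeholder field name
        outputs=["ChatOutput"],  # placeholder output name
        session_id=session_id,
        stream=False,
    )
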
@@ -210,13 +207,12 @@ class Graph:
         # of the vertices that are inputs
         # if the value is a list, we need to run multiple times
         vertex_outputs = []
-        inputs_values = inputs.get(INPUT_FIELD_NAME, "")
         if not isinstance(inputs_values, list):
             inputs_values = [inputs_values]
-        for input_value in inputs_values:
+        for input_dict in inputs_values:
             run_outputs = await self._run(
-                inputs={INPUT_FIELD_NAME: input_value},
-                input_components=inputs.get("components", []),
+                inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME)},
+                input_components=input_dict.get("components", []),
                 outputs=outputs,
                 stream=stream,
                 session_id=session_id,

@@ -97,7 +97,6 @@ class Vertex:
         self.use_result = False
         self.build_times: List[float] = []
         self.state = VertexStates.ACTIVE
-        self.graph_state = {}

     def update_graph_state(self, key, new_state, append: bool):
         if append:

@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Coroutine, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional, Tuple, Union

 from langchain.agents import AgentExecutor
 from langchain.chains.base import Chain

@@ -16,6 +16,9 @@ from langflow.interface.custom.custom_component import CustomComponent
 from langflow.interface.run import get_memory_key, update_memory_keys
 from langflow.services.session.service import SessionService

+if TYPE_CHECKING:
+    from langflow.api.v1.schemas import Tweaks
+

 def fix_memory_inputs(langchain_object):
     """

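Note: the `if TYPE_CHECKING:` block imports Tweaks for the type checker only, keeping the schema module out of the runtime import graph (a common way to break circular imports); this is why `process_tweaks` below spells the type as the string `"Tweaks"`. The general pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from langflow.api.v1.schemas import Tweaks  # never imported at runtime

    def process_tweaks(tweaks: "Tweaks") -> None:  # forward reference as a string
        ...
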
@@ -126,7 +129,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
     elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
         result = await runnable.ainvoke(inputs)
     else:
-        raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
+        raise ValueError(
+            f"Runnable {runnable} does not support inputs of type {type(inputs)}"
+        )
     # Check if the result is a list of AIMessages
     if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
         result = [r.content for r in result]

@@ -135,7 +140,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
     return result


-async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
+async def process_inputs_dict(
+    built_object: Union[Chain, VectorStore, Runnable], inputs: dict
+):
     if isinstance(built_object, Chain):
         if inputs is None:
             raise ValueError("Inputs must be provided for a Chain")

@@ -170,7 +177,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
     return await process_runnable(built_object, inputs)


-async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
+async def generate_result(
+    built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
+):
     if isinstance(inputs, dict):
         result = await process_inputs_dict(built_object, inputs)
     elif isinstance(inputs, List) and isinstance(built_object, Runnable):

@@ -197,7 +206,7 @@ async def run_graph(
     flow_id: str,
     stream: bool,
     session_id: Optional[str] = None,
-    inputs: Optional[dict[str, Union[List[str], str]]] = None,
+    inputs: Optional[list[dict[str, Union[List[str], str]]]] = None,
     outputs: Optional[List[str]] = None,
     artifacts: Optional[Dict[str, Any]] = None,
     session_service: Optional[SessionService] = None,

@@ -209,9 +218,11 @@ async def run_graph(
     else:
         graph_data = graph._graph_data
     if not session_id and session_service is not None:
-        session_id = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
+        session_id = session_service.generate_key(
+            session_id=flow_id, data_graph=graph_data
+        )
     if inputs is None:
-        inputs = {}
+        inputs = [{}]

     outputs = await graph.run(
         inputs,

@@ -224,14 +235,18 @@ async def run_graph(
     return outputs, session_id


-def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
+def validate_input(
+    graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
+) -> List[Dict[str, Any]]:
     if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
         raise ValueError("graph_data and tweaks should be dictionaries")

     nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")

     if not isinstance(nodes, list):
-        raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
+        raise ValueError(
+            "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
+        )

     return nodes

@@ -240,7 +255,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
     template_data = node.get("data", {}).get("node", {}).get("template")

     if not isinstance(template_data, dict):
-        logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
+        logger.warning(
+            f"Template data for node {node.get('id')} should be a dictionary"
+        )
         return

     for tweak_name, tweak_value in node_tweaks.items():

@@ -255,7 +272,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
             vertex.params[tweak_name] = tweak_value


-def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
+def process_tweaks(
+    graph_data: Dict[str, Any], tweaks: Union["Tweaks", Dict[str, Dict[str, Any]]]
+) -> Dict[str, Any]:
     """
     This function is used to tweak the graph data using the node id and the tweaks dict.

@@ -291,6 +310,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
         if node_tweaks := tweaks.get(node_id):
             apply_tweaks_on_vertex(vertex, node_tweaks)
         else:
-            logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
+            logger.warning(
+                "Each node should be a Vertex with an 'id' attribute of type str"
+            )

     return graph