diff --git a/poetry.lock b/poetry.lock index a7ed6a921..b3dc0efeb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3681,6 +3681,24 @@ babel = ["Babel"] lingua = ["lingua"] testing = ["pytest"] +[[package]] +name = "markdown" +version = "3.5.2" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"}, + {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -8212,6 +8230,7 @@ emoji = "*" filetype = "*" langdetect = "*" lxml = "*" +markdown = {version = "*", optional = true, markers = "extra == \"md\""} nltk = "*" numpy = "*" python-iso639 = "*" @@ -9026,4 +9045,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "1462954b3befc2989ae226f2214111be786eb05bade578c9c80b4ed80d5b59ff" +content-hash = "b35a356770d3425f524b0c46a449696db1fa7c13fae77324188cb6ffa4a4c5a7" diff --git a/pyproject.toml b/pyproject.toml index 339bdff5e..09d70518a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,7 @@ pytube = "^15.0.0" python-socketio = "^5.11.0" llama-index = "0.9.48" langchain-openai = "^0.0.6" -unstructured = "^0.12.4" +unstructured = {extras = ["md"], version = "^0.12.4"} [tool.poetry.group.dev.dependencies] pytest-asyncio = "^0.23.1" diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index 43f71e67b..3183954a3 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -1,9 +1,9 @@ -from typing import Callable, Optional, Union +from typing import Optional from langchain.chains import ConversationChain from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text +from langflow.field_typing import BaseLanguageModel, BaseMemory, Text class ConversationChainComponent(CustomComponent): @@ -26,7 +26,7 @@ class ConversationChainComponent(CustomComponent): inputs: str, llm: BaseLanguageModel, memory: Optional[BaseMemory] = None, - ) -> Union[Chain, Callable, Text]: + ) -> Text: if memory is None: chain = ConversationChain(llm=llm) else: diff --git a/src/backend/langflow/components/chains/LLMCheckerChain.py b/src/backend/langflow/components/chains/LLMCheckerChain.py index 527cafbb7..bfee0b5a9 100644 --- a/src/backend/langflow/components/chains/LLMCheckerChain.py +++ b/src/backend/langflow/components/chains/LLMCheckerChain.py @@ -1,14 +1,15 @@ -from typing import Callable, Union - from langchain.chains import LLMCheckerChain + from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, Chain +from langflow.field_typing import BaseLanguageModel, Text class LLMCheckerChainComponent(CustomComponent): display_name = "LLMCheckerChain" description = "" - documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker" + documentation = ( + "https://python.langchain.com/docs/modules/chains/additional/llm_checker" + ) def build_config(self): return { @@ -17,6 +18,12 @@ class LLMCheckerChainComponent(CustomComponent): def build( self, + inputs: str, llm: BaseLanguageModel, - ) -> Union[Chain, Callable]: - return LLMCheckerChain.from_llm(llm=llm) + ) -> Text: + + chain = LLMCheckerChain.from_llm(llm=llm) + response = chain.invoke({chain.input_key: inputs}) + result = response.get(chain.output_key) + self.status = result + return result diff --git a/src/backend/langflow/components/chains/LLMMathChain.py b/src/backend/langflow/components/chains/LLMMathChain.py index 28f430e6d..919de34e6 100644 --- a/src/backend/langflow/components/chains/LLMMathChain.py +++ b/src/backend/langflow/components/chains/LLMMathChain.py @@ -1,15 +1,17 @@ -from typing import Callable, Optional, Union +from typing import Optional from langchain.chains import LLMChain, LLMMathChain from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain +from langflow.field_typing import BaseLanguageModel, BaseMemory, Text class LLMMathChainComponent(CustomComponent): display_name = "LLMMathChain" description = "Chain that interprets a prompt and executes python code to do math." - documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math" + documentation = ( + "https://python.langchain.com/docs/modules/chains/additional/llm_math" + ) def build_config(self): return { @@ -22,10 +24,21 @@ class LLMMathChainComponent(CustomComponent): def build( self, + inputs: Text, llm: BaseLanguageModel, llm_chain: LLMChain, input_key: str = "question", output_key: str = "answer", memory: Optional[BaseMemory] = None, - ) -> Union[LLMMathChain, Callable, Chain]: - return LLMMathChain(llm=llm, llm_chain=llm_chain, input_key=input_key, output_key=output_key, memory=memory) + ) -> Text: + chain = LLMMathChain( + llm=llm, + llm_chain=llm_chain, + input_key=input_key, + output_key=output_key, + memory=memory, + ) + response = chain.invoke({input_key: inputs}) + result = response.get(output_key) + self.status = result + return result diff --git a/src/backend/langflow/components/chains/SQLGenerator.py b/src/backend/langflow/components/chains/SQLGenerator.py index 5efb0f738..ea22a6de0 100644 --- a/src/backend/langflow/components/chains/SQLGenerator.py +++ b/src/backend/langflow/components/chains/SQLGenerator.py @@ -32,21 +32,39 @@ class SQLGeneratorComponent(CustomComponent): db: SQLDatabase, llm: BaseLanguageModel, top_k: int = 5, - prompt: Optional[PromptTemplate] = None, + prompt: Optional[Text] = None, ) -> Text: + if prompt: + prompt_template = PromptTemplate.from_template(template=prompt) + else: + prompt_template = None + if top_k > 0: kwargs = { "k": top_k, } - if not prompt: + if not prompt_template: sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs) else: - template = prompt.template if hasattr(prompt, "template") else prompt + template = ( + prompt_template.template + if hasattr(prompt, "template") + else prompt_template + ) # Check if {question} is in the prompt - if "{question}" not in template or "question" not in template.input_variables: - raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.") - sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt, **kwargs) - query_writer = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()} + if ( + "{question}" not in template + or "question" not in template.input_variables + ): + raise ValueError( + "Prompt must contain `{question}` to be used with Natural Language to SQL." + ) + sql_query_chain = create_sql_query_chain( + llm=llm, db=db, prompt=prompt_template, **kwargs + ) + query_writer = sql_query_chain | { + "query": lambda x: x.replace("SQLQuery:", "").strip() + } response = query_writer.invoke({"question": inputs}) query = response.get("query") self.status = query diff --git a/src/backend/langflow/components/io/ChatInput.py b/src/backend/langflow/components/io/ChatInput.py index 42f2715e6..4d8fc509c 100644 --- a/src/backend/langflow/components/io/ChatInput.py +++ b/src/backend/langflow/components/io/ChatInput.py @@ -1,7 +1,6 @@ -from typing import Optional, Union +from typing import Optional from langflow import CustomComponent -from langflow.field_typing import Text from langflow.schema import Record @@ -25,9 +24,9 @@ class ChatInput(CustomComponent): "display_name": "Session ID", "info": "Session ID of the chat history.", }, - "as_record": { - "display_name": "As Record", - "info": "If true, the message will be returned as a Record.", + "return_record": { + "display_name": "Return Record", + "info": "Return the message as a record containing the sender, sender_name, and session_id.", }, } @@ -36,25 +35,24 @@ class ChatInput(CustomComponent): sender: Optional[str] = "User", sender_name: Optional[str] = "User", message: Optional[str] = None, - as_record: Optional[bool] = False, session_id: Optional[str] = None, - ) -> Union[Text, Record]: - self.status = message - if as_record: + return_record: Optional[bool] = False, + ) -> Record: + if return_record: if isinstance(message, Record): # Update the data of the record message.data["sender"] = sender message.data["sender_name"] = sender_name message.data["session_id"] = session_id - return message - return Record( - text=message, - data={ - "sender": sender, - "sender_name": sender_name, - "session_id": session_id, - }, - ) + else: + message = Record( + text=message, + data={ + "sender": sender, + "sender_name": sender_name, + "session_id": session_id, + }, + ) if not message: message = "" self.status = message diff --git a/src/backend/langflow/components/io/ChatOutput.py b/src/backend/langflow/components/io/ChatOutput.py index 5896adb6b..05639cdb2 100644 --- a/src/backend/langflow/components/io/ChatOutput.py +++ b/src/backend/langflow/components/io/ChatOutput.py @@ -28,9 +28,9 @@ class ChatOutput(CustomComponent): "info": "Session ID of the chat history.", "input_types": ["Text"], }, - "as_record": { - "display_name": "As Record", - "info": "If true, the message will be returned as a Record.", + "return_record": { + "display_name": "Return Record", + "info": "Return the message as a record containing the sender, sender_name, and session_id.", }, } @@ -40,25 +40,23 @@ class ChatOutput(CustomComponent): sender_name: Optional[str] = "AI", session_id: Optional[str] = None, message: Optional[str] = None, - as_record: Optional[bool] = False, + return_record: Optional[bool] = False, ) -> Union[Text, Record]: - self.status = message - if as_record: + if return_record: if isinstance(message, Record): # Update the data of the record message.data["sender"] = sender message.data["sender_name"] = sender_name message.data["session_id"] = session_id - - return message - return Record( - text=message, - data={ - "sender": sender, - "sender_name": sender_name, - "session_id": session_id, - }, - ) + else: + message = Record( + text=message, + data={ + "sender": sender, + "sender_name": sender_name, + "session_id": session_id, + }, + ) if not message: message = "" self.status = message diff --git a/src/backend/langflow/components/utilities/RecordsAsText.py b/src/backend/langflow/components/utilities/RecordsAsText.py index 8c4be331a..b47cf1b04 100644 --- a/src/backend/langflow/components/utilities/RecordsAsText.py +++ b/src/backend/langflow/components/utilities/RecordsAsText.py @@ -27,7 +27,10 @@ class RecordsAsTextComponent(CustomComponent): if isinstance(records, Record): records = [records] - formated_records = [template.format(text=record.text, **record.data) for record in records] + formated_records = [ + template.format(text=record.text, data=record.data, **record.data) + for record in records + ] result_string = "\n".join(formated_records) self.status = result_string return result_string diff --git a/src/backend/langflow/components/vectorstores/Chroma.py b/src/backend/langflow/components/vectorstores/Chroma.py index 69e978701..d9b617e61 100644 --- a/src/backend/langflow/components/vectorstores/Chroma.py +++ b/src/backend/langflow/components/vectorstores/Chroma.py @@ -84,7 +84,8 @@ class ChromaComponent(CustomComponent): if chroma_server_host is not None: chroma_settings = chromadb.config.Settings( - chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None, + chroma_server_cors_allow_origins=chroma_server_cors_allow_origins + or None, chroma_server_host=chroma_server_host, chroma_server_port=chroma_server_port or None, chroma_server_grpc_port=chroma_server_grpc_port or None, @@ -99,12 +100,14 @@ class ChromaComponent(CustomComponent): if documents is not None and embedding is not None: if len(documents) == 0: - raise ValueError("If documents are provided, there must be at least one document.") + raise ValueError( + "If documents are provided, there must be at least one document." + ) chroma = Chroma.from_documents( documents=documents, # type: ignore persist_directory=index_directory, collection_name=collection_name, - embedding_function=embedding, + embedding=embedding, client_settings=chroma_settings, ) else: diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 85ac1785d..df3b83434 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -11,31 +11,6 @@ agents: documentation: "" SQLAgent: documentation: "" -chains: - # LLMChain: - # documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain" - LLMMathChain: - documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_math" - LLMCheckerChain: - documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_checker" - # ConversationChain: - # documentation: "" - SeriesCharacterChain: - documentation: "" - MidJourneyPromptChain: - documentation: "" - TimeTravelGuideChain: - documentation: "" - SQLDatabaseChain: - documentation: "" - RetrievalQA: - documentation: "https://python.langchain.com/docs/modules/chains/popular/vector_db_qa" - RetrievalQAWithSourcesChain: - documentation: "" - ConversationalRetrievalChain: - documentation: "https://python.langchain.com/docs/modules/chains/popular/chat_vector_db" - CombineDocsChain: - documentation: "" documentloaders: AirbyteJSONLoader: documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/airbyte_json" diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 59a8966b5..1ea34fd51 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -9,13 +9,8 @@ from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.utils import process_flow from langflow.graph.schema import InterfaceComponentTypes from langflow.graph.vertex.base import Vertex -from langflow.graph.vertex.types import ( - ChatVertex, - FileToolVertex, - LLMVertex, - RoutingVertex, - ToolkitVertex, -) +from langflow.graph.vertex.types import (ChatVertex, FileToolVertex, LLMVertex, + RoutingVertex, ToolkitVertex) from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload @@ -85,7 +80,9 @@ class Graph: def build_parent_child_map(self): parent_child_map = defaultdict(list) for vertex in self.vertices: - parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] + parent_child_map[vertex.id] = [ + child.id for child in self.get_successors(vertex) + ] return parent_child_map def increment_run_count(self): @@ -149,6 +146,16 @@ class Graph: # both graphs have the same vertices and edges # but the data of the vertices might be different + def update_edges_from_vertex(self, vertex: Vertex, other_vertex: Vertex) -> None: + """Updates the edges of a vertex in the Graph.""" + new_edges = [] + for edge in self.edges: + if edge.source_id == other_vertex.id or edge.target_id == other_vertex.id: + continue + new_edges.append(edge) + new_edges += other_vertex.edges + self.edges = new_edges + def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool: return vertex.__repr__() == other_vertex.__repr__() @@ -173,10 +180,6 @@ class Graph: # Find vertices that are in self but not in other (removed vertices) removed_vertex_ids = existing_vertex_ids - other_vertex_ids - # Create a set for new edges - edges_to_add = set() - edges_to_remove = set() - # Update existing vertices that have changed for vertex_id in existing_vertex_ids.intersection(other_vertex_ids): self_vertex = self.get_vertex(vertex_id) @@ -184,6 +187,8 @@ class Graph: if not self.vertex_data_is_identical(self_vertex, other_vertex): self_vertex._data = other_vertex._data self_vertex._parse_data() + # Now we update the edges of the vertex + self.update_edges_from_vertex(self_vertex, other_vertex) self_vertex.params = {} self_vertex._build_params() self_vertex.graph = self @@ -195,25 +200,6 @@ class Graph: self_vertex.artifacts = None self_vertex.set_top_level(self.top_level_vertices) self.reset_all_edges_of_vertex(self_vertex) - if not self.vertex_edges_are_identical(self_vertex, other_vertex): - # New edges are the edges of the other vertex and not the self vertex - # If there are more edges in the other vertex than in the self vertex - # then we need to add the new edges to the self vertex - # if there are less edges in the other vertex than in the self vertex - # then we need to remove the edges that are not in the other vertex - - if len(self_vertex.edges) < len(other_vertex.edges): - edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges) - elif len(self_vertex.edges) > len(other_vertex.edges): - edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges) - - # Add new edges - # to self.edges if they are not already in self.edges - for edge in edges_to_add: - if edge not in self.edges: - self.edges.append(edge) - for edge in edges_to_remove: - self.edges.remove(edge) # Remove vertices for vertex_id in removed_vertex_ids: @@ -290,7 +276,11 @@ class Graph: return self.vertices.remove(vertex) self.vertex_map.pop(vertex_id) - self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id] + self.edges = [ + edge + for edge in self.edges + if edge.source_id != vertex_id and edge.target_id != vertex_id + ] def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" @@ -311,7 +301,9 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError(f"{vertex.vertex_type} is not connected to any other components") + raise ValueError( + f"{vertex.display_name} is not connected to any other components" + ) def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -327,7 +319,11 @@ class Graph: def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]: """Returns a list of edges for a given vertex.""" - return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] + return [ + edge + for edge in self.edges + if edge.source_id == vertex_id or edge.target_id == vertex_id + ] def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]: """Returns the vertices connected to a vertex.""" @@ -365,7 +361,9 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -389,11 +387,17 @@ class Graph: def get_predecessors(self, vertex): """Returns the predecessors of a vertex.""" - return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] + return [ + self.get_vertex(source_id) + for source_id in self.predecessor_map.get(vertex.id, []) + ] def get_successors(self, vertex): """Returns the successors of a vertex.""" - return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])] + return [ + self.get_vertex(target_id) + for target_id in self.successor_map.get(vertex.id, []) + ] def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a vertex.""" @@ -432,7 +436,9 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: + def _get_vertex_class( + self, node_type: str, node_base_type: str, node_id: str + ) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -463,14 +469,18 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + VertexClass = self._get_vertex_class( + vertex_type, vertex_base_type, vertex_data["id"] + ) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + def get_children_by_vertex_type( + self, vertex: Vertex, vertex_type: str + ) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -482,7 +492,9 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + edges_repr = "\n".join( + [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] + ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def sort_up_to_vertex(self, vertex_id: str) -> "Graph": @@ -513,7 +525,9 @@ class Graph: """Performs a layered topological sort of the vertices in the graph.""" # Queue for vertices with no incoming edges - queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0) + queue = deque( + vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 + ) layers = [] current_layer = 0 @@ -569,7 +583,9 @@ class Graph: return refined_layers - def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_chat_inputs_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: chat_inputs_first = [] for layer in vertices_layers: for vertex_id in layer: @@ -597,11 +613,15 @@ class Graph: self.increment_run_count() return vertices_layers - def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_interface_components_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" def contains_interface_component(vertex): - return any(component.value in vertex for component in InterfaceComponentTypes) + return any( + component.value in vertex for component in InterfaceComponentTypes + ) # Sort each inner list so that vertices containing ChatInput or ChatOutput come first sorted_vertices = [ @@ -620,9 +640,13 @@ class Graph: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" if len(vertices_ids) == 1: return vertices_ids - vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time) + vertices_ids.sort( + key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time + ) return vertices_ids - sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] + sorted_vertices = [ + sort_layer_by_avg_build_time(layer) for layer in vertices_layers + ] return sorted_vertices diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 25b5f1b30..d31edfbce 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -2,7 +2,8 @@ import ast import inspect import types from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional +from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, + Optional) from loguru import logger @@ -72,11 +73,17 @@ class Vertex: def set_state(self, state: str): self.state = VertexStates[state] - if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2: + if ( + self.state == VertexStates.INACTIVE + and self.graph.in_degree_map[self.id] < 2 + ): # If the vertex is inactive and has only one in degree # it means that it is not a merge point in the graph self.graph.inactive_vertices.add(self.id) - elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactive_vertices: + elif ( + self.state == VertexStates.ACTIVE + and self.id in self.graph.inactive_vertices + ): self.graph.inactive_vertices.remove(self.id) @property @@ -104,7 +111,9 @@ class Vertex: ): if edge.target_id not in edge_results: edge_results[edge.target_id] = {} - edge_results[edge.target_id][edge.target_param] = await edge.get_result(source=self, target=target) + edge_results[edge.target_id][edge.target_param] = await edge.get_result( + source=self, target=target + ) return edge_results def set_result(self, result: "ResultData") -> None: @@ -114,7 +123,9 @@ class Vertex: # If the Vertex.type is a power component # then we need to return the built object # instead of the result dict - if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject): + if self.is_interface_component and not isinstance( + self._built_object, UnbuiltObject + ): result = self._built_object # if it is not a dict or a string and hasattr model_dump then # return the model_dump @@ -124,7 +135,11 @@ class Vertex: if isinstance(self._built_result, UnbuiltResult): return {} - return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result} + return ( + self._built_result + if isinstance(self._built_result, dict) + else {"result": self._built_result} + ) def set_artifacts(self) -> None: pass @@ -187,17 +202,29 @@ class Vertex: self.output = self.data["node"]["base_classes"] self.display_name = self.data["node"]["display_name"] self.pinned = self.data["node"].get("pinned", False) - template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} + template_dicts = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } self.required_inputs = [ - template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"] + template_dicts[key]["type"] + for key, value in template_dicts.items() + if value["required"] ] self.optional_inputs = [ - template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"] + template_dicts[key]["type"] + for key, value in template_dicts.items() + if not value["required"] ] # Add the template_dicts[key]["input_types"] to the optional_inputs self.optional_inputs.extend( - [input_type for value in template_dicts.values() for input_type in value.get("input_types", [])] + [ + input_type + for value in template_dicts.values() + for input_type in value.get("input_types", []) + ] ) template_dict = self.data["node"]["template"] @@ -240,7 +267,11 @@ class Vertex: if self.graph is None: raise ValueError("Graph not found") - template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} + template_dict = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } params = {} for edge in self.edges: @@ -278,7 +309,7 @@ class Vertex: full_path = storage_service.build_full_path(flow_id, file_name) params[key] = full_path else: - raise ValueError(f"File path not found for {self.vertex_type}") + raise ValueError(f"File path not found for {self.display_name}") elif value.get("type") in DIRECT_TYPES and params.get(key) is None: val = value.get("value") if value.get("type") == "code": @@ -291,7 +322,11 @@ class Vertex: # list of dicts, so we need to convert it to a dict # before passing it to the build method if isinstance(val, list): - params[key] = {k: v for item in value.get("value", []) for k, v in item.items()} + params[key] = { + k: v + for item in value.get("value", []) + for k, v in item.items() + } elif isinstance(val, dict): params[key] = val elif value.get("type") == "int" and val is not None: @@ -327,7 +362,7 @@ class Vertex: """ Initiate the build process. """ - logger.debug(f"Building {self.vertex_type}") + logger.debug(f"Building {self.display_name}") await self._build_each_node_in_params_dict(user_id) await self._get_and_instantiate_class(user_id) self._validate_built_object() @@ -354,7 +389,9 @@ class Vertex: if isinstance(self._built_object, str): self._built_result = self._built_object - result = await generate_result(self._built_object, inputs, self.has_external_output, session_id) + result = await generate_result( + self._built_object, inputs, self.has_external_output, session_id + ) self._built_result = result async def _build_each_node_in_params_dict(self, user_id=None): @@ -382,7 +419,9 @@ class Vertex: """ return all(self._is_node(node) for node in value) - async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any: + async def get_result( + self, requester: Optional["Vertex"] = None, user_id=None, timeout=None + ) -> Any: # PLEASE REVIEW THIS IF STATEMENT # Check if the Vertex was built already if self._built: @@ -416,7 +455,9 @@ class Vertex: self._extend_params_list_with_result(key, result) self.params[key] = result - async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None): + async def _build_list_of_nodes_and_update_params( + self, key, nodes: List["Vertex"], user_id=None + ): """ Iterates over a list of nodes, builds each and updates the params dictionary. """ @@ -457,7 +498,7 @@ class Vertex: Gets the class from a dictionary and instantiates it with the params. """ if self.base_type is None: - raise ValueError(f"Base type for node {self.vertex_type} not found") + raise ValueError(f"Base type for node {self.display_name} not found") try: result = await loading.instantiate_class( node_type=self.vertex_type, @@ -468,7 +509,9 @@ class Vertex: self._update_built_object_and_artifacts(result) except Exception as exc: logger.exception(exc) - raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc + raise ValueError( + f"Error building node {self.display_name}: {str(exc)}" + ) from exc def _update_built_object_and_artifacts(self, result): """ @@ -484,9 +527,9 @@ class Vertex: Checks if the built object is None and raises a ValueError if so. """ if isinstance(self._built_object, UnbuiltObject): - raise ValueError(f"{self.vertex_type}: {self._built_object_repr()}") + raise ValueError(f"{self.display_name}: {self._built_object_repr()}") elif self._built_object is None: - message = f"{self.vertex_type} returned None." + message = f"{self.display_name} returned None." if self.base_type == "custom_components": message += " Make sure your build method returns a component." @@ -498,6 +541,7 @@ class Vertex: self._built_result = UnbuiltResult() self.artifacts = {} self.steps_ran = [] + self._build_params() def build_inactive(self): # Just set the results to None @@ -538,16 +582,24 @@ class Vertex: return self._built_object # Get the requester edge - requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None) + requester_edge = next( + (edge for edge in self.edges if edge.target_id == requester.id), None + ) # Return the result of the requester edge - return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester) + return ( + None + if requester_edge is None + else await requester_edge.get_result(source=self, target=requester) + ) def add_edge(self, edge: "ContractEdge") -> None: if edge not in self.edges: self.edges.append(edge) def __repr__(self) -> str: - return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" + return ( + f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" + ) def __eq__(self, __o: object) -> bool: try: @@ -560,7 +612,11 @@ class Vertex: def _built_object_repr(self): # Add a message with an emoji, stars for sucess, - return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵💫" + return ( + "Built sucessfully ✨" + if self._built_object is not None + else "Failed to build 😵💫" + ) class StatefulVertex(Vertex): diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index cec1f6caf..b7746b6a1 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -122,10 +122,12 @@ class DocumentLoaderVertex(StatefulVertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len( - self._built_object - ) - return f"""{self.vertex_type}({len(self._built_object)} documents) + avg_length = sum( + len(doc.page_content) + for doc in self._built_object + if hasattr(doc, "page_content") + ) / len(self._built_object) + return f"""{self.display_name}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} Documents: {self._built_object[:3]}...""" return f"{self.vertex_type}()" @@ -197,7 +199,9 @@ class TextSplitterVertex(StatefulVertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) + avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( + self._built_object + ) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} \nDocuments: {self._built_object[:3]}...""" @@ -244,18 +248,27 @@ class PromptVertex(StatelessVertex): user_id = kwargs.get("user_id", None) tools = kwargs.get("tools", []) if not self._built or force: - if "input_variables" not in self.params or self.params["input_variables"] is None: + if ( + "input_variables" not in self.params + or self.params["input_variables"] is None + ): self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: - tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] + tools = ( + [tool_node.build(user_id=user_id) for tool_node in tools] + if tools is not None + else [] + ) # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ - key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions" + key + for key, value in self.params.items() + if isinstance(value, str) and key != "format_instructions" ] else: prompt_params = ["template"] @@ -265,14 +278,20 @@ class PromptVertex(StatelessVertex): prompt_text = self.params[param] variables = extract_input_variables_from_prompt(prompt_text) self.params["input_variables"].extend(variables) - self.params["input_variables"] = list(set(self.params["input_variables"])) + self.params["input_variables"] = list( + set(self.params["input_variables"]) + ) elif isinstance(self.params, dict): self.params.pop("input_variables", None) await self._build(user_id=user_id) def _built_object_repr(self): - if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"): + if ( + not self.artifacts + or self._built_object is None + or not hasattr(self._built_object, "format") + ): return super()._built_object_repr() elif isinstance(self._built_object, UnbuiltObject): return super()._built_object_repr() @@ -284,7 +303,9 @@ class PromptVertex(StatelessVertex): # so the prompt format doesn't break artifacts.pop("handle_keys", None) try: - if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): + if not hasattr(self._built_object, "template") and hasattr( + self._built_object, "prompt" + ): template = self._built_object.prompt.template else: template = self._built_object.template @@ -292,7 +313,11 @@ class PromptVertex(StatelessVertex): if value: replace_key = "{" + key + "}" template = template.replace(replace_key, value) - return template if isinstance(template, str) else f"{self.vertex_type}({template})" + return ( + template + if isinstance(template, str) + else f"{self.vertex_type}({template})" + ) except KeyError: return str(self._built_object) diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index c42a6474d..8d59ac002 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -477,11 +477,11 @@ export default function GenericNode({ ) : (