Merge branch 'zustand/io/migration' of github.com:logspace-ai/langflow into zustand/io/migration

This commit is contained in:
igorrCarvalho 2024-02-26 21:08:49 -03:00
commit ab60f59578
26 changed files with 466 additions and 287 deletions

21
poetry.lock generated
View file

@ -3681,6 +3681,24 @@ babel = ["Babel"]
lingua = ["lingua"]
testing = ["pytest"]
[[package]]
name = "markdown"
version = "3.5.2"
description = "Python implementation of John Gruber's Markdown."
optional = false
python-versions = ">=3.8"
files = [
{file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"},
{file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"},
]
[package.dependencies]
importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
[package.extras]
docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
testing = ["coverage", "pyyaml"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@ -8212,6 +8230,7 @@ emoji = "*"
filetype = "*"
langdetect = "*"
lxml = "*"
markdown = {version = "*", optional = true, markers = "extra == \"md\""}
nltk = "*"
numpy = "*"
python-iso639 = "*"
@ -9026,4 +9045,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
content-hash = "1462954b3befc2989ae226f2214111be786eb05bade578c9c80b4ed80d5b59ff"
content-hash = "b35a356770d3425f524b0c46a449696db1fa7c13fae77324188cb6ffa4a4c5a7"

View file

@ -105,7 +105,7 @@ pytube = "^15.0.0"
python-socketio = "^5.11.0"
llama-index = "0.9.48"
langchain-openai = "^0.0.6"
unstructured = "^0.12.4"
unstructured = {extras = ["md"], version = "^0.12.4"}
[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.23.1"

View file

@ -1,9 +1,9 @@
from typing import Callable, Optional, Union
from typing import Optional
from langchain.chains import ConversationChain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
class ConversationChainComponent(CustomComponent):
@ -26,7 +26,7 @@ class ConversationChainComponent(CustomComponent):
inputs: str,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
) -> Union[Chain, Callable, Text]:
) -> Text:
if memory is None:
chain = ConversationChain(llm=llm)
else:

View file

@ -1,14 +1,15 @@
from typing import Callable, Union
from langchain.chains import LLMCheckerChain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, Chain
from langflow.field_typing import BaseLanguageModel, Text
class LLMCheckerChainComponent(CustomComponent):
display_name = "LLMCheckerChain"
description = ""
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
documentation = (
"https://python.langchain.com/docs/modules/chains/additional/llm_checker"
)
def build_config(self):
return {
@ -17,6 +18,12 @@ class LLMCheckerChainComponent(CustomComponent):
def build(
self,
inputs: str,
llm: BaseLanguageModel,
) -> Union[Chain, Callable]:
return LLMCheckerChain.from_llm(llm=llm)
) -> Text:
chain = LLMCheckerChain.from_llm(llm=llm)
response = chain.invoke({chain.input_key: inputs})
result = response.get(chain.output_key)
self.status = result
return result

View file

@ -1,15 +1,17 @@
from typing import Callable, Optional, Union
from typing import Optional
from langchain.chains import LLMChain, LLMMathChain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
class LLMMathChainComponent(CustomComponent):
display_name = "LLMMathChain"
description = "Chain that interprets a prompt and executes python code to do math."
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math"
documentation = (
"https://python.langchain.com/docs/modules/chains/additional/llm_math"
)
def build_config(self):
return {
@ -22,10 +24,21 @@ class LLMMathChainComponent(CustomComponent):
def build(
self,
inputs: Text,
llm: BaseLanguageModel,
llm_chain: LLMChain,
input_key: str = "question",
output_key: str = "answer",
memory: Optional[BaseMemory] = None,
) -> Union[LLMMathChain, Callable, Chain]:
return LLMMathChain(llm=llm, llm_chain=llm_chain, input_key=input_key, output_key=output_key, memory=memory)
) -> Text:
chain = LLMMathChain(
llm=llm,
llm_chain=llm_chain,
input_key=input_key,
output_key=output_key,
memory=memory,
)
response = chain.invoke({input_key: inputs})
result = response.get(output_key)
self.status = result
return result

View file

@ -32,21 +32,39 @@ class SQLGeneratorComponent(CustomComponent):
db: SQLDatabase,
llm: BaseLanguageModel,
top_k: int = 5,
prompt: Optional[PromptTemplate] = None,
prompt: Optional[Text] = None,
) -> Text:
if prompt:
prompt_template = PromptTemplate.from_template(template=prompt)
else:
prompt_template = None
if top_k > 0:
kwargs = {
"k": top_k,
}
if not prompt:
if not prompt_template:
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
else:
template = prompt.template if hasattr(prompt, "template") else prompt
template = (
prompt_template.template
if hasattr(prompt, "template")
else prompt_template
)
# Check if {question} is in the prompt
if "{question}" not in template or "question" not in template.input_variables:
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt, **kwargs)
query_writer = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
if (
"{question}" not in template
or "question" not in template.input_variables
):
raise ValueError(
"Prompt must contain `{question}` to be used with Natural Language to SQL."
)
sql_query_chain = create_sql_query_chain(
llm=llm, db=db, prompt=prompt_template, **kwargs
)
query_writer = sql_query_chain | {
"query": lambda x: x.replace("SQLQuery:", "").strip()
}
response = query_writer.invoke({"question": inputs})
query = response.get("query")
self.status = query

View file

@ -1,7 +1,6 @@
from typing import Optional, Union
from typing import Optional
from langflow import CustomComponent
from langflow.field_typing import Text
from langflow.schema import Record
@ -25,9 +24,9 @@ class ChatInput(CustomComponent):
"display_name": "Session ID",
"info": "Session ID of the chat history.",
},
"as_record": {
"display_name": "As Record",
"info": "If true, the message will be returned as a Record.",
"return_record": {
"display_name": "Return Record",
"info": "Return the message as a record containing the sender, sender_name, and session_id.",
},
}
@ -36,25 +35,24 @@ class ChatInput(CustomComponent):
sender: Optional[str] = "User",
sender_name: Optional[str] = "User",
message: Optional[str] = None,
as_record: Optional[bool] = False,
session_id: Optional[str] = None,
) -> Union[Text, Record]:
self.status = message
if as_record:
return_record: Optional[bool] = False,
) -> Record:
if return_record:
if isinstance(message, Record):
# Update the data of the record
message.data["sender"] = sender
message.data["sender_name"] = sender_name
message.data["session_id"] = session_id
return message
return Record(
text=message,
data={
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
},
)
else:
message = Record(
text=message,
data={
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
},
)
if not message:
message = ""
self.status = message

View file

@ -28,9 +28,9 @@ class ChatOutput(CustomComponent):
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"as_record": {
"display_name": "As Record",
"info": "If true, the message will be returned as a Record.",
"return_record": {
"display_name": "Return Record",
"info": "Return the message as a record containing the sender, sender_name, and session_id.",
},
}
@ -40,25 +40,23 @@ class ChatOutput(CustomComponent):
sender_name: Optional[str] = "AI",
session_id: Optional[str] = None,
message: Optional[str] = None,
as_record: Optional[bool] = False,
return_record: Optional[bool] = False,
) -> Union[Text, Record]:
self.status = message
if as_record:
if return_record:
if isinstance(message, Record):
# Update the data of the record
message.data["sender"] = sender
message.data["sender_name"] = sender_name
message.data["session_id"] = session_id
return message
return Record(
text=message,
data={
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
},
)
else:
message = Record(
text=message,
data={
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
},
)
if not message:
message = ""
self.status = message

View file

@ -27,7 +27,10 @@ class RecordsAsTextComponent(CustomComponent):
if isinstance(records, Record):
records = [records]
formated_records = [template.format(text=record.text, **record.data) for record in records]
formated_records = [
template.format(text=record.text, data=record.data, **record.data)
for record in records
]
result_string = "\n".join(formated_records)
self.status = result_string
return result_string

View file

@ -84,7 +84,8 @@ class ChromaComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,
@ -99,12 +100,14 @@ class ChromaComponent(CustomComponent):
if documents is not None and embedding is not None:
if len(documents) == 0:
raise ValueError("If documents are provided, there must be at least one document.")
raise ValueError(
"If documents are provided, there must be at least one document."
)
chroma = Chroma.from_documents(
documents=documents, # type: ignore
persist_directory=index_directory,
collection_name=collection_name,
embedding_function=embedding,
embedding=embedding,
client_settings=chroma_settings,
)
else:

View file

@ -11,31 +11,6 @@ agents:
documentation: ""
SQLAgent:
documentation: ""
chains:
# LLMChain:
# documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain"
LLMMathChain:
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_math"
LLMCheckerChain:
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
# ConversationChain:
# documentation: ""
SeriesCharacterChain:
documentation: ""
MidJourneyPromptChain:
documentation: ""
TimeTravelGuideChain:
documentation: ""
SQLDatabaseChain:
documentation: ""
RetrievalQA:
documentation: "https://python.langchain.com/docs/modules/chains/popular/vector_db_qa"
RetrievalQAWithSourcesChain:
documentation: ""
ConversationalRetrievalChain:
documentation: "https://python.langchain.com/docs/modules/chains/popular/chat_vector_db"
CombineDocsChain:
documentation: ""
documentloaders:
AirbyteJSONLoader:
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/airbyte_json"

View file

@ -9,13 +9,8 @@ from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.utils import process_flow
from langflow.graph.schema import InterfaceComponentTypes
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import (
ChatVertex,
FileToolVertex,
LLMVertex,
RoutingVertex,
ToolkitVertex,
)
from langflow.graph.vertex.types import (ChatVertex, FileToolVertex, LLMVertex,
RoutingVertex, ToolkitVertex)
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.utils import payload
@ -85,7 +80,9 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
return parent_child_map
def increment_run_count(self):
@ -149,6 +146,16 @@ class Graph:
# both graphs have the same vertices and edges
# but the data of the vertices might be different
def update_edges_from_vertex(self, vertex: Vertex, other_vertex: Vertex) -> None:
"""Updates the edges of a vertex in the Graph."""
new_edges = []
for edge in self.edges:
if edge.source_id == other_vertex.id or edge.target_id == other_vertex.id:
continue
new_edges.append(edge)
new_edges += other_vertex.edges
self.edges = new_edges
def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
return vertex.__repr__() == other_vertex.__repr__()
@ -173,10 +180,6 @@ class Graph:
# Find vertices that are in self but not in other (removed vertices)
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
# Create a set for new edges
edges_to_add = set()
edges_to_remove = set()
# Update existing vertices that have changed
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
self_vertex = self.get_vertex(vertex_id)
@ -184,6 +187,8 @@ class Graph:
if not self.vertex_data_is_identical(self_vertex, other_vertex):
self_vertex._data = other_vertex._data
self_vertex._parse_data()
# Now we update the edges of the vertex
self.update_edges_from_vertex(self_vertex, other_vertex)
self_vertex.params = {}
self_vertex._build_params()
self_vertex.graph = self
@ -195,25 +200,6 @@ class Graph:
self_vertex.artifacts = None
self_vertex.set_top_level(self.top_level_vertices)
self.reset_all_edges_of_vertex(self_vertex)
if not self.vertex_edges_are_identical(self_vertex, other_vertex):
# New edges are the edges of the other vertex and not the self vertex
# If there are more edges in the other vertex than in the self vertex
# then we need to add the new edges to the self vertex
# if there are less edges in the other vertex than in the self vertex
# then we need to remove the edges that are not in the other vertex
if len(self_vertex.edges) < len(other_vertex.edges):
edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges)
elif len(self_vertex.edges) > len(other_vertex.edges):
edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges)
# Add new edges
# to self.edges if they are not already in self.edges
for edge in edges_to_add:
if edge not in self.edges:
self.edges.append(edge)
for edge in edges_to_remove:
self.edges.remove(edge)
# Remove vertices
for vertex_id in removed_vertex_ids:
@ -290,7 +276,11 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -311,7 +301,9 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
raise ValueError(
f"{vertex.display_name} is not connected to any other components"
)
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -327,7 +319,11 @@ class Graph:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
return [
edge
for edge in self.edges
if edge.source_id == vertex_id or edge.target_id == vertex_id
]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
@ -365,7 +361,9 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -389,11 +387,17 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -432,7 +436,9 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -463,14 +469,18 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -482,7 +492,9 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
@ -513,7 +525,9 @@ class Graph:
"""Performs a layered topological sort of the vertices in the graph."""
# Queue for vertices with no incoming edges
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
)
layers = []
current_layer = 0
@ -569,7 +583,9 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -597,11 +613,15 @@ class Graph:
self.increment_run_count()
return vertices_layers
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(component.value in vertex for component in InterfaceComponentTypes)
return any(
component.value in vertex for component in InterfaceComponentTypes
)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -620,9 +640,13 @@ class Graph:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
return vertices_ids
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
return sorted_vertices

View file

@ -2,7 +2,8 @@ import ast
import inspect
import types
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional
from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, List,
Optional)
from loguru import logger
@ -72,11 +73,17 @@ class Vertex:
def set_state(self, state: str):
self.state = VertexStates[state]
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2:
if (
self.state == VertexStates.INACTIVE
and self.graph.in_degree_map[self.id] < 2
):
# If the vertex is inactive and has only one in degree
# it means that it is not a merge point in the graph
self.graph.inactive_vertices.add(self.id)
elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactive_vertices:
elif (
self.state == VertexStates.ACTIVE
and self.id in self.graph.inactive_vertices
):
self.graph.inactive_vertices.remove(self.id)
@property
@ -104,7 +111,9 @@ class Vertex:
):
if edge.target_id not in edge_results:
edge_results[edge.target_id] = {}
edge_results[edge.target_id][edge.target_param] = await edge.get_result(source=self, target=target)
edge_results[edge.target_id][edge.target_param] = await edge.get_result(
source=self, target=target
)
return edge_results
def set_result(self, result: "ResultData") -> None:
@ -114,7 +123,9 @@ class Vertex:
# If the Vertex.type is a power component
# then we need to return the built object
# instead of the result dict
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
if self.is_interface_component and not isinstance(
self._built_object, UnbuiltObject
):
result = self._built_object
# if it is not a dict or a string and hasattr model_dump then
# return the model_dump
@ -124,7 +135,11 @@ class Vertex:
if isinstance(self._built_result, UnbuiltResult):
return {}
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
return (
self._built_result
if isinstance(self._built_result, dict)
else {"result": self._built_result}
)
def set_artifacts(self) -> None:
pass
@ -187,17 +202,29 @@ class Vertex:
self.output = self.data["node"]["base_classes"]
self.display_name = self.data["node"]["display_name"]
self.pinned = self.data["node"].get("pinned", False)
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
self.required_inputs = [
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
]
self.optional_inputs = [
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
]
# Add the template_dicts[key]["input_types"] to the optional_inputs
self.optional_inputs.extend(
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
[
input_type
for value in template_dicts.values()
for input_type in value.get("input_types", [])
]
)
template_dict = self.data["node"]["template"]
@ -240,7 +267,11 @@ class Vertex:
if self.graph is None:
raise ValueError("Graph not found")
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
params = {}
for edge in self.edges:
@ -278,7 +309,7 @@ class Vertex:
full_path = storage_service.build_full_path(flow_id, file_name)
params[key] = full_path
else:
raise ValueError(f"File path not found for {self.vertex_type}")
raise ValueError(f"File path not found for {self.display_name}")
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
val = value.get("value")
if value.get("type") == "code":
@ -291,7 +322,11 @@ class Vertex:
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
if isinstance(val, list):
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
params[key] = {
k: v
for item in value.get("value", [])
for k, v in item.items()
}
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
@ -327,7 +362,7 @@ class Vertex:
"""
Initiate the build process.
"""
logger.debug(f"Building {self.vertex_type}")
logger.debug(f"Building {self.display_name}")
await self._build_each_node_in_params_dict(user_id)
await self._get_and_instantiate_class(user_id)
self._validate_built_object()
@ -354,7 +389,9 @@ class Vertex:
if isinstance(self._built_object, str):
self._built_result = self._built_object
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
result = await generate_result(
self._built_object, inputs, self.has_external_output, session_id
)
self._built_result = result
async def _build_each_node_in_params_dict(self, user_id=None):
@ -382,7 +419,9 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
async def get_result(
self, requester: Optional["Vertex"] = None, user_id=None, timeout=None
) -> Any:
# PLEASE REVIEW THIS IF STATEMENT
# Check if the Vertex was built already
if self._built:
@ -416,7 +455,9 @@ class Vertex:
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
async def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
@ -457,7 +498,7 @@ class Vertex:
Gets the class from a dictionary and instantiates it with the params.
"""
if self.base_type is None:
raise ValueError(f"Base type for node {self.vertex_type} not found")
raise ValueError(f"Base type for node {self.display_name} not found")
try:
result = await loading.instantiate_class(
node_type=self.vertex_type,
@ -468,7 +509,9 @@ class Vertex:
self._update_built_object_and_artifacts(result)
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
raise ValueError(
f"Error building node {self.display_name}: {str(exc)}"
) from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -484,9 +527,9 @@ class Vertex:
Checks if the built object is None and raises a ValueError if so.
"""
if isinstance(self._built_object, UnbuiltObject):
raise ValueError(f"{self.vertex_type}: {self._built_object_repr()}")
raise ValueError(f"{self.display_name}: {self._built_object_repr()}")
elif self._built_object is None:
message = f"{self.vertex_type} returned None."
message = f"{self.display_name} returned None."
if self.base_type == "custom_components":
message += " Make sure your build method returns a component."
@ -498,6 +541,7 @@ class Vertex:
self._built_result = UnbuiltResult()
self.artifacts = {}
self.steps_ran = []
self._build_params()
def build_inactive(self):
# Just set the results to None
@ -538,16 +582,24 @@ class Vertex:
return self._built_object
# Get the requester edge
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
requester_edge = next(
(edge for edge in self.edges if edge.target_id == requester.id), None
)
# Return the result of the requester edge
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
return (
None
if requester_edge is None
else await requester_edge.get_result(source=self, target=requester)
)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:
self.edges.append(edge)
def __repr__(self) -> str:
return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
return (
f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
)
def __eq__(self, __o: object) -> bool:
try:
@ -560,7 +612,11 @@ class Vertex:
def _built_object_repr(self):
# Add a message with an emoji, stars for sucess,
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"
return (
"Built sucessfully ✨"
if self._built_object is not None
else "Failed to build 😵‍💫"
)
class StatefulVertex(Vertex):

View file

@ -122,10 +122,12 @@ class DocumentLoaderVertex(StatefulVertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
self._built_object
)
return f"""{self.vertex_type}({len(self._built_object)} documents)
avg_length = sum(
len(doc.page_content)
for doc in self._built_object
if hasattr(doc, "page_content")
) / len(self._built_object)
return f"""{self.display_name}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
Documents: {self._built_object[:3]}..."""
return f"{self.vertex_type}()"
@ -197,7 +199,9 @@ class TextSplitterVertex(StatefulVertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
@ -244,18 +248,27 @@ class PromptVertex(StatelessVertex):
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if "input_variables" not in self.params or self.params["input_variables"] is None:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -265,14 +278,20 @@ class PromptVertex(StatelessVertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
def _built_object_repr(self):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
@ -284,7 +303,9 @@ class PromptVertex(StatelessVertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -292,7 +313,11 @@ class PromptVertex(StatelessVertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
except KeyError:
return str(self._built_object)

View file

@ -477,11 +477,11 @@ export default function GenericNode({
) : (
<div className="max-h-96 overflow-auto">
{typeof validationStatus.params === "string"
? `${durationString}\n${validationStatus.params}`
? (`${durationString}\n${validationStatus.params}`
.split("\n")
.map((line, index) => (
<div key={index}>{line}</div>
))
)))
: durationString}
</div>
)

View file

@ -49,7 +49,7 @@ export default function AccordionComponent({
>
{trigger}
</AccordionTrigger>
<AccordionContent className="AccordionContent">
<AccordionContent className="AccordionContent flex flex-col">
{children}
</AccordionContent>
</AccordionItem>

View file

@ -12,12 +12,12 @@ export default function IOInputField({
const setNode = useFlowStore((state) => state.setNode);
const node = nodes.find((node) => node.id === inputId);
function handleInputType() {
if (!node) return "no node found";
if (!node) return <>"No node found!"</>;
switch (inputType) {
case "TextInput":
return (
<Textarea
className="h-full w-full custom-scroll"
className="w-full"
placeholder={"Enter text..."}
value={node.data.node!.template["value"].value}
onChange={(e) => {
@ -47,7 +47,7 @@ export default function IOInputField({
default:
return (
<Textarea
className="h-full w-full custom-scroll"
className="w-full custom-scroll"
placeholder={"Enter text..."}
value={node.data.node!.template["value"]}
onChange={(e) => {
@ -62,10 +62,5 @@ export default function IOInputField({
);
}
}
return (
<div className="font-xl flex h-full w-full flex-col items-start gap-4 p-4 font-semibold">
{inputType}
{handleInputType()}
</div>
);
return handleInputType();
}

View file

@ -12,15 +12,15 @@ export default function IOOutputView({
const flowPool = useFlowStore((state) => state.flowPool);
const node = nodes.find((node) => node.id === outputId);
function handleOutputType() {
if (!node) return "no node found";
if (!node) return <>"No node found!"</>;
switch (outputType) {
case "TextOutput":
return (
<Textarea
className="h-full w-full custom-scroll"
placeholder={"Enter text..."}
className="w-full custom-scroll"
placeholder={"Empty"}
// update to real value on flowPool
value={flowPool[node.id][flowPool[node.id].length - 1].data.results}
value={((flowPool[node.id] ?? [])[(flowPool[node.id]?.length ?? 1) - 1])?.params ?? ""}
readOnly
/>
);
@ -28,7 +28,7 @@ export default function IOOutputView({
default:
return (
<Textarea
className="h-full w-full custom-scroll"
className="w-full custom-scroll"
placeholder={"Enter text..."}
value={node.data.node!.template["value"]}
onChange={(e) => {
@ -43,10 +43,5 @@ export default function IOOutputView({
);
}
}
return (
<div className="font-xl flex h-full w-full flex-col items-start gap-4 p-4 font-semibold">
{outputType}
{handleOutputType()}
</div>
);
return handleOutputType();
}

View file

@ -4,7 +4,6 @@ import { CHAT_FORM_DIALOG_SUBTITLE } from "../../constants/constants";
import BaseModal from "../../modals/baseModal";
import useAlertStore from "../../stores/alertStore";
import useFlowStore from "../../stores/flowStore";
import { validateNodes } from "../../utils/reactflowUtils";
import { cn } from "../../utils/utils";
import AccordionComponent from "../AccordionComponent";
import IOInputField from "../IOInputField";
@ -51,33 +50,22 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
async function sendMessage(count = 1): Promise<void> {
if (isBuilding) return;
const { nodes, edges } = getFlow();
let nodeValidationErrors = validateNodes(nodes, edges);
if (nodeValidationErrors.length === 0) {
setIsBuilding(true);
setLockChat(true);
setChatValue("");
const chatInputNode = nodes.find((node) => node.id === chatInput?.id);
if (chatInputNode) {
let newNode = cloneDeep(chatInputNode);
newNode.data.node!.template["message"].value = chatValue;
setNode(chatInput!.id, newNode);
}
for (let i = 0; i < count; i++) {
await buildFlow().catch((err) => {
console.error(err);
setLockChat(false);
});
}
setLockChat(false);
//set chat message in the flow and run build
//@ts-ignore
} else {
setErrorData({
title: "Oops! Looks like you missed some required information:",
list: nodeValidationErrors,
setIsBuilding(true);
setLockChat(true);
setChatValue("");
const chatInputNode = nodes.find((node) => node.id === chatInput?.id);
if (chatInputNode) {
let newNode = cloneDeep(chatInputNode);
newNode.data.node!.template["message"].value = chatValue;
setNode(chatInput!.id, newNode);
}
for (let i = 0; i < count; i++) {
await buildFlow().catch((err) => {
console.error(err);
setLockChat(false);
});
}
setLockChat(false);
}
useEffect(() => {
@ -260,27 +248,52 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
)}
{haveChat ? (
selectedViewField ? (
inputs.some((input) => input.id === selectedViewField.id) ? (
<IOInputField
inputType={selectedViewField.type!}
inputId={selectedViewField.id!}
<div className="flex h-full w-full">
{selectedViewField && (
<div
className={cn(
"flex h-full w-full flex-col items-start gap-4 p-4",
!selectedViewField ? "hidden" : ""
)}
>
<div className="font-xl flex items-center justify-center gap-3 font-semibold">
<button onClick={() => setSelectedViewField(undefined)}>
<IconComponent
name={"ArrowLeft"}
className="h-6 w-6"
></IconComponent>
</button>
{selectedViewField.type}
</div>
<div className="h-full">
{inputs.some(
(input) => input.id === selectedViewField.id
) ? (
<IOInputField
inputType={selectedViewField.type!}
inputId={selectedViewField.id!}
/>
) : (
<IOOutputView
outputType={selectedViewField.type!}
outputId={selectedViewField.id!}
/>
)}
</div>
</div>
)}
<div
className={cn("flex w-full h-full",selectedViewField ? "hidden" : "")}
>
<NewChatView
sendMessage={sendMessage}
chatValue={chatValue}
setChatValue={setChatValue}
lockChat={lockChat}
setLockChat={setLockChat}
/>
) : (
<IOOutputView
outputType={selectedViewField.type!}
outputId={selectedViewField.id!}
/>
)
) : (
<NewChatView
sendMessage={sendMessage}
chatValue={chatValue}
setChatValue={setChatValue}
lockChat={lockChat}
setLockChat={setLockChat}
/>
)
</div>
</div>
) : (
<div className="absolute bottom-8 right-8"></div>
)}

View file

@ -1,4 +1,5 @@
import { ShadToolTipType } from "../../types/components";
import { cn } from "../../utils/utils";
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
export default function ShadTooltip({
@ -14,7 +15,7 @@ export default function ShadTooltip({
<TooltipTrigger asChild={asChild}>{children}</TooltipTrigger>
<TooltipContent
className={styleClasses}
className={cn(styleClasses, "max-w-96") }
side={side}
avoidCollisions={false}
sticky="always"

View file

@ -20,48 +20,42 @@ export default function TextAreaComponent({
}, [disabled]);
return (
<div
className={
"flex w-full items-center " + (disabled ? "pointer-events-none" : "")
}
>
<GenericModal
type={TypeModal.TEXT}
buttonText="Finishing Editing"
modalTitle="Edit Text"
value={value}
setValue={(value: string) => {
onChange(value);
}}
>
<div className="flex w-full items-center" data-testid={"div-" + id}>
<Input
id={id}
data-testid={id}
<div className={"flex w-full items-center " + (disabled ? "" : "")}>
<div className="flex w-full items-center" data-testid={"div-" + id}>
<Input
id={id}
data-testid={id}
value={value}
disabled={disabled}
className={editNode ? "input-edit-node w-full" : " w-full"}
placeholder={"Type something..."}
onChange={(event) => {
onChange(event.target.value);
}}
/>
<div>
<GenericModal
type={TypeModal.TEXT}
buttonText="Finish Editing"
modalTitle="Edit Text"
value={value}
disabled={disabled}
className={
editNode
? "input-edit-node pointer-events-none "
: " pointer-events-none"
}
placeholder={"Type something..."}
onChange={(event) => {
onChange(event.target.value);
setValue={(value: string) => {
onChange(value);
}}
/>
{!editNode && (
<IconComponent
id={id}
name="ExternalLink"
className={
"icons-parameters-comp" +
(disabled ? " text-ring" : " hover:text-accent-foreground")
}
/>
)}
>
{!editNode && (
<IconComponent
id={id}
name="ExternalLink"
className={
"icons-parameters-comp" +
(disabled ? " text-ring" : " hover:text-accent-foreground")
}
/>
)}
</GenericModal>
</div>
</GenericModal>
</div>
</div>
);
}

View file

@ -681,4 +681,4 @@ export const LANGFLOW_SUPPORTED_TYPES = new Set([
export const priorityFields = new Set(["code", "template"]);
export const INPUT_TYPES = new Set(["ChatInput", "TextInput"]);
export const OUTPUT_TYPES = new Set(["ChatOutput", "PromptTemplate"]);
export const OUTPUT_TYPES = new Set(["ChatOutput", "TextOutput"]);

View file

@ -21,7 +21,7 @@ type TriggerProps = {
};
const Content: React.FC<ContentProps> = ({ children }) => {
return <div className="h-full w-full">{children}</div>;
return <div className="h-full w-full flex flex-col">{children}</div>;
};
const Trigger: React.FC<TriggerProps> = ({ children, asChild, disable }) => {
return (

View file

@ -9,6 +9,7 @@ import {
applyNodeChanges,
} from "reactflow";
import { create } from "zustand";
import { FLOW_BUILD_SUCCESS_ALERT } from "../alerts_constants";
import { BuildStatus } from "../constants/enums";
import { getFlowPool, updateFlowInDatabase } from "../controllers/API";
import { VertexBuildTypeAPI } from "../types/api";
@ -26,12 +27,12 @@ import {
getNodeId,
scapeJSONParse,
scapedJSONStringfy,
validateNodes,
} from "../utils/reactflowUtils";
import { getInputsAndOutputs } from "../utils/storeUtils";
import useAlertStore from "./alertStore";
import { useDarkStore } from "./darkStore";
import useFlowsManagerStore from "./flowsManagerStore";
import { FLOW_BUILD_SUCCESS_ALERT } from "../alerts_constants";
// this is our useStore hook that we can use in our components to get parts of the store and call actions
const useFlowStore = create<FlowStoreType>((set, get) => ({
@ -377,6 +378,20 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
const setSuccessData = useAlertStore.getState().setSuccessData;
const setErrorData = useAlertStore.getState().setErrorData;
const setNoticeData = useAlertStore.getState().setNoticeData;
function validateSubgraph(nodes: string[]) {
const errors = validateNodes(
get().nodes.filter((node) => nodes.includes(node.id)),
get().edges
);
if (errors.length > 0) {
setErrorData({
title: "Oops! Looks like you missed something",
list: errors,
});
get().setIsBuilding(false);
throw new Error("Invalid nodes");
}
}
function handleBuildUpdate(
vertexBuildData: VertexBuildTypeAPI,
status: BuildStatus
@ -397,10 +412,12 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
name: currentFlow!.name,
description: currentFlow!.description,
});
setNoticeData({ title: "Running components" });
await buildVertices({
flowId: currentFlow!.id,
nodeId,
onGetOrderSuccess: () => {
setNoticeData({ title: "Running components" });
},
onBuildComplete: () => {
if (nodeId) {
setSuccessData({
@ -422,6 +439,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
onBuildStart: (idList) => {
useFlowStore.getState().updateBuildStatus(idList, BuildStatus.BUILDING);
},
validateNodes: validateSubgraph,
});
get().revertBuiltStatusFromBuilding();
},

View file

@ -1,17 +1,19 @@
import { AxiosError } from "axios";
import { BuildStatus } from "../constants/enums";
import { getVerticesOrder, postBuildVertex } from "../controllers/API";
import useAlertStore from "../stores/alertStore";
import useFlowStore from "../stores/flowStore";
import { VertexBuildTypeAPI } from "../types/api";
type BuildVerticesParams = {
flowId: string; // Assuming FlowType is the type for your flow
nodeId?: string | null; // Assuming nodeId is of type string, and it's optional
onProgressUpdate?: (progress: number) => void; // Replace number with the actual type if it's not a number
onGetOrderSuccess?: () => void;
onBuildUpdate?: (data: VertexBuildTypeAPI, status: BuildStatus) => void; // Replace any with the actual type if it's not any
onBuildComplete?: (allNodesValid: boolean) => void;
onBuildError?: (title, list, idList: string[]) => void;
onBuildStart?: (idList: string[]) => void;
validateNodes?: (nodes: string[]) => void;
};
function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI {
@ -35,17 +37,37 @@ function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI {
export async function buildVertices({
flowId,
nodeId = null,
onProgressUpdate,
onGetOrderSuccess,
onBuildUpdate,
onBuildComplete,
onBuildError,
onBuildStart,
validateNodes,
}: BuildVerticesParams) {
let orderResponse = await getVerticesOrder(flowId, nodeId);
const setErrorData = useAlertStore.getState().setErrorData;
let orderResponse;
try {
orderResponse = await getVerticesOrder(flowId, nodeId);
} catch (error) {
console.log(error);
setErrorData({
title: "Oops! Looks like you missed something",
list: [error.response?.data?.detail ?? "Unknown Error"],
});
useFlowStore.getState().setIsBuilding(false);
throw new Error("Invalid nodes");
}
if (onGetOrderSuccess) onGetOrderSuccess();
let verticesOrder: Array<Array<string>> = orderResponse.data.ids;
let vertices_layers: Array<Array<string>> = [];
let stop = false;
if (validateNodes) {
try {
validateNodes(verticesOrder.flatMap((id) => id));
} catch (e) {
return;
}
}
if (nodeId) {
for (let i = 0; i < verticesOrder.length; i += 1) {
const innerArray = verticesOrder[i];

View file

@ -1,5 +1,6 @@
import {
AlertCircle,
ArrowLeft,
ArrowUpToLine,
Bell,
BookMarked,
@ -340,6 +341,7 @@ export const nodeIconsLucide: iconsType = {
Bell,
ChevronLeft,
ChevronDown,
ArrowLeft,
Shield,
Plus,
Redo,