Merge branch 'zustand/io/migration' of github.com:logspace-ai/langflow into zustand/io/migration
This commit is contained in:
commit
ab60f59578
26 changed files with 466 additions and 287 deletions
21
poetry.lock
generated
21
poetry.lock
generated
|
|
@ -3681,6 +3681,24 @@ babel = ["Babel"]
|
|||
lingua = ["lingua"]
|
||||
testing = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown"
|
||||
version = "3.5.2"
|
||||
description = "Python implementation of John Gruber's Markdown."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"},
|
||||
{file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
|
||||
testing = ["coverage", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
|
|
@ -8212,6 +8230,7 @@ emoji = "*"
|
|||
filetype = "*"
|
||||
langdetect = "*"
|
||||
lxml = "*"
|
||||
markdown = {version = "*", optional = true, markers = "extra == \"md\""}
|
||||
nltk = "*"
|
||||
numpy = "*"
|
||||
python-iso639 = "*"
|
||||
|
|
@ -9026,4 +9045,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.12"
|
||||
content-hash = "1462954b3befc2989ae226f2214111be786eb05bade578c9c80b4ed80d5b59ff"
|
||||
content-hash = "b35a356770d3425f524b0c46a449696db1fa7c13fae77324188cb6ffa4a4c5a7"
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ pytube = "^15.0.0"
|
|||
python-socketio = "^5.11.0"
|
||||
llama-index = "0.9.48"
|
||||
langchain-openai = "^0.0.6"
|
||||
unstructured = "^0.12.4"
|
||||
unstructured = {extras = ["md"], version = "^0.12.4"}
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest-asyncio = "^0.23.1"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Callable, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
|
||||
|
||||
|
||||
class ConversationChainComponent(CustomComponent):
|
||||
|
|
@ -26,7 +26,7 @@ class ConversationChainComponent(CustomComponent):
|
|||
inputs: str,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable, Text]:
|
||||
) -> Text:
|
||||
if memory is None:
|
||||
chain = ConversationChain(llm=llm)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from typing import Callable, Union
|
||||
|
||||
from langchain.chains import LLMCheckerChain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, Chain
|
||||
from langflow.field_typing import BaseLanguageModel, Text
|
||||
|
||||
|
||||
class LLMCheckerChainComponent(CustomComponent):
|
||||
display_name = "LLMCheckerChain"
|
||||
description = ""
|
||||
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
|
||||
documentation = (
|
||||
"https://python.langchain.com/docs/modules/chains/additional/llm_checker"
|
||||
)
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
@ -17,6 +18,12 @@ class LLMCheckerChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
llm: BaseLanguageModel,
|
||||
) -> Union[Chain, Callable]:
|
||||
return LLMCheckerChain.from_llm(llm=llm)
|
||||
) -> Text:
|
||||
|
||||
chain = LLMCheckerChain.from_llm(llm=llm)
|
||||
response = chain.invoke({chain.input_key: inputs})
|
||||
result = response.get(chain.output_key)
|
||||
self.status = result
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
from typing import Callable, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chains import LLMChain, LLMMathChain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
|
||||
|
||||
|
||||
class LLMMathChainComponent(CustomComponent):
|
||||
display_name = "LLMMathChain"
|
||||
description = "Chain that interprets a prompt and executes python code to do math."
|
||||
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math"
|
||||
documentation = (
|
||||
"https://python.langchain.com/docs/modules/chains/additional/llm_math"
|
||||
)
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
@ -22,10 +24,21 @@ class LLMMathChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: Text,
|
||||
llm: BaseLanguageModel,
|
||||
llm_chain: LLMChain,
|
||||
input_key: str = "question",
|
||||
output_key: str = "answer",
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[LLMMathChain, Callable, Chain]:
|
||||
return LLMMathChain(llm=llm, llm_chain=llm_chain, input_key=input_key, output_key=output_key, memory=memory)
|
||||
) -> Text:
|
||||
chain = LLMMathChain(
|
||||
llm=llm,
|
||||
llm_chain=llm_chain,
|
||||
input_key=input_key,
|
||||
output_key=output_key,
|
||||
memory=memory,
|
||||
)
|
||||
response = chain.invoke({input_key: inputs})
|
||||
result = response.get(output_key)
|
||||
self.status = result
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -32,21 +32,39 @@ class SQLGeneratorComponent(CustomComponent):
|
|||
db: SQLDatabase,
|
||||
llm: BaseLanguageModel,
|
||||
top_k: int = 5,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
prompt: Optional[Text] = None,
|
||||
) -> Text:
|
||||
if prompt:
|
||||
prompt_template = PromptTemplate.from_template(template=prompt)
|
||||
else:
|
||||
prompt_template = None
|
||||
|
||||
if top_k > 0:
|
||||
kwargs = {
|
||||
"k": top_k,
|
||||
}
|
||||
if not prompt:
|
||||
if not prompt_template:
|
||||
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
|
||||
else:
|
||||
template = prompt.template if hasattr(prompt, "template") else prompt
|
||||
template = (
|
||||
prompt_template.template
|
||||
if hasattr(prompt, "template")
|
||||
else prompt_template
|
||||
)
|
||||
# Check if {question} is in the prompt
|
||||
if "{question}" not in template or "question" not in template.input_variables:
|
||||
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
|
||||
sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt, **kwargs)
|
||||
query_writer = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
|
||||
if (
|
||||
"{question}" not in template
|
||||
or "question" not in template.input_variables
|
||||
):
|
||||
raise ValueError(
|
||||
"Prompt must contain `{question}` to be used with Natural Language to SQL."
|
||||
)
|
||||
sql_query_chain = create_sql_query_chain(
|
||||
llm=llm, db=db, prompt=prompt_template, **kwargs
|
||||
)
|
||||
query_writer = sql_query_chain | {
|
||||
"query": lambda x: x.replace("SQLQuery:", "").strip()
|
||||
}
|
||||
response = query_writer.invoke({"question": inputs})
|
||||
query = response.get("query")
|
||||
self.status = query
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import Text
|
||||
from langflow.schema import Record
|
||||
|
||||
|
||||
|
|
@ -25,9 +24,9 @@ class ChatInput(CustomComponent):
|
|||
"display_name": "Session ID",
|
||||
"info": "Session ID of the chat history.",
|
||||
},
|
||||
"as_record": {
|
||||
"display_name": "As Record",
|
||||
"info": "If true, the message will be returned as a Record.",
|
||||
"return_record": {
|
||||
"display_name": "Return Record",
|
||||
"info": "Return the message as a record containing the sender, sender_name, and session_id.",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -36,25 +35,24 @@ class ChatInput(CustomComponent):
|
|||
sender: Optional[str] = "User",
|
||||
sender_name: Optional[str] = "User",
|
||||
message: Optional[str] = None,
|
||||
as_record: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Union[Text, Record]:
|
||||
self.status = message
|
||||
if as_record:
|
||||
return_record: Optional[bool] = False,
|
||||
) -> Record:
|
||||
if return_record:
|
||||
if isinstance(message, Record):
|
||||
# Update the data of the record
|
||||
message.data["sender"] = sender
|
||||
message.data["sender_name"] = sender_name
|
||||
message.data["session_id"] = session_id
|
||||
return message
|
||||
return Record(
|
||||
text=message,
|
||||
data={
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
message = Record(
|
||||
text=message,
|
||||
data={
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
if not message:
|
||||
message = ""
|
||||
self.status = message
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ class ChatOutput(CustomComponent):
|
|||
"info": "Session ID of the chat history.",
|
||||
"input_types": ["Text"],
|
||||
},
|
||||
"as_record": {
|
||||
"display_name": "As Record",
|
||||
"info": "If true, the message will be returned as a Record.",
|
||||
"return_record": {
|
||||
"display_name": "Return Record",
|
||||
"info": "Return the message as a record containing the sender, sender_name, and session_id.",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -40,25 +40,23 @@ class ChatOutput(CustomComponent):
|
|||
sender_name: Optional[str] = "AI",
|
||||
session_id: Optional[str] = None,
|
||||
message: Optional[str] = None,
|
||||
as_record: Optional[bool] = False,
|
||||
return_record: Optional[bool] = False,
|
||||
) -> Union[Text, Record]:
|
||||
self.status = message
|
||||
if as_record:
|
||||
if return_record:
|
||||
if isinstance(message, Record):
|
||||
# Update the data of the record
|
||||
message.data["sender"] = sender
|
||||
message.data["sender_name"] = sender_name
|
||||
message.data["session_id"] = session_id
|
||||
|
||||
return message
|
||||
return Record(
|
||||
text=message,
|
||||
data={
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
message = Record(
|
||||
text=message,
|
||||
data={
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
if not message:
|
||||
message = ""
|
||||
self.status = message
|
||||
|
|
|
|||
|
|
@ -27,7 +27,10 @@ class RecordsAsTextComponent(CustomComponent):
|
|||
if isinstance(records, Record):
|
||||
records = [records]
|
||||
|
||||
formated_records = [template.format(text=record.text, **record.data) for record in records]
|
||||
formated_records = [
|
||||
template.format(text=record.text, data=record.data, **record.data)
|
||||
for record in records
|
||||
]
|
||||
result_string = "\n".join(formated_records)
|
||||
self.status = result_string
|
||||
return result_string
|
||||
|
|
|
|||
|
|
@ -84,7 +84,8 @@ class ChromaComponent(CustomComponent):
|
|||
|
||||
if chroma_server_host is not None:
|
||||
chroma_settings = chromadb.config.Settings(
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
|
||||
or None,
|
||||
chroma_server_host=chroma_server_host,
|
||||
chroma_server_port=chroma_server_port or None,
|
||||
chroma_server_grpc_port=chroma_server_grpc_port or None,
|
||||
|
|
@ -99,12 +100,14 @@ class ChromaComponent(CustomComponent):
|
|||
|
||||
if documents is not None and embedding is not None:
|
||||
if len(documents) == 0:
|
||||
raise ValueError("If documents are provided, there must be at least one document.")
|
||||
raise ValueError(
|
||||
"If documents are provided, there must be at least one document."
|
||||
)
|
||||
chroma = Chroma.from_documents(
|
||||
documents=documents, # type: ignore
|
||||
persist_directory=index_directory,
|
||||
collection_name=collection_name,
|
||||
embedding_function=embedding,
|
||||
embedding=embedding,
|
||||
client_settings=chroma_settings,
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -11,31 +11,6 @@ agents:
|
|||
documentation: ""
|
||||
SQLAgent:
|
||||
documentation: ""
|
||||
chains:
|
||||
# LLMChain:
|
||||
# documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain"
|
||||
LLMMathChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_math"
|
||||
LLMCheckerChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
|
||||
# ConversationChain:
|
||||
# documentation: ""
|
||||
SeriesCharacterChain:
|
||||
documentation: ""
|
||||
MidJourneyPromptChain:
|
||||
documentation: ""
|
||||
TimeTravelGuideChain:
|
||||
documentation: ""
|
||||
SQLDatabaseChain:
|
||||
documentation: ""
|
||||
RetrievalQA:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/popular/vector_db_qa"
|
||||
RetrievalQAWithSourcesChain:
|
||||
documentation: ""
|
||||
ConversationalRetrievalChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/popular/chat_vector_db"
|
||||
CombineDocsChain:
|
||||
documentation: ""
|
||||
documentloaders:
|
||||
AirbyteJSONLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/airbyte_json"
|
||||
|
|
|
|||
|
|
@ -9,13 +9,8 @@ from langflow.graph.graph.constants import lazy_load_vertex_dict
|
|||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.schema import InterfaceComponentTypes
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (
|
||||
ChatVertex,
|
||||
FileToolVertex,
|
||||
LLMVertex,
|
||||
RoutingVertex,
|
||||
ToolkitVertex,
|
||||
)
|
||||
from langflow.graph.vertex.types import (ChatVertex, FileToolVertex, LLMVertex,
|
||||
RoutingVertex, ToolkitVertex)
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
|
||||
|
|
@ -85,7 +80,9 @@ class Graph:
|
|||
def build_parent_child_map(self):
|
||||
parent_child_map = defaultdict(list)
|
||||
for vertex in self.vertices:
|
||||
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
|
||||
parent_child_map[vertex.id] = [
|
||||
child.id for child in self.get_successors(vertex)
|
||||
]
|
||||
return parent_child_map
|
||||
|
||||
def increment_run_count(self):
|
||||
|
|
@ -149,6 +146,16 @@ class Graph:
|
|||
# both graphs have the same vertices and edges
|
||||
# but the data of the vertices might be different
|
||||
|
||||
def update_edges_from_vertex(self, vertex: Vertex, other_vertex: Vertex) -> None:
|
||||
"""Updates the edges of a vertex in the Graph."""
|
||||
new_edges = []
|
||||
for edge in self.edges:
|
||||
if edge.source_id == other_vertex.id or edge.target_id == other_vertex.id:
|
||||
continue
|
||||
new_edges.append(edge)
|
||||
new_edges += other_vertex.edges
|
||||
self.edges = new_edges
|
||||
|
||||
def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
|
||||
return vertex.__repr__() == other_vertex.__repr__()
|
||||
|
||||
|
|
@ -173,10 +180,6 @@ class Graph:
|
|||
# Find vertices that are in self but not in other (removed vertices)
|
||||
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
|
||||
|
||||
# Create a set for new edges
|
||||
edges_to_add = set()
|
||||
edges_to_remove = set()
|
||||
|
||||
# Update existing vertices that have changed
|
||||
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
|
||||
self_vertex = self.get_vertex(vertex_id)
|
||||
|
|
@ -184,6 +187,8 @@ class Graph:
|
|||
if not self.vertex_data_is_identical(self_vertex, other_vertex):
|
||||
self_vertex._data = other_vertex._data
|
||||
self_vertex._parse_data()
|
||||
# Now we update the edges of the vertex
|
||||
self.update_edges_from_vertex(self_vertex, other_vertex)
|
||||
self_vertex.params = {}
|
||||
self_vertex._build_params()
|
||||
self_vertex.graph = self
|
||||
|
|
@ -195,25 +200,6 @@ class Graph:
|
|||
self_vertex.artifacts = None
|
||||
self_vertex.set_top_level(self.top_level_vertices)
|
||||
self.reset_all_edges_of_vertex(self_vertex)
|
||||
if not self.vertex_edges_are_identical(self_vertex, other_vertex):
|
||||
# New edges are the edges of the other vertex and not the self vertex
|
||||
# If there are more edges in the other vertex than in the self vertex
|
||||
# then we need to add the new edges to the self vertex
|
||||
# if there are less edges in the other vertex than in the self vertex
|
||||
# then we need to remove the edges that are not in the other vertex
|
||||
|
||||
if len(self_vertex.edges) < len(other_vertex.edges):
|
||||
edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges)
|
||||
elif len(self_vertex.edges) > len(other_vertex.edges):
|
||||
edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges)
|
||||
|
||||
# Add new edges
|
||||
# to self.edges if they are not already in self.edges
|
||||
for edge in edges_to_add:
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
for edge in edges_to_remove:
|
||||
self.edges.remove(edge)
|
||||
|
||||
# Remove vertices
|
||||
for vertex_id in removed_vertex_ids:
|
||||
|
|
@ -290,7 +276,11 @@ class Graph:
|
|||
return
|
||||
self.vertices.remove(vertex)
|
||||
self.vertex_map.pop(vertex_id)
|
||||
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
|
||||
self.edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source_id != vertex_id and edge.target_id != vertex_id
|
||||
]
|
||||
|
||||
def _build_vertex_params(self) -> None:
|
||||
"""Identifies and handles the LLM vertex within the graph."""
|
||||
|
|
@ -311,7 +301,9 @@ class Graph:
|
|||
return
|
||||
for vertex in self.vertices:
|
||||
if not self._validate_vertex(vertex):
|
||||
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
|
||||
raise ValueError(
|
||||
f"{vertex.display_name} is not connected to any other components"
|
||||
)
|
||||
|
||||
def _validate_vertex(self, vertex: Vertex) -> bool:
|
||||
"""Validates a vertex."""
|
||||
|
|
@ -327,7 +319,11 @@ class Graph:
|
|||
|
||||
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
|
||||
"""Returns a list of edges for a given vertex."""
|
||||
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
|
||||
return [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source_id == vertex_id or edge.target_id == vertex_id
|
||||
]
|
||||
|
||||
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
|
||||
"""Returns the vertices connected to a vertex."""
|
||||
|
|
@ -365,7 +361,9 @@ class Graph:
|
|||
def dfs(vertex):
|
||||
if state[vertex] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError("Graph contains a cycle, cannot perform topological sort")
|
||||
raise ValueError(
|
||||
"Graph contains a cycle, cannot perform topological sort"
|
||||
)
|
||||
if state[vertex] == 0:
|
||||
state[vertex] = 1
|
||||
for edge in vertex.edges:
|
||||
|
|
@ -389,11 +387,17 @@ class Graph:
|
|||
|
||||
def get_predecessors(self, vertex):
|
||||
"""Returns the predecessors of a vertex."""
|
||||
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
|
||||
return [
|
||||
self.get_vertex(source_id)
|
||||
for source_id in self.predecessor_map.get(vertex.id, [])
|
||||
]
|
||||
|
||||
def get_successors(self, vertex):
|
||||
"""Returns the successors of a vertex."""
|
||||
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
|
||||
return [
|
||||
self.get_vertex(target_id)
|
||||
for target_id in self.successor_map.get(vertex.id, [])
|
||||
]
|
||||
|
||||
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
|
||||
"""Returns the neighbors of a vertex."""
|
||||
|
|
@ -432,7 +436,9 @@ class Graph:
|
|||
edges.append(ContractEdge(source, target, edge))
|
||||
return edges
|
||||
|
||||
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
|
||||
def _get_vertex_class(
|
||||
self, node_type: str, node_base_type: str, node_id: str
|
||||
) -> Type[Vertex]:
|
||||
"""Returns the node class based on the node type."""
|
||||
# First we check for the node_base_type
|
||||
node_name = node_id.split("-")[0]
|
||||
|
|
@ -463,14 +469,18 @@ class Graph:
|
|||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
|
||||
VertexClass = self._get_vertex_class(
|
||||
vertex_type, vertex_base_type, vertex_data["id"]
|
||||
)
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
vertices.append(vertex_instance)
|
||||
|
||||
return vertices
|
||||
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
|
||||
def get_children_by_vertex_type(
|
||||
self, vertex: Vertex, vertex_type: str
|
||||
) -> List[Vertex]:
|
||||
"""Returns the children of a vertex based on the vertex type."""
|
||||
children = []
|
||||
vertex_types = [vertex.data["type"]]
|
||||
|
|
@ -482,7 +492,9 @@ class Graph:
|
|||
|
||||
def __repr__(self):
|
||||
vertex_ids = [vertex.id for vertex in self.vertices]
|
||||
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
|
||||
edges_repr = "\n".join(
|
||||
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
|
||||
)
|
||||
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
|
||||
|
||||
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
|
||||
|
|
@ -513,7 +525,9 @@ class Graph:
|
|||
"""Performs a layered topological sort of the vertices in the graph."""
|
||||
|
||||
# Queue for vertices with no incoming edges
|
||||
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
|
||||
queue = deque(
|
||||
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
|
||||
)
|
||||
layers = []
|
||||
|
||||
current_layer = 0
|
||||
|
|
@ -569,7 +583,9 @@ class Graph:
|
|||
|
||||
return refined_layers
|
||||
|
||||
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
def sort_chat_inputs_first(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
chat_inputs_first = []
|
||||
for layer in vertices_layers:
|
||||
for vertex_id in layer:
|
||||
|
|
@ -597,11 +613,15 @@ class Graph:
|
|||
self.increment_run_count()
|
||||
return vertices_layers
|
||||
|
||||
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
def sort_interface_components_first(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
||||
def contains_interface_component(vertex):
|
||||
return any(component.value in vertex for component in InterfaceComponentTypes)
|
||||
return any(
|
||||
component.value in vertex for component in InterfaceComponentTypes
|
||||
)
|
||||
|
||||
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
|
||||
sorted_vertices = [
|
||||
|
|
@ -620,9 +640,13 @@ class Graph:
|
|||
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
|
||||
if len(vertices_ids) == 1:
|
||||
return vertices_ids
|
||||
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
|
||||
vertices_ids.sort(
|
||||
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
|
||||
)
|
||||
|
||||
return vertices_ids
|
||||
|
||||
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
|
||||
sorted_vertices = [
|
||||
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
|
||||
]
|
||||
return sorted_vertices
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import ast
|
|||
import inspect
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, List,
|
||||
Optional)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -72,11 +73,17 @@ class Vertex:
|
|||
|
||||
def set_state(self, state: str):
|
||||
self.state = VertexStates[state]
|
||||
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2:
|
||||
if (
|
||||
self.state == VertexStates.INACTIVE
|
||||
and self.graph.in_degree_map[self.id] < 2
|
||||
):
|
||||
# If the vertex is inactive and has only one in degree
|
||||
# it means that it is not a merge point in the graph
|
||||
self.graph.inactive_vertices.add(self.id)
|
||||
elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactive_vertices:
|
||||
elif (
|
||||
self.state == VertexStates.ACTIVE
|
||||
and self.id in self.graph.inactive_vertices
|
||||
):
|
||||
self.graph.inactive_vertices.remove(self.id)
|
||||
|
||||
@property
|
||||
|
|
@ -104,7 +111,9 @@ class Vertex:
|
|||
):
|
||||
if edge.target_id not in edge_results:
|
||||
edge_results[edge.target_id] = {}
|
||||
edge_results[edge.target_id][edge.target_param] = await edge.get_result(source=self, target=target)
|
||||
edge_results[edge.target_id][edge.target_param] = await edge.get_result(
|
||||
source=self, target=target
|
||||
)
|
||||
return edge_results
|
||||
|
||||
def set_result(self, result: "ResultData") -> None:
|
||||
|
|
@ -114,7 +123,9 @@ class Vertex:
|
|||
# If the Vertex.type is a power component
|
||||
# then we need to return the built object
|
||||
# instead of the result dict
|
||||
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
|
||||
if self.is_interface_component and not isinstance(
|
||||
self._built_object, UnbuiltObject
|
||||
):
|
||||
result = self._built_object
|
||||
# if it is not a dict or a string and hasattr model_dump then
|
||||
# return the model_dump
|
||||
|
|
@ -124,7 +135,11 @@ class Vertex:
|
|||
|
||||
if isinstance(self._built_result, UnbuiltResult):
|
||||
return {}
|
||||
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
|
||||
return (
|
||||
self._built_result
|
||||
if isinstance(self._built_result, dict)
|
||||
else {"result": self._built_result}
|
||||
)
|
||||
|
||||
def set_artifacts(self) -> None:
|
||||
pass
|
||||
|
|
@ -187,17 +202,29 @@ class Vertex:
|
|||
self.output = self.data["node"]["base_classes"]
|
||||
self.display_name = self.data["node"]["display_name"]
|
||||
self.pinned = self.data["node"].get("pinned", False)
|
||||
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
|
||||
template_dicts = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
|
||||
self.required_inputs = [
|
||||
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
if value["required"]
|
||||
]
|
||||
self.optional_inputs = [
|
||||
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
if not value["required"]
|
||||
]
|
||||
# Add the template_dicts[key]["input_types"] to the optional_inputs
|
||||
self.optional_inputs.extend(
|
||||
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
|
||||
[
|
||||
input_type
|
||||
for value in template_dicts.values()
|
||||
for input_type in value.get("input_types", [])
|
||||
]
|
||||
)
|
||||
|
||||
template_dict = self.data["node"]["template"]
|
||||
|
|
@ -240,7 +267,11 @@ class Vertex:
|
|||
if self.graph is None:
|
||||
raise ValueError("Graph not found")
|
||||
|
||||
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
|
||||
template_dict = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
params = {}
|
||||
|
||||
for edge in self.edges:
|
||||
|
|
@ -278,7 +309,7 @@ class Vertex:
|
|||
full_path = storage_service.build_full_path(flow_id, file_name)
|
||||
params[key] = full_path
|
||||
else:
|
||||
raise ValueError(f"File path not found for {self.vertex_type}")
|
||||
raise ValueError(f"File path not found for {self.display_name}")
|
||||
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
|
||||
val = value.get("value")
|
||||
if value.get("type") == "code":
|
||||
|
|
@ -291,7 +322,11 @@ class Vertex:
|
|||
# list of dicts, so we need to convert it to a dict
|
||||
# before passing it to the build method
|
||||
if isinstance(val, list):
|
||||
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
|
||||
params[key] = {
|
||||
k: v
|
||||
for item in value.get("value", [])
|
||||
for k, v in item.items()
|
||||
}
|
||||
elif isinstance(val, dict):
|
||||
params[key] = val
|
||||
elif value.get("type") == "int" and val is not None:
|
||||
|
|
@ -327,7 +362,7 @@ class Vertex:
|
|||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
logger.debug(f"Building {self.display_name}")
|
||||
await self._build_each_node_in_params_dict(user_id)
|
||||
await self._get_and_instantiate_class(user_id)
|
||||
self._validate_built_object()
|
||||
|
|
@ -354,7 +389,9 @@ class Vertex:
|
|||
if isinstance(self._built_object, str):
|
||||
self._built_result = self._built_object
|
||||
|
||||
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
|
||||
result = await generate_result(
|
||||
self._built_object, inputs, self.has_external_output, session_id
|
||||
)
|
||||
self._built_result = result
|
||||
|
||||
async def _build_each_node_in_params_dict(self, user_id=None):
|
||||
|
|
@ -382,7 +419,9 @@ class Vertex:
|
|||
"""
|
||||
return all(self._is_node(node) for node in value)
|
||||
|
||||
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
|
||||
async def get_result(
|
||||
self, requester: Optional["Vertex"] = None, user_id=None, timeout=None
|
||||
) -> Any:
|
||||
# PLEASE REVIEW THIS IF STATEMENT
|
||||
# Check if the Vertex was built already
|
||||
if self._built:
|
||||
|
|
@ -416,7 +455,9 @@ class Vertex:
|
|||
self._extend_params_list_with_result(key, result)
|
||||
self.params[key] = result
|
||||
|
||||
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
|
||||
async def _build_list_of_nodes_and_update_params(
|
||||
self, key, nodes: List["Vertex"], user_id=None
|
||||
):
|
||||
"""
|
||||
Iterates over a list of nodes, builds each and updates the params dictionary.
|
||||
"""
|
||||
|
|
@ -457,7 +498,7 @@ class Vertex:
|
|||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.vertex_type} not found")
|
||||
raise ValueError(f"Base type for node {self.display_name} not found")
|
||||
try:
|
||||
result = await loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
|
|
@ -468,7 +509,9 @@ class Vertex:
|
|||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
|
||||
raise ValueError(
|
||||
f"Error building node {self.display_name}: {str(exc)}"
|
||||
) from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
"""
|
||||
|
|
@ -484,9 +527,9 @@ class Vertex:
|
|||
Checks if the built object is None and raises a ValueError if so.
|
||||
"""
|
||||
if isinstance(self._built_object, UnbuiltObject):
|
||||
raise ValueError(f"{self.vertex_type}: {self._built_object_repr()}")
|
||||
raise ValueError(f"{self.display_name}: {self._built_object_repr()}")
|
||||
elif self._built_object is None:
|
||||
message = f"{self.vertex_type} returned None."
|
||||
message = f"{self.display_name} returned None."
|
||||
if self.base_type == "custom_components":
|
||||
message += " Make sure your build method returns a component."
|
||||
|
||||
|
|
@ -498,6 +541,7 @@ class Vertex:
|
|||
self._built_result = UnbuiltResult()
|
||||
self.artifacts = {}
|
||||
self.steps_ran = []
|
||||
self._build_params()
|
||||
|
||||
def build_inactive(self):
|
||||
# Just set the results to None
|
||||
|
|
@ -538,16 +582,24 @@ class Vertex:
|
|||
return self._built_object
|
||||
|
||||
# Get the requester edge
|
||||
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
|
||||
requester_edge = next(
|
||||
(edge for edge in self.edges if edge.target_id == requester.id), None
|
||||
)
|
||||
# Return the result of the requester edge
|
||||
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
|
||||
return (
|
||||
None
|
||||
if requester_edge is None
|
||||
else await requester_edge.get_result(source=self, target=requester)
|
||||
)
|
||||
|
||||
def add_edge(self, edge: "ContractEdge") -> None:
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
|
||||
return (
|
||||
f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
|
||||
)
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
try:
|
||||
|
|
@ -560,7 +612,11 @@ class Vertex:
|
|||
|
||||
def _built_object_repr(self):
|
||||
# Add a message with an emoji, stars for sucess,
|
||||
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵💫"
|
||||
return (
|
||||
"Built sucessfully ✨"
|
||||
if self._built_object is not None
|
||||
else "Failed to build 😵💫"
|
||||
)
|
||||
|
||||
|
||||
class StatefulVertex(Vertex):
|
||||
|
|
|
|||
|
|
@ -122,10 +122,12 @@ class DocumentLoaderVertex(StatefulVertex):
|
|||
# show how many documents are in the list?
|
||||
|
||||
if not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
avg_length = sum(
|
||||
len(doc.page_content)
|
||||
for doc in self._built_object
|
||||
if hasattr(doc, "page_content")
|
||||
) / len(self._built_object)
|
||||
return f"""{self.display_name}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {self._built_object[:3]}..."""
|
||||
return f"{self.vertex_type}()"
|
||||
|
|
@ -197,7 +199,9 @@ class TextSplitterVertex(StatefulVertex):
|
|||
# show how many documents are in the list?
|
||||
|
||||
if not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
\nDocuments: {self._built_object[:3]}..."""
|
||||
|
|
@ -244,18 +248,27 @@ class PromptVertex(StatelessVertex):
|
|||
user_id = kwargs.get("user_id", None)
|
||||
tools = kwargs.get("tools", [])
|
||||
if not self._built or force:
|
||||
if "input_variables" not in self.params or self.params["input_variables"] is None:
|
||||
if (
|
||||
"input_variables" not in self.params
|
||||
or self.params["input_variables"] is None
|
||||
):
|
||||
self.params["input_variables"] = []
|
||||
# Check if it is a ZeroShotPrompt and needs a tool
|
||||
if "ShotPrompt" in self.vertex_type:
|
||||
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
|
||||
tools = (
|
||||
[tool_node.build(user_id=user_id) for tool_node in tools]
|
||||
if tools is not None
|
||||
else []
|
||||
)
|
||||
# flatten the list of tools if it is a list of lists
|
||||
# first check if it is a list
|
||||
if tools and isinstance(tools, list) and isinstance(tools[0], list):
|
||||
tools = flatten_list(tools)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
|
||||
key
|
||||
for key, value in self.params.items()
|
||||
if isinstance(value, str) and key != "format_instructions"
|
||||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
|
|
@ -265,14 +278,20 @@ class PromptVertex(StatelessVertex):
|
|||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
self.params["input_variables"] = list(
|
||||
set(self.params["input_variables"])
|
||||
)
|
||||
elif isinstance(self.params, dict):
|
||||
self.params.pop("input_variables", None)
|
||||
|
||||
await self._build(user_id=user_id)
|
||||
|
||||
def _built_object_repr(self):
|
||||
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
|
||||
if (
|
||||
not self.artifacts
|
||||
or self._built_object is None
|
||||
or not hasattr(self._built_object, "format")
|
||||
):
|
||||
return super()._built_object_repr()
|
||||
elif isinstance(self._built_object, UnbuiltObject):
|
||||
return super()._built_object_repr()
|
||||
|
|
@ -284,7 +303,9 @@ class PromptVertex(StatelessVertex):
|
|||
# so the prompt format doesn't break
|
||||
artifacts.pop("handle_keys", None)
|
||||
try:
|
||||
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
|
||||
if not hasattr(self._built_object, "template") and hasattr(
|
||||
self._built_object, "prompt"
|
||||
):
|
||||
template = self._built_object.prompt.template
|
||||
else:
|
||||
template = self._built_object.template
|
||||
|
|
@ -292,7 +313,11 @@ class PromptVertex(StatelessVertex):
|
|||
if value:
|
||||
replace_key = "{" + key + "}"
|
||||
template = template.replace(replace_key, value)
|
||||
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
|
||||
return (
|
||||
template
|
||||
if isinstance(template, str)
|
||||
else f"{self.vertex_type}({template})"
|
||||
)
|
||||
except KeyError:
|
||||
return str(self._built_object)
|
||||
|
||||
|
|
|
|||
|
|
@ -477,11 +477,11 @@ export default function GenericNode({
|
|||
) : (
|
||||
<div className="max-h-96 overflow-auto">
|
||||
{typeof validationStatus.params === "string"
|
||||
? `${durationString}\n${validationStatus.params}`
|
||||
? (`${durationString}\n${validationStatus.params}`
|
||||
.split("\n")
|
||||
.map((line, index) => (
|
||||
<div key={index}>{line}</div>
|
||||
))
|
||||
)))
|
||||
: durationString}
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ export default function AccordionComponent({
|
|||
>
|
||||
{trigger}
|
||||
</AccordionTrigger>
|
||||
<AccordionContent className="AccordionContent">
|
||||
<AccordionContent className="AccordionContent flex flex-col">
|
||||
{children}
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ export default function IOInputField({
|
|||
const setNode = useFlowStore((state) => state.setNode);
|
||||
const node = nodes.find((node) => node.id === inputId);
|
||||
function handleInputType() {
|
||||
if (!node) return "no node found";
|
||||
if (!node) return <>"No node found!"</>;
|
||||
switch (inputType) {
|
||||
case "TextInput":
|
||||
return (
|
||||
<Textarea
|
||||
className="h-full w-full custom-scroll"
|
||||
className="w-full"
|
||||
placeholder={"Enter text..."}
|
||||
value={node.data.node!.template["value"].value}
|
||||
onChange={(e) => {
|
||||
|
|
@ -47,7 +47,7 @@ export default function IOInputField({
|
|||
default:
|
||||
return (
|
||||
<Textarea
|
||||
className="h-full w-full custom-scroll"
|
||||
className="w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
value={node.data.node!.template["value"]}
|
||||
onChange={(e) => {
|
||||
|
|
@ -62,10 +62,5 @@ export default function IOInputField({
|
|||
);
|
||||
}
|
||||
}
|
||||
return (
|
||||
<div className="font-xl flex h-full w-full flex-col items-start gap-4 p-4 font-semibold">
|
||||
{inputType}
|
||||
{handleInputType()}
|
||||
</div>
|
||||
);
|
||||
return handleInputType();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,15 +12,15 @@ export default function IOOutputView({
|
|||
const flowPool = useFlowStore((state) => state.flowPool);
|
||||
const node = nodes.find((node) => node.id === outputId);
|
||||
function handleOutputType() {
|
||||
if (!node) return "no node found";
|
||||
if (!node) return <>"No node found!"</>;
|
||||
switch (outputType) {
|
||||
case "TextOutput":
|
||||
return (
|
||||
<Textarea
|
||||
className="h-full w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
className="w-full custom-scroll"
|
||||
placeholder={"Empty"}
|
||||
// update to real value on flowPool
|
||||
value={flowPool[node.id][flowPool[node.id].length - 1].data.results}
|
||||
value={((flowPool[node.id] ?? [])[(flowPool[node.id]?.length ?? 1) - 1])?.params ?? ""}
|
||||
readOnly
|
||||
/>
|
||||
);
|
||||
|
|
@ -28,7 +28,7 @@ export default function IOOutputView({
|
|||
default:
|
||||
return (
|
||||
<Textarea
|
||||
className="h-full w-full custom-scroll"
|
||||
className="w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
value={node.data.node!.template["value"]}
|
||||
onChange={(e) => {
|
||||
|
|
@ -43,10 +43,5 @@ export default function IOOutputView({
|
|||
);
|
||||
}
|
||||
}
|
||||
return (
|
||||
<div className="font-xl flex h-full w-full flex-col items-start gap-4 p-4 font-semibold">
|
||||
{outputType}
|
||||
{handleOutputType()}
|
||||
</div>
|
||||
);
|
||||
return handleOutputType();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import { CHAT_FORM_DIALOG_SUBTITLE } from "../../constants/constants";
|
|||
import BaseModal from "../../modals/baseModal";
|
||||
import useAlertStore from "../../stores/alertStore";
|
||||
import useFlowStore from "../../stores/flowStore";
|
||||
import { validateNodes } from "../../utils/reactflowUtils";
|
||||
import { cn } from "../../utils/utils";
|
||||
import AccordionComponent from "../AccordionComponent";
|
||||
import IOInputField from "../IOInputField";
|
||||
|
|
@ -51,33 +50,22 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
|
|||
async function sendMessage(count = 1): Promise<void> {
|
||||
if (isBuilding) return;
|
||||
const { nodes, edges } = getFlow();
|
||||
let nodeValidationErrors = validateNodes(nodes, edges);
|
||||
if (nodeValidationErrors.length === 0) {
|
||||
setIsBuilding(true);
|
||||
setLockChat(true);
|
||||
setChatValue("");
|
||||
const chatInputNode = nodes.find((node) => node.id === chatInput?.id);
|
||||
if (chatInputNode) {
|
||||
let newNode = cloneDeep(chatInputNode);
|
||||
newNode.data.node!.template["message"].value = chatValue;
|
||||
setNode(chatInput!.id, newNode);
|
||||
}
|
||||
for (let i = 0; i < count; i++) {
|
||||
await buildFlow().catch((err) => {
|
||||
console.error(err);
|
||||
setLockChat(false);
|
||||
});
|
||||
}
|
||||
setLockChat(false);
|
||||
|
||||
//set chat message in the flow and run build
|
||||
//@ts-ignore
|
||||
} else {
|
||||
setErrorData({
|
||||
title: "Oops! Looks like you missed some required information:",
|
||||
list: nodeValidationErrors,
|
||||
setIsBuilding(true);
|
||||
setLockChat(true);
|
||||
setChatValue("");
|
||||
const chatInputNode = nodes.find((node) => node.id === chatInput?.id);
|
||||
if (chatInputNode) {
|
||||
let newNode = cloneDeep(chatInputNode);
|
||||
newNode.data.node!.template["message"].value = chatValue;
|
||||
setNode(chatInput!.id, newNode);
|
||||
}
|
||||
for (let i = 0; i < count; i++) {
|
||||
await buildFlow().catch((err) => {
|
||||
console.error(err);
|
||||
setLockChat(false);
|
||||
});
|
||||
}
|
||||
setLockChat(false);
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -260,27 +248,52 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
|
|||
)}
|
||||
|
||||
{haveChat ? (
|
||||
selectedViewField ? (
|
||||
inputs.some((input) => input.id === selectedViewField.id) ? (
|
||||
<IOInputField
|
||||
inputType={selectedViewField.type!}
|
||||
inputId={selectedViewField.id!}
|
||||
<div className="flex h-full w-full">
|
||||
{selectedViewField && (
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-full w-full flex-col items-start gap-4 p-4",
|
||||
!selectedViewField ? "hidden" : ""
|
||||
)}
|
||||
>
|
||||
<div className="font-xl flex items-center justify-center gap-3 font-semibold">
|
||||
<button onClick={() => setSelectedViewField(undefined)}>
|
||||
<IconComponent
|
||||
name={"ArrowLeft"}
|
||||
className="h-6 w-6"
|
||||
></IconComponent>
|
||||
</button>
|
||||
{selectedViewField.type}
|
||||
</div>
|
||||
<div className="h-full">
|
||||
{inputs.some(
|
||||
(input) => input.id === selectedViewField.id
|
||||
) ? (
|
||||
<IOInputField
|
||||
inputType={selectedViewField.type!}
|
||||
inputId={selectedViewField.id!}
|
||||
/>
|
||||
) : (
|
||||
<IOOutputView
|
||||
outputType={selectedViewField.type!}
|
||||
outputId={selectedViewField.id!}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={cn("flex w-full h-full",selectedViewField ? "hidden" : "")}
|
||||
>
|
||||
<NewChatView
|
||||
sendMessage={sendMessage}
|
||||
chatValue={chatValue}
|
||||
setChatValue={setChatValue}
|
||||
lockChat={lockChat}
|
||||
setLockChat={setLockChat}
|
||||
/>
|
||||
) : (
|
||||
<IOOutputView
|
||||
outputType={selectedViewField.type!}
|
||||
outputId={selectedViewField.id!}
|
||||
/>
|
||||
)
|
||||
) : (
|
||||
<NewChatView
|
||||
sendMessage={sendMessage}
|
||||
chatValue={chatValue}
|
||||
setChatValue={setChatValue}
|
||||
lockChat={lockChat}
|
||||
setLockChat={setLockChat}
|
||||
/>
|
||||
)
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="absolute bottom-8 right-8"></div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { ShadToolTipType } from "../../types/components";
|
||||
import { cn } from "../../utils/utils";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
|
||||
|
||||
export default function ShadTooltip({
|
||||
|
|
@ -14,7 +15,7 @@ export default function ShadTooltip({
|
|||
<TooltipTrigger asChild={asChild}>{children}</TooltipTrigger>
|
||||
|
||||
<TooltipContent
|
||||
className={styleClasses}
|
||||
className={cn(styleClasses, "max-w-96") }
|
||||
side={side}
|
||||
avoidCollisions={false}
|
||||
sticky="always"
|
||||
|
|
|
|||
|
|
@ -20,48 +20,42 @@ export default function TextAreaComponent({
|
|||
}, [disabled]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={
|
||||
"flex w-full items-center " + (disabled ? "pointer-events-none" : "")
|
||||
}
|
||||
>
|
||||
<GenericModal
|
||||
type={TypeModal.TEXT}
|
||||
buttonText="Finishing Editing"
|
||||
modalTitle="Edit Text"
|
||||
value={value}
|
||||
setValue={(value: string) => {
|
||||
onChange(value);
|
||||
}}
|
||||
>
|
||||
<div className="flex w-full items-center" data-testid={"div-" + id}>
|
||||
<Input
|
||||
id={id}
|
||||
data-testid={id}
|
||||
<div className={"flex w-full items-center " + (disabled ? "" : "")}>
|
||||
<div className="flex w-full items-center" data-testid={"div-" + id}>
|
||||
<Input
|
||||
id={id}
|
||||
data-testid={id}
|
||||
value={value}
|
||||
disabled={disabled}
|
||||
className={editNode ? "input-edit-node w-full" : " w-full"}
|
||||
placeholder={"Type something..."}
|
||||
onChange={(event) => {
|
||||
onChange(event.target.value);
|
||||
}}
|
||||
/>
|
||||
<div>
|
||||
<GenericModal
|
||||
type={TypeModal.TEXT}
|
||||
buttonText="Finish Editing"
|
||||
modalTitle="Edit Text"
|
||||
value={value}
|
||||
disabled={disabled}
|
||||
className={
|
||||
editNode
|
||||
? "input-edit-node pointer-events-none "
|
||||
: " pointer-events-none"
|
||||
}
|
||||
placeholder={"Type something..."}
|
||||
onChange={(event) => {
|
||||
onChange(event.target.value);
|
||||
setValue={(value: string) => {
|
||||
onChange(value);
|
||||
}}
|
||||
/>
|
||||
{!editNode && (
|
||||
<IconComponent
|
||||
id={id}
|
||||
name="ExternalLink"
|
||||
className={
|
||||
"icons-parameters-comp" +
|
||||
(disabled ? " text-ring" : " hover:text-accent-foreground")
|
||||
}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
{!editNode && (
|
||||
<IconComponent
|
||||
id={id}
|
||||
name="ExternalLink"
|
||||
className={
|
||||
"icons-parameters-comp" +
|
||||
(disabled ? " text-ring" : " hover:text-accent-foreground")
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</GenericModal>
|
||||
</div>
|
||||
</GenericModal>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -681,4 +681,4 @@ export const LANGFLOW_SUPPORTED_TYPES = new Set([
|
|||
export const priorityFields = new Set(["code", "template"]);
|
||||
|
||||
export const INPUT_TYPES = new Set(["ChatInput", "TextInput"]);
|
||||
export const OUTPUT_TYPES = new Set(["ChatOutput", "PromptTemplate"]);
|
||||
export const OUTPUT_TYPES = new Set(["ChatOutput", "TextOutput"]);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ type TriggerProps = {
|
|||
};
|
||||
|
||||
const Content: React.FC<ContentProps> = ({ children }) => {
|
||||
return <div className="h-full w-full">{children}</div>;
|
||||
return <div className="h-full w-full flex flex-col">{children}</div>;
|
||||
};
|
||||
const Trigger: React.FC<TriggerProps> = ({ children, asChild, disable }) => {
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import {
|
|||
applyNodeChanges,
|
||||
} from "reactflow";
|
||||
import { create } from "zustand";
|
||||
import { FLOW_BUILD_SUCCESS_ALERT } from "../alerts_constants";
|
||||
import { BuildStatus } from "../constants/enums";
|
||||
import { getFlowPool, updateFlowInDatabase } from "../controllers/API";
|
||||
import { VertexBuildTypeAPI } from "../types/api";
|
||||
|
|
@ -26,12 +27,12 @@ import {
|
|||
getNodeId,
|
||||
scapeJSONParse,
|
||||
scapedJSONStringfy,
|
||||
validateNodes,
|
||||
} from "../utils/reactflowUtils";
|
||||
import { getInputsAndOutputs } from "../utils/storeUtils";
|
||||
import useAlertStore from "./alertStore";
|
||||
import { useDarkStore } from "./darkStore";
|
||||
import useFlowsManagerStore from "./flowsManagerStore";
|
||||
import { FLOW_BUILD_SUCCESS_ALERT } from "../alerts_constants";
|
||||
|
||||
// this is our useStore hook that we can use in our components to get parts of the store and call actions
|
||||
const useFlowStore = create<FlowStoreType>((set, get) => ({
|
||||
|
|
@ -377,6 +378,20 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
const setSuccessData = useAlertStore.getState().setSuccessData;
|
||||
const setErrorData = useAlertStore.getState().setErrorData;
|
||||
const setNoticeData = useAlertStore.getState().setNoticeData;
|
||||
function validateSubgraph(nodes: string[]) {
|
||||
const errors = validateNodes(
|
||||
get().nodes.filter((node) => nodes.includes(node.id)),
|
||||
get().edges
|
||||
);
|
||||
if (errors.length > 0) {
|
||||
setErrorData({
|
||||
title: "Oops! Looks like you missed something",
|
||||
list: errors,
|
||||
});
|
||||
get().setIsBuilding(false);
|
||||
throw new Error("Invalid nodes");
|
||||
}
|
||||
}
|
||||
function handleBuildUpdate(
|
||||
vertexBuildData: VertexBuildTypeAPI,
|
||||
status: BuildStatus
|
||||
|
|
@ -397,10 +412,12 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
name: currentFlow!.name,
|
||||
description: currentFlow!.description,
|
||||
});
|
||||
setNoticeData({ title: "Running components" });
|
||||
await buildVertices({
|
||||
flowId: currentFlow!.id,
|
||||
nodeId,
|
||||
onGetOrderSuccess: () => {
|
||||
setNoticeData({ title: "Running components" });
|
||||
},
|
||||
onBuildComplete: () => {
|
||||
if (nodeId) {
|
||||
setSuccessData({
|
||||
|
|
@ -422,6 +439,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
onBuildStart: (idList) => {
|
||||
useFlowStore.getState().updateBuildStatus(idList, BuildStatus.BUILDING);
|
||||
},
|
||||
validateNodes: validateSubgraph,
|
||||
});
|
||||
get().revertBuiltStatusFromBuilding();
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,17 +1,19 @@
|
|||
import { AxiosError } from "axios";
|
||||
import { BuildStatus } from "../constants/enums";
|
||||
import { getVerticesOrder, postBuildVertex } from "../controllers/API";
|
||||
import useAlertStore from "../stores/alertStore";
|
||||
import useFlowStore from "../stores/flowStore";
|
||||
import { VertexBuildTypeAPI } from "../types/api";
|
||||
|
||||
type BuildVerticesParams = {
|
||||
flowId: string; // Assuming FlowType is the type for your flow
|
||||
nodeId?: string | null; // Assuming nodeId is of type string, and it's optional
|
||||
onProgressUpdate?: (progress: number) => void; // Replace number with the actual type if it's not a number
|
||||
onGetOrderSuccess?: () => void;
|
||||
onBuildUpdate?: (data: VertexBuildTypeAPI, status: BuildStatus) => void; // Replace any with the actual type if it's not any
|
||||
onBuildComplete?: (allNodesValid: boolean) => void;
|
||||
onBuildError?: (title, list, idList: string[]) => void;
|
||||
onBuildStart?: (idList: string[]) => void;
|
||||
validateNodes?: (nodes: string[]) => void;
|
||||
};
|
||||
|
||||
function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI {
|
||||
|
|
@ -35,17 +37,37 @@ function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI {
|
|||
export async function buildVertices({
|
||||
flowId,
|
||||
nodeId = null,
|
||||
onProgressUpdate,
|
||||
onGetOrderSuccess,
|
||||
onBuildUpdate,
|
||||
onBuildComplete,
|
||||
onBuildError,
|
||||
onBuildStart,
|
||||
validateNodes,
|
||||
}: BuildVerticesParams) {
|
||||
let orderResponse = await getVerticesOrder(flowId, nodeId);
|
||||
const setErrorData = useAlertStore.getState().setErrorData;
|
||||
let orderResponse;
|
||||
try {
|
||||
orderResponse = await getVerticesOrder(flowId, nodeId);
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
setErrorData({
|
||||
title: "Oops! Looks like you missed something",
|
||||
list: [error.response?.data?.detail ?? "Unknown Error"],
|
||||
});
|
||||
useFlowStore.getState().setIsBuilding(false);
|
||||
throw new Error("Invalid nodes");
|
||||
}
|
||||
if (onGetOrderSuccess) onGetOrderSuccess();
|
||||
let verticesOrder: Array<Array<string>> = orderResponse.data.ids;
|
||||
let vertices_layers: Array<Array<string>> = [];
|
||||
let stop = false;
|
||||
|
||||
if (validateNodes) {
|
||||
try {
|
||||
validateNodes(verticesOrder.flatMap((id) => id));
|
||||
} catch (e) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (nodeId) {
|
||||
for (let i = 0; i < verticesOrder.length; i += 1) {
|
||||
const innerArray = verticesOrder[i];
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import {
|
||||
AlertCircle,
|
||||
ArrowLeft,
|
||||
ArrowUpToLine,
|
||||
Bell,
|
||||
BookMarked,
|
||||
|
|
@ -340,6 +341,7 @@ export const nodeIconsLucide: iconsType = {
|
|||
Bell,
|
||||
ChevronLeft,
|
||||
ChevronDown,
|
||||
ArrowLeft,
|
||||
Shield,
|
||||
Plus,
|
||||
Redo,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue