Fix styleUtils import and remove unnecessary lines

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-22 17:07:36 -03:00
commit 9cfb03fbc9
47 changed files with 219 additions and 574 deletions

View file

@ -3,8 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
from langchain.schema import AgentAction, AgentFinish
from langchain_core.callbacks.base import (AsyncCallbackHandler,
BaseCallbackHandler)
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langflow.api.v1.schemas import ChatResponse, PromptResponse
from langflow.services.deps import get_chat_service
from langflow.utils.util import remove_ansi_escape_codes

View file

@ -43,13 +43,9 @@ async def chat(
user = await get_current_user_for_websocket(websocket, db)
await websocket.accept()
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
elif not user.is_active:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
if client_id in chat_service.cache_service:
await chat_service.handle_websocket(client_id, websocket)
@ -65,9 +61,7 @@ async def chat(
logger.error(f"Error in chat websocket: {exc}")
messsage = exc.detail if isinstance(exc, HTTPException) else str(exc)
if "Could not validate credentials" in str(exc):
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
else:
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage)
@ -137,12 +131,8 @@ async def build_vertex(
cache = chat_service.get_cache(flow_id)
if not cache:
# If there's no cache
logger.warning(
f"No cache found for {flow_id}. Building graph starting at {vertex_id}"
)
graph = build_and_cache_graph(
flow_id=flow_id, session=next(get_session()), chat_service=chat_service
)
logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}")
graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
else:
graph = cache.get("result")
result_dict = {}

View file

@ -66,8 +66,6 @@ async def get_transactions(
monitor_service: MonitorService = Depends(get_monitor_service),
):
try:
return monitor_service.get_transactions(
source=source, target=target, status=status, order_by=order_by
)
return monitor_service.get_transactions(source=source, target=target, status=status, order_by=order_by)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View file

@ -161,9 +161,7 @@ class StreamData(BaseModel):
data: dict
def __str__(self) -> str:
return (
f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
)
return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
class CustomComponentCode(BaseModel):

View file

@ -40,9 +40,7 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
add_new_variables_to_template(input_variables, prompt_request)
remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
)
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
update_input_variables_field(input_variables, prompt_request)
@ -57,19 +55,12 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
def get_old_custom_fields(prompt_request):
try:
if (
len(prompt_request.frontend_node.custom_fields) == 1
and prompt_request.name == ""
):
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
prompt_request.name = list(
prompt_request.frontend_node.custom_fields.keys()
)[0]
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
old_custom_fields = prompt_request.frontend_node.custom_fields[
prompt_request.name
]
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name]
if old_custom_fields is None:
old_custom_fields = []
@ -95,40 +86,26 @@ def add_new_variables_to_template(input_variables, prompt_request):
)
if variable in prompt_request.frontend_node.template:
# Set the new field with the old value
template_field.value = prompt_request.frontend_node.template[variable][
"value"
]
template_field.value = prompt_request.frontend_node.template[variable]["value"]
prompt_request.frontend_node.template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if (
variable
not in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
variable
)
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
):
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if (
variable
in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[
prompt_request.name
].remove(variable)
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
# Remove the variable from the template
prompt_request.frontend_node.template.pop(variable, None)
@ -140,6 +117,4 @@ def remove_old_variables_from_template(
def update_input_variables_field(input_variables, prompt_request):
if "input_variables" in prompt_request.frontend_node.template:
prompt_request.frontend_node.template["input_variables"][
"value"
] = input_variables
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables

View file

@ -52,9 +52,7 @@ class StoreMessages(CustomComponent):
if not records:
records = []
if not session_id or not sender or not sender_name:
raise ValueError(
"If passing texts, session_id, sender, and sender_name must be provided."
)
raise ValueError("If passing texts, session_id, sender, and sender_name must be provided.")
for text in texts:
record = Record(
text=text,

View file

@ -126,8 +126,7 @@ class ChatLiteLLMComponent(CustomComponent):
litellm.set_verbose = verbose
except ImportError:
raise ChatLiteLLMException(
"Could not import litellm python package. "
"Please install it with `pip install litellm`"
"Could not import litellm python package. " "Please install it with `pip install litellm`"
)
provider_map = {
"OpenAI": "openai_api_key",

View file

@ -63,4 +63,4 @@ class AmazonBedrockComponent(CustomComponent):
message = output.invoke(inputs)
result = message.content if hasattr(message, "content") else message
self.status = result
return result
return result

View file

@ -54,7 +54,7 @@ class AnthropicLLM(CustomComponent):
def build(
self,
model: str,
inputs:str,
inputs: str,
anthropic_api_key: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
@ -78,4 +78,3 @@ class AnthropicLLM(CustomComponent):
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -31,7 +31,7 @@ class CTransformersComponent(CustomComponent):
"inputs": {"display_name": "Input"},
}
def build(self, model: str, model_file: str,inputs:str, model_type: str, config: Optional[Dict] = None) -> Text:
def build(self, model: str, model_file: str, inputs: str, model_type: str, config: Optional[Dict] = None) -> Text:
output = CTransformers(model=model, model_file=model_file, model_type=model_type, config=config)
message = output.invoke(inputs)
result = message.content if hasattr(message, "content") else message

View file

@ -14,41 +14,41 @@ class GoogleGenerativeAIComponent(CustomComponent):
def build_config(self):
return {
"google_api_key":
{ "display_name":"Google API Key",
"info":"The Google API Key to use for the Google Generative AI.",
} ,
"max_output_tokens":{
"display_name":"Max Output Tokens",
"info":"The maximum number of tokens to generate.",
"google_api_key": {
"display_name": "Google API Key",
"info": "The Google API Key to use for the Google Generative AI.",
},
"max_output_tokens": {
"display_name": "Max Output Tokens",
"info": "The maximum number of tokens to generate.",
},
"temperature": {
"display_name":"Temperature",
"info":"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].",
"display_name": "Temperature",
"info": "Run inference with this temperature. Must by in the closed interval [0.0, 1.0].",
},
"top_k": {
"display_name":"Top K",
"info":"Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive.",
"range_spec":RangeSpec(min=0, max=2, step=0.1),
"advanced":True,
"display_name": "Top K",
"info": "Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive.",
"range_spec": RangeSpec(min=0, max=2, step=0.1),
"advanced": True,
},
"top_p": {
"display_name":"Top P",
"info":"The maximum cumulative probability of tokens to consider when sampling.",
"advanced":True,
"display_name": "Top P",
"info": "The maximum cumulative probability of tokens to consider when sampling.",
"advanced": True,
},
"n": {
"display_name":"N",
"info":"Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.",
"advanced":True,
"display_name": "N",
"info": "Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.",
"advanced": True,
},
"model": {
"display_name":"Model",
"info":"The name of the model to use. Supported examples: gemini-pro",
"options":["gemini-pro", "gemini-pro-vision"],
"display_name": "Model",
"info": "The name of the model to use. Supported examples: gemini-pro",
"options": ["gemini-pro", "gemini-pro-vision"],
},
"code": {
"advanced":True,
"advanced": True,
},
"inputs": {"display_name": "Input"},
}
@ -57,14 +57,14 @@ class GoogleGenerativeAIComponent(CustomComponent):
self,
google_api_key: str,
model: str,
inputs:str,
inputs: str,
max_output_tokens: Optional[int] = None,
temperature: float = 0.1,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
n: Optional[int] = 1,
) -> Text:
output = ChatGoogleGenerativeAI(
output = ChatGoogleGenerativeAI(
model=model,
max_output_tokens=max_output_tokens or None, # type: ignore
temperature=temperature,

View file

@ -47,4 +47,3 @@ class HuggingFaceEndpointsComponent(CustomComponent):
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -57,7 +57,7 @@ class LlamaCppComponent(CustomComponent):
def build(
self,
model_path: str,
inputs:str,
inputs: str,
grammar: Optional[str] = None,
cache: Optional[bool] = None,
client: Optional[Any] = None,

View file

@ -171,7 +171,7 @@ class ChatOllamaComponent(CustomComponent):
self,
base_url: Optional[str],
model: str,
inputs:str,
inputs: str,
mirostat: Optional[str],
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,

View file

@ -19,7 +19,6 @@ class PromptComponent(CustomComponent):
template: Prompt,
**kwargs,
) -> Text:
prompt_template = PromptTemplate.from_template(template)
attributes_to_check = ["text", "page_content"]

View file

@ -27,9 +27,7 @@ class RecordsAsTextComponent(CustomComponent):
if isinstance(records, Record):
records = [records]
formated_records = [
template.format(text=record.text, **record.data) for record in records
]
formated_records = [template.format(text=record.text, **record.data) for record in records]
result_string = "\n".join(formated_records)
self.status = result_string
return result_string

View file

@ -84,8 +84,7 @@ class ChromaComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,
@ -100,9 +99,7 @@ class ChromaComponent(CustomComponent):
if documents is not None and embedding is not None:
if len(documents) == 0:
raise ValueError(
"If documents are provided, there must be at least one document."
)
raise ValueError("If documents are provided, there must be at least one document.")
chroma = Chroma.from_documents(
documents=documents, # type: ignore
persist_directory=index_directory,
@ -111,7 +108,5 @@ class ChromaComponent(CustomComponent):
client_settings=chroma_settings,
)
else:
chroma = Chroma(
persist_directory=index_directory, client_settings=chroma_settings
)
chroma = Chroma(persist_directory=index_directory, client_settings=chroma_settings)
return chroma

View file

@ -92,8 +92,7 @@ class ChromaSearchComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,

View file

@ -11,9 +11,7 @@ if TYPE_CHECKING:
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
@ -21,9 +19,7 @@ class SourceHandle(BaseModel):
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
type: str = Field(..., description="Type of the target handle.")
@ -52,24 +48,16 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
else:
self.valid_handles = (
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has invalid handles"
)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
def __setstate__(self, state):
self.source_id = state["source_id"]
@ -86,11 +74,7 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
# Get what type of input the target node is expecting
self.matched_type = next(
@ -101,10 +85,7 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has no matched type"
)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
def __repr__(self) -> str:
return (
@ -116,11 +97,7 @@ class Edge:
return hash(self.__repr__())
def __eq__(self, __value: object) -> bool:
return (
self.__repr__() == __value.__repr__()
if isinstance(__value, Edge)
else False
)
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
class ContractEdge(Edge):
@ -176,9 +153,7 @@ class ContractEdge(Edge):
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
):
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
try:
monitor_service = get_monitor_service()
clean_params = build_clean_params(target)

View file

@ -225,11 +225,7 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -250,9 +246,7 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(
f"{vertex.vertex_type} is not connected to any other components"
)
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -268,11 +262,7 @@ class Graph:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [
edge
for edge in self.edges
if edge.source_id == vertex_id or edge.target_id == vertex_id
]
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
@ -310,9 +300,7 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
raise ValueError("Graph contains a cycle, cannot perform topological sort")
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -336,17 +324,11 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -385,9 +367,7 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -417,18 +397,14 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -440,9 +416,7 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
@ -473,9 +447,7 @@ class Graph:
"""Performs a layered topological sort of the vertices in the graph."""
# Queue for vertices with no incoming edges
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
)
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
layers = []
current_layer = 0
@ -531,10 +503,7 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -561,15 +530,11 @@ class Graph:
self.increment_run_count()
return vertices_layers
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(
component.value in vertex for component in InterfaceComponentTypes
)
return any(component.value in vertex for component in InterfaceComponentTypes)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -588,13 +553,9 @@ class Graph:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
return vertices_ids
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
return sorted_vertices

View file

@ -84,9 +84,7 @@ class Vertex:
):
if edge.target_id not in edge_results:
edge_results[edge.target_id] = {}
edge_results[edge.target_id][edge.target_param] = await edge.get_result(
source=self, target=target
)
edge_results[edge.target_id][edge.target_param] = await edge.get_result(source=self, target=target)
return edge_results
def set_result(self, result: "ResultData") -> None:
@ -96,9 +94,7 @@ class Vertex:
# If the Vertex.type is a power component
# then we need to return the built object
# instead of the result dict
if self.is_interface_component and not isinstance(
self._built_object, UnbuiltObject
):
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
result = self._built_object
# if it is not a dict or a string and hasattr model_dump then
# return the model_dump
@ -108,11 +104,7 @@ class Vertex:
if isinstance(self._built_result, UnbuiltResult):
return {}
return (
self._built_result
if isinstance(self._built_result, dict)
else {"result": self._built_result}
)
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
def set_artifacts(self) -> None:
pass
@ -174,29 +166,17 @@ class Vertex:
self.data = self._data["data"]
self.output = self.data["node"]["base_classes"]
self.pinned = self.data["node"].get("pinned", False)
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
self.required_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
]
self.optional_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
]
# Add the template_dicts[key]["input_types"] to the optional_inputs
self.optional_inputs.extend(
[
input_type
for value in template_dicts.values()
for input_type in value.get("input_types", [])
]
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
)
template_dict = self.data["node"]["template"]
@ -239,11 +219,7 @@ class Vertex:
if self.graph is None:
raise ValueError("Graph not found")
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = {}
for edge in self.edges:
@ -294,11 +270,7 @@ class Vertex:
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
if isinstance(val, list):
params[key] = {
k: v
for item in value.get("value", [])
for k, v in item.items()
}
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
@ -358,9 +330,7 @@ class Vertex:
if isinstance(self._built_object, str):
self._built_result = self._built_object
result = await generate_result(
self._built_object, inputs, self.has_external_output, session_id
)
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
self._built_result = result
async def _build_each_node_in_params_dict(self, user_id=None):
@ -388,9 +358,7 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
async def get_result(
self, requester: Optional["Vertex"] = None, user_id=None, timeout=None
) -> Any:
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
# PLEASE REVIEW THIS IF STATEMENT
# Check if the Vertex was built already
if self._built:
@ -424,9 +392,7 @@ class Vertex:
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
@ -478,9 +444,7 @@ class Vertex:
self._update_built_object_and_artifacts(result)
except Exception as exc:
logger.exception(exc)
raise ValueError(
f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}"
) from exc
raise ValueError(f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}") from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -539,15 +503,9 @@ class Vertex:
return self._built_object
# Get the requester edge
requester_edge = next(
(edge for edge in self.edges if edge.target_id == requester.id), None
)
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
# Return the result of the requester edge
return (
None
if requester_edge is None
else await requester_edge.get_result(source=self, target=requester)
)
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:
@ -567,11 +525,7 @@ class Vertex:
def _built_object_repr(self):
# Add a message with an emoji, stars for sucess,
return (
"Built sucessfully ✨"
if self._built_object is not None
else "Failed to build 😵‍💫"
)
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"
class StatefulVertex(Vertex):

View file

@ -119,11 +119,9 @@ class DocumentLoaderVertex(StatefulVertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(
len(doc.page_content)
for doc in self._built_object
if hasattr(doc, "page_content")
) / len(self._built_object)
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
self._built_object
)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
Documents: {self._built_object[:3]}..."""
@ -196,9 +194,7 @@ class TextSplitterVertex(StatefulVertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
@ -245,27 +241,18 @@ class PromptVertex(StatelessVertex):
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
if "input_variables" not in self.params or self.params["input_variables"] is None:
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -275,20 +262,14 @@ class PromptVertex(StatelessVertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
self.params["input_variables"] = list(set(self.params["input_variables"]))
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
def _built_object_repr(self):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
@ -300,9 +281,7 @@ class PromptVertex(StatelessVertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -310,11 +289,7 @@ class PromptVertex(StatelessVertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
except KeyError:
return str(self._built_object)
@ -422,15 +397,11 @@ class RoutingVertex(StatelessVertex):
else:
target_vertex.should_run = False
else:
raise ValueError(
f"RoutingVertex {self.id} must have a condition in the _built_object"
)
raise ValueError(f"RoutingVertex {self.id} must have a condition in the _built_object")
self._built_result = result
else:
raise ValueError(
f"RoutingVertex {self.id} must have a _built_object with a condition and a result"
)
raise ValueError(f"RoutingVertex {self.id} must have a _built_object with a condition and a result")
def dict_to_codeblock(d: dict) -> str:

View file

@ -21,9 +21,7 @@ class ComponentFunctionEntrypointNameNullError(HTTPException):
class Component:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = (
"The name of the entrypoint function must be provided."
)
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
code: Optional[str] = None
_function_entrypoint_name: str = "build"

View file

@ -100,8 +100,7 @@ class CustomComponent(Component):
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method."
" Try using PromptTemplate instead."
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
),
},
)
@ -115,20 +114,14 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -185,9 +178,7 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential
@ -197,9 +188,7 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
return credential_service.list_credentials(user_id=self._user_id, session=session)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -250,11 +239,7 @@ class CustomComponent(Component):
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (
session.query(Flow)
.filter(Flow.name == flow_name)
.filter(Flow.user_id == self.user_id)
).first()
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")

View file

@ -79,13 +79,9 @@ class DirectoryReader:
except Exception as e:
logger.error(f"Error while loading component: {e}")
continue
items.append(
{"name": menu["name"], "path": menu["path"], "components": components}
)
items.append({"name": menu["name"], "path": menu["path"], "components": components})
filtered = [menu for menu in items if menu["components"]]
logger.debug(
f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}'
)
logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}')
return {"menu": filtered}
def validate_code(self, file_content):
@ -118,9 +114,7 @@ class DirectoryReader:
Walk through the directory path and return a list of all .py files.
"""
if not (safe_path := self.get_safe_path()):
raise CustomComponentPathValueError(
f"The path needs to start with '{self.base_path}'."
)
raise CustomComponentPathValueError(f"The path needs to start with '{self.base_path}'.")
file_list = []
for root, _, files in os.walk(safe_path):
@ -165,9 +159,7 @@ class DirectoryReader:
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef):
for arg in node.args.args:
if self._is_type_hint_in_arg_annotation(
arg.annotation, type_hint_name
):
if self._is_type_hint_in_arg_annotation(arg.annotation, type_hint_name):
return True
except SyntaxError:
# Returns False if the code is not valid Python
@ -185,16 +177,14 @@ class DirectoryReader:
and annotation.value.id == type_hint_name
)
def is_type_hint_used_but_not_imported(
self, type_hint_name: str, code: str
) -> bool:
def is_type_hint_used_but_not_imported(self, type_hint_name: str, code: str) -> bool:
"""
Check if a type hint is used but not imported in the given code.
"""
try:
return self._is_type_hint_used_in_args(
return self._is_type_hint_used_in_args(type_hint_name, code) and not self._is_type_hint_imported(
type_hint_name, code
) and not self._is_type_hint_imported(type_hint_name, code)
)
except SyntaxError:
# Returns True if there's something wrong with the code
# TODO : Find a better way to handle this
@ -215,9 +205,9 @@ class DirectoryReader:
return False, "Syntax error"
elif not self.validate_build(file_content):
return False, "Missing build function"
elif self._is_type_hint_used_in_args(
elif self._is_type_hint_used_in_args("Optional", file_content) and not self._is_type_hint_imported(
"Optional", file_content
) and not self._is_type_hint_imported("Optional", file_content):
):
return (
False,
"Type hint 'Optional' is used but not imported in the code.",
@ -233,9 +223,7 @@ class DirectoryReader:
from the .py files in the directory.
"""
response = {"menu": []}
logger.debug(
"-------------------- Building component menu list --------------------"
)
logger.debug("-------------------- Building component menu list --------------------")
for file_path in file_paths:
menu_name = os.path.basename(os.path.dirname(file_path))
@ -255,9 +243,7 @@ class DirectoryReader:
# first check if it's already CamelCase
if "_" in component_name:
component_name_camelcase = " ".join(
word.title() for word in component_name.split("_")
)
component_name_camelcase = " ".join(word.title() for word in component_name.split("_"))
else:
component_name_camelcase = component_name
@ -265,9 +251,7 @@ class DirectoryReader:
try:
output_types = self.get_output_types_from_code(result_content)
except Exception as exc:
logger.exception(
f"Error while getting output types from code: {str(exc)}"
)
logger.exception(f"Error while getting output types from code: {str(exc)}")
output_types = [component_name_camelcase]
else:
output_types = [component_name_camelcase]
@ -283,9 +267,7 @@ class DirectoryReader:
if menu_result not in response["menu"]:
response["menu"].append(menu_result)
logger.debug(
"-------------------- Component menu list built --------------------"
)
logger.debug("-------------------- Component menu list built --------------------")
return response
@staticmethod

View file

@ -27,18 +27,14 @@ from langflow.utils import validate
from langflow.utils.util import get_base_classes
def add_output_types(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add output types to the frontend node"""
for return_type in return_types:
if return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -67,18 +63,14 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List
frontend_node.template.fields = reordered_fields
def add_base_classes(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -153,14 +145,10 @@ def add_new_custom_field(
# If options is a list, then it's a dropdown
# If options is None, then it's a list of strings
is_list = isinstance(field_config.get("options"), list)
field_config["is_list"] = (
is_list or field_config.get("is_list", False) or field_contains_list
)
field_config["is_list"] = is_list or field_config.get("is_list", False) or field_contains_list
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
@ -191,9 +179,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
if "name" not in extra_field or extra_field["name"] == "self":
continue
field_name, field_type, field_value, field_required = get_field_properties(
extra_field
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
config = field_config.get(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
@ -231,9 +217,7 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -261,9 +245,7 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -318,24 +300,16 @@ def build_custom_component_template(
frontend_node = build_frontend_node(custom_component.template_config)
logger.debug("Updated attributes")
field_config, custom_instance = run_build_config(
custom_component, user_id=user_id, update_field=update_field
)
field_config, custom_instance = run_build_config(custom_component, user_id=user_id, update_field=update_field)
logger.debug("Built field config")
entrypoint_args = custom_component.get_function_entrypoint_args
add_extra_fields(frontend_node, field_config, entrypoint_args)
frontend_node = add_code_field(
frontend_node, custom_component.code, field_config.get("code", {})
)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
add_base_classes(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_output_types(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
logger.debug("Added base classes")
reorder_fields(frontend_node, custom_instance._get_field_order())
@ -347,9 +321,7 @@ def build_custom_component_template(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -373,9 +345,7 @@ def build_custom_components(settings_service):
if not settings_service.settings.COMPONENTS_PATH:
return {}
logger.info(
f"Building custom components from {settings_service.settings.COMPONENTS_PATH}"
)
logger.info(f"Building custom components from {settings_service.settings.COMPONENTS_PATH}")
custom_components_from_file = {}
processed_paths = set()
for path in settings_service.settings.COMPONENTS_PATH:
@ -386,9 +356,7 @@ def build_custom_components(settings_service):
custom_component_dict = build_custom_component_list_from_path(path_str)
if custom_component_dict:
category = next(iter(custom_component_dict))
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict
)

View file

@ -146,9 +146,7 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
result = await runnable.ainvoke(inputs)
else:
raise ValueError(
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
)
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
# Check if the result is a list of AIMessages
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
result = [r.content for r in result]
@ -157,9 +155,7 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
return result
async def process_inputs_dict(
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
):
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
if isinstance(built_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -194,9 +190,7 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
return await process_runnable(built_object, inputs)
async def generate_result(
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
):
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
if isinstance(inputs, dict):
result = await process_inputs_dict(built_object, inputs)
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
@ -228,9 +222,7 @@ async def process_graph_cached(
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Load the graph using SessionService
session = await session_service.load_session(session_id, data_graph)
graph, artifacts = session if session else (None, None)
@ -266,18 +258,14 @@ async def build_graph_and_generate_result(
return Result(result=result, session_id=session_id)
def validate_input(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
raise ValueError("graph_data and tweaks should be dictionaries")
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
if not isinstance(nodes, list):
raise ValueError(
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
)
raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
return nodes
@ -286,9 +274,7 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data = node.get("data", {}).get("node", {}).get("template")
if not isinstance(template_data, dict):
logger.warning(
f"Template data for node {node.get('id')} should be a dictionary"
)
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
return
for tweak_name, tweak_value in node_tweaks.items():
@ -303,9 +289,7 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
vertex.params[tweak_name] = tweak_value
def process_tweaks(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
This function is used to tweak the graph data using the node id and the tweaks dict.
@ -326,9 +310,7 @@ def process_tweaks(
if node_tweaks := tweaks.get(node_id):
apply_tweaks(node, node_tweaks)
else:
logger.warning(
"Each node should be a dictionary with an 'id' key of type str"
)
logger.warning("Each node should be a dictionary with an 'id' key of type str")
return graph_data
@ -340,8 +322,6 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
if node_tweaks := tweaks.get(node_id):
apply_tweaks_on_vertex(vertex, node_tweaks)
else:
logger.warning(
"Each node should be a Vertex with an 'id' attribute of type str"
)
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
return graph

View file

@ -10,9 +10,7 @@ if TYPE_CHECKING:
class TransactionModel(BaseModel):
id: Optional[int] = Field(default=None, alias="id")
timestamp: Optional[datetime] = Field(
default_factory=datetime.now, alias="timestamp"
)
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
source: str
target: str
target_args: dict
@ -54,9 +52,7 @@ class MessageModel(BaseModel):
def from_record(cls, record: "Record"):
# first check if the record has all the required fields
if "sender" not in record.data and "sender_name" not in record.data:
raise ValueError(
"The record does not have the required fields 'sender' and 'sender_name' in the data."
)
raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.")
return cls(
sender=record.data["sender"],
sender_name=record.data["sender_name"],
@ -110,7 +106,6 @@ class VertexBuildModel(BaseModel):
class VertexBuildResponseModel(VertexBuildModel):
@field_serializer("data", "artifacts")
def serialize_dict(v):
return v

View file

@ -43,9 +43,7 @@ class MonitorService(Service):
def ensure_tables_exist(self):
for table_name, model in self.table_map.items():
drop_and_create_table_if_schema_mismatch(
str(self.db_path), table_name, model
)
drop_and_create_table_if_schema_mismatch(str(self.db_path), table_name, model)
def add_row(
self,

View file

@ -45,9 +45,7 @@ def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict:
return columns
def drop_and_create_table_if_schema_mismatch(
db_path: str, table_name: str, model: Type[BaseModel]
):
def drop_and_create_table_if_schema_mismatch(db_path: str, table_name: str, model: Type[BaseModel]):
with duckdb.connect(db_path) as conn:
# Get the current schema from the database
try:
@ -68,12 +66,8 @@ def drop_and_create_table_if_schema_mismatch(
conn.execute(f"CREATE SEQUENCE seq_{table_name} START 1;")
except duckdb.CatalogException:
pass
desired_schema[INDEX_KEY] = (
f"INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_{table_name}')"
)
columns_sql = ", ".join(
f"{name} {data_type}" for name, data_type in desired_schema.items()
)
desired_schema[INDEX_KEY] = f"INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_{table_name}')"
columns_sql = ", ".join(f"{name} {data_type}" for name, data_type in desired_schema.items())
create_table_sql = f"CREATE TABLE {table_name} ({columns_sql})"
conn.execute(create_table_sql)

View file

@ -31,9 +31,7 @@ class SettingsService(Service):
for key in settings_dict:
if key not in Settings.model_fields.keys():
raise KeyError(f"Key {key} not found in settings")
logger.debug(
f"Loading {len(settings_dict[key])} {key} from {file_path}"
)
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")
settings = Settings(**settings_dict)
if not settings.CONFIG_DIR:

View file

@ -95,9 +95,7 @@ async def build_vertex(
)
# Emit the vertex build response
response = VertexBuildResponse(
valid=valid, params=params, id=vertex.id, data=result_dict
)
response = VertexBuildResponse(valid=valid, params=params, id=vertex.id, data=result_dict)
await sio.emit("vertex_build", data=response.model_dump(), to=sid)
except Exception as exc:

View file

@ -88,9 +88,7 @@ class LocalStorageService(StorageService):
file_path.unlink()
logger.info(f"File {file_name} deleted successfully from flow {flow_id}.")
else:
logger.warning(
f"Attempted to delete non-existent file {file_name} in flow {flow_id}."
)
logger.warning(f"Attempted to delete non-existent file {file_name} in flow {flow_id}.")
def teardown(self):
"""Perform any cleanup operations when the service is being torn down."""

View file

@ -60,9 +60,7 @@ class TemplateField(BaseModel):
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
range_spec: Optional[RangeSpec] = Field(
default=None, serialization_alias="rangeSpec"
)
range_spec: Optional[RangeSpec] = Field(default=None, serialization_alias="rangeSpec")
"""Range specification for the field. Defaults to None."""
title_case: bool = True

View file

@ -88,11 +88,7 @@ class FrontendNode(BaseModel):
def process_base_classes(self, base_classes: List[str]) -> List[str]:
"""Removes unwanted base classes from the list of base classes."""
return [
base_class
for base_class in base_classes
if base_class not in CLASSES_TO_REMOVE
]
return [base_class for base_class in base_classes if base_class not in CLASSES_TO_REMOVE]
@field_serializer("display_name")
def process_display_name(self, display_name: str) -> str:
@ -172,9 +168,7 @@ class FrontendNode(BaseModel):
return _type
@staticmethod
def handle_special_field(
field, key: str, _type: str, SPECIAL_FIELD_HANDLERS
) -> str:
def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str:
"""Handles special field by using the respective handler if present."""
handler = SPECIAL_FIELD_HANDLERS.get(key)
return handler(field) if handler else _type
@ -185,11 +179,7 @@ class FrontendNode(BaseModel):
if "dict" in _type.lower() and field.name == "dict_":
field.field_type = "file"
field.file_types = [".json", ".yaml", ".yml"]
elif (
_type.startswith("Dict")
or _type.startswith("Mapping")
or _type.startswith("dict")
):
elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"):
field.field_type = "dict"
return _type
@ -200,9 +190,7 @@ class FrontendNode(BaseModel):
field.value = value["default"]
@staticmethod
def handle_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def handle_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values for certain fields."""
if key == "headers":
field.value = """{"Authorization": "Bearer <token>"}"""
@ -210,9 +198,7 @@ class FrontendNode(BaseModel):
FrontendNode._handle_api_key_specific_field_values(field, key, name)
@staticmethod
def _handle_model_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def _handle_model_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to models."""
model_dict = {
"OpenAI": constants.OPENAI_MODELS,
@ -225,9 +211,7 @@ class FrontendNode(BaseModel):
field.is_list = True
@staticmethod
def _handle_api_key_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def _handle_api_key_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to API keys."""
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
@ -267,10 +251,7 @@ class FrontendNode(BaseModel):
@staticmethod
def should_be_password(key: str, show: bool) -> bool:
"""Determines whether the field should be a password field."""
return (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and show
)
return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show
@staticmethod
def should_be_multiline(key: str) -> bool:

View file

@ -15,7 +15,6 @@ from langflow.template.template.base import Template
class MemoryFrontendNode(FrontendNode):
pinned: bool = True
def add_extra_fields(self) -> None:
@ -81,9 +80,7 @@ class MemoryFrontendNode(FrontendNode):
field.show = True
field.advanced = False
field.value = ""
field.info = (
INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
)
field.info = INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
if field.name == "memory_key":
field.value = "chat_history"

View file

@ -46,9 +46,7 @@ def validate_code(code):
# Evaluate the function definition
for node in tree.body:
if isinstance(node, ast.FunctionDef):
code_obj = compile(
ast.Module(body=[node], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[node], type_ignores=[]), "<string>", "exec")
try:
exec(code_obj)
except Exception as e:
@ -92,23 +90,15 @@ def execute_function(code, function_name, *args, **kwargs):
exec_globals,
locals(),
)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
function_code = next(
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
try:
exec(code_obj, exec_globals, locals())
except Exception as exc:
@ -135,23 +125,15 @@ def create_function(code, function_name):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
function_code = next(
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())
exec_globals[function_name] = locals()[function_name]
@ -213,22 +195,16 @@ def prepare_global_scope(code, module):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
elif isinstance(node, ast.ImportFrom) and node.module is not None:
try:
imported_module = importlib.import_module(node.module)
for alias in node.names:
exec_globals[alias.name] = getattr(imported_module, alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {node.module} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {node.module} not found. Please install it and try again.") from e
return exec_globals
@ -240,11 +216,7 @@ def extract_class_code(module, class_name):
:param class_name: Name of the class to extract
:return: AST node of the specified class
"""
class_code = next(
node
for node in module.body
if isinstance(node, ast.ClassDef) and node.name == class_name
)
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
class_code.parent = None
return class_code
@ -257,9 +229,7 @@ def compile_class_code(class_code):
:param class_code: AST node of the class
:return: Compiled code object of the class
"""
code_obj = compile(
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
return code_obj
@ -303,9 +273,7 @@ def get_default_imports(code_string):
langflow_imports = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
necessary_imports = find_names_in_code(code_string, langflow_imports)
langflow_module = importlib.import_module("langflow.field_typing")
default_imports.update(
{name: getattr(langflow_module, name) for name in necessary_imports}
)
default_imports.update({name: getattr(langflow_module, name) for name in necessary_imports})
return default_imports

View file

@ -5,6 +5,7 @@ import Tooltip from "../../components/TooltipComponent";
import IconComponent from "../../components/genericIconComponent";
import InputComponent from "../../components/inputComponent";
import { Button } from "../../components/ui/button";
import Loading from "../../components/ui/loading";
import { Textarea } from "../../components/ui/textarea";
import { priorityFields } from "../../constants/constants";
import { BuildStatus } from "../../constants/enums";
@ -18,7 +19,6 @@ import { handleKeyDown, scapedJSONStringfy } from "../../utils/reactflowUtils";
import { nodeColors, nodeIconsLucide } from "../../utils/styleUtils";
import { classNames, cn, getFieldTitle } from "../../utils/utils";
import ParameterComponent from "./components/parameterComponent";
import Loading from "../../components/ui/loading";
export default function GenericNode({
data,
@ -166,7 +166,7 @@ export default function GenericNode({
);
const getStatusClassName = (
validationStatus: validationStatusType | null,
validationStatus: validationStatusType | null
) => {
if (validationStatus && validationStatus.valid) {
return "green-status";
@ -181,10 +181,10 @@ export default function GenericNode({
const renderIconPlayOrPauseComponents = (
buildStatus: BuildStatus | undefined,
validationStatus: validationStatusType | null,
validationStatus: validationStatusType | null
) => {
if (buildStatus === BuildStatus.BUILDING) {
return <Loading/>
return <Loading />;
} else {
const className = getStatusClassName(validationStatus);
return <>{getIconPlayOrPauseComponent("Play", className)}</>;
@ -446,7 +446,9 @@ export default function GenericNode({
}));
}}
>
<Tooltip title={<span>{pinned ? "Pin Output" : "Unpin Output"}</span>}>
<Tooltip
title={<span>{pinned ? "Pin Output" : "Unpin Output"}</span>}
>
<div className="generic-node-status-position flex items-center">
<IconComponent
name={"Pin"}
@ -461,12 +463,12 @@ export default function GenericNode({
)}
{showNode && (
<Button
variant="outline"
className={"h-9 px-1.5"}
onClick={() => {
if(data?.build_status === BuildStatus.BUILDING || isBuilding) return;
buildFlow(data.id)
if (data?.build_status === BuildStatus.BUILDING || isBuilding)
return;
buildFlow(data.id);
}}
>
<div>
@ -499,7 +501,8 @@ export default function GenericNode({
<div className="generic-node-status-position flex items-center justify-center">
{renderIconPlayOrPauseComponents(
data?.build_status,
validationStatus)}
validationStatus
)}
</div>
</Tooltip>
</div>

View file

@ -35,7 +35,7 @@ export default function Chat({ flow }: ChatType): JSX.Element {
{/* <BuildTrigger open={open} flow={flow} /> */}
{hasIO && (
<IOView open={open} setOpen={setOpen}>
<ChatTrigger />
<ChatTrigger />
</IOView>
)}
</div>

View file

@ -3,9 +3,6 @@ import IconComponent from "../../../components/genericIconComponent";
import { Textarea } from "../../../components/ui/textarea";
import { chatInputType } from "../../../types/components";
import { classNames } from "../../../utils/utils";
import { Button } from "../../ui/button";
import { Input } from "../../ui/input";
import { Popover, PopoverContent, PopoverTrigger } from "../../ui/popover";
export default function ChatInput({
lockChat,
@ -113,7 +110,8 @@ export default function ChatInput({
)}
</button>
</div>
</div>{/*
</div>
{/*
<Popover>
<PopoverTrigger asChild>
<Button variant="primary" className="h-13 px-4">

View file

@ -106,7 +106,7 @@ export default function newChatView(): JSX.Element {
}, []);
async function sendMessage(count = 1): Promise<void> {
if(isBuilding) return;
if (isBuilding) return;
const { nodes, edges } = getFlow();
let nodeValidationErrors = validateNodes(nodes, edges);
if (nodeValidationErrors.length === 0) {
@ -114,7 +114,7 @@ export default function newChatView(): JSX.Element {
setLockChat(true);
setChatValue("");
const chatInputId = inputIds.find((inputId) =>
inputId.includes("ChatInput")
inputId.includes("ChatInput")
);
const chatInput: NodeType = getNode(chatInputId!) as NodeType;
if (chatInput) {

View file

@ -35,7 +35,7 @@ const nodeTypes = {
genericNode: GenericNode,
};
export default function Page({
export default function Page({
flow,
view,
}: {
@ -377,7 +377,7 @@ export default function Page({
zoomOnPinch={!view}
panOnDrag={!view}
proOptions={{ hideAttribution: true }}
onPaneClick={onPaneClick}
onPaneClick={onPaneClick}
>
<Background className="" />
{!view && (

View file

@ -80,7 +80,6 @@ export default function NodeToolbarComponent({
window.open(url, "_blank", "noreferrer");
};
useEffect(() => {
if (!showModalAdvanced) {
onCloseAdvancedModal!(false);

View file

@ -2,7 +2,6 @@ import { cloneDeep } from "lodash";
import {
Edge,
EdgeChange,
MarkerType,
Node,
NodeChange,
addEdge,
@ -10,7 +9,6 @@ import {
applyNodeChanges,
} from "reactflow";
import { create } from "zustand";
import { INPUT_TYPES, OUTPUT_TYPES } from "../constants/constants";
import { BuildStatus } from "../constants/enums";
import { getFlowPool, updateFlowInDatabase } from "../controllers/API";
import {
@ -123,7 +121,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
});
const flowsManager = useFlowsManagerStore.getState();
if(!(get().isBuilding)){
if (!get().isBuilding) {
flowsManager.autoSaveCurrentFlow(
newChange,
newEdges,
@ -139,7 +137,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
});
const flowsManager = useFlowsManagerStore.getState();
if(!(get().isBuilding)){
if (!get().isBuilding) {
flowsManager.autoSaveCurrentFlow(
get().nodes,
newChange,
@ -345,7 +343,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
sourceHandle: scapeJSONParse(connection.sourceHandle!),
},
style: { stroke: "#555" },
className:"stroke-foreground stroke-connection",
className: "stroke-foreground stroke-connection",
},
oldEdges
);

View file

@ -1,5 +1,5 @@
import { create } from "zustand";
import { checkHasApiKey, checkHasStore } from "../controllers/API";
import { checkHasStore } from "../controllers/API";
import { StoreStoreType } from "../types/zustand/store";
export const useStoreStore = create<StoreStoreType>((set) => ({

View file

@ -63,13 +63,15 @@ export async function buildVertices({
onBuildError,
verticesIds,
buildResults,
stopBuild:()=>{stop=true}
stopBuild: () => {
stop = true;
},
});
if(stop){
if (stop) {
break;
}
}
if(stop){
if (stop) {
break;
}
}
@ -96,7 +98,7 @@ async function buildVertex({
onBuildError?: (title, list, idList: string[]) => void;
verticesIds: string[];
buildResults: boolean[];
stopBuild:()=>void;
stopBuild: () => void;
}) {
try {
const buildRes = await postBuildVertex(flowId, id);

View file

@ -4,8 +4,8 @@ import {
Bell,
BookMarked,
BookmarkPlus,
Boxes,
Bot,
Boxes,
Cable,
Check,
CheckCircle2,