Fix merge conflicts in backend code

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-18 18:30:55 -03:00
commit c57aadbbc6
6 changed files with 805 additions and 194 deletions

View file

@ -2,14 +2,17 @@ from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.credential import Credential, CredentialCreate, CredentialRead, CredentialUpdate
from langflow.services.database.models.user.model import User
from langflow.services.database.models.credential import (
Credential,
CredentialCreate,
CredentialRead,
CredentialUpdate,
)
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_settings_service
from sqlmodel import Session, select
router = APIRouter(prefix="/credentials", tags=["Credentials"])
@ -26,7 +29,10 @@ def create_credential(
try:
# check if credential name already exists
credential_exists = session.exec(
select(Credential).where(Credential.name == credential.name, Credential.user_id == current_user.id)
select(Credential).where(
Credential.name == credential.name,
Credential.user_id == current_user.id,
)
).first()
if credential_exists:
raise HTTPException(status_code=400, detail="Credential name already exists")
@ -105,7 +111,6 @@ def delete_credential(
).first()
if not db_credential:
raise HTTPException(status_code=404, detail="Credential not found")
session.delete(db_credential)
session.commit()
return db_credential

View file

@ -1,20 +1,49 @@
import ast
import asyncio
import inspect
import types
from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
)
from langflow.graph.utils import UnbuiltObject
from langflow.graph.schema import (
INPUT_COMPONENTS,
INPUT_FIELD_NAME,
OUTPUT_COMPONENTS,
InterfaceComponentTypes,
ResultData,
)
from langflow.graph.utils import UnbuiltObject, UnbuiltResult
from langflow.graph.vertex.utils import generate_result
from langflow.interface.initialize import loading
from langflow.interface.listing import lazy_load_dict
from langflow.services.deps import get_storage_service
from langflow.utils.constants import DIRECT_TYPES
from langflow.utils.util import sync_to_async
from langflow.utils.schemas import ChatOutputResponse
from langflow.utils.util import sync_to_async, unescape_string
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.edge.base import Edge
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.base import Graph
class VertexStates(str, Enum):
"""Vertex are related to it being active, inactive, or in an error state."""
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
class Vertex:
def __init__(
self,
@ -24,25 +53,109 @@ class Vertex:
is_task: bool = False,
params: Optional[Dict] = None,
) -> None:
self.graph = graph
# is_external means that the Vertex send or receives data from
# an external source (e.g the chat)
self._lock = asyncio.Lock()
self.will_stream = False
self.updated_raw_params = False
self.id: str = data["id"]
self.is_state = False
self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS)
self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS)
self.has_session_id = None
self._custom_component = None
self.has_external_input = False
self.has_external_output = False
self.graph = graph
self._data = data
self.base_type: Optional[str] = base_type
self._parse_data()
self._built_object = UnbuiltObject()
self._built_result = None
self._built = False
self.artifacts: Dict[str, Any] = {}
self.steps: List[Callable] = [self._build]
self.steps_ran: List[Callable] = []
self.task_id: Optional[str] = None
self.is_task = is_task
self.params = params or {}
self.parent_node_id: Optional[str] = self._data.get("parent_node_id")
self.load_from_db_fields: List[str] = []
self.parent_is_top_level = False
self.layer = None
self.should_run = True
self.result: Optional[ResultData] = None
try:
self.is_interface_component = self.vertex_type in InterfaceComponentTypes
except ValueError:
self.is_interface_component = False
self.use_result = False
self.build_times: List[float] = []
self.state = VertexStates.ACTIVE
def update_graph_state(self, key, new_state, append: bool):
if append:
self.graph.append_state(key, new_state, caller=self.id)
else:
self.graph.update_state(key, new_state, caller=self.id)
def set_state(self, state: str):
self.state = VertexStates[state]
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2:
# If the vertex is inactive and has only one in degree
# it means that it is not a merge point in the graph
self.graph.inactivated_vertices.add(self.id)
elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactivated_vertices:
self.graph.inactivated_vertices.remove(self.id)
@property
def edges(self) -> List["Edge"]:
def avg_build_time(self):
return sum(self.build_times) / len(self.build_times) if self.build_times else 0
def add_build_time(self, time):
self.build_times.append(time)
def set_result(self, result: ResultData) -> None:
self.result = result
def get_built_result(self):
# If the Vertex.type is a power component
# then we need to return the built object
# instead of the result dict
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
result = self._built_object
# if it is not a dict or a string and hasattr model_dump then
# return the model_dump
if not isinstance(result, (dict, str)) and hasattr(result, "content"):
return result.content
return result
if isinstance(self._built_object, str):
self._built_result = self._built_object
if isinstance(self._built_result, UnbuiltResult):
return {}
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
def set_artifacts(self) -> None:
pass
@property
def edges(self) -> List["ContractEdge"]:
return self.graph.get_vertex_edges(self.id)
@property
def predecessors(self) -> List["Vertex"]:
return self.graph.get_predecessors(self)
@property
def successors(self) -> List["Vertex"]:
return self.graph.get_successors(self)
@property
def successors_ids(self) -> List[str]:
return self.graph.successor_map.get(self.id, [])
def __getstate__(self):
return {
"_data": self._data,
@ -55,14 +168,20 @@ class Vertex:
"parent_node_id": self.parent_node_id,
"parent_is_top_level": self.parent_is_top_level,
"load_from_db_fields": self.load_from_db_fields,
"is_input": self.is_input,
"is_output": self.is_output,
}
def __setstate__(self, state):
self._lock = asyncio.Lock()
self._data = state["_data"]
self.params = state["params"]
self.base_type = state["base_type"]
self.is_task = state["is_task"]
self.id = state["id"]
self.frozen = state.get("frozen", False)
self.is_input = state.get("is_input", False)
self.is_output = state.get("is_output", False)
self._parse_data()
if "_built_object" in state:
self._built_object = state["_built_object"]
@ -70,11 +189,17 @@ class Vertex:
else:
self._built_object = UnbuiltObject()
self._built = False
if "_built_result" in state:
self._built_result = state["_built_result"]
else:
self._built_result = UnbuiltResult()
self.artifacts: Dict[str, Any] = {}
self.task_id: Optional[str] = None
self.parent_node_id = state["parent_node_id"]
self.parent_is_top_level = state["parent_is_top_level"]
self.load_from_db_fields = state["load_from_db_fields"]
self.layer = state.get("layer")
self.steps = state.get("steps", [self._build])
def set_top_level(self, top_level_vertices: List[str]) -> None:
self.parent_is_top_level = self.parent_node_id in top_level_vertices
@ -82,8 +207,15 @@ class Vertex:
def _parse_data(self) -> None:
self.data = self._data["data"]
self.output = self.data["node"]["base_classes"]
self.display_name = self.data["node"].get("display_name", self.id.split("-")[0])
self.frozen = self.data["node"].get("frozen", False)
self.selected_output_type = self.data["node"].get("selected_output_type")
self.is_input = self.data["node"].get("is_input") or self.is_input
self.is_output = self.data["node"].get("is_output") or self.is_output
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
self.has_session_id = "session_id" in template_dicts
self.required_inputs = [
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
]
@ -135,8 +267,12 @@ class Vertex:
if self.graph is None:
raise ValueError("Graph not found")
if self.updated_raw_params:
self.updated_raw_params = False
return
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = self.params.copy() if self.params else {}
params = {}
for edge in self.edges:
if not hasattr(edge, "target_param"):
@ -152,7 +288,13 @@ class Vertex:
params[param_key] = []
params[param_key].append(self.graph.get_vertex(edge.source_id))
elif edge.target_id == self.id:
params[param_key] = self.graph.get_vertex(edge.source_id)
if isinstance(template_dict[param_key].get("value"), dict):
# we don't know the key of the dict but we need to set the value
# to the vertex that is the source of the edge
param_dict = template_dict[param_key]["value"]
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
load_from_db_fields = []
for field_name, field in template_dict.items():
@ -169,16 +311,18 @@ class Vertex:
# what is inside value.get('content')
# value.get('value') is the file name
if file_path := field.get("file_path"):
params[field_name] = file_path
storage_service = get_storage_service()
flow_id, file_name = file_path.split("/")
full_path = storage_service.build_full_path(flow_id, file_name)
params[field_name] = full_path
else:
raise ValueError(f"File path not found for {self.vertex_type}")
raise ValueError(f"File path not found for {self.display_name}")
elif field.get("type") in DIRECT_TYPES and params.get(field_name) is None:
val = field.get("value")
if field.get("type") == "code":
try:
params[field_name] = ast.literal_eval(val) if val else None
except Exception as exc:
logger.debug(f"Error parsing code: {exc}")
except Exception:
params[field_name] = val
elif field.get("type") in ["dict", "NestedDict"]:
# When dict comes from the frontend it comes as a
@ -198,6 +342,16 @@ class Vertex:
params[field_name] = float(val)
except ValueError:
params[field_name] = val
params[field_name] = val
elif field.get("type") == "str" and val is not None:
# val may contain escaped \n, \t, etc.
# so we need to unescape it
if isinstance(val, list):
params[field_name] = [unescape_string(v) for v in val]
elif isinstance(val, str):
params[field_name] = unescape_string(val)
elif val is not None and val != "":
params[field_name] = val
elif val is not None and val != "":
params[field_name] = val
@ -210,92 +364,226 @@ class Vertex:
else:
params.pop(field_name, None)
# Add _type to params
self._raw_params = params
self.params = params
self.load_from_db_fields = load_from_db_fields
self._raw_params = params.copy()
def update_raw_params(self, new_params: Dict[str, str], overwrite: bool = False):
"""
Update the raw parameters of the vertex with the given new parameters.
Args:
new_params (Dict[str, Any]): The new parameters to update.
Raises:
ValueError: If any key in new_params is not found in self._raw_params.
"""
# First check if the input_value in _raw_params is not a vertex
if not new_params:
return
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
return
if not overwrite:
for key in new_params.copy():
if key not in self._raw_params:
new_params.pop(key)
self._raw_params.update(new_params)
self.updated_raw_params = True
async def _build(self, user_id=None):
"""
Initiate the build process.
"""
logger.debug(f"Building {self.vertex_type}")
await self._build_each_node_in_params_dict(user_id)
logger.debug(f"Building {self.display_name}")
await self._build_each_vertex_in_params_dict(user_id)
await self._get_and_instantiate_class(user_id)
self._validate_built_object()
self._built = True
async def _build_each_node_in_params_dict(self, user_id=None):
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:
"""
Iterates over each node in the params dictionary and builds it.
Extracts messages from the artifacts.
Args:
artifacts (Dict[str, Any]): The artifacts to extract messages from.
Returns:
List[str]: The extracted messages.
"""
for key, value in self.params.copy().items():
if self._is_node(value):
messages = []
for key, artifact in artifacts.items():
if not isinstance(artifact, dict):
continue
if "message" in artifact:
chat_output_response = ChatOutputResponse(
message=artifact["message"],
sender=artifact.get("sender"),
sender_name=artifact.get("sender_name"),
session_id=artifact.get("session_id"),
component_id=self.id,
)
messages.append(chat_output_response.model_dump(exclude_none=True))
return messages
def _finalize_build(self):
result_dict = self.get_built_result()
# We need to set the artifacts to pass information
# to the frontend
self.set_artifacts()
artifacts = self.artifacts
messages = self.extract_messages_from_artifacts(artifacts)
result_dict = ResultData(
results=result_dict,
artifacts=artifacts,
messages=messages,
component_display_name=self.display_name,
component_id=self.id,
)
self.set_result(result_dict)
async def _run(
self,
user_id: str,
inputs: Optional[dict] = None,
session_id: Optional[str] = None,
):
# user_id is just for compatibility with the other build methods
inputs = inputs or {}
# inputs = {key: value or "" for key, value in inputs.items()}
# if hasattr(self._built_object, "input_keys"):
# # test if all keys are in inputs
# # and if not add them with empty string
# # for key in self._built_object.input_keys:
# # if key not in inputs:
# # inputs[key] = ""
# if inputs == {} and hasattr(self._built_object, "prompt"):
# inputs = self._built_object.prompt.partial_variables
if isinstance(self._built_object, str):
self._built_result = self._built_object
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
self._built_result = result
async def _build_each_vertex_in_params_dict(self, user_id=None):
"""
Iterates over each vertex in the params dictionary and builds it.
"""
for key, value in self._raw_params.items():
if self._is_vertex(value):
if value == self:
del self.params[key]
continue
await self._build_node_and_update_params(key, value, user_id)
elif isinstance(value, list) and self._is_list_of_nodes(value):
await self._build_list_of_nodes_and_update_params(key, value, user_id)
await self._build_vertex_and_update_params(
key,
value,
)
elif isinstance(value, list) and self._is_list_of_vertices(value):
await self._build_list_of_vertices_and_update_params(key, value)
elif isinstance(value, dict):
await self._build_dict_and_update_params(
key,
value,
)
elif key not in self.params or self.updated_raw_params:
self.params[key] = value
def _is_node(self, value):
async def _build_dict_and_update_params(
self,
key,
vertices_dict: Dict[str, "Vertex"],
):
"""
Iterates over a dictionary of vertices, builds each and updates the params dictionary.
"""
for sub_key, value in vertices_dict.items():
if not self._is_vertex(value):
self.params[key][sub_key] = value
else:
result = await value.get_result()
self.params[key][sub_key] = result
def _is_vertex(self, value):
"""
Checks if the provided value is an instance of Vertex.
"""
return isinstance(value, Vertex)
def _is_list_of_nodes(self, value):
def _is_list_of_vertices(self, value):
"""
Checks if the provided value is a list of Vertex instances.
"""
return all(self._is_node(node) for node in value)
return all(self._is_vertex(vertex) for vertex in value)
async def get_result(self, user_id=None, timeout=None) -> Any:
# Check if the Vertex was built already
if self._built:
return self._built_object
if self.is_task and self.task_id is not None:
task = self.get_task()
result = task.get(timeout=timeout)
if isinstance(result, Coroutine):
result = await result
if result is not None: # If result is ready
self._update_built_object_and_artifacts(result)
return self._built_object
else:
# Handle the case when the result is not ready (retry, throw exception, etc.)
pass
# If there's no task_id, build the vertex locally
await self.build(user_id=user_id)
return self._built_object
async def _build_node_and_update_params(self, key, node, user_id=None):
async def get_result(
self,
) -> Any:
"""
Builds a given node and updates the params dictionary accordingly.
Retrieves the result of the vertex.
This is a read-only method so it raises an error if the vertex has not been built yet.
Returns:
The result of the vertex.
"""
async with self._lock:
return await self._get_result()
async def _get_result(self) -> Any:
"""
Retrieves the result of the built component.
If the component has not been built yet, a ValueError is raised.
Returns:
The built result if use_result is True, else the built object.
"""
if not self._built:
raise ValueError(f"Component {self.display_name} has not been built yet")
return self._built_result if self.use_result else self._built_object
async def _build_vertex_and_update_params(self, key, vertex: "Vertex"):
"""
Builds a given vertex and updates the params dictionary accordingly.
"""
result = await node.get_result(user_id)
result = await vertex.get_result()
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
async def _build_list_of_vertices_and_update_params(
self,
key,
vertices: List["Vertex"],
):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
Iterates over a list of vertices, builds each and updates the params dictionary.
"""
self.params[key] = []
for node in nodes:
built = await node.get_result(user_id)
if isinstance(built, list):
if key not in self.params:
self.params[key] = []
self.params[key].extend(built)
for vertex in vertices:
result = await vertex.get_result()
# Weird check to see if the params[key] is a list
# because sometimes it is a Record and breaks the code
if not isinstance(self.params[key], list):
self.params[key] = [self.params[key]]
if isinstance(result, list):
self.params[key].extend(result)
else:
self.params[key].append(built)
try:
if self.params[key] == result:
continue
self.params[key].append(result)
except AttributeError as e:
logger.exception(e)
raise ValueError(
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}"
f"Error building vertex {self.display_name}: {str(e)}"
) from e
def _handle_func(self, key, result):
"""
@ -324,26 +612,27 @@ class Vertex:
Gets the class from a dictionary and instantiates it with the params.
"""
if self.base_type is None:
raise ValueError(f"Base type for node {self.vertex_type} not found")
raise ValueError(f"Base type for vertex {self.display_name} not found")
try:
result = await loading.instantiate_class(
node_type=self.vertex_type,
base_type=self.base_type,
load_from_db_fields=self.load_from_db_fields,
params=self.params,
user_id=user_id,
vertex=self,
)
self._update_built_object_and_artifacts(result)
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}") from exc
raise ValueError(f"Error building vertex {self.display_name}: {str(exc)}") from exc
def _update_built_object_and_artifacts(self, result):
"""
Updates the built object and its artifacts.
"""
if isinstance(result, tuple):
self._built_object, self.artifacts = result
if len(result) == 2:
self._built_object, self.artifacts = result
elif len(result) == 3:
self._custom_component, self._built_object, self.artifacts = result
else:
self._built_object = result
@ -352,30 +641,105 @@ class Vertex:
Checks if the built object is None and raises a ValueError if so.
"""
if isinstance(self._built_object, UnbuiltObject):
raise ValueError(f"{self.vertex_type}: {self._built_object_repr()}")
raise ValueError(f"{self.display_name}: {self._built_object_repr()}")
elif self._built_object is None:
message = f"{self.vertex_type} returned None."
message = f"{self.display_name} returned None."
if self.base_type == "custom_components":
message += " Make sure your build method returns a component."
logger.warning(message)
elif isinstance(self._built_object, (Iterator, AsyncIterator)):
if self.display_name in ["Text Output"]:
raise ValueError(f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead.")
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:
await self._build(user_id, *args, **kwargs)
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
self._built = False
self._built_object = UnbuiltObject()
self._built_result = UnbuiltResult()
self.artifacts = {}
self.steps_ran = []
self._build_params()
return self._built_object
def _is_chat_input(self):
return False
def add_edge(self, edge: "Edge") -> None:
def build_inactive(self):
# Just set the results to None
self._built = True
self._built_object = None
self._built_result = None
async def build(
self,
user_id=None,
inputs: Optional[Dict[str, Any]] = None,
requester: Optional["Vertex"] = None,
**kwargs,
) -> Any:
async with self._lock:
if self.state == VertexStates.INACTIVE:
# If the vertex is inactive, return None
self.build_inactive()
return
if self.frozen and self._built:
return self.get_requester_result(requester)
elif self._built and requester is not None:
# This means that the vertex has already been built
# and we are just getting the result for the requester
return await self.get_requester_result(requester)
self._reset()
if self._is_chat_input() and inputs:
inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")}
self.update_raw_params(inputs, overwrite=True)
# Run steps
for step in self.steps:
if step not in self.steps_ran:
if inspect.iscoroutinefunction(step):
await step(user_id=user_id, **kwargs)
else:
step(user_id=user_id, **kwargs)
self.steps_ran.append(step)
self._finalize_build()
return await self.get_requester_result(requester)
async def get_requester_result(self, requester: Optional["Vertex"]):
# If the requester is None, this means that
# the Vertex is the root of the graph
if requester is None:
return self._built_object
# Get the requester edge
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
# Return the result of the requester edge
return (
None
if requester_edge is None
else await requester_edge.get_result_from_source(source=self, target=requester)
)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:
self.edges.append(edge)
def __repr__(self) -> str:
return f"Vertex(id={self.id}, data={self.data})"
return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
def __eq__(self, __o: object) -> bool:
try:
return self.id == __o.id if isinstance(__o, Vertex) else False
if not isinstance(__o, Vertex):
return False
# We should create a more robust comparison
# for the Vertex class
ids_are_equal = self.id == __o.id
# self._data is a dict and we need to compare them
# to check if they are equal
data_are_equal = self.data == __o.data
return ids_are_equal and data_are_equal
except AttributeError:
return False

View file

@ -1,20 +1,42 @@
import operator
from typing import Any, Callable, List, Optional, Union
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
List,
Optional,
Sequence,
Union,
)
from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langchain_core.documents import Document
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.interface.custom.custom_component.component import Component
from langflow.schema import Record
from langflow.schema.dotdict import dotdict
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_credential_service, get_db_service
from langflow.services.deps import (
get_credential_service,
get_db_service,
get_storage_service,
)
from langflow.services.storage.service import StorageService
from langflow.utils import validate
from pydantic import BaseModel
from sqlmodel import select
from .component import Component
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
class CustomComponent(Component):
@ -35,6 +57,67 @@ class CustomComponent(Component):
_tree (Optional[dict]): The code tree of the custom component.
"""
display_name: Optional[str] = None
"""The display name of the component. Defaults to None."""
description: Optional[str] = None
"""The description of the component. Defaults to None."""
icon: Optional[str] = None
"""The icon of the component. It should be an emoji. Defaults to None."""
is_input: Optional[bool] = None
"""The input state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
is_output: Optional[bool] = None
"""The output state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
code: Optional[str] = None
"""The code of the component. Defaults to None."""
field_config: dict = {}
"""The field configuration of the component. Defaults to an empty dictionary."""
field_order: Optional[List[str]] = None
"""The field order of the component. Defaults to an empty list."""
frozen: Optional[bool] = False
"""The default frozen state of the component. Defaults to False."""
build_parameters: Optional[dict] = None
"""The build parameters of the component. Defaults to None."""
selected_output_type: Optional[str] = None
"""The selected output type of the component. Defaults to None."""
vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
user_id: Optional[Union[UUID, str]] = None
status: Optional[Any] = None
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_records: Optional[List[Record]] = None
def update_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def append_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
except Exception as e:
raise ValueError(f"Error appending state: {e}")
def get_state(self, name: str):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
return self.vertex.graph.get_state(name=name)
except Exception as e:
raise ValueError(f"Error getting state: {e}")
_tree: Optional[dict] = None
def __init__(self, **data):
"""
Initializes a new instance of the CustomComponent class.
@ -45,6 +128,29 @@ class CustomComponent(Component):
self.cache = TTLCache(maxsize=1024, ttl=60)
super().__init__(**data)
@staticmethod
def resolve_path(path: str) -> str:
"""Resolves the path to an absolute path."""
path_object = Path(path)
if path_object.parts[0] == "~":
path_object = path_object.expanduser()
elif path_object.is_relative_to("."):
path_object = path_object.resolve()
return str(path_object)
def get_full_path(self, path: str) -> str:
storage_svc: "StorageService" = get_storage_service()
flow_id, file_name = path.split("/", 1)
return storage_svc.build_full_path(flow_id, file_name)
@property
def graph(self):
return self.vertex.graph
def _get_field_order(self):
return self.field_order or list(self.field_config.keys())
def custom_repr(self):
"""
Returns the custom representation of the custom component.
@ -58,7 +164,7 @@ class CustomComponent(Component):
return yaml.dump(self.repr_value)
if isinstance(self.repr_value, str):
return self.repr_value
return str(self.repr_value)
return self.repr_value
def build_config(self):
"""
@ -69,6 +175,15 @@ class CustomComponent(Component):
"""
return self.field_config
def update_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: Optional[str] = None,
):
build_config[field_name] = field_value
return build_config
@property
def tree(self):
"""
@ -79,6 +194,78 @@ class CustomComponent(Component):
"""
return self.get_code_tree(self.code or "")
def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]:
"""
Converts input data into a list of Record objects.
Args:
data (Any): The input data to be converted. It can be a single item or a sequence of items.
If the input data is a Langchain Document, text_key and data_key are ignored.
keys (List[str], optional): The keys to access the text and data values in each item.
It should be a list of strings where the first element is the text key and the second element is the data key.
Defaults to None, in which case the default keys "text" and "data" are used.
Returns:
List[Record]: A list of Record objects.
Raises:
ValueError: If the input data is not of a valid type or if the specified keys are not found in the data.
"""
if not keys:
keys = []
records = []
if not isinstance(data, Sequence):
data = [data]
for item in data:
data_dict = {}
if isinstance(item, Document):
data_dict = item.metadata
data_dict["text"] = item.page_content
elif isinstance(item, BaseModel):
model_dump = item.model_dump()
for key in keys:
if silent_errors:
data_dict[key] = model_dump.get(key, "")
else:
try:
data_dict[key] = model_dump[key]
except KeyError:
raise ValueError(f"Key {key} not found in {item}")
elif isinstance(item, str):
data_dict = {"text": item}
elif isinstance(item, dict):
data_dict = item.copy()
else:
raise ValueError(f"Invalid data type: {type(item)}")
records.append(Record(data=data_dict))
return records
def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str:
"""
Create references from a list of records.
Args:
records (List[dict]): A list of records, where each record is a dictionary.
include_data (bool, optional): Whether to include data in the references. Defaults to False.
Returns:
str: A string containing the references in markdown format.
"""
if not records:
return ""
markdown_string = "---\n"
for record in records:
markdown_string += f"- Text: {record.text}"
if include_data:
markdown_string += f" Data: {record.data}"
markdown_string += "\n"
return markdown_string
@property
def get_function_entrypoint_args(self) -> list:
"""
@ -93,17 +280,7 @@ class CustomComponent(Component):
args = build_method["args"]
for arg in args:
if arg.get("type") == "prompt":
raise HTTPException(
status_code=400,
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
),
},
)
elif not arg.get("type") and arg.get("name") != "self":
if not arg.get("type") and arg.get("name") != "self":
# Set the type to Data
arg["type"] = "Data"
return args
@ -145,7 +322,10 @@ class CustomComponent(Component):
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,
]:
return_type = extract_inner_type_from_generic_alias(return_type)
# If the return type is not a Union, then we just return it as a list
@ -265,7 +445,6 @@ class CustomComponent(Component):
return get_index
@property
def get_function(self):
"""
Gets the function associated with the custom component.
@ -275,18 +454,9 @@ class CustomComponent(Component):
"""
return validate.create_function(self.code, self.function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
"""
Loads a flow with the specified ID and applies tweaks if provided.
Args:
flow_id (str): The ID of the flow to load.
tweaks (Optional[dict]): The tweaks to apply to the flow.
Returns:
Any: The loaded flow.
"""
from langflow.processing.process import build_sorted_vertices, process_tweaks
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
from langflow.graph.graph.base import Graph
from langflow.processing.process import process_tweaks
db_service = get_db_service()
with session_getter(db_service) as session:
@ -295,62 +465,52 @@ class CustomComponent(Component):
raise ValueError(f"Flow {flow_id} not found")
if tweaks:
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
return await build_sorted_vertices(graph_data, self.user_id)
graph = Graph.from_payload(graph_data, flow_id=flow_id)
return graph
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
"""
Lists the flows associated with the custom component.
async def run_flow(
self,
input_value: Union[str, list[str]],
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
tweaks: Optional[dict] = None,
) -> Any:
if not flow_id and not flow_name:
raise ValueError("Flow ID or Flow Name is required")
if not self._flows_records:
self.list_flows()
if not flow_id and self._flows_records:
flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
raise ValueError(f"Multiple flows found with the name {flow_name}")
flow_id = flow_ids[0]
Args:
get_session (Optional[Callable]): The function to get the session.
if not flow_id:
raise ValueError(f"Flow {flow_name} not found")
if isinstance(input_value, str):
input_value = [input_value]
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = [{"input_value": input_val} for input_val in input_value]
return await graph.run(input_value_dict, stream=False)
Returns:
List[Flow]: The list of flows associated with the custom component.
"""
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]:
if not self._user_id:
raise ValueError("Session is invalid")
try:
get_session = get_session or session_getter
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.query(Flow).filter(Flow.user_id == self.user_id).all()
return flows
flows = session.exec(
select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa
).all()
flows_records = [flow.to_record() for flow in flows]
self._flows_records = flows_records
return flows_records
except Exception as e:
raise ValueError("Session is invalid") from e
async def get_flow(
self,
*,
flow_name: Optional[str] = None,
flow_id: Optional[str] = None,
tweaks: Optional[dict] = None,
get_session: Optional[Callable] = None,
) -> Flow:
"""
Gets a flow with the specified name or ID and applies tweaks if provided.
Args:
flow_name (Optional[str]): The name of the flow to get.
flow_id (Optional[str]): The ID of the flow to get.
tweaks (Optional[dict]): The tweaks to apply to the flow.
get_session (Optional[Callable]): The function to get the session.
Returns:
Flow: The flow with the specified name or ID.
"""
get_session = get_session or session_getter
db_service = get_db_service()
with get_session(db_service) as session:
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")
if not flow:
raise ValueError(f"Flow {flow_name or flow_id} not found")
return await self.load_flow(flow.id, tweaks)
raise ValueError(f"Error listing flows: {e}")
def build(self, *args: Any, **kwargs: Any) -> Any:
"""

View file

@ -11,49 +11,60 @@ from langchain.chains.base import Chain
from langchain.document_loaders.base import BaseLoader
from langchain_community.vectorstores import VectorStore
from langchain_core.documents import Document
from loguru import logger
from pydantic import ValidationError
from langflow.interface.custom.eval import eval_custom_component_code
from langflow.interface.custom.utils import get_function
from langflow.interface.custom_lists import CUSTOM_NODES
from langflow.interface.importing.utils import eval_custom_component_code, get_function, import_by_type
from langflow.interface.importing.utils import import_by_type
from langflow.interface.initialize.llm import initialize_vertexai
from langflow.interface.initialize.utils import handle_format_kwargs, handle_node_type, handle_partial_variables
from langflow.interface.initialize.utils import (
handle_format_kwargs,
handle_node_type,
handle_partial_variables,
)
from langflow.interface.initialize.vector_store import vecstore_initializer
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.retrievers.base import retriever_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.utils import load_file_into_dict
from langflow.interface.wrappers.base import wrapper_creator
from langflow.schema.schema import Record
from langflow.utils import validate
from langflow.utils.util import unescape_string
from loguru import logger
from pydantic import ValidationError
if TYPE_CHECKING:
from langflow import CustomComponent
def build_vertex_in_params(params: Dict) -> Dict:
from langflow.graph.vertex.base import Vertex
# If any of the values in params is a Vertex, we will build it
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
async def instantiate_class(
node_type: str, base_type: str, load_from_db_fields: list[str], params: Dict, user_id=None
vertex: "Vertex",
user_id=None,
) -> Any:
"""Instantiate class from module type and key, and params"""
vertex_type = vertex.vertex_type
base_type = vertex.base_type
params = vertex.params
params = convert_params_to_sets(params)
params = convert_kwargs(params)
if node_type in CUSTOM_NODES:
if custom_node := CUSTOM_NODES.get(node_type):
if vertex_type in CUSTOM_NODES:
if custom_node := CUSTOM_NODES.get(vertex_type):
if hasattr(custom_node, "initialize"):
return custom_node.initialize(**params)
return custom_node(**params)
logger.debug(f"Instantiating {node_type} of type {base_type}")
class_object = import_by_type(_type=base_type, name=node_type)
logger.debug(f"Instantiating {vertex_type} of type {base_type}")
if not base_type:
raise ValueError("No base type provided for vertex")
class_object = import_by_type(_type=base_type, name=vertex_type)
return await instantiate_based_on_type(
class_object, base_type, node_type, load_from_db_fields, params, user_id=user_id
class_object=class_object,
base_type=base_type,
node_type=vertex_type,
params=params,
user_id=user_id,
vertex=vertex,
)
@ -81,7 +92,14 @@ def convert_kwargs(params):
return params
async def instantiate_based_on_type(class_object, base_type, node_type, load_from_db_fields, params, user_id):
async def instantiate_based_on_type(
class_object,
base_type,
node_type,
params,
user_id,
vertex,
):
if base_type == "agents":
return instantiate_agent(node_type, class_object, params)
elif base_type == "prompts":
@ -108,14 +126,20 @@ async def instantiate_based_on_type(class_object, base_type, node_type, load_fro
return instantiate_chains(node_type, class_object, params)
elif base_type == "output_parsers":
return instantiate_output_parser(node_type, class_object, params)
elif base_type == "llms":
elif base_type == "models":
return instantiate_llm(node_type, class_object, params)
elif base_type == "retrievers":
return instantiate_retriever(node_type, class_object, params)
elif base_type == "memory":
return instantiate_memory(node_type, class_object, params)
elif base_type == "custom_components":
return await instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id)
return await instantiate_custom_component(
node_type,
class_object,
params,
user_id,
vertex,
)
elif base_type == "wrappers":
return instantiate_wrapper(node_type, class_object, params)
else:
@ -137,11 +161,17 @@ def update_params_with_load_from_db_fields(custom_component: "CustomComponent",
return params
async def instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id):
async def instantiate_custom_component(node_type, class_object, params, user_id, vertex):
params_copy = params.copy()
class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code"))
custom_component = class_object(user_id=user_id)
params_copy = update_params_with_load_from_db_fields(custom_component, params_copy, load_from_db_fields)
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
custom_component: "CustomComponent" = class_object(
user_id=user_id,
parameters=params_copy,
vertex=vertex,
selected_output_type=vertex.selected_output_type,
)
params_copy = update_params_with_load_from_db_fields(custom_component, params_copy, vertex.load_from_db_fields)
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
params_copy["retriever"] = params_copy["retriever"].as_retriever()
@ -150,12 +180,14 @@ async def instantiate_custom_component(node_type, class_object, load_from_db_fie
if is_async:
# Await the build method directly if it's async
built_object = await custom_component.build(**params_copy)
build_result = await custom_component.build(**params_copy)
else:
# Call the build method directly if it's sync
built_object = custom_component.build(**params_copy)
return built_object, {"repr": custom_component.custom_repr()}
build_result = custom_component.build(**params_copy)
custom_repr = custom_component.custom_repr()
if not custom_repr and isinstance(build_result, (dict, Record, str)):
custom_repr = build_result
return custom_component, build_result, {"repr": custom_repr}
def instantiate_wrapper(node_type, class_object, params):
@ -387,7 +419,10 @@ def instantiate_textsplitter(
# separators might come in as an escaped string like \\n
# so we need to convert it to a string
if "separators" in params:
params["separators"] = params["separators"].encode().decode("unicode-escape")
if isinstance(params["separators"], str):
params["separators"] = unescape_string(params["separators"])
elif isinstance(params["separators"], list):
params["separators"] = [unescape_string(separator) for separator in params["separators"]]
text_splitter = class_object(**params)
else:
from langchain.text_splitter import Language

View file

@ -3,11 +3,10 @@ from typing import Any
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain_core.runnables import Runnable
from loguru import logger
from langflow.api.v1.schemas import ChatMessage
from langflow.interface.utils import try_setting_streaming_options
from langflow.processing.base import get_result_and_steps
from loguru import logger
LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor)

View file

@ -1,7 +1,14 @@
from typing import Any, Callable, Optional, Union
from langflow.field_typing.range_spec import RangeSpec
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_serializer,
field_validator,
model_serializer,
)
class TemplateField(BaseModel):
@ -29,7 +36,7 @@ class TemplateField(BaseModel):
"""The value of the field. Default is None."""
file_types: list[str] = Field(default=[], serialization_alias="fileTypes")
"""List of file types associated with the field. Default is an empty list. (duplicate)"""
"""List of file types associated with the field . Default is an empty list."""
file_path: Optional[str] = ""
"""The file path of the field if it is a file. Defaults to None."""
@ -58,18 +65,38 @@ class TemplateField(BaseModel):
info: Optional[str] = ""
"""Additional information about the field to be shown in the tooltip. Defaults to an empty string."""
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
real_time_refresh: Optional[bool] = None
"""Specifies if the field should have real time refresh. `refresh_button` must be False. Defaults to None."""
refresh_button: Optional[bool] = None
"""Specifies if the field should have a refresh button. Defaults to False."""
refresh_button_text: Optional[str] = None
"""Specifies the text for the refresh button. Defaults to None."""
range_spec: Optional[RangeSpec] = Field(default=None, serialization_alias="rangeSpec")
"""Range specification for the field. Defaults to None."""
load_from_db: bool = False
"""Specifies if the field should be loaded from the database. Defaults to False."""
title_case: bool = False
"""Specifies if the field should be displayed in title case. Defaults to True."""
def to_dict(self):
return self.model_dump(by_alias=True, exclude_none=True)
@model_serializer(mode="wrap")
def serialize_model(self, handler):
result = handler(self)
# If the field is str, we add the Text input type
if self.field_type in ["str", "Text"]:
if "input_types" not in result:
result["input_types"] = ["Text"]
if self.field_type == "Text":
result["type"] = "str"
else:
result["type"] = self.field_type
return result
@field_serializer("file_path")
def serialize_file_path(self, value):
return value if self.field_type == "file" else ""
@ -79,3 +106,24 @@ class TemplateField(BaseModel):
if value == "float" and self.range_spec is None:
self.range_spec = RangeSpec()
return value
@field_serializer("display_name")
def serialize_display_name(self, value, _info):
# If display_name is not set, use name and convert to title case
# if title_case is True
if value is None:
# name is probably a snake_case string
# Ex: "file_path" -> "File Path"
value = self.name.replace("_", " ")
if self.title_case:
value = value.title()
return value
@field_validator("file_types")
def validate_file_types(cls, value):
if not isinstance(value, list):
raise ValueError("file_types must be a list")
return [
(f".{file_type}" if isinstance(file_type, str) and not file_type.startswith(".") else file_type)
for file_type in value
]