Refactor vertex class and add custom component functionality
This commit is contained in:
parent
311dcc812d
commit
38ed38d64c
3 changed files with 150 additions and 34 deletions
|
|
@ -2,8 +2,7 @@ import ast
|
|||
import inspect
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, List,
|
||||
Optional)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -41,6 +40,7 @@ class Vertex:
|
|||
) -> None:
|
||||
# is_external means that the Vertex send or receives data from
|
||||
# an external source (e.g the chat)
|
||||
self._custom_component = None
|
||||
self.has_external_input = False
|
||||
self.has_external_output = False
|
||||
self.graph = graph
|
||||
|
|
@ -202,6 +202,7 @@ class Vertex:
|
|||
self.output = self.data["node"]["base_classes"]
|
||||
self.display_name = self.data["node"]["display_name"]
|
||||
self.pinned = self.data["node"].get("pinned", False)
|
||||
self.selected_output_type = self.data["node"].get("selected_output_type")
|
||||
template_dicts = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
|
|
@ -500,11 +501,17 @@ class Vertex:
|
|||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.display_name} not found")
|
||||
try:
|
||||
outgoing_edges = self.graph.get_vertex_edges(
|
||||
self.id, is_source=True, is_target=False
|
||||
)
|
||||
|
||||
result = await loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
user_id=user_id,
|
||||
outgoing_edges=outgoing_edges,
|
||||
selected_output_type=self.selected_output_type,
|
||||
)
|
||||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
|
|
@ -518,7 +525,10 @@ class Vertex:
|
|||
Updates the built object and its artifacts.
|
||||
"""
|
||||
if isinstance(result, tuple):
|
||||
self._built_object, self.artifacts = result
|
||||
if len(result) == 2:
|
||||
self._built_object, self.artifacts = result
|
||||
elif len(result) == 3:
|
||||
self._custom_component, self._built_object, self.artifacts = result
|
||||
else:
|
||||
self._built_object = result
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,15 @@
|
|||
import operator
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
|
|
@ -24,6 +33,9 @@ from langflow.services.deps import (
|
|||
from langflow.services.storage.service import StorageService
|
||||
from langflow.utils import validate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
|
||||
|
||||
class CustomComponent(Component):
|
||||
display_name: Optional[str] = None
|
||||
|
|
@ -40,6 +52,12 @@ class CustomComponent(Component):
|
|||
"""The field order of the component. Defaults to an empty list."""
|
||||
pinned: Optional[bool] = False
|
||||
"""The default pinned state of the component. Defaults to False."""
|
||||
build_parameters: Optional[dict] = None
|
||||
"""The build parameters of the component. Defaults to None."""
|
||||
selected_output_type: Optional[str] = None
|
||||
"""The selected output type of the component. Defaults to None."""
|
||||
outgoing_edges: Optional[List["ContractEdge"]] = None
|
||||
"""The edge target parameter of the component. Defaults to None."""
|
||||
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
|
||||
function_entrypoint_name: ClassVar[str] = "build"
|
||||
function: Optional[Callable] = None
|
||||
|
|
@ -88,7 +106,9 @@ class CustomComponent(Component):
|
|||
def tree(self):
|
||||
return self.get_code_tree(self.code or "")
|
||||
|
||||
def to_records(self, data: Any, text_key: str = "text", data_key: str = "data") -> List[dict]:
|
||||
def to_records(
|
||||
self, data: Any, text_key: str = "text", data_key: str = "data"
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Convert data into a list of records.
|
||||
|
||||
|
|
@ -115,7 +135,9 @@ class CustomComponent(Component):
|
|||
|
||||
return records
|
||||
|
||||
def create_references_from_records(self, records: List[dict], include_data: bool = False) -> str:
|
||||
def create_references_from_records(
|
||||
self, records: List[dict], include_data: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Create references from a list of records.
|
||||
|
||||
|
|
@ -150,7 +172,8 @@ class CustomComponent(Component):
|
|||
detail={
|
||||
"error": "Type hint Error",
|
||||
"traceback": (
|
||||
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
|
||||
"Prompt type is not supported in the build method."
|
||||
" Try using PromptTemplate instead."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
@ -164,14 +187,20 @@ class CustomComponent(Component):
|
|||
if not self.code:
|
||||
return {}
|
||||
|
||||
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in self.tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return {}
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
return build_methods[0] if build_methods else {}
|
||||
|
|
@ -228,7 +257,9 @@ class CustomComponent(Component):
|
|||
# Retrieve and decrypt the credential by name for the current user
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
|
||||
return credential_service.get_credential(
|
||||
user_id=self._user_id or "", name=name, session=session
|
||||
)
|
||||
|
||||
return get_credential
|
||||
|
||||
|
|
@ -238,7 +269,9 @@ class CustomComponent(Component):
|
|||
credential_service = get_credential_service()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.list_credentials(user_id=self._user_id, session=session)
|
||||
return credential_service.list_credentials(
|
||||
user_id=self._user_id, session=session
|
||||
)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable."""
|
||||
|
|
@ -289,7 +322,11 @@ class CustomComponent(Component):
|
|||
if flow_id:
|
||||
flow = session.query(Flow).get(flow_id)
|
||||
elif flow_name:
|
||||
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.name == flow_name)
|
||||
.filter(Flow.user_id == self.user_id)
|
||||
).first()
|
||||
else:
|
||||
raise ValueError("Either flow_name or flow_id must be provided")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Type
|
||||
|
||||
import orjson
|
||||
from langchain.agents import agent as agent_module
|
||||
|
|
@ -34,16 +34,27 @@ from langflow.utils import validate
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langflow import CustomComponent
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
|
||||
|
||||
def build_vertex_in_params(params: Dict) -> Dict:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
# If any of the values in params is a Vertex, we will build it
|
||||
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
|
||||
return {
|
||||
key: value.build() if isinstance(value, Vertex) else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
|
||||
|
||||
async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
|
||||
async def instantiate_class(
|
||||
node_type: str,
|
||||
base_type: str,
|
||||
params: Dict,
|
||||
user_id=None,
|
||||
outgoing_edges: Optional[List["ContractEdge"]] = None,
|
||||
selected_output_type: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
|
|
@ -55,7 +66,15 @@ async def instantiate_class(node_type: str, base_type: str, params: Dict, user_i
|
|||
return custom_node(**params)
|
||||
logger.debug(f"Instantiating {node_type} of type {base_type}")
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
|
||||
return await instantiate_based_on_type(
|
||||
class_object,
|
||||
base_type,
|
||||
node_type,
|
||||
params,
|
||||
user_id=user_id,
|
||||
outgoing_edges=outgoing_edges,
|
||||
selected_output_type=selected_output_type,
|
||||
)
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
|
|
@ -82,7 +101,15 @@ def convert_kwargs(params):
|
|||
return params
|
||||
|
||||
|
||||
async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
|
||||
async def instantiate_based_on_type(
|
||||
class_object,
|
||||
base_type,
|
||||
node_type,
|
||||
params,
|
||||
user_id,
|
||||
outgoing_edges,
|
||||
selected_output_type,
|
||||
):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(node_type, class_object, params)
|
||||
elif base_type == "prompts":
|
||||
|
|
@ -116,17 +143,33 @@ async def instantiate_based_on_type(class_object, base_type, node_type, params,
|
|||
elif base_type == "memory":
|
||||
return instantiate_memory(node_type, class_object, params)
|
||||
elif base_type == "custom_components":
|
||||
return await instantiate_custom_component(node_type, class_object, params, user_id)
|
||||
return await instantiate_custom_component(
|
||||
node_type,
|
||||
class_object,
|
||||
params,
|
||||
user_id,
|
||||
outgoing_edges,
|
||||
selected_output_type,
|
||||
)
|
||||
elif base_type == "wrappers":
|
||||
return instantiate_wrapper(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
async def instantiate_custom_component(node_type, class_object, params, user_id):
|
||||
async def instantiate_custom_component(
|
||||
node_type, class_object, params, user_id, outgoing_edges, selected_output_type
|
||||
):
|
||||
params_copy = params.copy()
|
||||
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component: "CustomComponent" = class_object(user_id=user_id)
|
||||
class_object: Type["CustomComponent"] = eval_custom_component_code(
|
||||
params_copy.pop("code")
|
||||
)
|
||||
custom_component: "CustomComponent" = class_object(
|
||||
user_id=user_id,
|
||||
parameters=params_copy,
|
||||
outgoing_edges=outgoing_edges,
|
||||
selected_output_type=selected_output_type,
|
||||
)
|
||||
|
||||
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
|
||||
params_copy["retriever"] = params_copy["retriever"].as_retriever()
|
||||
|
|
@ -141,7 +184,7 @@ async def instantiate_custom_component(node_type, class_object, params, user_id)
|
|||
# Call the build method directly if it's sync
|
||||
build_result = custom_component.build(**params_copy)
|
||||
|
||||
return build_result, {"repr": custom_component.custom_repr()}
|
||||
return custom_component, build_result, {"repr": custom_component.custom_repr()}
|
||||
|
||||
|
||||
def instantiate_wrapper(node_type, class_object, params):
|
||||
|
|
@ -194,7 +237,9 @@ def instantiate_memory(node_type, class_object, params):
|
|||
# I want to catch a specific attribute error that happens
|
||||
# when the object does not have a cursor attribute
|
||||
except Exception as exc:
|
||||
if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc):
|
||||
if "object has no attribute 'cursor'" in str(
|
||||
exc
|
||||
) or 'object has no field "conn"' in str(exc):
|
||||
raise AttributeError(
|
||||
(
|
||||
"Failed to build connection to database."
|
||||
|
|
@ -237,7 +282,9 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params:
|
|||
if class_method := getattr(class_object, method, None):
|
||||
agent = class_method(**params)
|
||||
tools = params.get("tools", [])
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, handle_parsing_errors=True
|
||||
)
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
||||
|
|
@ -293,7 +340,11 @@ def instantiate_embedding(node_type, class_object, params: Dict):
|
|||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {key: value for key, value in params.items() if key in class_object.model_fields}
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.model_fields
|
||||
}
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
|
|
@ -305,7 +356,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
|||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
if "documents" in params:
|
||||
params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)]
|
||||
params["documents"] = [
|
||||
doc for doc in params["documents"] if isinstance(doc, Document)
|
||||
]
|
||||
if initializer := vecstore_initializer.get(class_object.__name__):
|
||||
vecstore = initializer(class_object, params)
|
||||
else:
|
||||
|
|
@ -320,7 +373,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
|||
return vecstore
|
||||
|
||||
|
||||
def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict):
|
||||
def instantiate_documentloader(
|
||||
node_type: str, class_object: Type[BaseLoader], params: Dict
|
||||
):
|
||||
if "file_filter" in params:
|
||||
# file_filter will be a string but we need a function
|
||||
# that will be used to filter the files using file_filter
|
||||
|
|
@ -329,13 +384,17 @@ def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], p
|
|||
# in x and if it is, we will return True
|
||||
file_filter = params.pop("file_filter")
|
||||
extensions = file_filter.split(",")
|
||||
params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions)
|
||||
params["file_filter"] = lambda x: any(
|
||||
extension.strip() in x for extension in extensions
|
||||
)
|
||||
metadata = params.pop("metadata", None)
|
||||
if metadata and isinstance(metadata, str):
|
||||
try:
|
||||
metadata = orjson.loads(metadata)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError("The metadata you provided is not a valid JSON string.") from exc
|
||||
raise ValueError(
|
||||
"The metadata you provided is not a valid JSON string."
|
||||
) from exc
|
||||
|
||||
if node_type == "WebBaseLoader":
|
||||
if web_path := params.pop("web_path", None):
|
||||
|
|
@ -368,12 +427,16 @@ def instantiate_textsplitter(
|
|||
"Try changing the chunk_size of the Text Splitter."
|
||||
) from exc
|
||||
|
||||
if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params:
|
||||
if (
|
||||
"separator_type" in params and params["separator_type"] == "Text"
|
||||
) or "separator_type" not in params:
|
||||
params.pop("separator_type", None)
|
||||
# separators might come in as an escaped string like \\n
|
||||
# so we need to convert it to a string
|
||||
if "separators" in params:
|
||||
params["separators"] = params["separators"].encode().decode("unicode-escape")
|
||||
params["separators"] = (
|
||||
params["separators"].encode().decode("unicode-escape")
|
||||
)
|
||||
text_splitter = class_object(**params)
|
||||
else:
|
||||
from langchain.text_splitter import Language
|
||||
|
|
@ -400,7 +463,8 @@ def replace_zero_shot_prompt_with_prompt_template(nodes):
|
|||
tools = [
|
||||
tool
|
||||
for tool in nodes
|
||||
if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
if tool["type"] != "chatOutputNode"
|
||||
and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
]
|
||||
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
|
||||
break
|
||||
|
|
@ -414,7 +478,9 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs)
|
|||
# agent has hidden args for memory. might need to be support
|
||||
# memory = params["memory"]
|
||||
# if allowed_tools is not a list or set, make it a list
|
||||
if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool):
|
||||
if not isinstance(allowed_tools, (list, set)) and isinstance(
|
||||
allowed_tools, BaseTool
|
||||
):
|
||||
allowed_tools = [allowed_tools]
|
||||
tool_names = [tool.name for tool in allowed_tools]
|
||||
# Agent class requires an output_parser but Agent classes
|
||||
|
|
@ -442,7 +508,10 @@ def build_prompt_template(prompt, tools):
|
|||
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
|
||||
|
||||
tool_strings = "\n".join(
|
||||
[f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools]
|
||||
[
|
||||
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue