refactor: Update ChatInput record_response method to use "text" key instead of "message"

This commit is contained in:
ogabrielluiz 2024-06-03 18:26:01 -03:00
commit a91e97ca30
17 changed files with 189 additions and 61 deletions

View file

@@ -71,9 +71,7 @@ async def login_to_get_access_token(
@router.get("/auto_login")
async def auto_login(
response: Response,
db: Session = Depends(get_session),
settings_service=Depends(get_settings_service)
response: Response, db: Session = Depends(get_session), settings_service=Depends(get_settings_service)
):
auth_settings = settings_service.auth_settings
if settings_service.auth_settings.AUTO_LOGIN:

View file

@@ -30,4 +30,4 @@ class TextInput(TextComponent):
]
def text_response(self) -> Text:
return self.input_value if self.input_value else ""
return self.build(input_value=self.input_value, record_template=self.record_template)

View file

@@ -30,4 +30,4 @@ class TextOutput(TextComponent):
]
def text_response(self) -> Text:
return self.input_value if self.input_value else ""
return self.build(input_value=self.input_value, record_template=self.record_template)

View file

@@ -65,8 +65,6 @@ class CustomComponent(BaseComponent):
"""The default frozen state of the component. Defaults to False."""
build_parameters: Optional[dict] = None
"""The build parameters of the component. Defaults to None."""
selected_output_type: Optional[str] = None
"""The selected output type of the component. Defaults to None."""
vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
code_class_base_inheritance: ClassVar[str] = "CustomComponent"

View file

@@ -220,8 +220,6 @@ class DirectoryReader:
return False, "Empty file"
elif not self.validate_code(file_content):
return False, "Syntax error"
elif not self.validate_build(file_content):
return False, "Missing build function"
elif self._is_type_hint_used_in_args("Optional", file_content) and not self._is_type_hint_imported(
"Optional", file_content
):

View file

@@ -14,8 +14,10 @@ class SourceHandle(BaseModel):
baseClasses: Optional[List[str]] = Field(None, description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
name: str = Field(..., description="Name of the source handle.")
output_types: List[str] = Field(..., description="List of output types for the source handle.")
name: Optional[str] = Field(None, description="Name of the source handle.")
output_types: Optional[List[str]] = Field(
default_factory=list, description="List of output types for the source handle."
)
class TargetHandle(BaseModel):
@@ -49,6 +51,12 @@ class Edge:
self.validate_edge(source, target)
def validate_handles(self, source, target) -> None:
if isinstance(self._source_handle, str) or self.source_handle.baseClasses:
self._legacy_validate_handles(source, target)
else:
self._validate_handles(source, target)
def _validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.output_types
else:
@@ -61,6 +69,19 @@ class Edge:
logger.debug(self.target_handle)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
def _legacy_validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
def __setstate__(self, state):
self.source_id = state["source_id"]
self.target_id = state["target_id"]
@@ -69,6 +90,14 @@ class Edge:
self.target_handle = state.get("target_handle")
def validate_edge(self, source, target) -> None:
# If the self.source_handle has baseClasses, then we are using the legacy
# way of defining the source and target handles
if isinstance(self._source_handle, str) or self.source_handle.baseClasses:
self._legacy_validate_edge(source, target)
else:
self._validate_edge(source, target)
def _validate_edge(self, source, target) -> None:
# Validate that the outputs of the source node are valid inputs
# for the target node
# .outputs is a list of Output objects as dictionaries
@@ -97,6 +126,27 @@ class Edge:
None,
)
no_matched_type = self.matched_type is None
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type. ")
def _legacy_validate_edge(self, source, target) -> None:
# Validate that the outputs of the source node are valid inputs
# for the target node
self.source_types = source.output
self.target_reqs = target.required_inputs + target.optional_inputs
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
# Get what type of input the target node is expecting
self.matched_type = next(
(output for output in self.source_types if output in self.target_reqs),
None,
)
no_matched_type = self.matched_type is None
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)

View file

@@ -20,7 +20,6 @@ from langflow.schema.schema import INPUT_FIELD_NAME, InputType
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service
from langflow.services.monitor.utils import log_transaction
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@@ -526,6 +525,7 @@ class Graph:
raise ValueError(
f"Invalid payload. Expected keys 'nodes' and 'edges'. Found {list(payload.keys())}"
) from exc
raise ValueError(f"Error while creating graph from payload: {exc}") from exc
def __eq__(self, other: object) -> bool:
@@ -764,11 +764,9 @@ class Graph:
next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices(
lock, set_cache_coro, vertex
)
log_transaction(vertex, status="success")
return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex
except Exception as exc:
logger.exception(f"Error building vertex: {exc}")
log_transaction(vertex, status="failure", error=str(exc))
raise exc
async def get_next_and_top_level_vertices(

View file

@@ -212,14 +212,18 @@ class Vertex:
def _parse_data(self) -> None:
self.data = self._data["data"]
self.outputs = self.data["node"]["outputs"]
if self.data["node"]["template"]["_type"] == "Component":
if "outputs" not in self.data["node"]:
raise ValueError(f"Outputs not found for {self.display_name}")
self.outputs = self.data["node"]["outputs"]
else:
self.outputs = self.data["node"]["outputs"]
self.output = self.data["node"]["base_classes"]
self.display_name = self.data["node"].get("display_name", self.id.split("-")[0])
self.description = self.data["node"].get("description", "")
self.frozen = self.data["node"].get("frozen", False)
self.selected_output_type = (
str(self.data.get("selected_output_type")).strip() if self.data.get("selected_output_type") else None
)
self.is_input = self.data["node"].get("is_input") or self.is_input
self.is_output = self.data["node"].get("is_output") or self.is_output
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}

View file

@@ -86,7 +86,8 @@ class ComponentVertex(Vertex):
edge = self.get_edge_with_target(requester.id)
if edge is None:
raise ValueError(f"Edge not found between {self.display_name} and {requester.display_name}")
if edge.source_handle.name not in self.results:
raise ValueError(f"Result not found for {edge.source_handle.name}. Results: {self.results}")
result = self.results[edge.source_handle.name]
log_transaction(source=self, target=requester, flow_id=self.graph.flow_id, status="success")

View file

@@ -1,3 +1,5 @@
import copy
import json
import logging
import os
from collections import defaultdict
@@ -42,39 +44,129 @@ def update_projects_components_with_latest_component_versions(project_data, all_
latest_template = latest_node.get("template")
node_data["template"]["code"] = latest_template["code"]
for attr in NODE_FORMAT_ATTRIBUTES:
if attr in latest_node:
# Check if it needs to be updated
if latest_node[attr] != node_data.get(attr):
node_changes_log[node_data["display_name"]].append(
{
"attr": attr,
"old_value": node_data.get(attr),
"new_value": latest_node[attr],
}
)
node_data[attr] = latest_node[attr]
for field_name, field_dict in latest_template.items():
if field_name not in node_data["template"]:
continue
# The idea here is to update some attributes of the field
for attr in FIELD_FORMAT_ATTRIBUTES:
if attr in field_dict and attr in node_data["template"].get(field_name):
if "outputs" in latest_node:
node_data["outputs"] = latest_node["outputs"]
if node_data["template"]["_type"] != latest_template["_type"]:
node_data["template"] = latest_template
else:
for attr in NODE_FORMAT_ATTRIBUTES:
if attr in latest_node:
# Check if it needs to be updated
if field_dict[attr] != node_data["template"][field_name][attr]:
if latest_node[attr] != node_data.get(attr):
node_changes_log[node_data["display_name"]].append(
{
"attr": f"{field_name}.{attr}",
"old_value": node_data["template"][field_name][attr],
"new_value": field_dict[attr],
"attr": attr,
"old_value": node_data.get(attr),
"new_value": latest_node[attr],
}
)
node_data["template"][field_name][attr] = field_dict[attr]
node_data[attr] = latest_node[attr]
for field_name, field_dict in latest_template.items():
if field_name not in node_data["template"]:
continue
# The idea here is to update some attributes of the field
for attr in FIELD_FORMAT_ATTRIBUTES:
if attr in field_dict and attr in node_data["template"].get(field_name):
# Check if it needs to be updated
if field_dict[attr] != node_data["template"][field_name][attr]:
node_changes_log[node_data["display_name"]].append(
{
"attr": f"{field_name}.{attr}",
"old_value": node_data["template"][field_name][attr],
"new_value": field_dict[attr],
}
)
node_data["template"][field_name][attr] = field_dict[attr]
project_data_copy = update_new_output(project_data_copy)
log_node_changes(node_changes_log)
return project_data_copy
def scape_json_parse(json_string: str) -> dict:
parsed_string = json_string.replace("œ", '"')
return json.loads(parsed_string)
def update_new_output(data):
nodes = copy.deepcopy(data["nodes"])
edges = copy.deepcopy(data["edges"])
for edge in edges:
if "sourceHandle" in edge and "targetHandle" in edge:
new_source_handle = scape_json_parse(edge["sourceHandle"])
new_target_handle = scape_json_parse(edge["targetHandle"])
_id = new_source_handle["id"]
source_node_index = next((index for (index, d) in enumerate(nodes) if d["id"] == _id), -1)
source_node = nodes[source_node_index] if source_node_index != -1 else None
if "baseClasses" in new_source_handle:
if "output_types" not in new_source_handle:
if source_node and "node" in source_node["data"] and "output_types" in source_node["data"]["node"]:
new_source_handle["output_types"] = source_node["data"]["node"]["output_types"]
else:
new_source_handle["output_types"] = new_source_handle["baseClasses"]
del new_source_handle["baseClasses"]
if "inputTypes" in new_target_handle and new_target_handle["inputTypes"]:
intersection = [
type_ for type_ in new_source_handle["output_types"] if type_ in new_target_handle["inputTypes"]
]
else:
intersection = [
type_ for type_ in new_source_handle["output_types"] if type_ == new_target_handle["type"]
]
selected = intersection[0] if intersection else None
if "name" not in new_source_handle:
new_source_handle["name"] = " | ".join(new_source_handle["output_types"])
new_source_handle["output_types"] = [selected] if selected else []
if source_node and not source_node["data"]["node"].get("outputs"):
if "outputs" not in source_node["data"]["node"]:
source_node["data"]["node"]["outputs"] = []
types = source_node["data"]["node"].get(
"output_types", source_node["data"]["node"].get("base_classes", [])
)
if not any(output.get("selected") == selected for output in source_node["data"]["node"]["outputs"]):
source_node["data"]["node"]["outputs"].append(
{
"types": types,
"selected": selected,
"name": " | ".join(types),
}
)
deduplicated_outputs = []
for output in source_node["data"]["node"]["outputs"]:
if output["name"] not in [d["name"] for d in deduplicated_outputs]:
deduplicated_outputs.append(output)
source_node["data"]["node"]["outputs"] = deduplicated_outputs
edge["sourceHandle"] = json.dumps(new_source_handle)
edge["data"]["sourceHandle"] = new_source_handle
edge["data"]["targetHandle"] = new_target_handle
# The above sets the edges but some of the sourceHandles do not have valid name
# which can be found in the nodes. We need to update the sourceHandle with the
# name from node['data']['node']['outputs']
for node in nodes:
if "outputs" in node["data"]["node"]:
for output in node["data"]["node"]["outputs"]:
for edge in edges:
if node["id"] != edge["source"] or output.get("method") is None:
continue
source_handle = scape_json_parse(edge["sourceHandle"])
if source_handle["output_types"] == output.get("types") and source_handle["name"] != output["name"]:
source_handle["name"] = output["name"]
edge["sourceHandle"] = json.dumps(source_handle)
edge["data"]["sourceHandle"] = source_handle
data_copy = copy.deepcopy(data)
data_copy["nodes"] = nodes
data_copy["edges"] = edges
return data_copy
def log_node_changes(node_changes_log):
# The idea here is to log the changes that were made to the nodes in debug
# Something like:

View file

@@ -40,7 +40,6 @@ async def instantiate_class(
user_id=user_id,
parameters=params_copy,
vertex=vertex,
selected_output_type=vertex.selected_output_type,
)
params_copy = update_params_with_load_from_db_fields(
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars

View file

@@ -215,10 +215,7 @@ def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID
username = settings_service.auth_settings.SUPERUSER
super_user = get_user_by_username(db, username)
if not super_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Super user hasn't been created"
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")
access_token_expires_longterm = timedelta(days=365)
access_token = create_token(
data={"sub": str(super_user.id)},

View file

@@ -1,5 +1,4 @@
import os
from typing import Optional
import yaml
from loguru import logger
@@ -8,6 +7,7 @@ from langflow.services.base import Service
from langflow.services.settings.auth import AuthSettings
from langflow.services.settings.base import Settings
class SettingsService(Service):
name = "settings_service"

View file

@@ -72,8 +72,7 @@ class FrontendNode(BaseModel):
def serialize_model(self, handler):
result = handler(self)
if hasattr(self, "template") and hasattr(self.template, "to_dict"):
format_func = self.format_field if self._format_template else None
result["template"] = self.template.to_dict(format_func)
result["template"] = self.template.to_dict()
name = result.pop("name")
# Migrate base classes to outputs

View file

@@ -253,9 +253,7 @@ def build_class_constructor(compiled_class, exec_globals, class_name):
globals()[module_name] = module
instance = exec_globals[class_name](*args, **kwargs)
# Get selected type from global scope
if instance.selected_output_type in exec_globals:
instance.selected_output_type = exec_globals[instance.selected_output_type]
return instance
build_custom_class.__globals__.update(exec_globals)

View file

@@ -227,8 +227,9 @@ def client_fixture(session: Session, monkeypatch, request, load_flows_dir):
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(pytest.BASIC_EXAMPLE_PATH,
os.path.join(load_flows_dir, "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"))
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, os.path.join(load_flows_dir, "c54f9130-f2fa-4a3e-b22a-3856d946351b.json")
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

View file

@@ -1,5 +1,3 @@
import os
from typing import Optional, List
from uuid import UUID, uuid4
import orjson
@@ -13,7 +11,6 @@ from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from langflow.services.settings.base import Settings
@pytest.fixture(scope="module")
@@ -263,5 +260,3 @@ def test_load_flows(client: TestClient, load_flows_dir):
response = client.get("api/v1/flows/c54f9130-f2fa-4a3e-b22a-3856d946351b")
assert response.status_code == 200
assert response.json()["name"] == "BasicExample"