This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-08 17:42:22 -03:00
commit 522bd304e8
22 changed files with 151 additions and 485 deletions

View file

@ -211,7 +211,5 @@ async def download_file(
current_user: User = Depends(get_current_active_user),
):
"""Download all flows as a file."""
flows = read_flows(
current_user=current_user, session=session, settings_service=settings_service
)
flows = read_flows(current_user=current_user, session=session, settings_service=settings_service)
return FlowListRead(flows=flows)

View file

@ -40,7 +40,7 @@ async def login_to_get_access_token(
httponly=auth_settings.REFRESH_HTTPONLY,
samesite=auth_settings.REFRESH_SAME_SITE,
secure=auth_settings.REFRESH_SECURE,
expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES*60,
expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60,
)
response.set_cookie(
"access_token_lf",
@ -48,7 +48,7 @@ async def login_to_get_access_token(
httponly=auth_settings.ACCESS_HTTPONLY,
samesite=auth_settings.ACCESS_SAME_SITE,
secure=auth_settings.ACCESS_SECURE,
expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES*60,
expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
return tokens
else:
@ -74,7 +74,7 @@ async def auto_login(
httponly=auth_settings.ACCESS_HTTPONLY,
samesite=auth_settings.ACCESS_SAME_SITE,
secure=auth_settings.ACCESS_SECURE,
expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES*60,
expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
return tokens
@ -101,7 +101,7 @@ async def refresh_token(request: Request, response: Response, settings_service=D
httponly=auth_settings.REFRESH_TOKEN_HTTPONLY,
samesite=auth_settings.REFRESH_SAME_SITE,
secure=auth_settings.REFRESH_SECURE,
expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES*60,
expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60,
)
response.set_cookie(
"access_token_lf",
@ -109,7 +109,7 @@ async def refresh_token(request: Request, response: Response, settings_service=D
httponly=auth_settings.ACCESS_HTTPONLY,
samesite=auth_settings.ACCESS_SAME_SITE,
secure=auth_settings.ACCESS_SECURE,
expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES*60,
expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
return tokens
else:

View file

@ -54,9 +54,7 @@ def retrieve_file_paths(
glob = "**/*" if recursive else "*"
paths = walk_level(path_obj, depth) if depth else path_obj.glob(glob)
file_paths = [
Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)
]
file_paths = [Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)]
return file_paths
@ -90,13 +88,10 @@ def parse_text_file_to_record(file_path: str, silent_errors: bool) -> Optional[R
text = read_text_file(file_path)
# if file is json, yaml, or xml, we can parse it
if file_path.endswith(".json"):
text = json.loads(text)
elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
text = yaml.safe_load(text)
elif file_path.endswith(".xml"):
text = ET.fromstring(text)
except Exception as e:
if not silent_errors:
@ -116,10 +111,7 @@ def get_elements(
if use_multithreading:
records = parallel_load_records(file_paths, silent_errors, max_concurrency)
else:
records = [
partition_file_to_record(file_path, silent_errors)
for file_path in file_paths
]
records = [partition_file_to_record(file_path, silent_errors) for file_path in file_paths]
records = list(filter(None, records))
return records

View file

@ -56,9 +56,7 @@ class APIRequest(CustomComponent):
data = body if body else None
payload = json.dumps(data)
try:
response = await client.request(
method, url, headers=headers, content=payload, timeout=timeout
)
response = await client.request(method, url, headers=headers, content=payload, timeout=timeout)
try:
result = response.json()
except Exception:
@ -111,10 +109,7 @@ class APIRequest(CustomComponent):
bodies += [None] * (len(urls) - len(bodies))
async with httpx.AsyncClient() as client:
results = await asyncio.gather(
*[
self.make_request(client, method, u, headers, rec, timeout)
for u, rec in zip(urls, bodies)
]
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]
)
self.status = results
return results

View file

@ -53,20 +53,14 @@ class DirectoryComponent(CustomComponent):
silent_errors: bool = False,
use_multithreading: bool = True,
) -> List[Optional[Record]]:
resolved_path = self.resolve_path(path)
file_paths = retrieve_file_paths(resolved_path, load_hidden, recursive, depth)
loaded_records = []
if use_multithreading:
loaded_records = parallel_load_records(
file_paths, silent_errors, max_concurrency
)
loaded_records = parallel_load_records(file_paths, silent_errors, max_concurrency)
else:
loaded_records = [
parse_text_file_to_record(file_path, silent_errors)
for file_path in file_paths
]
loaded_records = [parse_text_file_to_record(file_path, silent_errors) for file_path in file_paths]
loaded_records = list(filter(None, loaded_records))
self.status = loaded_records
return loaded_records

View file

@ -21,9 +21,7 @@ class ExtractKeyFromRecordComponent(CustomComponent):
},
}
def build(
self, record: Record, keys: list[str], silent_error: bool = True
) -> Record:
def build(self, record: Record, keys: list[str], silent_error: bool = True) -> Record:
"""
Extracts the keys from a record.

View file

@ -24,9 +24,7 @@ class PromptComponent(CustomComponent):
prompt_template = PromptTemplate.from_template(Text(template))
kwargs = dict_values_to_string(kwargs)
kwargs = {
k: "\n".join(v) if isinstance(v, list) else v for k, v in kwargs.items()
}
kwargs = {k: "\n".join(v) if isinstance(v, list) else v for k, v in kwargs.items()}
try:
formated_prompt = prompt_template.format(**kwargs)
except Exception as exc:

View file

@ -60,13 +60,8 @@ class Vertex:
self.updated_raw_params = False
self.id: str = data["id"]
self.is_state = False
self.is_input = any(
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
)
self.is_output = any(
output_component_name in self.id
for output_component_name in OUTPUT_COMPONENTS
)
self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS)
self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS)
self.has_session_id = None
self._custom_component = None
self.has_external_input = False
@ -106,17 +101,11 @@ class Vertex:
def set_state(self, state: str):
self.state = VertexStates[state]
if (
self.state == VertexStates.INACTIVE
and self.graph.in_degree_map[self.id] < 2
):
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2:
# If the vertex is inactive and has only one in degree
# it means that it is not a merge point in the graph
self.graph.inactivated_vertices.add(self.id)
elif (
self.state == VertexStates.ACTIVE
and self.id in self.graph.inactivated_vertices
):
elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactivated_vertices:
self.graph.inactivated_vertices.remove(self.id)
@property
@ -133,9 +122,7 @@ class Vertex:
# If the Vertex.type is a power component
# then we need to return the built object
# instead of the result dict
if self.is_interface_component and not isinstance(
self._built_object, UnbuiltObject
):
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
result = self._built_object
# if it is not a dict or a string and hasattr model_dump then
# return the model_dump
@ -147,11 +134,7 @@ class Vertex:
if isinstance(self._built_result, UnbuiltResult):
return {}
return (
self._built_result
if isinstance(self._built_result, dict)
else {"result": self._built_result}
)
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
def set_artifacts(self) -> None:
pass
@ -225,31 +208,19 @@ class Vertex:
self.selected_output_type = self.data["node"].get("selected_output_type")
self.is_input = self.data["node"].get("is_input") or self.is_input
self.is_output = self.data["node"].get("is_output") or self.is_output
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
self.has_session_id = "session_id" in template_dicts
self.required_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
]
self.optional_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
]
# Add the template_dicts[key]["input_types"] to the optional_inputs
self.optional_inputs.extend(
[
input_type
for value in template_dicts.values()
for input_type in value.get("input_types", [])
]
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
)
template_dict = self.data["node"]["template"]
@ -296,11 +267,7 @@ class Vertex:
self.updated_raw_params = False
return
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = {}
for edge in self.edges:
@ -321,10 +288,7 @@ class Vertex:
# we don't know the key of the dict but we need to set the value
# to the vertex that is the source of the edge
param_dict = template_dict[param_key]["value"]
params[param_key] = {
key: self.graph.get_vertex(edge.source_id)
for key in param_dict.keys()
}
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
@ -360,11 +324,7 @@ class Vertex:
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
if isinstance(val, list):
params[key] = {
k: v
for item in value.get("value", [])
for k, v in item.items()
}
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
@ -489,9 +449,7 @@ class Vertex:
if isinstance(self._built_object, str):
self._built_result = self._built_object
result = await generate_result(
self._built_object, inputs, self.has_external_output, session_id
)
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
self._built_result = result
async def _build_each_node_in_params_dict(self, user_id=None):
@ -511,9 +469,7 @@ class Vertex:
elif key not in self.params or self.updated_raw_params:
self.params[key] = value
async def _build_dict_and_update_params(
self, key, nodes_dict: Dict[str, "Vertex"], user_id=None
):
async def _build_dict_and_update_params(self, key, nodes_dict: Dict[str, "Vertex"], user_id=None):
"""
Iterates over a dictionary of nodes, builds each and updates the params dictionary.
"""
@ -536,9 +492,7 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
async def get_result(
self, requester: Optional["Vertex"] = None, user_id=None, timeout=None
) -> Any:
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
# PLEASE REVIEW THIS IF STATEMENT
# Check if the Vertex was built already
if self._built:
@ -572,9 +526,7 @@ class Vertex:
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
@ -641,9 +593,7 @@ class Vertex:
except Exception as exc:
logger.exception(exc)
raise ValueError(
f"Error building node {self.display_name}: {str(exc)}"
) from exc
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -671,9 +621,7 @@ class Vertex:
logger.warning(message)
elif isinstance(self._built_object, (Iterator, AsyncIterator)):
if self.display_name in ["Text Output"]:
raise ValueError(
f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
)
raise ValueError(f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead.")
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
self._built = False
@ -736,24 +684,16 @@ class Vertex:
return self._built_object
# Get the requester edge
requester_edge = next(
(edge for edge in self.edges if edge.target_id == requester.id), None
)
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
# Return the result of the requester edge
return (
None
if requester_edge is None
else await requester_edge.get_result(source=self, target=requester)
)
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:
self.edges.append(edge)
def __repr__(self) -> str:
return (
f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
)
return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
def __eq__(self, __o: object) -> bool:
try:
@ -774,8 +714,4 @@ class Vertex:
def _built_object_repr(self):
# Add a message with an emoji, stars for sucess,
return (
"Built sucessfully ✨"
if self._built_object is not None
else "Failed to build 😵‍💫"
)
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"

View file

@ -124,11 +124,9 @@ class DocumentLoaderVertex(Vertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(
len(record.text)
for record in self._built_object
if hasattr(record, "text")
) / len(self._built_object)
avg_length = sum(len(record.text) for record in self._built_object if hasattr(record, "text")) / len(
self._built_object
)
return f"""{self.display_name}({len(self._built_object)} records)
\nAvg. Record Length (characters): {int(avg_length)}
Records: {self._built_object[:3]}..."""
@ -201,9 +199,7 @@ class TextSplitterVertex(Vertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
@ -250,27 +246,18 @@ class PromptVertex(Vertex):
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
if "input_variables" not in self.params or self.params["input_variables"] is None:
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -280,20 +267,14 @@ class PromptVertex(Vertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
self.params["input_variables"] = list(set(self.params["input_variables"]))
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
def _built_object_repr(self):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
@ -305,9 +286,7 @@ class PromptVertex(Vertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -315,11 +294,7 @@ class PromptVertex(Vertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
except KeyError:
return str(self._built_object)
@ -354,14 +329,8 @@ class ChatVertex(Vertex):
return f"Task {self.task_id} is not running"
if self.artifacts:
# dump as a yaml string
artifacts = {
k.title().replace("_", " "): v
for k, v in self.artifacts.items()
if v is not None
}
yaml_str = yaml.dump(
artifacts, default_flow_style=False, allow_unicode=True
)
artifacts = {k.title().replace("_", " "): v for k, v in self.artifacts.items() if v is not None}
yaml_str = yaml.dump(artifacts, default_flow_style=False, allow_unicode=True)
return yaml_str
return super()._built_object_repr()

View file

@ -17,9 +17,7 @@ CUSTOM_TOOLS = {
"PythonFunctionTool": PythonFunctionTool,
}
OTHER_TOOLS = {
tool: import_class(f"langchain_community.tools.{tool}") for tool in tools.__all__
}
OTHER_TOOLS = {tool: import_class(f"langchain_community.tools.{tool}") for tool in tools.__all__}
ALL_TOOLS_NAMES = {
**_BASE_TOOLS,

View file

@ -73,9 +73,7 @@ class Record(BaseModel):
return self.data.get(key, self._default_value)
except KeyError:
# Fallback to default behavior to raise AttributeError for undefined attributes
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{key}'"
)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
def __setattr__(self, key, value):
"""

View file

@ -1,4 +1,3 @@
import datetime
import secrets
from pathlib import Path
from typing import Optional
@ -37,7 +36,7 @@ class AuthSettings(BaseSettings):
NEW_USER_IS_ACTIVE: bool = False
SUPERUSER: str = DEFAULT_SUPERUSER
SUPERUSER_PASSWORD: str = DEFAULT_SUPERUSER_PASSWORD
REFRESH_SAME_SITE: str = "none"
"""The SameSite attribute of the refresh token cookie."""
REFRESH_SECURE: bool = True

View file

@ -20,12 +20,8 @@ def remove_ansi_escape_codes(text):
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
def build_template_from_function(
name: str, type_to_loader_dict: Dict, add_function: bool = False
):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
# Raise error if name is not in chains
if name not in classes:
@ -46,10 +42,8 @@ def build_template_from_function(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__, function=value_
)
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
)
except Exception:
variables[class_field_items]["default"] = None
@ -57,9 +51,7 @@ def build_template_from_function(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
# Adding function to base classes to allow
# the output to be a function
@ -74,9 +66,7 @@ def build_template_from_function(
}
def build_template_from_class(
name: str, type_to_cls_dict: Dict, add_function: bool = False
):
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
@ -100,11 +90,9 @@ def build_template_from_class(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
except Exception:
variables[class_field_items]["default"] = None
@ -112,9 +100,7 @@ def build_template_from_class(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
base_classes = get_base_classes(_class)
# Adding function to base classes to allow
@ -146,9 +132,7 @@ def build_template_from_method(
# Check if the method exists in this class
if not hasattr(_class, method_name):
raise ValueError(
f"Method {method_name} not found in class {class_name}"
)
raise ValueError(f"Method {method_name} not found in class {class_name}")
# Get the method
method = getattr(_class, method_name)
@ -167,14 +151,8 @@ def build_template_from_method(
"_type": _type,
**{
name: {
"default": (
param.default if param.default != param.empty else None
),
"type": (
param.annotation
if param.annotation != param.empty
else None
),
"default": (param.default if param.default != param.empty else None),
"type": (param.annotation if param.annotation != param.empty else None),
"required": param.default == param.empty,
}
for name, param in params.items()
@ -261,9 +239,7 @@ def sync_to_async(func):
return async_wrapper
def format_dict(
dictionary: Dict[str, Any], class_name: Optional[str] = None
) -> Dict[str, Any]:
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
"""
Formats a dictionary by removing certain keys and modifying the
values of other keys.
@ -349,9 +325,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
The modified type string.
"""
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
_type = (
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
)
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
value["list"] = True
else:
value["list"] = False
@ -454,9 +428,7 @@ def set_headers_value(value: Dict[str, Any]) -> None:
value["value"] = """{"Authorization": "Bearer <token>"}"""
def add_options_to_field(
value: Dict[str, Any], class_name: Optional[str], key: str
) -> None:
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
"""
Adds options to the field based on the class name and key.
"""

View file

@ -28,35 +28,19 @@ if TYPE_CHECKING:
def pytest_configure():
pytest.BASIC_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "basic_example.json"
)
pytest.COMPLEX_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "complex_example.json"
)
pytest.OPENAPI_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "Openapi.json"
)
pytest.GROUPED_CHAT_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
)
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
)
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
)
pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json"
pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json"
pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json"
pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json"
pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json"
pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json"
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = (
Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json"
)
pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json"
pytest.TWO_OUTPUTS = (
Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json"
)
pytest.VECTOR_STORE_PATH = (
Path(__file__).parent.absolute() / "data" / "Vector_store.json"
)
pytest.TWO_OUTPUTS = Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json"
pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json"
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():
retun "Hello World"
@ -67,9 +51,7 @@ def get_text():
def check_openai_api_key_in_environment_variables():
import os
assert (
os.environ.get("OPENAI_API_KEY") is not None
), "OPENAI_API_KEY is not set in environment variables"
assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables"
@pytest.fixture()
@ -83,9 +65,7 @@ async def async_client() -> AsyncGenerator:
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@ -121,9 +101,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
# monkeypatch langflow.services.task.manager.USE_CELERY to True
# monkeypatch.setattr(manager, "USE_CELERY", True)
monkeypatch.setattr(
celery_app, "celery_app", celery_app.make_celery("langflow", Config)
)
monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config))
# def get_session_override():
# return session
@ -274,11 +252,7 @@ def active_user(client):
is_superuser=False,
)
# check if user exists
if (
active_user := session.query(User)
.filter(User.username == user.username)
.first()
):
if active_user := session.query(User).filter(User.username == user.username).first():
return active_user
session.add(user)
session.commit()
@ -301,9 +275,7 @@ def flow(client, json_flow: str, active_user):
from langflow.services.database.models.flow.model import FlowCreate
loaded_json = json.loads(json_flow)
flow_data = FlowCreate(
name="test_flow", data=loaded_json.get("data"), user_id=active_user.id
)
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
flow = Flow.model_validate(flow_data)
with session_getter(get_db_service()) as session:
@ -327,9 +299,7 @@ def json_two_outputs():
@pytest.fixture
def added_flow_with_prompt_and_history(
client, json_flow_with_prompt_and_history, logged_in_headers
):
def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history, logged_in_headers):
flow = orjson.loads(json_flow_with_prompt_and_history)
data = flow["data"]
flow = FlowCreate(name="Basic Chat", description="description", data=data)
@ -369,9 +339,7 @@ def added_vector_store(client, json_vector_store, logged_in_headers):
vector_store = orjson.loads(json_vector_store)
data = vector_store["data"]
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
response = client.post(
"api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers
)
response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == vector_store.name
assert response.json()["data"] == vector_store.data
@ -389,11 +357,7 @@ def created_api_key(active_user):
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if (
existing_api_key := session.query(ApiKey)
.filter(ApiKey.api_key == api_key.api_key)
.first()
):
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
return existing_api_key
session.add(api_key)
session.commit()
@ -405,9 +369,7 @@ def created_api_key(active_user):
def get_starter_project(active_user):
# once the client is created, we can get the starter project
with session_getter(get_db_service()) as session:
flow = session.exec(
select(Flow).where(Flow.folder == STARTER_FOLDER_NAME)
).first()
flow = session.exec(select(Flow).where(Flow.folder == STARTER_FOLDER_NAME)).first()
if not flow:
raise ValueError("No starter project found")

View file

@ -104,9 +104,7 @@ def test_custom_component_init():
"""
function_entrypoint_name = "build"
custom_component = CustomComponent(
code=code_default, function_entrypoint_name=function_entrypoint_name
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
assert custom_component.code == code_default
assert custom_component.function_entrypoint_name == function_entrypoint_name
@ -115,9 +113,7 @@ def test_custom_component_build_template_config():
"""
Test the build_template_config property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
config = custom_component.build_template_config()
assert isinstance(config, dict)
@ -126,9 +122,7 @@ def test_custom_component_get_function():
"""
Test the get_function property of the CustomComponent class.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = custom_component.get_function()
assert isinstance(my_function, types.FunctionType)
@ -213,9 +207,7 @@ def test_custom_component_get_function_entrypoint_args():
Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 4
assert args[0]["name"] == "self"
@ -229,9 +221,7 @@ def test_custom_component_get_function_entrypoint_return_type():
property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == [Document]
@ -240,9 +230,7 @@ def test_custom_component_get_main_class_name():
"""
Test the get_main_class_name property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == "YourComponent"
@ -252,9 +240,7 @@ def test_custom_component_get_function_valid():
Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = custom_component.get_function
assert callable(my_function)
@ -289,9 +275,7 @@ def test_code_parser_parse_callable_details_no_args():
parser = CodeParser("")
node = ast.FunctionDef(
name="test",
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
@ -337,9 +321,7 @@ def test_code_parser_parse_function_def_not_init():
parser = CodeParser("")
stmt = ast.FunctionDef(
name="test",
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
@ -357,9 +339,7 @@ def test_code_parser_parse_function_def_init():
parser = CodeParser("")
stmt = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
@ -394,9 +374,7 @@ def test_custom_component_get_code_tree_syntax_error():
Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(
code="import os as", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
custom_component.get_code_tree(custom_component.code)
@ -450,9 +428,7 @@ def test_custom_component_build_not_implemented():
Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
with pytest.raises(NotImplementedError):
custom_component.build()
@ -486,9 +462,7 @@ def test_flow(db):
}
# Create flow
flow = FlowCreate(
id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data
)
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
# Add to database
db.add(flow)

View file

@ -28,9 +28,7 @@ async def test_successful_get_request(api_request):
respx.get(url).mock(return_value=Response(200, json=mock_response))
# Making the request
result = await api_request.make_request(
client=httpx.AsyncClient(), method=method, url=url
)
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
# Assertions
assert result.data["status_code"] == 200
@ -46,9 +44,7 @@ async def test_failed_request(api_request):
respx.get(url).mock(return_value=Response(404))
# Making the request
result = await api_request.make_request(
client=httpx.AsyncClient(), method=method, url=url
)
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
# Assertions
assert result.data["status_code"] == 404
@ -60,14 +56,10 @@ async def test_timeout(api_request):
# Mocking a timeout
url = "https://example.com/api/timeout"
method = "GET"
respx.get(url).mock(
side_effect=httpx.TimeoutException(message="Timeout", request=None)
)
respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None))
# Making the request
result = await api_request.make_request(
client=httpx.AsyncClient(), method=method, url=url, timeout=1
)
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1)
# Assertions
assert result.data["status_code"] == 408
@ -121,7 +113,7 @@ def test_directory_component_build_with_multithreading(
mock_parallel_load_records.return_value = [Mock()]
# Act
result = directory_component.build(
directory_component.build(
path,
types,
depth,
@ -134,9 +126,7 @@ def test_directory_component_build_with_multithreading(
# Assert
mock_resolve_path.assert_called_once_with(path)
mock_retrieve_file_paths.assert_called_once_with(
path, types, load_hidden, recursive, depth
)
mock_retrieve_file_paths.assert_called_once_with(path, types, load_hidden, recursive, depth)
mock_parallel_load_records.assert_called_once_with(
mock_retrieve_file_paths.return_value, silent_errors, max_concurrency
)

View file

@ -27,9 +27,7 @@ def json_style():
)
def test_create_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
@ -39,9 +37,7 @@ def test_create_flow(
assert response.json()["data"] == flow.data
# flow is optional so we can create a flow without a flow
flow = FlowCreate(name="Test Flow")
response = client.post(
"api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers
)
response = client.post("api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
@ -82,9 +78,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
assert response.json()["data"] == flow.data
def test_update_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
@ -97,9 +91,7 @@ def test_update_flow(
description="updated description",
data=data,
)
response = client.patch(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers)
assert response.status_code == 200
assert response.json()["name"] == updated_flow.name
@ -107,9 +99,7 @@ def test_update_flow(
# assert response.json()["data"] == updated_flow.data
def test_delete_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
@ -120,9 +110,7 @@ def test_delete_flow(
assert response.json()["message"] == "Flow deleted successfully"
def test_create_flows(
client: TestClient, session: Session, json_flow: str, logged_in_headers
):
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
@ -133,9 +121,7 @@ def test_create_flows(
]
)
# Make request to endpoint
response = client.post(
"api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers
)
response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers)
# Check response status code
assert response.status_code == 201
# Check response data
@ -149,9 +135,7 @@ def test_create_flows(
assert response_data[1]["data"] == data
def test_upload_file(
client: TestClient, session: Session, json_flow: str, logged_in_headers
):
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
@ -220,9 +204,7 @@ def test_download_file(
assert response_data[1]["data"] == data
def test_create_flow_with_invalid_data(
client: TestClient, active_user, logged_in_headers
):
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
flow = {"name": "a" * 256, "data": "Invalid flow data"}
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
assert response.status_code == 422
@ -234,29 +216,19 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers
assert response.status_code == 404
def test_update_flow_idempotency(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post(
"api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers
)
response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers)
flow_id = response.json()["id"]
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
response1 = client.put(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response2 = client.put(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers)
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers)
assert response1.json() == response2.json()
def test_update_nonexistent_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
uuid = uuid4()
@ -265,9 +237,7 @@ def test_update_nonexistent_flow(
description="description",
data=data,
)
response = client.patch(
f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers
)
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers)
assert response.status_code == 404

View file

@ -29,10 +29,7 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
href,
headers=headers,
)
if (
task_status_response.status_code == 200
and task_status_response.json()["status"] == "SUCCESS"
):
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
return task_status_response.json()
time.sleep(sleep_time)
return None # Return None if task did not complete in time
@ -126,11 +123,7 @@ def created_api_key(active_user):
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if (
existing_api_key := session.query(ApiKey)
.filter(ApiKey.api_key == api_key.api_key)
.first()
):
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
return existing_api_key
session.add(api_key)
session.commit()
@ -296,11 +289,7 @@ def test_get_all(client: TestClient, logged_in_headers):
dir_reader = DirectoryReader(settings.COMPONENTS_PATH[0])
files = dir_reader.get_files()
# json_response is a dict of dicts
all_names = [
component_name
for _, components in response.json().items()
for component_name in components
]
all_names = [component_name for _, components in response.json().items() for component_name in components]
json_response = response.json()
# We need to test the custom nodes
assert len(all_names) > len(files)
@ -425,19 +414,13 @@ def test_various_prompts(client, prompt, expected_input_variables):
def test_get_vertices_flow_not_found(client, logged_in_headers):
response = client.get(
"/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers
)
assert (
response.status_code == 500
) # Or whatever status code you've set for invalid ID
response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
assert response.status_code == 500 # Or whatever status code you've set for invalid ID
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
assert response.status_code == 200
assert "ids" in response.json()
# The response should contain the list in this order
@ -452,19 +435,13 @@ def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_head
def test_build_vertex_invalid_flow_id(client, logged_in_headers):
response = client.post(
"/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers
)
response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers)
assert response.status_code == 500
def test_build_vertex_invalid_vertex_id(
client, added_flow_with_prompt_and_history, logged_in_headers
):
def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.post(
f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers
)
response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers)
assert response.status_code == 500

View file

@ -4,7 +4,6 @@ import pytest
from langflow.services.deps import get_storage_service
from langflow.services.storage.service import StorageService
from langflow.services.storage.utils import build_content_type_from_extension
@pytest.fixture
@ -87,10 +86,8 @@ def test_file_operations(client, created_api_key, flow):
assert file_name in response.json()["files"]
# Step 3: Download the file and verify its content
mime_type = build_content_type_from_extension(file_name.split(".")[-1])
response = client.get(
f"api/v1/files/download/{flow_id}/{file_name}", headers=headers
)
response = client.get(f"api/v1/files/download/{flow_id}/{file_name}", headers=headers)
assert response.status_code == 200
assert response.content == file_content
# the headers are application/octet-stream
@ -98,9 +95,7 @@ def test_file_operations(client, created_api_key, flow):
# mime_type is inside media_type
# Step 4: Delete the file
response = client.delete(
f"api/v1/files/delete/{flow_id}/{file_name}", headers=headers
)
response = client.delete(f"api/v1/files/delete/{flow_id}/{file_name}", headers=headers)
assert response.status_code == 200
assert response.json() == {"message": f"File {file_name} deleted successfully"}

View file

@ -39,13 +39,7 @@ def sample_nodes():
return [
{
"id": "node1",
"data": {
"node": {
"template": {
"some_field": {"show": True, "advanced": False, "name": "Name1"}
}
}
},
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
},
{
"id": "node2",
@ -63,11 +57,7 @@ def sample_nodes():
},
{
"id": "node3",
"data": {
"node": {
"template": {"unrelated_field": {"show": True, "advanced": True}}
}
},
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
},
]
@ -152,15 +142,9 @@ def test_get_node_neighbors_basic(basic_graph):
# Root Node is an Agent, it requires an LLMChain and tools
# We need to check if there is a Chain in the one of the neighbors'
# data attribute in the type key
assert any(
"ConversationBufferMemory" in neighbor.data["type"]
for neighbor, val in neighbors.items()
if val
)
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
assert any(
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
)
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
def test_get_node(basic_graph):
@ -259,9 +243,7 @@ def test_find_last_node(grouped_chat_json_flow):
def test_ungroup_node(grouped_chat_json_flow):
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
group_node = grouped_chat_data["nodes"][
2
] # Assuming the first node is a group node
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
base_flow = copy.deepcopy(grouped_chat_data)
ungroup_node(group_node["data"], base_flow)
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
@ -313,14 +295,9 @@ def test_process_flow_one_group(one_grouped_chat_json_flow):
assert "edges" in processed_flow
# Now get the node that has ChatOpenAI in its id
chat_openai_node = next(
(node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None
)
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
assert chat_openai_node is not None
assert (
chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"]
== "test"
)
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
@ -369,17 +346,11 @@ def test_update_template(sample_template, sample_nodes):
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
assert (
node1_updated["data"]["node"]["template"]["some_field"]["display_name"]
== "Name1"
)
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
assert (
node2_updated["data"]["node"]["template"]["other_field"]["display_name"]
== "DisplayName2"
)
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
# Ensure node3 remains unchanged
assert node3_updated == sample_nodes[2]
@ -410,9 +381,7 @@ def test_set_new_target_handle():
"data": {
"node": {
"flow": True,
"template": {
"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}
},
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
}
}
}
@ -432,9 +401,7 @@ def test_update_source_handle():
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
"edges": [{"source": "some_node"}],
}
updated_edge = update_source_handle(
new_edge, flow_data["nodes"], flow_data["edges"]
)
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
assert updated_edge["source"] == "last_node"
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"

View file

@ -268,13 +268,9 @@ async def test_load_langchain_object_with_cached_session(client, basic_graph_dat
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = "non-existent-session-id"
graph1, artifacts1 = await session_service.load_session(
session_id1, basic_graph_data
)
graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data)
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(
session_id1, basic_graph_data
)
graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data)
assert graph1 == graph2
assert artifacts1 == artifacts2
@ -286,15 +282,11 @@ async def test_load_langchain_object_with_no_cached_session(client, basic_graph_
session_service = get_session_service()
session_id1 = "non-existent-session-id"
session_id = session_service.build_key(session_id1, basic_graph_data)
graph1, artifacts1 = await session_service.load_session(
session_id, data_graph=basic_graph_data, flow_id="flow_id"
)
graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
# Clear the cache
session_service.clear_session(session_id)
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(
session_id, data_graph=basic_graph_data, flow_id="flow_id"
)
graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
assert id(graph1) != id(graph2)
# Since the cache was cleared, objects should be different
@ -305,12 +297,8 @@ async def test_load_langchain_object_without_session_id(client, basic_graph_data
# Provide a non-existent session_id
session_service = get_session_service()
session_id1 = None
graph1, artifacts1 = await session_service.load_session(
session_id1, data_graph=basic_graph_data, flow_id="flow_id"
)
graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
# Use the new session_id to get the langchain_object again
graph2, artifacts2 = await session_service.load_session(
session_id1, data_graph=basic_graph_data, flow_id="flow_id"
)
graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
assert graph1 == graph2

View file

@ -91,9 +91,7 @@ from langflow.services.utils import teardown_superuser
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_default_superuser(
mock_get_session, mock_get_settings_service
):
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = True
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
@ -113,9 +111,7 @@ def test_teardown_superuser_default_superuser(
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_no_default_superuser(
mock_get_session, mock_get_settings_service
):
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
ADMIN_USER_NAME = "admin_user"
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = False