diff --git a/src/backend/langflow/api/v1/flows.py b/src/backend/langflow/api/v1/flows.py index e58a44e70..149206fd0 100644 --- a/src/backend/langflow/api/v1/flows.py +++ b/src/backend/langflow/api/v1/flows.py @@ -211,7 +211,5 @@ async def download_file( current_user: User = Depends(get_current_active_user), ): """Download all flows as a file.""" - flows = read_flows( - current_user=current_user, session=session, settings_service=settings_service - ) + flows = read_flows(current_user=current_user, session=session, settings_service=settings_service) return FlowListRead(flows=flows) diff --git a/src/backend/langflow/api/v1/login.py b/src/backend/langflow/api/v1/login.py index 30387b92a..d2e031a63 100644 --- a/src/backend/langflow/api/v1/login.py +++ b/src/backend/langflow/api/v1/login.py @@ -40,7 +40,7 @@ async def login_to_get_access_token( httponly=auth_settings.REFRESH_HTTPONLY, samesite=auth_settings.REFRESH_SAME_SITE, secure=auth_settings.REFRESH_SECURE, - expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES*60, + expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60, ) response.set_cookie( "access_token_lf", @@ -48,7 +48,7 @@ async def login_to_get_access_token( httponly=auth_settings.ACCESS_HTTPONLY, samesite=auth_settings.ACCESS_SAME_SITE, secure=auth_settings.ACCESS_SECURE, - expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES*60, + expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, ) return tokens else: @@ -74,7 +74,7 @@ async def auto_login( httponly=auth_settings.ACCESS_HTTPONLY, samesite=auth_settings.ACCESS_SAME_SITE, secure=auth_settings.ACCESS_SECURE, - expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES*60, + expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, ) return tokens @@ -101,7 +101,7 @@ async def refresh_token(request: Request, response: Response, settings_service=D httponly=auth_settings.REFRESH_TOKEN_HTTPONLY, samesite=auth_settings.REFRESH_SAME_SITE, secure=auth_settings.REFRESH_SECURE, - expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES*60, + expires=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60, ) response.set_cookie( "access_token_lf", @@ -109,7 +109,7 @@ async def refresh_token(request: Request, response: Response, settings_service=D httponly=auth_settings.ACCESS_HTTPONLY, samesite=auth_settings.ACCESS_SAME_SITE, secure=auth_settings.ACCESS_SECURE, - expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES*60, + expires=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, ) return tokens else: diff --git a/src/backend/langflow/base/data/utils.py b/src/backend/langflow/base/data/utils.py index a9aaed45b..b2c11a270 100644 --- a/src/backend/langflow/base/data/utils.py +++ b/src/backend/langflow/base/data/utils.py @@ -54,9 +54,7 @@ def retrieve_file_paths( glob = "**/*" if recursive else "*" paths = walk_level(path_obj, depth) if depth else path_obj.glob(glob) - file_paths = [ - Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p) - ] + file_paths = [Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)] return file_paths @@ -90,13 +88,10 @@ def parse_text_file_to_record(file_path: str, silent_errors: bool) -> Optional[R text = read_text_file(file_path) # if file is json, yaml, or xml, we can parse it if file_path.endswith(".json"): - text = json.loads(text) elif file_path.endswith(".yaml") or file_path.endswith(".yml"): - text = yaml.safe_load(text) elif file_path.endswith(".xml"): - text = ET.fromstring(text) except Exception as e: if not silent_errors: @@ -116,10 +111,7 @@ def get_elements( if use_multithreading: records = parallel_load_records(file_paths, silent_errors, max_concurrency) else: - records = [ - partition_file_to_record(file_path, silent_errors) - for file_path in file_paths - ] + records = [partition_file_to_record(file_path, silent_errors) for file_path in file_paths] records = list(filter(None, records)) return records diff --git a/src/backend/langflow/components/data/APIRequest.py b/src/backend/langflow/components/data/APIRequest.py index a7cc2ac37..6199d541b 100644 --- a/src/backend/langflow/components/data/APIRequest.py +++ b/src/backend/langflow/components/data/APIRequest.py @@ -56,9 +56,7 @@ class APIRequest(CustomComponent): data = body if body else None payload = json.dumps(data) try: - response = await client.request( - method, url, headers=headers, content=payload, timeout=timeout - ) + response = await client.request(method, url, headers=headers, content=payload, timeout=timeout) try: result = response.json() except Exception: @@ -111,10 +109,7 @@ class APIRequest(CustomComponent): bodies += [None] * (len(urls) - len(bodies)) async with httpx.AsyncClient() as client: results = await asyncio.gather( - *[ - self.make_request(client, method, u, headers, rec, timeout) - for u, rec in zip(urls, bodies) - ] + *[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)] ) self.status = results return results diff --git a/src/backend/langflow/components/data/Directory.py b/src/backend/langflow/components/data/Directory.py index bc949343a..d0428009d 100644 --- a/src/backend/langflow/components/data/Directory.py +++ b/src/backend/langflow/components/data/Directory.py @@ -53,20 +53,14 @@ class DirectoryComponent(CustomComponent): silent_errors: bool = False, use_multithreading: bool = True, ) -> List[Optional[Record]]: - resolved_path = self.resolve_path(path) file_paths = retrieve_file_paths(resolved_path, load_hidden, recursive, depth) loaded_records = [] if use_multithreading: - loaded_records = parallel_load_records( - file_paths, silent_errors, max_concurrency - ) + loaded_records = parallel_load_records(file_paths, silent_errors, max_concurrency) else: - loaded_records = [ - parse_text_file_to_record(file_path, silent_errors) - for file_path in file_paths - ] + loaded_records = [parse_text_file_to_record(file_path, silent_errors) for file_path in file_paths] loaded_records = list(filter(None, loaded_records)) self.status = loaded_records return loaded_records diff --git a/src/backend/langflow/components/experimental/ExtractDataFromRecord.py b/src/backend/langflow/components/experimental/ExtractDataFromRecord.py index 55a48f6c1..6618ad213 100644 --- a/src/backend/langflow/components/experimental/ExtractDataFromRecord.py +++ b/src/backend/langflow/components/experimental/ExtractDataFromRecord.py @@ -21,9 +21,7 @@ class ExtractKeyFromRecordComponent(CustomComponent): }, } - def build( - self, record: Record, keys: list[str], silent_error: bool = True - ) -> Record: + def build(self, record: Record, keys: list[str], silent_error: bool = True) -> Record: """ Extracts the keys from a record. diff --git a/src/backend/langflow/components/inputs/Prompt.py b/src/backend/langflow/components/inputs/Prompt.py index a0b03e369..17055aec3 100644 --- a/src/backend/langflow/components/inputs/Prompt.py +++ b/src/backend/langflow/components/inputs/Prompt.py @@ -24,9 +24,7 @@ class PromptComponent(CustomComponent): prompt_template = PromptTemplate.from_template(Text(template)) kwargs = dict_values_to_string(kwargs) - kwargs = { - k: "\n".join(v) if isinstance(v, list) else v for k, v in kwargs.items() - } + kwargs = {k: "\n".join(v) if isinstance(v, list) else v for k, v in kwargs.items()} try: formated_prompt = prompt_template.format(**kwargs) except Exception as exc: diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index db7602551..06d761533 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -60,13 +60,8 @@ class Vertex: self.updated_raw_params = False self.id: str = data["id"] self.is_state = False - self.is_input = any( - input_component_name in self.id for input_component_name in INPUT_COMPONENTS - ) - self.is_output = any( - output_component_name in self.id - for output_component_name in OUTPUT_COMPONENTS - ) + self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS) + self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS) self.has_session_id = None self._custom_component = None self.has_external_input = False @@ -106,17 +101,11 @@ class Vertex: def set_state(self, state: str): self.state = VertexStates[state] - if ( - self.state == VertexStates.INACTIVE - and self.graph.in_degree_map[self.id] < 2 - ): + if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2: # If the vertex is inactive and has only one in degree # it means that it is not a merge point in the graph self.graph.inactivated_vertices.add(self.id) - elif ( - self.state == VertexStates.ACTIVE - and self.id in self.graph.inactivated_vertices - ): + elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactivated_vertices: self.graph.inactivated_vertices.remove(self.id) @property @@ -133,9 +122,7 @@ class Vertex: # If the Vertex.type is a power component # then we need to return the built object # instead of the result dict - if self.is_interface_component and not isinstance( - self._built_object, UnbuiltObject - ): + if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject): result = self._built_object # if it is not a dict or a string and hasattr model_dump then # return the model_dump @@ -147,11 +134,7 @@ class Vertex: if isinstance(self._built_result, UnbuiltResult): return {} - return ( - self._built_result - if isinstance(self._built_result, dict) - else {"result": self._built_result} - ) + return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result} def set_artifacts(self) -> None: pass @@ -225,31 +208,19 @@ class Vertex: self.selected_output_type = self.data["node"].get("selected_output_type") self.is_input = self.data["node"].get("is_input") or self.is_input self.is_output = self.data["node"].get("is_output") or self.is_output - template_dicts = { - key: value - for key, value in self.data["node"]["template"].items() - if isinstance(value, dict) - } + template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} self.has_session_id = "session_id" in template_dicts self.required_inputs = [ - template_dicts[key]["type"] - for key, value in template_dicts.items() - if value["required"] + template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"] ] self.optional_inputs = [ - template_dicts[key]["type"] - for key, value in template_dicts.items() - if not value["required"] + template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"] ] # Add the template_dicts[key]["input_types"] to the optional_inputs self.optional_inputs.extend( - [ - input_type - for value in template_dicts.values() - for input_type in value.get("input_types", []) - ] + [input_type for value in template_dicts.values() for input_type in value.get("input_types", [])] ) template_dict = self.data["node"]["template"] @@ -296,11 +267,7 @@ class Vertex: self.updated_raw_params = False return - template_dict = { - key: value - for key, value in self.data["node"]["template"].items() - if isinstance(value, dict) - } + template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} params = {} for edge in self.edges: @@ -321,10 +288,7 @@ class Vertex: # we don't know the key of the dict but we need to set the value # to the vertex that is the source of the edge param_dict = template_dict[param_key]["value"] - params[param_key] = { - key: self.graph.get_vertex(edge.source_id) - for key in param_dict.keys() - } + params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()} else: params[param_key] = self.graph.get_vertex(edge.source_id) @@ -360,11 +324,7 @@ class Vertex: # list of dicts, so we need to convert it to a dict # before passing it to the build method if isinstance(val, list): - params[key] = { - k: v - for item in value.get("value", []) - for k, v in item.items() - } + params[key] = {k: v for item in value.get("value", []) for k, v in item.items()} elif isinstance(val, dict): params[key] = val elif value.get("type") == "int" and val is not None: @@ -489,9 +449,7 @@ class Vertex: if isinstance(self._built_object, str): self._built_result = self._built_object - result = await generate_result( - self._built_object, inputs, self.has_external_output, session_id - ) + result = await generate_result(self._built_object, inputs, self.has_external_output, session_id) self._built_result = result async def _build_each_node_in_params_dict(self, user_id=None): @@ -511,9 +469,7 @@ class Vertex: elif key not in self.params or self.updated_raw_params: self.params[key] = value - async def _build_dict_and_update_params( - self, key, nodes_dict: Dict[str, "Vertex"], user_id=None - ): + async def _build_dict_and_update_params(self, key, nodes_dict: Dict[str, "Vertex"], user_id=None): """ Iterates over a dictionary of nodes, builds each and updates the params dictionary. """ @@ -536,9 +492,7 @@ class Vertex: """ return all(self._is_node(node) for node in value) - async def get_result( - self, requester: Optional["Vertex"] = None, user_id=None, timeout=None - ) -> Any: + async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any: # PLEASE REVIEW THIS IF STATEMENT # Check if the Vertex was built already if self._built: @@ -572,9 +526,7 @@ class Vertex: self._extend_params_list_with_result(key, result) self.params[key] = result - async def _build_list_of_nodes_and_update_params( - self, key, nodes: List["Vertex"], user_id=None - ): + async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None): """ Iterates over a list of nodes, builds each and updates the params dictionary. """ @@ -641,9 +593,7 @@ class Vertex: except Exception as exc: logger.exception(exc) - raise ValueError( - f"Error building node {self.display_name}: {str(exc)}" - ) from exc + raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc def _update_built_object_and_artifacts(self, result): """ @@ -671,9 +621,7 @@ class Vertex: logger.warning(message) elif isinstance(self._built_object, (Iterator, AsyncIterator)): if self.display_name in ["Text Output"]: - raise ValueError( - f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead." - ) + raise ValueError(f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead.") def _reset(self, params_update: Optional[Dict[str, Any]] = None): self._built = False @@ -736,24 +684,16 @@ class Vertex: return self._built_object # Get the requester edge - requester_edge = next( - (edge for edge in self.edges if edge.target_id == requester.id), None - ) + requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None) # Return the result of the requester edge - return ( - None - if requester_edge is None - else await requester_edge.get_result(source=self, target=requester) - ) + return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester) def add_edge(self, edge: "ContractEdge") -> None: if edge not in self.edges: self.edges.append(edge) def __repr__(self) -> str: - return ( - f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" - ) + return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" def __eq__(self, __o: object) -> bool: try: @@ -774,8 +714,4 @@ class Vertex: def _built_object_repr(self): # Add a message with an emoji, stars for sucess, - return ( - "Built sucessfully ✨" - if self._built_object is not None - else "Failed to build 😵‍💫" - ) + return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫" diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 8e1c194c5..1467919e4 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -124,11 +124,9 @@ class DocumentLoaderVertex(Vertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum( - len(record.text) - for record in self._built_object - if hasattr(record, "text") - ) / len(self._built_object) + avg_length = sum(len(record.text) for record in self._built_object if hasattr(record, "text")) / len( + self._built_object + ) return f"""{self.display_name}({len(self._built_object)} records) \nAvg. Record Length (characters): {int(avg_length)} Records: {self._built_object[:3]}...""" @@ -201,9 +199,7 @@ class TextSplitterVertex(Vertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( - self._built_object - ) + avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} \nDocuments: {self._built_object[:3]}...""" @@ -250,27 +246,18 @@ class PromptVertex(Vertex): user_id = kwargs.get("user_id", None) tools = kwargs.get("tools", []) if not self._built or force: - if ( - "input_variables" not in self.params - or self.params["input_variables"] is None - ): + if "input_variables" not in self.params or self.params["input_variables"] is None: self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: - tools = ( - [tool_node.build(user_id=user_id) for tool_node in tools] - if tools is not None - else [] - ) + tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ - key - for key, value in self.params.items() - if isinstance(value, str) and key != "format_instructions" + key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions" ] else: prompt_params = ["template"] @@ -280,20 +267,14 @@ class PromptVertex(Vertex): prompt_text = self.params[param] variables = extract_input_variables_from_prompt(prompt_text) self.params["input_variables"].extend(variables) - self.params["input_variables"] = list( - set(self.params["input_variables"]) - ) + self.params["input_variables"] = list(set(self.params["input_variables"])) elif isinstance(self.params, dict): self.params.pop("input_variables", None) await self._build(user_id=user_id) def _built_object_repr(self): - if ( - not self.artifacts - or self._built_object is None - or not hasattr(self._built_object, "format") - ): + if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"): return super()._built_object_repr() elif isinstance(self._built_object, UnbuiltObject): return super()._built_object_repr() @@ -305,9 +286,7 @@ class PromptVertex(Vertex): # so the prompt format doesn't break artifacts.pop("handle_keys", None) try: - if not hasattr(self._built_object, "template") and hasattr( - self._built_object, "prompt" - ): + if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): template = self._built_object.prompt.template else: template = self._built_object.template @@ -315,11 +294,7 @@ class PromptVertex(Vertex): if value: replace_key = "{" + key + "}" template = template.replace(replace_key, value) - return ( - template - if isinstance(template, str) - else f"{self.vertex_type}({template})" - ) + return template if isinstance(template, str) else f"{self.vertex_type}({template})" except KeyError: return str(self._built_object) @@ -354,14 +329,8 @@ class ChatVertex(Vertex): return f"Task {self.task_id} is not running" if self.artifacts: # dump as a yaml string - artifacts = { - k.title().replace("_", " "): v - for k, v in self.artifacts.items() - if v is not None - } - yaml_str = yaml.dump( - artifacts, default_flow_style=False, allow_unicode=True - ) + artifacts = {k.title().replace("_", " "): v for k, v in self.artifacts.items() if v is not None} + yaml_str = yaml.dump(artifacts, default_flow_style=False, allow_unicode=True) return yaml_str return super()._built_object_repr() diff --git a/src/backend/langflow/interface/tools/constants.py b/src/backend/langflow/interface/tools/constants.py index 2561d9996..194bea2e9 100644 --- a/src/backend/langflow/interface/tools/constants.py +++ b/src/backend/langflow/interface/tools/constants.py @@ -17,9 +17,7 @@ CUSTOM_TOOLS = { "PythonFunctionTool": PythonFunctionTool, } -OTHER_TOOLS = { - tool: import_class(f"langchain_community.tools.{tool}") for tool in tools.__all__ -} +OTHER_TOOLS = {tool: import_class(f"langchain_community.tools.{tool}") for tool in tools.__all__} ALL_TOOLS_NAMES = { **_BASE_TOOLS, diff --git a/src/backend/langflow/schema/schema.py b/src/backend/langflow/schema/schema.py index d937be247..5ee4be748 100644 --- a/src/backend/langflow/schema/schema.py +++ b/src/backend/langflow/schema/schema.py @@ -73,9 +73,7 @@ class Record(BaseModel): return self.data.get(key, self._default_value) except KeyError: # Fallback to default behavior to raise AttributeError for undefined attributes - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{key}'" - ) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") def __setattr__(self, key, value): """ diff --git a/src/backend/langflow/services/settings/auth.py b/src/backend/langflow/services/settings/auth.py index 074558bc3..8463d0781 100644 --- a/src/backend/langflow/services/settings/auth.py +++ b/src/backend/langflow/services/settings/auth.py @@ -1,4 +1,3 @@ -import datetime import secrets from pathlib import Path from typing import Optional @@ -37,7 +36,7 @@ class AuthSettings(BaseSettings): NEW_USER_IS_ACTIVE: bool = False SUPERUSER: str = DEFAULT_SUPERUSER SUPERUSER_PASSWORD: str = DEFAULT_SUPERUSER_PASSWORD - + REFRESH_SAME_SITE: str = "none" """The SameSite attribute of the refresh token cookie.""" REFRESH_SECURE: bool = True diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 9bc2d2030..4ec2c526c 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -20,12 +20,8 @@ def remove_ansi_escape_codes(text): return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text) -def build_template_from_function( - name: str, type_to_loader_dict: Dict, add_function: bool = False -): - classes = [ - item.__annotations__["return"].__name__ for item in type_to_loader_dict.values() - ] +def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False): + classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()] # Raise error if name is not in chains if name not in classes: @@ -46,10 +42,8 @@ def build_template_from_function( for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items]["default"] = ( - get_default_factory( - module=_class.__base__.__module__, function=value_ - ) + variables[class_field_items]["default"] = get_default_factory( + module=_class.__base__.__module__, function=value_ ) except Exception: variables[class_field_items]["default"] = None @@ -57,9 +51,7 @@ def build_template_from_function( variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] - if class_field_items in docs.params - else "" + docs.params[class_field_items] if class_field_items in docs.params else "" ) # Adding function to base classes to allow # the output to be a function @@ -74,9 +66,7 @@ def build_template_from_function( } -def build_template_from_class( - name: str, type_to_cls_dict: Dict, add_function: bool = False -): +def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False): classes = [item.__name__ for item in type_to_cls_dict.values()] # Raise error if name is not in chains @@ -100,11 +90,9 @@ def build_template_from_class( for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items]["default"] = ( - get_default_factory( - module=_class.__base__.__module__, - function=value_, - ) + variables[class_field_items]["default"] = get_default_factory( + module=_class.__base__.__module__, + function=value_, ) except Exception: variables[class_field_items]["default"] = None @@ -112,9 +100,7 @@ def build_template_from_class( variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] - if class_field_items in docs.params - else "" + docs.params[class_field_items] if class_field_items in docs.params else "" ) base_classes = get_base_classes(_class) # Adding function to base classes to allow @@ -146,9 +132,7 @@ def build_template_from_method( # Check if the method exists in this class if not hasattr(_class, method_name): - raise ValueError( - f"Method {method_name} not found in class {class_name}" - ) + raise ValueError(f"Method {method_name} not found in class {class_name}") # Get the method method = getattr(_class, method_name) @@ -167,14 +151,8 @@ def build_template_from_method( "_type": _type, **{ name: { - "default": ( - param.default if param.default != param.empty else None - ), - "type": ( - param.annotation - if param.annotation != param.empty - else None - ), + "default": (param.default if param.default != param.empty else None), + "type": (param.annotation if param.annotation != param.empty else None), "required": param.default == param.empty, } for name, param in params.items() @@ -261,9 +239,7 @@ def sync_to_async(func): return async_wrapper -def format_dict( - dictionary: Dict[str, Any], class_name: Optional[str] = None -) -> Dict[str, Any]: +def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]: """ Formats a dictionary by removing certain keys and modifying the values of other keys. @@ -349,9 +325,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str: The modified type string. """ if any(list_type in _type for list_type in ["List", "Sequence", "Set"]): - _type = ( - _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] - ) + _type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] value["list"] = True else: value["list"] = False @@ -454,9 +428,7 @@ def set_headers_value(value: Dict[str, Any]) -> None: value["value"] = """{"Authorization": "Bearer "}""" -def add_options_to_field( - value: Dict[str, Any], class_name: Optional[str], key: str -) -> None: +def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None: """ Adds options to the field based on the class name and key. """ diff --git a/tests/conftest.py b/tests/conftest.py index cf4e72245..01c4379de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,35 +28,19 @@ if TYPE_CHECKING: def pytest_configure(): - pytest.BASIC_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "basic_example.json" - ) - pytest.COMPLEX_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "complex_example.json" - ) - pytest.OPENAPI_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "Openapi.json" - ) - pytest.GROUPED_CHAT_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "grouped_chat.json" - ) - pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "one_group_chat.json" - ) - pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" - ) + pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json" + pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json" + pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json" + pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json" + pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json" + pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = ( Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json" ) pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json" - pytest.TWO_OUTPUTS = ( - Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json" - ) - pytest.VECTOR_STORE_PATH = ( - Path(__file__).parent.absolute() / "data" / "Vector_store.json" - ) + pytest.TWO_OUTPUTS = Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json" + pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json" pytest.CODE_WITH_SYNTAX_ERROR = """ def get_text(): retun "Hello World" @@ -67,9 +51,7 @@ def get_text(): def check_openai_api_key_in_environment_variables(): import os - assert ( - os.environ.get("OPENAI_API_KEY") is not None - ), "OPENAI_API_KEY is not set in environment variables" + assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables" @pytest.fixture() @@ -83,9 +65,7 @@ async def async_client() -> AsyncGenerator: @pytest.fixture(name="session") def session_fixture(): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) SQLModel.metadata.create_all(engine) with Session(engine) as session: yield session @@ -121,9 +101,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env): monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") # monkeypatch langflow.services.task.manager.USE_CELERY to True # monkeypatch.setattr(manager, "USE_CELERY", True) - monkeypatch.setattr( - celery_app, "celery_app", celery_app.make_celery("langflow", Config) - ) + monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config)) # def get_session_override(): # return session @@ -274,11 +252,7 @@ def active_user(client): is_superuser=False, ) # check if user exists - if ( - active_user := session.query(User) - .filter(User.username == user.username) - .first() - ): + if active_user := session.query(User).filter(User.username == user.username).first(): return active_user session.add(user) session.commit() @@ -301,9 +275,7 @@ def flow(client, json_flow: str, active_user): from langflow.services.database.models.flow.model import FlowCreate loaded_json = json.loads(json_flow) - flow_data = FlowCreate( - name="test_flow", data=loaded_json.get("data"), user_id=active_user.id - ) + flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id) flow = Flow.model_validate(flow_data) with session_getter(get_db_service()) as session: @@ -327,9 +299,7 @@ def json_two_outputs(): @pytest.fixture -def added_flow_with_prompt_and_history( - client, json_flow_with_prompt_and_history, logged_in_headers -): +def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history, logged_in_headers): flow = orjson.loads(json_flow_with_prompt_and_history) data = flow["data"] flow = FlowCreate(name="Basic Chat", description="description", data=data) @@ -369,9 +339,7 @@ def added_vector_store(client, json_vector_store, logged_in_headers): vector_store = orjson.loads(json_vector_store) data = vector_store["data"] vector_store = FlowCreate(name="Vector Store", description="description", data=data) - response = client.post( - "api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == vector_store.name assert response.json()["data"] == vector_store.data @@ -389,11 +357,7 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if ( - existing_api_key := session.query(ApiKey) - .filter(ApiKey.api_key == api_key.api_key) - .first() - ): + if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): return existing_api_key session.add(api_key) session.commit() @@ -405,9 +369,7 @@ def created_api_key(active_user): def get_starter_project(active_user): # once the client is created, we can get the starter project with session_getter(get_db_service()) as session: - flow = session.exec( - select(Flow).where(Flow.folder == STARTER_FOLDER_NAME) - ).first() + flow = session.exec(select(Flow).where(Flow.folder == STARTER_FOLDER_NAME)).first() if not flow: raise ValueError("No starter project found") diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index 1a06348eb..725d35564 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -104,9 +104,7 @@ def test_custom_component_init(): """ function_entrypoint_name = "build" - custom_component = CustomComponent( - code=code_default, function_entrypoint_name=function_entrypoint_name - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name) assert custom_component.code == code_default assert custom_component.function_entrypoint_name == function_entrypoint_name @@ -115,9 +113,7 @@ def test_custom_component_build_template_config(): """ Test the build_template_config property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") config = custom_component.build_template_config() assert isinstance(config, dict) @@ -126,9 +122,7 @@ def test_custom_component_get_function(): """ Test the get_function property of the CustomComponent class. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") my_function = custom_component.get_function() assert isinstance(my_function, types.FunctionType) @@ -213,9 +207,7 @@ def test_custom_component_get_function_entrypoint_args(): Test the get_function_entrypoint_args property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") args = custom_component.get_function_entrypoint_args assert len(args) == 4 assert args[0]["name"] == "self" @@ -229,9 +221,7 @@ def test_custom_component_get_function_entrypoint_return_type(): property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type assert return_type == [Document] @@ -240,9 +230,7 @@ def test_custom_component_get_main_class_name(): """ Test the get_main_class_name property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") class_name = custom_component.get_main_class_name assert class_name == "YourComponent" @@ -252,9 +240,7 @@ def test_custom_component_get_function_valid(): Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") my_function = custom_component.get_function assert callable(my_function) @@ -289,9 +275,7 @@ def test_code_parser_parse_callable_details_no_args(): parser = CodeParser("") node = ast.FunctionDef( name="test", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -337,9 +321,7 @@ def test_code_parser_parse_function_def_not_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="test", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -357,9 +339,7 @@ def test_code_parser_parse_function_def_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="__init__", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -394,9 +374,7 @@ def test_custom_component_get_code_tree_syntax_error(): Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. """ - custom_component = CustomComponent( - code="import os as", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="import os as", function_entrypoint_name="build") with pytest.raises(CodeSyntaxError): custom_component.get_code_tree(custom_component.code) @@ -450,9 +428,7 @@ def test_custom_component_build_not_implemented(): Test the build method of the CustomComponent class raises the NotImplementedError. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") with pytest.raises(NotImplementedError): custom_component.build() @@ -486,9 +462,7 @@ def test_flow(db): } # Create flow - flow = FlowCreate( - id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data - ) + flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data) # Add to database db.add(flow) diff --git a/tests/test_data_components.py b/tests/test_data_components.py index 0a7cee8b6..cbaa39733 100644 --- a/tests/test_data_components.py +++ b/tests/test_data_components.py @@ -28,9 +28,7 @@ async def test_successful_get_request(api_request): respx.get(url).mock(return_value=Response(200, json=mock_response)) # Making the request - result = await api_request.make_request( - client=httpx.AsyncClient(), method=method, url=url - ) + result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url) # Assertions assert result.data["status_code"] == 200 @@ -46,9 +44,7 @@ async def test_failed_request(api_request): respx.get(url).mock(return_value=Response(404)) # Making the request - result = await api_request.make_request( - client=httpx.AsyncClient(), method=method, url=url - ) + result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url) # Assertions assert result.data["status_code"] == 404 @@ -60,14 +56,10 @@ async def test_timeout(api_request): # Mocking a timeout url = "https://example.com/api/timeout" method = "GET" - respx.get(url).mock( - side_effect=httpx.TimeoutException(message="Timeout", request=None) - ) + respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None)) # Making the request - result = await api_request.make_request( - client=httpx.AsyncClient(), method=method, url=url, timeout=1 - ) + result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1) # Assertions assert result.data["status_code"] == 408 @@ -121,7 +113,7 @@ def test_directory_component_build_with_multithreading( mock_parallel_load_records.return_value = [Mock()] # Act - result = directory_component.build( + directory_component.build( path, types, depth, @@ -134,9 +126,7 @@ def test_directory_component_build_with_multithreading( # Assert mock_resolve_path.assert_called_once_with(path) - mock_retrieve_file_paths.assert_called_once_with( - path, types, load_hidden, recursive, depth - ) + mock_retrieve_file_paths.assert_called_once_with(path, types, load_hidden, recursive, depth) mock_parallel_load_records.assert_called_once_with( mock_retrieve_file_paths.return_value, silent_errors, max_concurrency ) diff --git a/tests/test_database.py b/tests/test_database.py index f04aaf9ff..554a1fc4f 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -27,9 +27,7 @@ def json_style(): ) -def test_create_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) @@ -39,9 +37,7 @@ def test_create_flow( assert response.json()["data"] == flow.data # flow is optional so we can create a flow without a flow flow = FlowCreate(name="Test Flow") - response = client.post( - "api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data @@ -82,9 +78,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he assert response.json()["data"] == flow.data -def test_update_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] @@ -97,9 +91,7 @@ def test_update_flow( description="updated description", data=data, ) - response = client.patch( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) + response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) assert response.status_code == 200 assert response.json()["name"] == updated_flow.name @@ -107,9 +99,7 @@ def test_update_flow( # assert response.json()["data"] == updated_flow.data -def test_delete_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) @@ -120,9 +110,7 @@ def test_delete_flow( assert response.json()["message"] == "Flow deleted successfully" -def test_create_flows( - client: TestClient, session: Session, json_flow: str, logged_in_headers -): +def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -133,9 +121,7 @@ def test_create_flows( ] ) # Make request to endpoint - response = client.post( - "api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers) # Check response status code assert response.status_code == 201 # Check response data @@ -149,9 +135,7 @@ def test_create_flows( assert response_data[1]["data"] == data -def test_upload_file( - client: TestClient, session: Session, json_flow: str, logged_in_headers -): +def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -220,9 +204,7 @@ def test_download_file( assert response_data[1]["data"] == data -def test_create_flow_with_invalid_data( - client: TestClient, active_user, logged_in_headers -): +def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers): flow = {"name": "a" * 256, "data": "Invalid flow data"} response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers) assert response.status_code == 422 @@ -234,29 +216,19 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers assert response.status_code == 404 -def test_update_flow_idempotency( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers): flow_data = orjson.loads(json_flow) data = flow_data["data"] flow_data = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post( - "api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers) flow_id = response.json()["id"] updated_flow = FlowCreate(name="Updated Flow", description="description", data=data) - response1 = client.put( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) - response2 = client.put( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) + response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) + response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) assert response1.json() == response2.json() -def test_update_nonexistent_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow_data = orjson.loads(json_flow) data = flow_data["data"] uuid = uuid4() @@ -265,9 +237,7 @@ def test_update_nonexistent_flow( description="description", data=data, ) - response = client.patch( - f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers - ) + response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers) assert response.status_code == 404 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 8aec03eac..4494934fd 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -29,10 +29,7 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1): href, headers=headers, ) - if ( - task_status_response.status_code == 200 - and task_status_response.json()["status"] == "SUCCESS" - ): + if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS": return task_status_response.json() time.sleep(sleep_time) return None # Return None if task did not complete in time @@ -126,11 +123,7 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if ( - existing_api_key := session.query(ApiKey) - .filter(ApiKey.api_key == api_key.api_key) - .first() - ): + if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): return existing_api_key session.add(api_key) session.commit() @@ -296,11 +289,7 @@ def test_get_all(client: TestClient, logged_in_headers): dir_reader = DirectoryReader(settings.COMPONENTS_PATH[0]) files = dir_reader.get_files() # json_response is a dict of dicts - all_names = [ - component_name - for _, components in response.json().items() - for component_name in components - ] + all_names = [component_name for _, components in response.json().items() for component_name in components] json_response = response.json() # We need to test the custom nodes assert len(all_names) > len(files) @@ -425,19 +414,13 @@ def test_various_prompts(client, prompt, expected_input_variables): def test_get_vertices_flow_not_found(client, logged_in_headers): - response = client.get( - "/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers - ) - assert ( - response.status_code == 500 - ) # Or whatever status code you've set for invalid ID + response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers) + assert response.status_code == 500 # Or whatever status code you've set for invalid ID def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers): flow_id = added_flow_with_prompt_and_history["id"] - response = client.get( - f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers - ) + response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) assert response.status_code == 200 assert "ids" in response.json() # The response should contain the list in this order @@ -452,19 +435,13 @@ def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_head def test_build_vertex_invalid_flow_id(client, logged_in_headers): - response = client.post( - "/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers - ) + response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers) assert response.status_code == 500 -def test_build_vertex_invalid_vertex_id( - client, added_flow_with_prompt_and_history, logged_in_headers -): +def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers): flow_id = added_flow_with_prompt_and_history["id"] - response = client.post( - f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers - ) + response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers) assert response.status_code == 500 diff --git a/tests/test_files.py b/tests/test_files.py index 0b3086b6f..b3a85cf03 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -4,7 +4,6 @@ import pytest from langflow.services.deps import get_storage_service from langflow.services.storage.service import StorageService -from langflow.services.storage.utils import build_content_type_from_extension @pytest.fixture @@ -87,10 +86,8 @@ def test_file_operations(client, created_api_key, flow): assert file_name in response.json()["files"] # Step 3: Download the file and verify its content - mime_type = build_content_type_from_extension(file_name.split(".")[-1]) - response = client.get( - f"api/v1/files/download/{flow_id}/{file_name}", headers=headers - ) + + response = client.get(f"api/v1/files/download/{flow_id}/{file_name}", headers=headers) assert response.status_code == 200 assert response.content == file_content # the headers are application/octet-stream @@ -98,9 +95,7 @@ def test_file_operations(client, created_api_key, flow): # mime_type is inside media_type # Step 4: Delete the file - response = client.delete( - f"api/v1/files/delete/{flow_id}/{file_name}", headers=headers - ) + response = client.delete(f"api/v1/files/delete/{flow_id}/{file_name}", headers=headers) assert response.status_code == 200 assert response.json() == {"message": f"File {file_name} deleted successfully"} diff --git a/tests/test_graph.py b/tests/test_graph.py index 3605d4c2a..e645a1c82 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -39,13 +39,7 @@ def sample_nodes(): return [ { "id": "node1", - "data": { - "node": { - "template": { - "some_field": {"show": True, "advanced": False, "name": "Name1"} - } - } - }, + "data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}}, }, { "id": "node2", @@ -63,11 +57,7 @@ def sample_nodes(): }, { "id": "node3", - "data": { - "node": { - "template": {"unrelated_field": {"show": True, "advanced": True}} - } - }, + "data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}}, }, ] @@ -152,15 +142,9 @@ def test_get_node_neighbors_basic(basic_graph): # Root Node is an Agent, it requires an LLMChain and tools # We need to check if there is a Chain in the one of the neighbors' # data attribute in the type key - assert any( - "ConversationBufferMemory" in neighbor.data["type"] - for neighbor, val in neighbors.items() - if val - ) + assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) - assert any( - "OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val - ) + assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) def test_get_node(basic_graph): @@ -259,9 +243,7 @@ def test_find_last_node(grouped_chat_json_flow): def test_ungroup_node(grouped_chat_json_flow): grouped_chat_data = json.loads(grouped_chat_json_flow).get("data") - group_node = grouped_chat_data["nodes"][ - 2 - ] # Assuming the first node is a group node + group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node base_flow = copy.deepcopy(grouped_chat_data) ungroup_node(group_node["data"], base_flow) # after ungroup_node is called, the base_flow and grouped_chat_data should be different @@ -313,14 +295,9 @@ def test_process_flow_one_group(one_grouped_chat_json_flow): assert "edges" in processed_flow # Now get the node that has ChatOpenAI in its id - chat_openai_node = next( - (node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None - ) + chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None) assert chat_openai_node is not None - assert ( - chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] - == "test" - ) + assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test" def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow): @@ -369,17 +346,11 @@ def test_update_template(sample_template, sample_nodes): assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False - assert ( - node1_updated["data"]["node"]["template"]["some_field"]["display_name"] - == "Name1" - ) + assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1" assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True - assert ( - node2_updated["data"]["node"]["template"]["other_field"]["display_name"] - == "DisplayName2" - ) + assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2" # Ensure node3 remains unchanged assert node3_updated == sample_nodes[2] @@ -410,9 +381,7 @@ def test_set_new_target_handle(): "data": { "node": { "flow": True, - "template": { - "field_1": {"proxy": {"field": "new_field", "id": "new_id"}} - }, + "template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}}, } } } @@ -432,9 +401,7 @@ def test_update_source_handle(): "nodes": [{"id": "some_node"}, {"id": "last_node"}], "edges": [{"source": "some_node"}], } - updated_edge = update_source_handle( - new_edge, flow_data["nodes"], flow_data["edges"] - ) + updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"]) assert updated_edge["source"] == "last_node" assert updated_edge["data"]["sourceHandle"]["id"] == "last_node" diff --git a/tests/test_process.py b/tests/test_process.py index f5bae0569..2548e9215 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -268,13 +268,9 @@ async def test_load_langchain_object_with_cached_session(client, basic_graph_dat # Provide a non-existent session_id session_service = get_session_service() session_id1 = "non-existent-session-id" - graph1, artifacts1 = await session_service.load_session( - session_id1, basic_graph_data - ) + graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data) # Use the new session_id to get the langchain_object again - graph2, artifacts2 = await session_service.load_session( - session_id1, basic_graph_data - ) + graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data) assert graph1 == graph2 assert artifacts1 == artifacts2 @@ -286,15 +282,11 @@ async def test_load_langchain_object_with_no_cached_session(client, basic_graph_ session_service = get_session_service() session_id1 = "non-existent-session-id" session_id = session_service.build_key(session_id1, basic_graph_data) - graph1, artifacts1 = await session_service.load_session( - session_id, data_graph=basic_graph_data, flow_id="flow_id" - ) + graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id") # Clear the cache session_service.clear_session(session_id) # Use the new session_id to get the langchain_object again - graph2, artifacts2 = await session_service.load_session( - session_id, data_graph=basic_graph_data, flow_id="flow_id" - ) + graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id") assert id(graph1) != id(graph2) # Since the cache was cleared, objects should be different @@ -305,12 +297,8 @@ async def test_load_langchain_object_without_session_id(client, basic_graph_data # Provide a non-existent session_id session_service = get_session_service() session_id1 = None - graph1, artifacts1 = await session_service.load_session( - session_id1, data_graph=basic_graph_data, flow_id="flow_id" - ) + graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id") # Use the new session_id to get the langchain_object again - graph2, artifacts2 = await session_service.load_session( - session_id1, data_graph=basic_graph_data, flow_id="flow_id" - ) + graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id") assert graph1 == graph2 diff --git a/tests/test_setup_superuser.py b/tests/test_setup_superuser.py index c7d343818..c2172429b 100644 --- a/tests/test_setup_superuser.py +++ b/tests/test_setup_superuser.py @@ -91,9 +91,7 @@ from langflow.services.utils import teardown_superuser @patch("langflow.services.deps.get_settings_service") @patch("langflow.services.deps.get_session") -def test_teardown_superuser_default_superuser( - mock_get_session, mock_get_settings_service -): +def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service): mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = True mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER @@ -113,9 +111,7 @@ def test_teardown_superuser_default_superuser( @patch("langflow.services.deps.get_settings_service") @patch("langflow.services.deps.get_session") -def test_teardown_superuser_no_default_superuser( - mock_get_session, mock_get_settings_service -): +def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service): ADMIN_USER_NAME = "admin_user" mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = False