diff --git a/src/backend/langflow/__main__.py b/src/backend/langflow/__main__.py index 4d3f76498..dab8402bd 100644 --- a/src/backend/langflow/__main__.py +++ b/src/backend/langflow/__main__.py @@ -98,12 +98,8 @@ def update_settings( @app.command() def run( - host: str = typer.Option( - "127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST" - ), - workers: int = typer.Option( - 1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS" - ), + host: str = typer.Option("127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"), + workers: int = typer.Option(1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"), timeout: int = typer.Option(300, help="Worker timeout in seconds."), port: int = typer.Option(7860, help="Port to listen on.", envvar="LANGFLOW_PORT"), components_path: Optional[Path] = typer.Option( @@ -111,19 +107,11 @@ def run( help="Path to the directory containing custom components.", envvar="LANGFLOW_COMPONENTS_PATH", ), - config: str = typer.Option( - Path(__file__).parent / "config.yaml", help="Path to the configuration file." - ), + config: str = typer.Option(Path(__file__).parent / "config.yaml", help="Path to the configuration file."), # .env file param - env_file: Path = typer.Option( - None, help="Path to the .env file containing environment variables." - ), - log_level: str = typer.Option( - "critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL" - ), - log_file: Path = typer.Option( - "logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE" - ), + env_file: Path = typer.Option(None, help="Path to the .env file containing environment variables."), + log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"), + log_file: Path = typer.Option("logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"), cache: Optional[str] = typer.Option( envvar="LANGFLOW_LANGCHAIN_CACHE", help="Type of cache to use. (InMemoryCache, SQLiteCache)", @@ -202,9 +190,7 @@ def run( def run_on_mac_or_linux(host, port, log_level, options, app, open_browser=True): - webapp_process = Process( - target=run_langflow, args=(host, port, log_level, options, app) - ) + webapp_process = Process(target=run_langflow, args=(host, port, log_level, options, app)) webapp_process.start() status_code = 0 while status_code != 200: @@ -280,9 +266,7 @@ def print_banner(host, port): ) # Create a panel with the title and the info text, and a border around it - panel = Panel( - f"{title}\n{info_text}", box=box.ROUNDED, border_style="blue", expand=False - ) + panel = Panel(f"{title}\n{info_text}", box=box.ROUNDED, border_style="blue", expand=False) # Print the banner with a separator line before and after rprint(panel) @@ -314,12 +298,8 @@ def run_langflow(host, port, log_level, options, app): @app.command() def superuser( username: str = typer.Option(..., prompt=True, help="Username for the superuser."), - password: str = typer.Option( - ..., prompt=True, hide_input=True, help="Password for the superuser." - ), - log_level: str = typer.Option( - "critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL" - ), + password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."), + log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"), ): """ Create a superuser. diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index c519fffed..79a6b74e2 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -2,9 +2,7 @@ API_WORDS = ["api", "key", "token"] def has_api_terms(word: str): - return "api" in word and ( - "key" in word or ("token" in word and "tokens" not in word) - ) + return "api" in word and ("key" in word or ("token" in word and "tokens" not in word)) def remove_api_keys(flow: dict): @@ -14,11 +12,7 @@ def remove_api_keys(flow: dict): node_data = node.get("data").get("node") template = node_data.get("template") for value in template.values(): - if ( - isinstance(value, dict) - and has_api_terms(value["name"]) - and value.get("password") - ): + if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"): value["value"] = None return flow @@ -39,9 +33,7 @@ def build_input_keys_response(langchain_object, artifacts): input_keys_response["input_keys"][key] = value # If the object has memory, that memory will have a memory_variables attribute # memory variables should be removed from the input keys - if hasattr(langchain_object, "memory") and hasattr( - langchain_object.memory, "memory_variables" - ): + if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"): # Remove memory variables from input keys input_keys_response["input_keys"] = { key: value @@ -51,9 +43,7 @@ def build_input_keys_response(langchain_object, artifacts): # Add memory variables to memory_keys input_keys_response["memory_keys"] = langchain_object.memory.memory_variables - if hasattr(langchain_object, "prompt") and hasattr( - langchain_object.prompt, "template" - ): + if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"): input_keys_response["template"] = langchain_object.prompt.template return input_keys_response diff --git a/src/backend/langflow/api/v1/api_key.py b/src/backend/langflow/api/v1/api_key.py index b05f7a5f5..fcd9c934f 100644 --- a/src/backend/langflow/api/v1/api_key.py +++ b/src/backend/langflow/api/v1/api_key.py @@ -79,9 +79,7 @@ def save_store_api_key( try: api_key = api_key_request.api_key # Encrypt the API key - encrypted = auth_utils.encrypt_api_key( - api_key, settings_service=settings_service - ) + encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service) current_user.store_api_key = encrypted db.commit() return {"detail": "API Key saved"} diff --git a/src/backend/langflow/api/v1/base.py b/src/backend/langflow/api/v1/base.py index f2c2f3f59..701c953cd 100644 --- a/src/backend/langflow/api/v1/base.py +++ b/src/backend/langflow/api/v1/base.py @@ -79,9 +79,7 @@ def validate_prompt(template: str): # Check if there are invalid characters in the input_variables input_variables = check_input_variables(input_variables) if any(var in INVALID_NAMES for var in input_variables): - raise ValueError( - f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. " - ) + raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ") try: PromptTemplate(template=template, input_variables=input_variables) @@ -132,9 +130,7 @@ def check_input_variables(input_variables: list): return input_variables -def build_error_message( - input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables -): +def build_error_message(input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables): input_variables_str = ", ".join([f"'{var}'" for var in input_variables]) error_string = f"Invalid input variables: {input_variables_str}. " diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py index 50a242d2d..da9005bd0 100644 --- a/src/backend/langflow/api/v1/callback.py +++ b/src/backend/langflow/api/v1/callback.py @@ -28,9 +28,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): resp = ChatResponse(message=token, type="stream", intermediate_steps="") await self.websocket.send_json(resp.dict()) - async def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: + async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any: """Run when tool starts running.""" resp = ChatResponse( message="", diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 090457e36..c8be9cca2 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -37,13 +37,9 @@ async def chat( await websocket.accept() user = await get_current_user(token, db) if not user: - await websocket.close( - code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" - ) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") if not user.is_active: - await websocket.close( - code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" - ) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") if client_id in chat_service.cache_service: await chat_service.handle_websocket(client_id, websocket) @@ -59,9 +55,7 @@ async def chat( logger.error(f"Error in chat websocket: {exc}") messsage = exc.detail if isinstance(exc, HTTPException) else str(exc) if "Could not validate credentials" in str(exc): - await websocket.close( - code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" - ) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") else: await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage) @@ -103,15 +97,10 @@ async def init_build( @router.get("/build/{flow_id}/status", response_model=BuiltResponse) -async def build_status( - flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service) -): +async def build_status(flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)): """Check the flow_id is in the cache_service.""" try: - built = ( - flow_id in cache_service - and cache_service[flow_id]["status"] == BuildStatus.SUCCESS - ) + built = flow_id in cache_service and cache_service[flow_id]["status"] == BuildStatus.SUCCESS return BuiltResponse( built=built, @@ -173,9 +162,7 @@ async def stream_build( params = vertex._built_object_repr() valid = True logger.debug(f"Building node {str(vertex.vertex_type)}") - logger.debug( - f"Output: {params[:100]}{'...' if len(params) > 100 else ''}" - ) + logger.debug(f"Output: {params[:100]}{'...' if len(params) > 100 else ''}") if vertex.artifacts: # The artifacts will be prompt variables # passed to build_input_keys_response @@ -187,9 +174,7 @@ async def stream_build( valid = False update_build_status(cache_service, flow_id, BuildStatus.FAILURE) - vertex_id = ( - vertex.parent_node_id if vertex.parent_is_top_level else vertex.id - ) + vertex_id = vertex.parent_node_id if vertex.parent_is_top_level else vertex.id if vertex_id in graph.top_level_nodes: response = { "valid": valid, @@ -203,9 +188,7 @@ async def stream_build( langchain_object = graph.build() # Now we need to check the input_keys to send them to the client if hasattr(langchain_object, "input_keys"): - input_keys_response = build_input_keys_response( - langchain_object, artifacts - ) + input_keys_response = build_input_keys_response(langchain_object, artifacts) else: input_keys_response = { "input_keys": None, diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 770172534..65ff08b51 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -92,12 +92,7 @@ async def process( ) # Get the flow that matches the flow_id and belongs to the user - flow = ( - session.query(Flow) - .filter(Flow.id == flow_id) - .filter(Flow.user_id == api_key_user.id) - .first() - ) + flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() if flow is None: raise ValueError(f"Flow {flow_id} not found") @@ -111,9 +106,7 @@ async def process( logger.error(f"Error processing tweaks: {exc}") if sync: task_id, result = await task_service.launch_and_await_task( - process_graph_cached_task - if task_service.use_celery - else process_graph_cached, + process_graph_cached_task if task_service.use_celery else process_graph_cached, graph_data, inputs, clear_cache, @@ -133,13 +126,9 @@ async def process( ) if session_id is None: # Generate a session ID - session_id = get_session_service().generate_key( - session_id=session_id, data_graph=graph_data - ) + session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data) task_id, task = await task_service.launch_task( - process_graph_cached_task - if task_service.use_celery - else process_graph_cached, + process_graph_cached_task if task_service.use_celery else process_graph_cached, graph_data, inputs, clear_cache, @@ -162,18 +151,12 @@ async def process( # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc): # This means the Flow ID is not a valid UUID which means it can't find the flow - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc except ValueError as exc: if f"Flow {flow_id} not found" in str(exc): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc else: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc) - ) from exc + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc except Exception as e: # Log stack trace logger.exception(e) diff --git a/src/backend/langflow/api/v1/flows.py b/src/backend/langflow/api/v1/flows.py index 6f06a3fbb..c3f894732 100644 --- a/src/backend/langflow/api/v1/flows.py +++ b/src/backend/langflow/api/v1/flows.py @@ -64,12 +64,7 @@ def read_flow( current_user: User = Depends(get_current_active_user), ): """Read a flow.""" - if user_flow := ( - session.query(Flow) - .filter(Flow.id == flow_id) - .filter(Flow.user_id == current_user.id) - .first() - ): + if user_flow := (session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == current_user.id).first()): return user_flow else: raise HTTPException(status_code=404, detail="Flow not found") diff --git a/src/backend/langflow/api/v1/login.py b/src/backend/langflow/api/v1/login.py index ba4debdba..6b22a7990 100644 --- a/src/backend/langflow/api/v1/login.py +++ b/src/backend/langflow/api/v1/login.py @@ -44,9 +44,7 @@ async def login_to_get_access_token( @router.get("/auto_login") -async def auto_login( - db: Session = Depends(get_session), settings_service=Depends(get_settings_service) -): +async def auto_login(db: Session = Depends(get_session), settings_service=Depends(get_settings_service)): if settings_service.auth_settings.AUTO_LOGIN: return create_user_longterm_token(db) @@ -60,9 +58,7 @@ async def auto_login( @router.post("/refresh") -async def refresh_token( - token: str, current_user: Session = Depends(get_current_active_user) -): +async def refresh_token(token: str, current_user: Session = Depends(get_current_active_user)): if token: return create_refresh_token(token) else: diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 73e45d673..e84d8295c 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -149,9 +149,7 @@ class StreamData(BaseModel): data: dict def __str__(self) -> str: - return ( - f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n" - ) + return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n" class CustomComponentCode(BaseModel): diff --git a/src/backend/langflow/api/v1/store.py b/src/backend/langflow/api/v1/store.py index ea508aa51..abdb1eaf1 100644 --- a/src/backend/langflow/api/v1/store.py +++ b/src/backend/langflow/api/v1/store.py @@ -28,9 +28,7 @@ def get_user_store_api_key( settings_service=Depends(get_settings_service), ): if not user.store_api_key: - raise HTTPException( - status_code=400, detail="You must have a store API key set." - ) + raise HTTPException(status_code=400, detail="You must have a store API key set.") decrypted = auth_utils.decrypt_api_key(user.store_api_key, settings_service) return decrypted @@ -107,9 +105,7 @@ def get_components( ) except HTTPStatusError as exc: if exc.response.status_code == 403: - raise ValueError( - "You are not authorized to access this public resource" - ) + raise ValueError("You are not authorized to access this public resource") try: if result: if len(result) >= limit: @@ -124,25 +120,19 @@ def get_components( comp_count = 0 except HTTPStatusError as exc: if exc.response.status_code == 403: - raise ValueError( - "You are not authorized to access this public resource" - ) + raise ValueError("You are not authorized to access this public resource") if store_api_Key and result: # Now, from the result, we need to get the components # the user likes and set the liked_by_user to True try: - updated_result = update_components_with_user_data( - result, store_service, store_api_Key - ) + updated_result = update_components_with_user_data(result, store_service, store_api_Key) authorized = True result = updated_result except Exception: # If we get an error here, it means the user is not authorized authorized = False - return ListComponentResponseModel( - results=result, authorized=authorized, count=comp_count - ) + return ListComponentResponseModel(results=result, authorized=authorized, count=comp_count) except Exception as exc: if isinstance(exc, HTTPStatusError): if exc.response.status_code == 403: @@ -237,9 +227,7 @@ def like_component( ): try: result = store_service.like_component(store_api_Key, component_id) - likes_count = store_service.get_component_likes_count( - store_api_Key, component_id - ) + likes_count = store_service.get_component_likes_count(store_api_Key, component_id) return UsersLikesResponse(likes_count=likes_count, liked_by_user=result) except Exception as exc: diff --git a/src/backend/langflow/api/v1/users.py b/src/backend/langflow/api/v1/users.py index 4d74229ea..22bab69de 100644 --- a/src/backend/langflow/api/v1/users.py +++ b/src/backend/langflow/api/v1/users.py @@ -46,9 +46,7 @@ def add_user( session.refresh(new_user) except IntegrityError as e: session.rollback() - raise HTTPException( - status_code=400, detail="This username is unavailable." - ) from e + raise HTTPException(status_code=400, detail="This username is unavailable.") from e return new_user @@ -96,14 +94,10 @@ def patch_user( Update an existing user's data. """ if not user.is_superuser and user.id != user_id: - raise HTTPException( - status_code=403, detail="You don't have the permission to update this user" - ) + raise HTTPException(status_code=403, detail="You don't have the permission to update this user") if user_update.password: if not user.is_superuser: - raise HTTPException( - status_code=400, detail="You can't change your password here" - ) + raise HTTPException(status_code=400, detail="You can't change your password here") user_update.password = get_password_hash(user_update.password) if user_db := get_user_by_id(session, user_id): @@ -123,16 +117,12 @@ def reset_password( Reset a user's password. """ if user_id != user.id: - raise HTTPException( - status_code=400, detail="You can't change another user's password" - ) + raise HTTPException(status_code=400, detail="You can't change another user's password") if not user: raise HTTPException(status_code=404, detail="User not found") if verify_password(user_update.password, user.password): - raise HTTPException( - status_code=400, detail="You can't use your current password" - ) + raise HTTPException(status_code=400, detail="You can't use your current password") new_password = get_password_hash(user_update.password) user.password = new_password session.commit() @@ -151,13 +141,9 @@ def delete_user( Delete a user from the database. """ if current_user.id == user_id: - raise HTTPException( - status_code=400, detail="You can't delete your own user account" - ) + raise HTTPException(status_code=400, detail="You can't delete your own user account") elif not current_user.is_superuser: - raise HTTPException( - status_code=403, detail="You don't have the permission to delete this user" - ) + raise HTTPException(status_code=403, detail="You don't have the permission to delete this user") user_db = session.query(User).filter(User.id == user_id).first() if not user_db: diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py index 65fb66bd2..c1ef04e54 100644 --- a/src/backend/langflow/api/v1/validate.py +++ b/src/backend/langflow/api/v1/validate.py @@ -41,9 +41,7 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest): add_new_variables_to_template(input_variables, prompt_request) - remove_old_variables_from_template( - old_custom_fields, input_variables, prompt_request - ) + remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request) update_input_variables_field(input_variables, prompt_request) @@ -58,19 +56,12 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest): def get_old_custom_fields(prompt_request): try: - if ( - len(prompt_request.frontend_node.custom_fields) == 1 - and prompt_request.name == "" - ): + if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "": # If there is only one custom field and the name is empty string # then we are dealing with the first prompt request after the node was created - prompt_request.name = list( - prompt_request.frontend_node.custom_fields.keys() - )[0] + prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0] - old_custom_fields = prompt_request.frontend_node.custom_fields[ - prompt_request.name - ].copy() + old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name].copy() except KeyError: old_custom_fields = [] prompt_request.frontend_node.custom_fields[prompt_request.name] = [] @@ -92,40 +83,26 @@ def add_new_variables_to_template(input_variables, prompt_request): ) if variable in prompt_request.frontend_node.template: # Set the new field with the old value - template_field.value = prompt_request.frontend_node.template[variable][ - "value" - ] + template_field.value = prompt_request.frontend_node.template[variable]["value"] prompt_request.frontend_node.template[variable] = template_field.to_dict() # Check if variable is not already in the list before appending - if ( - variable - not in prompt_request.frontend_node.custom_fields[prompt_request.name] - ): - prompt_request.frontend_node.custom_fields[prompt_request.name].append( - variable - ) + if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]: + prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable) except Exception as exc: logger.exception(exc) raise HTTPException(status_code=500, detail=str(exc)) from exc -def remove_old_variables_from_template( - old_custom_fields, input_variables, prompt_request -): +def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request): for variable in old_custom_fields: if variable not in input_variables: try: # Remove the variable from custom_fields associated with the given name - if ( - variable - in prompt_request.frontend_node.custom_fields[prompt_request.name] - ): - prompt_request.frontend_node.custom_fields[ - prompt_request.name - ].remove(variable) + if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]: + prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable) # Remove the variable from the template prompt_request.frontend_node.template.pop(variable, None) @@ -137,6 +114,4 @@ def remove_old_variables_from_template( def update_input_variables_field(input_variables, prompt_request): if "input_variables" in prompt_request.frontend_node.template: - prompt_request.frontend_node.template["input_variables"][ - "value" - ] = input_variables + prompt_request.frontend_node.template["input_variables"]["value"] = input_variables diff --git a/src/backend/langflow/components/agents/OpenAIConversationalAgent.py b/src/backend/langflow/components/agents/OpenAIConversationalAgent.py index 364abf962..5dc042aed 100644 --- a/src/backend/langflow/components/agents/OpenAIConversationalAgent.py +++ b/src/backend/langflow/components/agents/OpenAIConversationalAgent.py @@ -71,7 +71,9 @@ class ConversationalAgent(CustomComponent): extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)], ) agent = OpenAIFunctionsAgent( - llm=llm, tools=tools, prompt=prompt # type: ignore + llm=llm, + tools=tools, + prompt=prompt, # type: ignore ) return AgentExecutor( agent=agent, diff --git a/src/backend/langflow/components/chains/PromptRunner.py b/src/backend/langflow/components/chains/PromptRunner.py index 96a64f6fc..ba2fd7e34 100644 --- a/src/backend/langflow/components/chains/PromptRunner.py +++ b/src/backend/langflow/components/chains/PromptRunner.py @@ -18,9 +18,7 @@ class PromptRunner(CustomComponent): "code": {"show": False}, } - def build( - self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {} - ) -> Document: + def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Document: chain = prompt | llm # The input is an empty dict because the prompt is already filled result = chain.invoke(input=inputs) diff --git a/src/backend/langflow/components/retrievers/MetalRetriever.py b/src/backend/langflow/components/retrievers/MetalRetriever.py index b105cd24f..88393f26d 100644 --- a/src/backend/langflow/components/retrievers/MetalRetriever.py +++ b/src/backend/langflow/components/retrievers/MetalRetriever.py @@ -18,9 +18,7 @@ class MetalRetrieverComponent(CustomComponent): "code": {"show": False}, } - def build( - self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None - ) -> BaseRetriever: + def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> BaseRetriever: try: metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id) except Exception as e: diff --git a/src/backend/langflow/components/toolkits/Metaphor.py b/src/backend/langflow/components/toolkits/Metaphor.py index 6f43d24b4..a66da9bca 100644 --- a/src/backend/langflow/components/toolkits/Metaphor.py +++ b/src/backend/langflow/components/toolkits/Metaphor.py @@ -10,9 +10,7 @@ from langchain.agents.agent_toolkits.base import BaseToolkit class MetaphorToolkit(CustomComponent): display_name: str = "Metaphor" description: str = "Metaphor Toolkit" - documentation = ( - "https://python.langchain.com/docs/integrations/tools/metaphor_search" - ) + documentation = "https://python.langchain.com/docs/integrations/tools/metaphor_search" beta = True # api key should be password = True field_config = { @@ -33,9 +31,7 @@ class MetaphorToolkit(CustomComponent): @tool def search(query: str): """Call search engine with a query.""" - return client.search( - query, use_autoprompt=use_autoprompt, num_results=search_num_results - ) + return client.search(query, use_autoprompt=use_autoprompt, num_results=search_num_results) @tool def get_contents(ids: List[str]): diff --git a/src/backend/langflow/components/utilities/GetRequest.py b/src/backend/langflow/components/utilities/GetRequest.py index d5df32cca..546b3b10b 100644 --- a/src/backend/langflow/components/utilities/GetRequest.py +++ b/src/backend/langflow/components/utilities/GetRequest.py @@ -30,9 +30,7 @@ class GetRequest(CustomComponent): }, } - def get_document( - self, session: requests.Session, url: str, headers: Optional[dict], timeout: int - ) -> Document: + def get_document(self, session: requests.Session, url: str, headers: Optional[dict], timeout: int) -> Document: try: response = session.get(url, headers=headers, timeout=int(timeout)) try: diff --git a/src/backend/langflow/components/utilities/JSONDocumentBuilder.py b/src/backend/langflow/components/utilities/JSONDocumentBuilder.py index 26a2afd94..fc637c5a5 100644 --- a/src/backend/langflow/components/utilities/JSONDocumentBuilder.py +++ b/src/backend/langflow/components/utilities/JSONDocumentBuilder.py @@ -21,9 +21,7 @@ class JSONDocumentBuilder(CustomComponent): description: str = "Build a Document containing a JSON object using a key and another Document page content." output_types: list[str] = ["Document"] beta = True - documentation: str = ( - "https://docs.langflow.org/components/utilities#json-document-builder" - ) + documentation: str = "https://docs.langflow.org/components/utilities#json-document-builder" field_config = { "key": {"display_name": "Key"}, @@ -38,18 +36,11 @@ class JSONDocumentBuilder(CustomComponent): documents = None if isinstance(document, list): documents = [ - Document( - page_content=orjson_dumps({key: doc.page_content}, indent_2=False) - ) - for doc in document + Document(page_content=orjson_dumps({key: doc.page_content}, indent_2=False)) for doc in document ] elif isinstance(document, Document): - documents = Document( - page_content=orjson_dumps({key: document.page_content}, indent_2=False) - ) + documents = Document(page_content=orjson_dumps({key: document.page_content}, indent_2=False)) else: - raise TypeError( - f"Expected Document or list of Documents, got {type(document)}" - ) + raise TypeError(f"Expected Document or list of Documents, got {type(document)}") self.repr_value = documents return documents diff --git a/src/backend/langflow/components/utilities/PostRequest.py b/src/backend/langflow/components/utilities/PostRequest.py index 6857f4866..81c54660a 100644 --- a/src/backend/langflow/components/utilities/PostRequest.py +++ b/src/backend/langflow/components/utilities/PostRequest.py @@ -65,16 +65,12 @@ class PostRequest(CustomComponent): if not isinstance(document, list) and isinstance(document, Document): documents: list[Document] = [document] - elif isinstance(document, list) and all( - isinstance(doc, Document) for doc in document - ): + elif isinstance(document, list) and all(isinstance(doc, Document) for doc in document): documents = document else: raise ValueError("document must be a Document or a list of Documents") with requests.Session() as session: - documents = [ - self.post_document(session, doc, url, headers) for doc in documents - ] + documents = [self.post_document(session, doc, url, headers) for doc in documents] self.repr_value = documents return documents diff --git a/src/backend/langflow/components/utilities/UpdateRequest.py b/src/backend/langflow/components/utilities/UpdateRequest.py index d18c94a56..6f9ef91a5 100644 --- a/src/backend/langflow/components/utilities/UpdateRequest.py +++ b/src/backend/langflow/components/utilities/UpdateRequest.py @@ -39,9 +39,7 @@ class UpdateRequest(CustomComponent): ) -> Document: try: if method == "PATCH": - response = session.patch( - url, headers=headers, data=document.page_content - ) + response = session.patch(url, headers=headers, data=document.page_content) elif method == "PUT": response = session.put(url, headers=headers, data=document.page_content) else: @@ -78,17 +76,12 @@ class UpdateRequest(CustomComponent): if not isinstance(document, list) and isinstance(document, Document): documents: list[Document] = [document] - elif isinstance(document, list) and all( - isinstance(doc, Document) for doc in document - ): + elif isinstance(document, list) and all(isinstance(doc, Document) for doc in document): documents = document else: raise ValueError("document must be a Document or a list of Documents") with requests.Session() as session: - documents = [ - self.update_document(session, doc, url, headers, method) - for doc in documents - ] + documents = [self.update_document(session, doc, url, headers, method) for doc in documents] self.repr_value = documents return documents diff --git a/src/backend/langflow/components/vectorstores/Chroma.py b/src/backend/langflow/components/vectorstores/Chroma.py index 3cfd4771e..c798ab83b 100644 --- a/src/backend/langflow/components/vectorstores/Chroma.py +++ b/src/backend/langflow/components/vectorstores/Chroma.py @@ -86,8 +86,7 @@ class ChromaComponent(CustomComponent): if chroma_server_host is not None: chroma_settings = chromadb.config.Settings( - chroma_server_cors_allow_origins=chroma_server_cors_allow_origins - or None, + chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None, chroma_server_host=chroma_server_host, chroma_server_port=chroma_server_port or None, chroma_server_grpc_port=chroma_server_grpc_port or None, @@ -104,6 +103,4 @@ class ChromaComponent(CustomComponent): client_settings=chroma_settings, ) - return Chroma( - persist_directory=persist_directory, client_settings=chroma_settings - ) + return Chroma(persist_directory=persist_directory, client_settings=chroma_settings) diff --git a/src/backend/langflow/components/vectorstores/Vectara.py b/src/backend/langflow/components/vectorstores/Vectara.py index 6f7fcc0bb..eeee290d1 100644 --- a/src/backend/langflow/components/vectorstores/Vectara.py +++ b/src/backend/langflow/components/vectorstores/Vectara.py @@ -10,9 +10,7 @@ from langchain.schema import BaseRetriever class VectaraComponent(CustomComponent): display_name: str = "Vectara" description: str = "Implementation of Vector Store using Vectara" - documentation = ( - "https://python.langchain.com/docs/integrations/vectorstores/vectara" - ) + documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara" beta = True # api key should be password = True field_config = { diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 82714e395..f9d77741b 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -8,9 +8,7 @@ if TYPE_CHECKING: class SourceHandle(BaseModel): - baseClasses: List[str] = Field( - ..., description="List of base classes for the source handle." - ) + baseClasses: List[str] = Field(..., description="List of base classes for the source handle.") dataType: str = Field(..., description="Data type for the source handle.") id: str = Field(..., description="Unique identifier for the source handle.") @@ -18,9 +16,7 @@ class SourceHandle(BaseModel): class TargetHandle(BaseModel): fieldName: str = Field(..., description="Field name for the target handle.") id: str = Field(..., description="Unique identifier for the target handle.") - inputTypes: Optional[List[str]] = Field( - None, description="List of input types for the target handle." - ) + inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.") type: str = Field(..., description="Type of the target handle.") @@ -49,23 +45,17 @@ class Edge: def validate_handles(self) -> None: if self.target_handle.inputTypes is None: - self.valid_handles = ( - self.target_handle.type in self.source_handle.baseClasses - ) + self.valid_handles = self.target_handle.type in self.source_handle.baseClasses else: self.valid_handles = ( - any( - baseClass in self.target_handle.inputTypes - for baseClass in self.source_handle.baseClasses - ) + any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses) or self.target_handle.type in self.source_handle.baseClasses ) if not self.valid_handles: logger.debug(self.source_handle) logger.debug(self.target_handle) raise ValueError( - f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " - f"has invalid handles" + f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has invalid handles" ) def __setstate__(self, state): @@ -87,11 +77,7 @@ class Edge: # Both lists contain strings and sometimes a string contains the value we are # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] # so we need to check if any of the strings in source_types is in target_reqs - self.valid = any( - output in target_req - for output in self.source_types - for target_req in self.target_reqs - ) + self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs) # Get what type of input the target node is expecting self.matched_type = next( @@ -103,8 +89,7 @@ class Edge: logger.debug(self.source_types) logger.debug(self.target_reqs) raise ValueError( - f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " - f"has no matched type" + f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has no matched type" ) def __repr__(self) -> str: @@ -117,8 +102,4 @@ class Edge: return hash(self.__repr__()) def __eq__(self, __value: object) -> bool: - return ( - self.__repr__() == __value.__repr__() - if isinstance(__value, Edge) - else False - ) + return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 83976ae49..aeb7ec91d 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -104,9 +104,7 @@ class Graph: return for node in self.nodes: if not self._validate_node(node): - raise ValueError( - f"{node.vertex_type} is not connected to any other components" - ) + raise ValueError(f"{node.vertex_type} is not connected to any other components") def _validate_node(self, node: Vertex) -> bool: """Validates a node.""" @@ -119,9 +117,7 @@ class Graph: def get_nodes_with_target(self, node: Vertex) -> List[Vertex]: """Returns the nodes connected to a node.""" - connected_nodes: List[Vertex] = [ - edge.source for edge in self.edges if edge.target == node - ] + connected_nodes: List[Vertex] = [edge.source for edge in self.edges if edge.target == node] return connected_nodes def build(self) -> Chain: @@ -149,9 +145,7 @@ class Graph: def dfs(node): if state[node] == 1: # We have a cycle - raise ValueError( - "Graph contains a cycle, cannot perform topological sort" - ) + raise ValueError("Graph contains a cycle, cannot perform topological sort") if state[node] == 0: state[node] = 1 for edge in node.edges: @@ -245,7 +239,5 @@ class Graph: def __repr__(self): node_ids = [node.id for node in self.nodes] - edges_repr = "\n".join( - [f"{edge.source.id} --> {edge.target.id}" for edge in self.edges] - ) + edges_repr = "\n".join([f"{edge.source.id} --> {edge.target.id}" for edge in self.edges]) return f"Graph:\nNodes: {node_ids}\nConnections:\n{edges_repr}" diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index c9fea48b5..abfc2970f 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -47,10 +47,7 @@ class VertexTypesDict(LazyLoadDictBase): **{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()}, **{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()}, **{t: types.OutputParserVertex for t in output_parser_creator.to_list()}, - **{ - t: types.CustomComponentVertex - for t in custom_component_creator.to_list() - }, + **{t: types.CustomComponentVertex for t in custom_component_creator.to_list()}, **{t: types.RetrieverVertex for t in retriever_creator.to_list()}, } diff --git a/src/backend/langflow/graph/graph/utils.py b/src/backend/langflow/graph/graph/utils.py index 71b81fea1..a3299739e 100644 --- a/src/backend/langflow/graph/graph/utils.py +++ b/src/backend/langflow/graph/graph/utils.py @@ -28,23 +28,14 @@ def ungroup_node(group_node_data, base_flow): g_edges = flow["data"]["edges"] # Redirect edges to the correct proxy node - updated_edges = get_updated_edges( - base_flow, g_nodes, g_edges, group_node_data["id"] - ) + updated_edges = get_updated_edges(base_flow, g_nodes, g_edges, group_node_data["id"]) # Update template values update_template(template, g_nodes) - nodes = [ - n for n in base_flow["nodes"] if n["id"] != group_node_data["id"] - ] + g_nodes + nodes = [n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]] + g_nodes edges = ( - [ - e - for e in base_flow["edges"] - if e["target"] != group_node_data["id"] - and e["source"] != group_node_data["id"] - ] + [e for e in base_flow["edges"] if e["target"] != group_node_data["id"] and e["source"] != group_node_data["id"]] + g_edges + updated_edges ) @@ -66,11 +57,7 @@ def process_flow(flow_object): if node_id in processed_nodes: return - if ( - node.get("data") - and node["data"].get("node") - and node["data"]["node"].get("flow") - ): + if node.get("data") and node["data"].get("node") and node["data"]["node"].get("flow"): process_flow(node["data"]["node"]["flow"]["data"]) new_nodes = ungroup_node(node["data"], cloned_flow) # Add new nodes to the queue for future processing @@ -108,26 +95,16 @@ def update_template(template, g_nodes): if node_index != -1: display_name = None show = g_nodes[node_index]["data"]["node"]["template"][field]["show"] - advanced = g_nodes[node_index]["data"]["node"]["template"][field][ - "advanced" - ] + advanced = g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]: - display_name = g_nodes[node_index]["data"]["node"]["template"][field][ - "display_name" - ] + display_name = g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] else: - display_name = g_nodes[node_index]["data"]["node"]["template"][field][ - "name" - ] + display_name = g_nodes[node_index]["data"]["node"]["template"][field]["name"] g_nodes[node_index]["data"]["node"]["template"][field] = value g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show - g_nodes[node_index]["data"]["node"]["template"][field][ - "advanced" - ] = advanced - g_nodes[node_index]["data"]["node"]["template"][field][ - "display_name" - ] = display_name + g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] = advanced + g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name def update_target_handle(new_edge, g_nodes, group_node_id): diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 88a4db26e..5862ca3ae 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -51,9 +51,7 @@ class Vertex: self.params.pop(target_param, None) continue - if target_param in self.params and not is_basic_type( - self.params[target_param] - ): + if target_param in self.params and not is_basic_type(self.params[target_param]): # edge.source.params = {} edge.source._build_params() edge.source._built_object = UnbuiltObject() @@ -99,29 +97,17 @@ class Vertex: def _parse_data(self) -> None: self.data = self._data["data"] self.output = self.data["node"]["base_classes"] - template_dicts = { - key: value - for key, value in self.data["node"]["template"].items() - if isinstance(value, dict) - } + template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} self.required_inputs = [ - template_dicts[key]["type"] - for key, value in template_dicts.items() - if value["required"] + template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"] ] self.optional_inputs = [ - template_dicts[key]["type"] - for key, value in template_dicts.items() - if not value["required"] + template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"] ] # Add the template_dicts[key]["input_types"] to the optional_inputs self.optional_inputs.extend( - [ - input_type - for value in template_dicts.values() - for input_type in value.get("input_types", []) - ] + [input_type for value in template_dicts.values() for input_type in value.get("input_types", [])] ) template_dict = self.data["node"]["template"] @@ -160,11 +146,7 @@ class Vertex: # and use that as the value for the param # If the type is "str", then we need to get the value of the "value" key # and use that as the value for the param - template_dict = { - key: value - for key, value in self.data["node"]["template"].items() - if isinstance(value, dict) - } + template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} params = self.params.copy() if self.params else {} for edge in self.edges: @@ -209,11 +191,7 @@ class Vertex: # before passing it to the build method _value = value.get("value") if isinstance(_value, list): - params[key] = { - k: v - for item in value.get("value", []) - for k, v in item.items() - } + params[key] = {k: v for item in value.get("value", []) for k, v in item.items()} elif isinstance(_value, dict): params[key] = _value elif value.get("type") == "int" and value.get("value") is not None: @@ -304,9 +282,7 @@ class Vertex: self._extend_params_list_with_result(key, result) self.params[key] = result - def _build_list_of_nodes_and_update_params( - self, key, nodes: List["Vertex"], user_id=None - ): + def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None): """ Iterates over a list of nodes, builds each and updates the params dictionary. """ @@ -358,9 +334,7 @@ class Vertex: self._update_built_object_and_artifacts(result) except Exception as exc: logger.exception(exc) - raise ValueError( - f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}" - ) from exc + raise ValueError(f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}") from exc def _update_built_object_and_artifacts(self, result): """ @@ -408,8 +382,4 @@ class Vertex: def _built_object_repr(self): # Add a message with an emoji, stars for sucess, - return ( - "Built sucessfully ✨" - if self._built_object is not None - else "Failed to build 😵‍💫" - ) + return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫" diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index fdbabe510..5e7dfae5a 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -107,11 +107,9 @@ class DocumentLoaderVertex(Vertex): # show how many documents are in the list? if self._built_object: - avg_length = sum( - len(doc.page_content) - for doc in self._built_object - if hasattr(doc, "page_content") - ) / len(self._built_object) + avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len( + self._built_object + ) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} Documents: {self._built_object[:3]}...""" @@ -184,9 +182,7 @@ class TextSplitterVertex(Vertex): # show how many documents are in the list? if self._built_object: - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( - self._built_object - ) + avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} \nDocuments: {self._built_object[:3]}...""" @@ -232,27 +228,18 @@ class PromptVertex(Vertex): **kwargs, ) -> Any: if not self._built or force: - if ( - "input_variables" not in self.params - or self.params["input_variables"] is None - ): + if "input_variables" not in self.params or self.params["input_variables"] is None: self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: - tools = ( - [tool_node.build(user_id=user_id) for tool_node in tools] - if tools is not None - else [] - ) + tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ - key - for key, value in self.params.items() - if isinstance(value, str) and key != "format_instructions" + key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions" ] else: prompt_params = ["template"] @@ -262,9 +249,7 @@ class PromptVertex(Vertex): prompt_text = self.params[param] variables = extract_input_variables_from_prompt(prompt_text) self.params["input_variables"].extend(variables) - self.params["input_variables"] = list( - set(self.params["input_variables"]) - ) + self.params["input_variables"] = list(set(self.params["input_variables"])) elif isinstance(self.params, dict): self.params.pop("input_variables", None) @@ -272,11 +257,7 @@ class PromptVertex(Vertex): return self._built_object def _built_object_repr(self): - if ( - not self.artifacts - or self._built_object is None - or not hasattr(self._built_object, "format") - ): + if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"): return super()._built_object_repr() # We'll build the prompt with the artifacts # to show the user what the prompt looks like @@ -286,9 +267,7 @@ class PromptVertex(Vertex): # so the prompt format doesn't break artifacts.pop("handle_keys", None) try: - if not hasattr(self._built_object, "template") and hasattr( - self._built_object, "prompt" - ): + if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): template = self._built_object.prompt.template else: template = self._built_object.template @@ -296,11 +275,7 @@ class PromptVertex(Vertex): if value: replace_key = "{" + key + "}" template = template.replace(replace_key, value) - return ( - template - if isinstance(template, str) - else f"{self.vertex_type}({template})" - ) + return template if isinstance(template, str) else f"{self.vertex_type}({template})" except KeyError: return str(self._built_object) diff --git a/src/backend/langflow/interface/agents/base.py b/src/backend/langflow/interface/agents/base.py index 8bc12a8df..68ae7b91a 100644 --- a/src/backend/langflow/interface/agents/base.py +++ b/src/backend/langflow/interface/agents/base.py @@ -42,9 +42,7 @@ class AgentCreator(LangChainTypeCreator): add_function=True, method_name=self.from_method_nodes[name], ) - return build_template_from_class( - name, self.type_to_loader_dict, add_function=True - ) + return build_template_from_class(name, self.type_to_loader_dict, add_function=True) except ValueError as exc: raise ValueError("Agent not found") from exc except AttributeError as exc: @@ -56,15 +54,8 @@ class AgentCreator(LangChainTypeCreator): names = [] settings_service = get_settings_service() for _, agent in self.type_to_loader_dict.items(): - agent_name = ( - agent.function_name() - if hasattr(agent, "function_name") - else agent.__name__ - ) - if ( - agent_name in settings_service.settings.AGENTS - or settings_service.settings.DEV - ): + agent_name = agent.function_name() if hasattr(agent, "function_name") else agent.__name__ + if agent_name in settings_service.settings.AGENTS or settings_service.settings.DEV: names.append(agent_name) return names diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index 735b27917..95c81f137 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -66,7 +66,8 @@ class JsonAgent(CustomAgentExecutor): prompt=prompt, ) agent = ZeroShotAgent( - llm_chain=llm_chain, allowed_tools=tool_names # type: ignore + llm_chain=llm_chain, + allowed_tools=tool_names, # type: ignore ) return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) @@ -90,11 +91,7 @@ class CSVAgent(CustomAgentExecutor): @classmethod def from_toolkit_and_llm( - cls, - path: str, - llm: BaseLanguageModel, - pandas_kwargs: Optional[dict] = None, - **kwargs: Any + cls, path: str, llm: BaseLanguageModel, pandas_kwargs: Optional[dict] = None, **kwargs: Any ): import pandas as pd # type: ignore @@ -115,7 +112,9 @@ class CSVAgent(CustomAgentExecutor): ) tool_names = {tool.name for tool in tools} agent = ZeroShotAgent( - llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore + llm_chain=llm_chain, + allowed_tools=tool_names, + **kwargs, # type: ignore ) return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) @@ -139,9 +138,7 @@ class VectorStoreAgent(CustomAgentExecutor): super().__init__(*args, **kwargs) @classmethod - def from_toolkit_and_llm( - cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any - ): + def from_toolkit_and_llm(cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any): """Construct a vectorstore agent from an LLM and tools.""" toolkit = VectorStoreToolkit(vectorstore_info=vectorstoreinfo, llm=llm) @@ -154,11 +151,11 @@ class VectorStoreAgent(CustomAgentExecutor): ) tool_names = {tool.name for tool in tools} agent = ZeroShotAgent( - llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore - ) - return AgentExecutor.from_agent_and_tools( - agent=agent, tools=tools, verbose=True, handle_parsing_errors=True + llm_chain=llm_chain, + allowed_tools=tool_names, + **kwargs, # type: ignore ) + return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True) def run(self, *args, **kwargs): return super().run(*args, **kwargs) @@ -179,9 +176,7 @@ class SQLAgent(CustomAgentExecutor): super().__init__(*args, **kwargs) @classmethod - def from_toolkit_and_llm( - cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any - ): + def from_toolkit_and_llm(cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any): """Construct an SQL agent from an LLM and tools.""" db = SQLDatabase.from_uri(database_uri) toolkit = SQLDatabaseToolkit(db=db, llm=llm) @@ -199,9 +194,7 @@ class SQLAgent(CustomAgentExecutor): llmchain = LLMChain( llm=llm, - prompt=PromptTemplate( - template=QUERY_CHECKER, input_variables=["query", "dialect"] - ), + prompt=PromptTemplate(template=QUERY_CHECKER, input_variables=["query", "dialect"]), ) tools = [ @@ -224,7 +217,9 @@ class SQLAgent(CustomAgentExecutor): ) tool_names = {tool.name for tool in tools} # type: ignore agent = ZeroShotAgent( - llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore + llm_chain=llm_chain, + allowed_tools=tool_names, + **kwargs, # type: ignore ) return AgentExecutor.from_agent_and_tools( agent=agent, @@ -255,10 +250,7 @@ class VectorStoreRouterAgent(CustomAgentExecutor): @classmethod def from_toolkit_and_llm( - cls, - llm: BaseLanguageModel, - vectorstoreroutertoolkit: VectorStoreRouterToolkit, - **kwargs: Any + cls, llm: BaseLanguageModel, vectorstoreroutertoolkit: VectorStoreRouterToolkit, **kwargs: Any ): """Construct a vector store router agent from an LLM and tools.""" @@ -274,11 +266,11 @@ class VectorStoreRouterAgent(CustomAgentExecutor): ) tool_names = {tool.name for tool in tools} agent = ZeroShotAgent( - llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore - ) - return AgentExecutor.from_agent_and_tools( - agent=agent, tools=tools, verbose=True, handle_parsing_errors=True + llm_chain=llm_chain, + allowed_tools=tool_names, + **kwargs, # type: ignore ) + return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True) def run(self, *args, **kwargs): return super().run(*args, **kwargs) diff --git a/src/backend/langflow/interface/base.py b/src/backend/langflow/interface/base.py index 87538b248..bc9db0a93 100644 --- a/src/backend/langflow/interface/base.py +++ b/src/backend/langflow/interface/base.py @@ -30,13 +30,8 @@ class LangChainTypeCreator(BaseModel, ABC): settings_service = get_settings_service() if self.name_docs_dict is None: try: - type_settings = getattr( - settings_service.settings, self.type_name.upper() - ) - self.name_docs_dict = { - name: value_dict["documentation"] - for name, value_dict in type_settings.items() - } + type_settings = getattr(settings_service.settings, self.type_name.upper()) + self.name_docs_dict = {name: value_dict["documentation"] for name, value_dict in type_settings.items()} except AttributeError as exc: logger.error(f"Error getting settings for {self.type_name}: {exc}") diff --git a/src/backend/langflow/interface/chains/base.py b/src/backend/langflow/interface/chains/base.py index 9d84c707c..1ca2e944c 100644 --- a/src/backend/langflow/interface/chains/base.py +++ b/src/backend/langflow/interface/chains/base.py @@ -33,8 +33,7 @@ class ChainCreator(LangChainTypeCreator): if self.type_dict is None: settings_service = get_settings_service() self.type_dict: dict[str, Any] = { - chain_name: import_class(f"langchain.chains.{chain_name}") - for chain_name in chains.__all__ + chain_name: import_class(f"langchain.chains.{chain_name}") for chain_name in chains.__all__ } from langflow.interface.chains.custom import CUSTOM_CHAINS @@ -45,8 +44,7 @@ class ChainCreator(LangChainTypeCreator): self.type_dict = { name: chain for name, chain in self.type_dict.items() - if name in settings_service.settings.CHAINS - or settings_service.settings.DEV + if name in settings_service.settings.CHAINS or settings_service.settings.DEV } return self.type_dict @@ -61,9 +59,7 @@ class ChainCreator(LangChainTypeCreator): method_name=self.from_method_nodes[name], add_function=True, ) - return build_template_from_class( - name, self.type_to_loader_dict, add_function=True - ) + return build_template_from_class(name, self.type_to_loader_dict, add_function=True) except ValueError as exc: raise ValueError(f"Chain {name} not found: {exc}") from exc except AttributeError as exc: @@ -73,11 +69,7 @@ class ChainCreator(LangChainTypeCreator): def to_list(self) -> List[str]: names = [] for _, chain in self.type_to_loader_dict.items(): - chain_name = ( - chain.function_name() - if hasattr(chain, "function_name") - else chain.__name__ - ) + chain_name = chain.function_name() if hasattr(chain, "function_name") else chain.__name__ names.append(chain_name) return names diff --git a/src/backend/langflow/interface/chains/custom.py b/src/backend/langflow/interface/chains/custom.py index 01dd9bab0..dc960f71c 100644 --- a/src/backend/langflow/interface/chains/custom.py +++ b/src/backend/langflow/interface/chains/custom.py @@ -41,9 +41,7 @@ class BaseCustomConversationChain(ConversationChain): values["template"] = values["template"].format(**format_dict) values["template"] = values["template"] - values["input_variables"] = extract_input_variables_from_prompt( - values["template"] - ) + values["input_variables"] = extract_input_variables_from_prompt(values["template"]) values["prompt"].template = values["template"] values["prompt"].input_variables = values["input_variables"] return values @@ -54,9 +52,7 @@ class SeriesCharacterChain(BaseCustomConversationChain): character: str series: str - template: Optional[ - str - ] = """I want you to act like {character} from {series}. + template: Optional[str] = """I want you to act like {character} from {series}. I want you to respond and answer like {character}. do not write any explanations. only answer like {character}. You must know all of the knowledge of {character}. Current conversation: @@ -71,9 +67,7 @@ Human: {input} class MidJourneyPromptChain(BaseCustomConversationChain): """MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts.""" - template: Optional[ - str - ] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program. + template: Optional[str] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program. Your job is to provide detailed and creative descriptions that will inspire unique and interesting images from the AI. Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible. For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures. @@ -87,9 +81,7 @@ class MidJourneyPromptChain(BaseCustomConversationChain): class TimeTravelGuideChain(BaseCustomConversationChain): - template: Optional[ - str - ] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information. + template: Optional[str] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information. Current conversation: {history} Human: {input} diff --git a/src/backend/langflow/interface/custom/code_parser.py b/src/backend/langflow/interface/custom/code_parser.py index d42f82635..5202ed507 100644 --- a/src/backend/langflow/interface/custom/code_parser.py +++ b/src/backend/langflow/interface/custom/code_parser.py @@ -127,22 +127,14 @@ class CodeParser: num_defaults = len(node.args.defaults) num_missing_defaults = num_args - num_defaults missing_defaults = [None] * num_missing_defaults - default_values = [ - ast.unparse(default).strip("'") if default else None - for default in node.args.defaults - ] + default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults] # Now check all default values to see if there # are any "None" values in the middle - default_values = [ - None if value == "None" else value for value in default_values - ] + default_values = [None if value == "None" else value for value in default_values] defaults = missing_defaults + default_values - args = [ - self.parse_arg(arg, default) - for arg, default in zip(node.args.args, defaults) - ] + args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)] return args def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: @@ -160,17 +152,11 @@ class CodeParser: """ Parses the keyword-only arguments of a function or method node. """ - kw_defaults = [None] * ( - len(node.args.kwonlyargs) - len(node.args.kw_defaults) - ) + [ - ast.unparse(default) if default else None - for default in node.args.kw_defaults + kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [ + ast.unparse(default) if default else None for default in node.args.kw_defaults ] - args = [ - self.parse_arg(arg, default) - for arg, default in zip(node.args.kwonlyargs, kw_defaults) - ] + args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)] return args def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: @@ -254,9 +240,7 @@ class CodeParser: Extracts global variables from the code. """ global_var = { - "targets": [ - t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets - ], + "targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets], "value": ast.unparse(node.value), } self.data["global_vars"].append(global_var) diff --git a/src/backend/langflow/interface/custom/component.py b/src/backend/langflow/interface/custom/component.py index 06db5bd46..92df3ef93 100644 --- a/src/backend/langflow/interface/custom/component.py +++ b/src/backend/langflow/interface/custom/component.py @@ -17,9 +17,7 @@ class ComponentFunctionEntrypointNameNullError(HTTPException): class Component(BaseModel): ERROR_CODE_NULL = "Python code must be provided." - ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = ( - "The name of the entrypoint function must be provided." - ) + ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = "The name of the entrypoint function must be provided." code: Optional[str] function_entrypoint_name = "build" diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 3b4218f45..4b3c3ebac 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -53,9 +53,9 @@ class CustomComponent(Component, extra=Extra.allow): reader = DirectoryReader("", False) for type_hint in TYPE_HINT_LIST: - if reader._is_type_hint_used_in_args( + if reader._is_type_hint_used_in_args(type_hint, code) and not reader._is_type_hint_imported( type_hint, code - ) and not reader._is_type_hint_imported(type_hint, code): + ): error_detail = { "error": "Type hint Error", "traceback": f"Type hint '{type_hint}' is used but not imported in the code.", @@ -74,20 +74,14 @@ class CustomComponent(Component, extra=Extra.allow): return "" tree = self.get_code_tree(self.code) - component_classes = [ - cls - for cls in tree["classes"] - if self.code_class_base_inheritance in cls["bases"] - ] + component_classes = [cls for cls in tree["classes"] if self.code_class_base_inheritance in cls["bases"]] if not component_classes: return "" # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method - for method in component_class["methods"] - if method["name"] == self.function_entrypoint_name + method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name ] if not build_methods: @@ -103,8 +97,7 @@ class CustomComponent(Component, extra=Extra.allow): detail={ "error": "Type hint Error", "traceback": ( - "Prompt type is not supported in the build method." - " Try using PromptTemplate instead." + "Prompt type is not supported in the build method." " Try using PromptTemplate instead." ), }, ) @@ -119,20 +112,14 @@ class CustomComponent(Component, extra=Extra.allow): return [] tree = self.get_code_tree(self.code) - component_classes = [ - cls - for cls in tree["classes"] - if self.code_class_base_inheritance in cls["bases"] - ] + component_classes = [cls for cls in tree["classes"] if self.code_class_base_inheritance in cls["bases"]] if not component_classes: return [] # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method - for method in component_class["methods"] - if method["name"] == self.function_entrypoint_name + method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name ] if not build_methods: @@ -230,11 +217,7 @@ class CustomComponent(Component, extra=Extra.allow): if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: - flow = ( - session.query(Flow) - .filter(Flow.name == flow_name) - .filter(Flow.user_id == self.user_id) - ).first() + flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first() else: raise ValueError("Either flow_name or flow_id must be provided") diff --git a/src/backend/langflow/interface/custom/directory_reader.py b/src/backend/langflow/interface/custom/directory_reader.py index 01b11a4a6..e80f0bd28 100644 --- a/src/backend/langflow/interface/custom/directory_reader.py +++ b/src/backend/langflow/interface/custom/directory_reader.py @@ -76,9 +76,7 @@ class DirectoryReader: for menu in data["menu"] ] filtered = [menu for menu in items if menu["components"]] - logger.debug( - f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}' - ) + logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}') return {"menu": filtered} def validate_code(self, file_content): @@ -111,9 +109,7 @@ class DirectoryReader: Walk through the directory path and return a list of all .py files. """ if not (safe_path := self.get_safe_path()): - raise CustomComponentPathValueError( - f"The path needs to start with '{self.base_path}'." - ) + raise CustomComponentPathValueError(f"The path needs to start with '{self.base_path}'.") file_list = [] for root, _, files in os.walk(safe_path): @@ -158,9 +154,7 @@ class DirectoryReader: for node in ast.walk(module): if isinstance(node, ast.FunctionDef): for arg in node.args.args: - if self._is_type_hint_in_arg_annotation( - arg.annotation, type_hint_name - ): + if self._is_type_hint_in_arg_annotation(arg.annotation, type_hint_name): return True except SyntaxError: # Returns False if the code is not valid Python @@ -178,16 +172,14 @@ class DirectoryReader: and annotation.value.id == type_hint_name ) - def is_type_hint_used_but_not_imported( - self, type_hint_name: str, code: str - ) -> bool: + def is_type_hint_used_but_not_imported(self, type_hint_name: str, code: str) -> bool: """ Check if a type hint is used but not imported in the given code. """ try: - return self._is_type_hint_used_in_args( + return self._is_type_hint_used_in_args(type_hint_name, code) and not self._is_type_hint_imported( type_hint_name, code - ) and not self._is_type_hint_imported(type_hint_name, code) + ) except SyntaxError: # Returns True if there's something wrong with the code # TODO : Find a better way to handle this @@ -208,9 +200,9 @@ class DirectoryReader: return False, "Syntax error" elif not self.validate_build(file_content): return False, "Missing build function" - elif self._is_type_hint_used_in_args( + elif self._is_type_hint_used_in_args("Optional", file_content) and not self._is_type_hint_imported( "Optional", file_content - ) and not self._is_type_hint_imported("Optional", file_content): + ): return ( False, "Type hint 'Optional' is used but not imported in the code.", @@ -226,9 +218,7 @@ class DirectoryReader: from the .py files in the directory. """ response = {"menu": []} - logger.debug( - "-------------------- Building component menu list --------------------" - ) + logger.debug("-------------------- Building component menu list --------------------") for file_path in file_paths: menu_name = os.path.basename(os.path.dirname(file_path)) @@ -248,9 +238,7 @@ class DirectoryReader: # first check if it's already CamelCase if "_" in component_name: - component_name_camelcase = " ".join( - word.title() for word in component_name.split("_") - ) + component_name_camelcase = " ".join(word.title() for word in component_name.split("_")) else: component_name_camelcase = component_name @@ -266,7 +254,5 @@ class DirectoryReader: logger.debug(f"Component info: {component_info}") if menu_result not in response["menu"]: response["menu"].append(menu_result) - logger.debug( - "-------------------- Component menu list built --------------------" - ) + logger.debug("-------------------- Component menu list built --------------------") return response diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 5a22d989f..9ad34edc2 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -46,34 +46,26 @@ toolkit_type_to_cls_dict: dict[str, Any] = { # Memories memory_type_to_cls_dict: dict[str, Any] = { - memory_name: import_class(f"langchain.memory.{memory_name}") - for memory_name in memory.__all__ + memory_name: import_class(f"langchain.memory.{memory_name}") for memory_name in memory.__all__ } # Wrappers -wrapper_type_to_cls_dict: dict[str, Any] = { - wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper] -} +wrapper_type_to_cls_dict: dict[str, Any] = {wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]} # Embeddings embedding_type_to_cls_dict: dict[str, Any] = { - embedding_name: import_class(f"langchain.embeddings.{embedding_name}") - for embedding_name in embeddings.__all__ + embedding_name: import_class(f"langchain.embeddings.{embedding_name}") for embedding_name in embeddings.__all__ } # Document Loaders documentloaders_type_to_cls_dict: dict[str, Any] = { - documentloader_name: import_class( - f"langchain.document_loaders.{documentloader_name}" - ) + documentloader_name: import_class(f"langchain.document_loaders.{documentloader_name}") for documentloader_name in document_loaders.__all__ } # Text Splitters -textsplitter_type_to_cls_dict: dict[str, Any] = dict( - inspect.getmembers(text_splitter, inspect.isclass) -) +textsplitter_type_to_cls_dict: dict[str, Any] = dict(inspect.getmembers(text_splitter, inspect.isclass)) # merge CUSTOM_AGENTS and CUSTOM_CHAINS CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore diff --git a/src/backend/langflow/interface/document_loaders/base.py b/src/backend/langflow/interface/document_loaders/base.py index 8099390a5..6142d2fa2 100644 --- a/src/backend/langflow/interface/document_loaders/base.py +++ b/src/backend/langflow/interface/document_loaders/base.py @@ -35,8 +35,7 @@ class DocumentLoaderCreator(LangChainTypeCreator): return [ documentloader.__name__ for documentloader in self.type_to_loader_dict.values() - if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS - or settings_service.settings.DEV + if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/embeddings/base.py b/src/backend/langflow/interface/embeddings/base.py index b253e4dfa..834ea61fa 100644 --- a/src/backend/langflow/interface/embeddings/base.py +++ b/src/backend/langflow/interface/embeddings/base.py @@ -37,8 +37,7 @@ class EmbeddingCreator(LangChainTypeCreator): return [ embedding.__name__ for embedding in self.type_to_loader_dict.values() - if embedding.__name__ in settings_service.settings.EMBEDDINGS - or settings_service.settings.DEV + if embedding.__name__ in settings_service.settings.EMBEDDINGS or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index 44d72df42..6b28d792e 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -104,10 +104,7 @@ def import_prompt(prompt: str) -> Type[PromptTemplate]: def import_wrapper(wrapper: str) -> Any: """Import wrapper from wrapper name""" - if ( - isinstance(wrapper_creator.type_dict, dict) - and wrapper in wrapper_creator.type_dict - ): + if isinstance(wrapper_creator.type_dict, dict) and wrapper in wrapper_creator.type_dict: return wrapper_creator.type_dict.get(wrapper) diff --git a/src/backend/langflow/interface/initialize/llm.py b/src/backend/langflow/interface/initialize/llm.py index eaed04b77..05219d7d3 100644 --- a/src/backend/langflow/interface/initialize/llm.py +++ b/src/backend/langflow/interface/initialize/llm.py @@ -2,8 +2,6 @@ def initialize_vertexai(class_object, params): if credentials_path := params.get("credentials"): from google.oauth2 import service_account # type: ignore - credentials_object = service_account.Credentials.from_service_account_file( - filename=credentials_path - ) + credentials_object = service_account.Credentials.from_service_account_file(filename=credentials_path) params["credentials"] = credentials_object return class_object(**params) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 7b3ad4f6f..bf8fa2187 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -44,15 +44,10 @@ def build_vertex_in_params(params: Dict) -> Dict: from langflow.graph.vertex.base import Vertex # If any of the values in params is a Vertex, we will build it - return { - key: value.build() if isinstance(value, Vertex) else value - for key, value in params.items() - } + return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()} -def instantiate_class( - node_type: str, base_type: str, params: Dict, user_id=None -) -> Any: +def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any: """Instantiate class from module type and key, and params""" params = convert_params_to_sets(params) params = convert_kwargs(params) @@ -64,9 +59,7 @@ def instantiate_class( return custom_node(**params) logger.debug(f"Instantiating {node_type} of type {base_type}") class_object = import_by_type(_type=base_type, name=node_type) - return instantiate_based_on_type( - class_object, base_type, node_type, params, user_id=user_id - ) + return instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id) def convert_params_to_sets(params): @@ -194,9 +187,7 @@ def instantiate_memory(node_type, class_object, params): # I want to catch a specific attribute error that happens # when the object does not have a cursor attribute except Exception as exc: - if "object has no attribute 'cursor'" in str( - exc - ) or 'object has no field "conn"' in str(exc): + if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc): raise AttributeError( ( "Failed to build connection to database." @@ -235,9 +226,7 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: if class_method := getattr(class_object, method, None): agent = class_method(**params) tools = params.get("tools", []) - return AgentExecutor.from_agent_and_tools( - agent=agent, tools=tools, handle_parsing_errors=True - ) + return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True) return load_agent_executor(class_object, params) @@ -290,11 +279,7 @@ def instantiate_embedding(node_type, class_object, params: Dict): try: return class_object(**params) except ValidationError: - params = { - key: value - for key, value in params.items() - if key in class_object.__fields__ - } + params = {key: value for key, value in params.items() if key in class_object.__fields__} return class_object(**params) @@ -304,9 +289,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict): if "texts" in params: params["documents"] = params.pop("texts") if "documents" in params: - params["documents"] = [ - doc for doc in params["documents"] if isinstance(doc, Document) - ] + params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)] if initializer := vecstore_initializer.get(class_object.__name__): vecstore = initializer(class_object, params) else: @@ -321,9 +304,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict): return vecstore -def instantiate_documentloader( - node_type: str, class_object: Type[BaseLoader], params: Dict -): +def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict): if "file_filter" in params: # file_filter will be a string but we need a function # that will be used to filter the files using file_filter @@ -332,17 +313,13 @@ def instantiate_documentloader( # in x and if it is, we will return True file_filter = params.pop("file_filter") extensions = file_filter.split(",") - params["file_filter"] = lambda x: any( - extension.strip() in x for extension in extensions - ) + params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions) metadata = params.pop("metadata", None) if metadata and isinstance(metadata, str): try: metadata = orjson.loads(metadata) except json.JSONDecodeError as exc: - raise ValueError( - "The metadata you provided is not a valid JSON string." - ) from exc + raise ValueError("The metadata you provided is not a valid JSON string.") from exc if node_type == "WebBaseLoader": if web_path := params.pop("web_path", None): @@ -375,16 +352,12 @@ def instantiate_textsplitter( "Try changing the chunk_size of the Text Splitter." ) from exc - if ( - "separator_type" in params and params["separator_type"] == "Text" - ) or "separator_type" not in params: + if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params: params.pop("separator_type", None) # separators might come in as an escaped string like \\n # so we need to convert it to a string if "separators" in params: - params["separators"] = ( - params["separators"].encode().decode("unicode-escape") - ) + params["separators"] = params["separators"].encode().decode("unicode-escape") text_splitter = class_object(**params) else: from langchain.text_splitter import Language @@ -411,8 +384,7 @@ def replace_zero_shot_prompt_with_prompt_template(nodes): tools = [ tool for tool in nodes - if tool["type"] != "chatOutputNode" - and "Tool" in tool["data"]["node"]["base_classes"] + if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"] ] node["data"] = build_prompt_template(prompt=node["data"], tools=tools) break @@ -426,9 +398,7 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs) # agent has hidden args for memory. might need to be support # memory = params["memory"] # if allowed_tools is not a list or set, make it a list - if not isinstance(allowed_tools, (list, set)) and isinstance( - allowed_tools, BaseTool - ): + if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool): allowed_tools = [allowed_tools] tool_names = [tool.name for tool in allowed_tools] # Agent class requires an output_parser but Agent classes @@ -456,10 +426,7 @@ def build_prompt_template(prompt, tools): format_instructions = prompt["node"]["template"]["format_instructions"]["value"] tool_strings = "\n".join( - [ - f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" - for tool in tools - ] + [f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools] ) tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) diff --git a/src/backend/langflow/interface/initialize/utils.py b/src/backend/langflow/interface/initialize/utils.py index 199626de5..4aa3fc52b 100644 --- a/src/backend/langflow/interface/initialize/utils.py +++ b/src/backend/langflow/interface/initialize/utils.py @@ -30,9 +30,7 @@ def check_tools_in_params(params: Dict): def instantiate_from_template(class_object, params: Dict): - from_template_params = { - "template": params.pop("prompt", params.pop("template", "")) - } + from_template_params = {"template": params.pop("prompt", params.pop("template", ""))} if not from_template_params.get("template"): raise ValueError("Prompt template is required") return class_object.from_template(**from_template_params) @@ -48,9 +46,7 @@ def handle_format_kwargs(prompt, params: Dict): def handle_partial_variables(prompt, format_kwargs: Dict): partial_variables = format_kwargs.copy() - partial_variables = { - key: value for key, value in partial_variables.items() if value - } + partial_variables = {key: value for key, value in partial_variables.items() if value} # Remove handle_keys otherwise LangChain raises an error partial_variables.pop("handle_keys", None) if partial_variables and hasattr(prompt, "partial"): @@ -62,9 +58,7 @@ def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict): variable = params[input_variable] if isinstance(variable, str): format_kwargs[input_variable] = variable - elif isinstance(variable, BaseOutputParser) and hasattr( - variable, "get_format_instructions" - ): + elif isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions"): format_kwargs[input_variable] = variable.get_format_instructions() elif is_instance_of_list_or_document(variable): format_kwargs = format_document(variable, input_variable, format_kwargs) @@ -107,8 +101,7 @@ def try_to_load_json(content): def needs_handle_keys(variable): return is_instance_of_list_or_document(variable) or ( - isinstance(variable, BaseOutputParser) - and hasattr(variable, "get_format_instructions") + isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions") ) diff --git a/src/backend/langflow/interface/initialize/vector_store.py b/src/backend/langflow/interface/initialize/vector_store.py index 59e7c0d3b..736d3cc2b 100644 --- a/src/backend/langflow/interface/initialize/vector_store.py +++ b/src/backend/langflow/interface/initialize/vector_store.py @@ -17,9 +17,7 @@ import orjson def docs_in_params(params: dict) -> bool: """Check if params has documents OR texts and one of them is not an empty list, If any of them is not an empty list, return True, else return False""" - return ("documents" in params and params["documents"]) or ( - "texts" in params and params["texts"] - ) + return ("documents" in params and params["documents"]) or ("texts" in params and params["texts"]) def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict): @@ -31,9 +29,7 @@ def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dic from pymongo import MongoClient import certifi - client: MongoClient = MongoClient( - MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where() - ) + client: MongoClient = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where()) db_name = params.pop("db_name", None) collection_name = params.pop("collection_name", None) if not db_name or not collection_name: @@ -141,9 +137,7 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict): pinecone_env = os.getenv("PINECONE_ENV") if pinecone_api_key is None or pinecone_env is None: - raise ValueError( - "Pinecone API key and environment must be provided in the params" - ) + raise ValueError("Pinecone API key and environment must be provided in the params") # initialize pinecone pinecone.init( @@ -177,19 +171,13 @@ def initialize_chroma(class_object: Type[Chroma], params: dict): import chromadb # type: ignore settings_params = { - key: params[key] - for key, value_ in params.items() - if key.startswith("chroma_server_") and value_ + key: params[key] for key, value_ in params.items() if key.startswith("chroma_server_") and value_ } chroma_settings = chromadb.config.Settings(**settings_params) params["client_settings"] = chroma_settings else: # remove all chroma_server_ keys from params - params = { - key: value - for key, value in params.items() - if not key.startswith("chroma_server_") - } + params = {key: value for key, value in params.items() if not key.startswith("chroma_server_")} persist = params.pop("persist", False) if not docs_in_params(params): diff --git a/src/backend/langflow/interface/llms/base.py b/src/backend/langflow/interface/llms/base.py index 5076d459f..4b0654a1a 100644 --- a/src/backend/langflow/interface/llms/base.py +++ b/src/backend/langflow/interface/llms/base.py @@ -38,8 +38,7 @@ class LLMCreator(LangChainTypeCreator): return [ llm.__name__ for llm in self.type_to_loader_dict.values() - if llm.__name__ in settings_service.settings.LLMS - or settings_service.settings.DEV + if llm.__name__ in settings_service.settings.LLMS or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/memories/base.py b/src/backend/langflow/interface/memories/base.py index 5f74e1450..a6bff858f 100644 --- a/src/backend/langflow/interface/memories/base.py +++ b/src/backend/langflow/interface/memories/base.py @@ -53,8 +53,7 @@ class MemoryCreator(LangChainTypeCreator): return [ memory.__name__ for memory in self.type_to_loader_dict.values() - if memory.__name__ in settings_service.settings.MEMORIES - or settings_service.settings.DEV + if memory.__name__ in settings_service.settings.MEMORIES or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/output_parsers/base.py b/src/backend/langflow/interface/output_parsers/base.py index 8e7cbb05d..926f403ab 100644 --- a/src/backend/langflow/interface/output_parsers/base.py +++ b/src/backend/langflow/interface/output_parsers/base.py @@ -26,17 +26,14 @@ class OutputParserCreator(LangChainTypeCreator): if self.type_dict is None: settings_service = get_settings_service() self.type_dict = { - output_parser_name: import_class( - f"langchain.output_parsers.{output_parser_name}" - ) + output_parser_name: import_class(f"langchain.output_parsers.{output_parser_name}") # if output_parser_name is not lower case it is a class for output_parser_name in output_parsers.__all__ } self.type_dict = { name: output_parser for name, output_parser in self.type_dict.items() - if name in settings_service.settings.OUTPUT_PARSERS - or settings_service.settings.DEV + if name in settings_service.settings.OUTPUT_PARSERS or settings_service.settings.DEV } return self.type_dict diff --git a/src/backend/langflow/interface/prompts/base.py b/src/backend/langflow/interface/prompts/base.py index 040c2a1a4..164ea2b10 100644 --- a/src/backend/langflow/interface/prompts/base.py +++ b/src/backend/langflow/interface/prompts/base.py @@ -36,8 +36,7 @@ class PromptCreator(LangChainTypeCreator): self.type_dict = { name: prompt for name, prompt in self.type_dict.items() - if name in settings_service.settings.PROMPTS - or settings_service.settings.DEV + if name in settings_service.settings.PROMPTS or settings_service.settings.DEV } return self.type_dict diff --git a/src/backend/langflow/interface/prompts/custom.py b/src/backend/langflow/interface/prompts/custom.py index ef16f1474..47b990326 100644 --- a/src/backend/langflow/interface/prompts/custom.py +++ b/src/backend/langflow/interface/prompts/custom.py @@ -42,17 +42,13 @@ class BaseCustomPrompt(PromptTemplate): values["template"] = values["template"].format(**format_dict) values["template"] = values["template"] - values["input_variables"] = extract_input_variables_from_prompt( - values["template"] - ) + values["input_variables"] = extract_input_variables_from_prompt(values["template"]) return values class SeriesCharacterPrompt(BaseCustomPrompt): # Add a very descriptive description for the prompt generator - description: Optional[ - str - ] = "A prompt that asks the AI to act like a character from a series." + description: Optional[str] = "A prompt that asks the AI to act like a character from a series." character: str series: str template: str = """I want you to act like {character} from {series}. @@ -68,6 +64,4 @@ Human: {input} input_variables: List[str] = ["character", "series"] -CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = { - "SeriesCharacterPrompt": SeriesCharacterPrompt -} +CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {"SeriesCharacterPrompt": SeriesCharacterPrompt} diff --git a/src/backend/langflow/interface/retrievers/base.py b/src/backend/langflow/interface/retrievers/base.py index 2c789fea8..27eda0772 100644 --- a/src/backend/langflow/interface/retrievers/base.py +++ b/src/backend/langflow/interface/retrievers/base.py @@ -39,9 +39,7 @@ class RetrieverCreator(LangChainTypeCreator): method_name=self.from_method_nodes[name], ) else: - return build_template_from_class( - name, type_to_cls_dict=self.type_to_loader_dict - ) + return build_template_from_class(name, type_to_cls_dict=self.type_to_loader_dict) except ValueError as exc: raise ValueError(f"Retriever {name} not found") from exc except AttributeError as exc: @@ -53,8 +51,7 @@ class RetrieverCreator(LangChainTypeCreator): return [ retriever for retriever in self.type_to_loader_dict.keys() - if retriever in settings_service.settings.RETRIEVERS - or settings_service.settings.DEV + if retriever in settings_service.settings.RETRIEVERS or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/text_splitters/base.py b/src/backend/langflow/interface/text_splitters/base.py index 3132fb315..29c77a12f 100644 --- a/src/backend/langflow/interface/text_splitters/base.py +++ b/src/backend/langflow/interface/text_splitters/base.py @@ -35,8 +35,7 @@ class TextSplitterCreator(LangChainTypeCreator): return [ textsplitter.__name__ for textsplitter in self.type_to_loader_dict.values() - if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS - or settings_service.settings.DEV + if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/toolkits/base.py b/src/backend/langflow/interface/toolkits/base.py index cbd681b79..7c57579d1 100644 --- a/src/backend/langflow/interface/toolkits/base.py +++ b/src/backend/langflow/interface/toolkits/base.py @@ -32,13 +32,10 @@ class ToolkitCreator(LangChainTypeCreator): if self.type_dict is None: settings_service = get_settings_service() self.type_dict = { - toolkit_name: import_class( - f"langchain.agents.agent_toolkits.{toolkit_name}" - ) + toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}") # if toolkit_name is not lower case it is a class for toolkit_name in agent_toolkits.__all__ - if not toolkit_name.islower() - and toolkit_name in settings_service.settings.TOOLKITS + if not toolkit_name.islower() and toolkit_name in settings_service.settings.TOOLKITS } return self.type_dict @@ -61,9 +58,7 @@ class ToolkitCreator(LangChainTypeCreator): def get_create_function(self, name: str) -> Callable: if loader_name := self.create_functions.get(name): - return import_module( - f"from langchain.agents.agent_toolkits import {loader_name[0]}" - ) + return import_module(f"from langchain.agents.agent_toolkits import {loader_name[0]}") else: raise ValueError("Toolkit not found") diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index 2b0082fc5..42c9e01ca 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -31,9 +31,7 @@ TOOL_INPUTS = { placeholder="", value="", ), - "llm": TemplateField( - field_type="BaseLanguageModel", required=True, is_list=False, show=True - ), + "llm": TemplateField(field_type="BaseLanguageModel", required=True, is_list=False, show=True), "func": TemplateField( field_type="function", required=True, @@ -76,10 +74,7 @@ class ToolCreator(LangChainTypeCreator): tool_name = tool_params.get("name") or tool - if ( - tool_name in settings_service.settings.TOOLS - or settings_service.settings.DEV - ): + if tool_name in settings_service.settings.TOOLS or settings_service.settings.DEV: if tool_name == "JsonSpec": tool_params["path"] = tool_params.pop("dict_") # type: ignore all_tools[tool_name] = { diff --git a/src/backend/langflow/interface/tools/util.py b/src/backend/langflow/interface/tools/util.py index 8e4f582c1..7c8020aa9 100644 --- a/src/backend/langflow/interface/tools/util.py +++ b/src/backend/langflow/interface/tools/util.py @@ -21,16 +21,12 @@ def get_func_tool_params(func, **kwargs) -> Union[Dict, None]: for keyword in tool.keywords: if keyword.arg == "name": try: - tool_params["name"] = ast.literal_eval( - keyword.value - ) + tool_params["name"] = ast.literal_eval(keyword.value) except ValueError: break elif keyword.arg == "description": try: - tool_params["description"] = ast.literal_eval( - keyword.value - ) + tool_params["description"] = ast.literal_eval(keyword.value) except ValueError: continue @@ -43,9 +39,7 @@ def get_func_tool_params(func, **kwargs) -> Union[Dict, None]: else: # get the class object from the return statement try: - class_obj = eval( - compile(ast.Expression(tool), "", "eval") - ) + class_obj = eval(compile(ast.Expression(tool), "", "eval")) except Exception: return None diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 0046c84fb..755bfc435 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -114,14 +114,10 @@ def add_new_custom_field( # If options is a list, then it's a dropdown # If options is None, then it's a list of strings is_list = isinstance(field_config.get("options"), list) - field_config["is_list"] = ( - is_list or field_config.get("is_list", False) or field_contains_list - ) + field_config["is_list"] = is_list or field_config.get("is_list", False) or field_contains_list if "name" in field_config: - warnings.warn( - "The 'name' key in field_config is used to build the object and can't be changed." - ) + warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.") field_config.pop("name", None) required = field_config.pop("required", field_required) @@ -185,9 +181,7 @@ def extract_type_from_optional(field_type): def build_frontend_node(custom_component: CustomComponent): """Build a frontend node for a custom component""" try: - return ( - CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__) - ) + return CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__) except Exception as exc: logger.error(f"Error while building base frontend node: {exc}") @@ -236,9 +230,7 @@ def add_extra_fields(frontend_node, field_config, function_args): if "name" not in extra_field or extra_field["name"] == "self": continue - field_name, field_type, field_value, field_required = get_field_properties( - extra_field - ) + field_name, field_type, field_value, field_required = get_field_properties(extra_field) config = field_config.get(field_name, {}) frontend_node = add_new_custom_field( frontend_node, @@ -273,8 +265,7 @@ def add_base_classes(frontend_node, return_types: List[str]): status_code=400, detail={ "error": ( - "Invalid return type should be one of: " - f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" + "Invalid return type should be one of: " f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" ), "traceback": traceback.format_exc(), }, @@ -296,8 +287,7 @@ def add_output_types(frontend_node, return_types: List[str]): status_code=400, detail={ "error": ( - "Invalid return type should be one of: " - f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" + "Invalid return type should be one of: " f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}" ), "traceback": traceback.format_exc(), }, @@ -325,16 +315,10 @@ def build_langchain_template_custom_component(custom_component: CustomComponent) add_extra_fields(frontend_node, field_config, entrypoint_args) logger.debug("Added extra fields") - frontend_node = add_code_field( - frontend_node, custom_component.code, field_config.get("code", {}) - ) + frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {})) logger.debug("Added code field") - add_base_classes( - frontend_node, custom_component.get_function_entrypoint_return_type - ) - add_output_types( - frontend_node, custom_component.get_function_entrypoint_return_type - ) + add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type) + add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type) logger.debug("Added base classes") return frontend_node except Exception as exc: @@ -343,9 +327,7 @@ def build_langchain_template_custom_component(custom_component: CustomComponent) raise HTTPException( status_code=400, detail={ - "error": ( - "Invalid type convertion. Please check your code and try again." - ), + "error": ("Invalid type convertion. Please check your code and try again."), "traceback": traceback.format_exc(), }, ) from exc @@ -377,9 +359,7 @@ def build_valid_menu(valid_components): valid_menu[menu_name] = {} for component in menu_item["components"]: - logger.debug( - f"Building component: {component.get('name'), component.get('output_types')}" - ) + logger.debug(f"Building component: {component.get('name'), component.get('output_types')}") try: component_name = component["name"] component_code = component["code"] @@ -388,9 +368,7 @@ def build_valid_menu(valid_components): component_extractor = CustomComponent(code=component_code) component_extractor.is_check_valid() - component_template = build_langchain_template_custom_component( - component_extractor - ) + component_template = build_langchain_template_custom_component(component_extractor) component_template["output_types"] = component_output_types if len(component_output_types) == 1: component_name = component_output_types[0] @@ -398,9 +376,7 @@ def build_valid_menu(valid_components): file_name = component.get("file").split(".")[0] if "_" in file_name: # turn .py file into camelcase - component_name = "".join( - [word.capitalize() for word in file_name.split("_")] - ) + component_name = "".join([word.capitalize() for word in file_name.split("_")]) else: component_name = file_name @@ -409,9 +385,7 @@ def build_valid_menu(valid_components): except Exception as exc: logger.error(f"Error loading Component: {component['output_types']}") - logger.exception( - f"Error while building custom component {component_output_types}: {exc}" - ) + logger.exception(f"Error while building custom component {component_output_types}: {exc}") return valid_menu @@ -449,20 +423,14 @@ def build_invalid_menu(invalid_components): logger.debug(f"Added {component_name} to invalid menu to {menu_name}") except Exception as exc: - logger.exception( - f"Error while creating custom component [{component_name}]: {str(exc)}" - ) + logger.exception(f"Error while creating custom component [{component_name}]: {str(exc)}") return invalid_menu def merge_nested_dicts_with_renaming(dict1, dict2): for key, value in dict2.items(): - if ( - key in dict1 - and isinstance(value, dict) - and isinstance(dict1.get(key), dict) - ): + if key in dict1 and isinstance(value, dict) and isinstance(dict1.get(key), dict): for sub_key, sub_value in value.items(): if sub_key in dict1[key]: new_key = get_new_key(dict1[key], sub_key) @@ -479,9 +447,7 @@ def build_langchain_custom_component_list_from_path(path: str): file_list = load_files_from_path(path) reader = DirectoryReader(path, False) - valid_components, invalid_components = build_and_validate_all_files( - reader, file_list - ) + valid_components, invalid_components = build_and_validate_all_files(reader, file_list) valid_menu = build_valid_menu(valid_components) invalid_menu = build_invalid_menu(invalid_components) @@ -495,18 +461,14 @@ def get_all_types_dict(settings_service): # need to merge all the keys into one dict custom_components_from_file: dict[str, Any] = {} if settings_service.settings.COMPONENTS_PATH: - logger.info( - f"Building custom components from {settings_service.settings.COMPONENTS_PATH}" - ) + logger.info(f"Building custom components from {settings_service.settings.COMPONENTS_PATH}") custom_component_dicts = [] processed_paths = [] for path in settings_service.settings.COMPONENTS_PATH: if str(path) in processed_paths: continue - custom_component_dict = build_langchain_custom_component_list_from_path( - str(path) - ) + custom_component_dict = build_langchain_custom_component_list_from_path(str(path)) custom_component_dicts.append(custom_component_dict) processed_paths.append(str(path)) @@ -516,16 +478,12 @@ def get_all_types_dict(settings_service): if not custom_component_dict: continue category = list(custom_component_dict.keys())[0] - logger.info( - f"Loading {len(custom_component_dict[category])} component(s) from category {category}" - ) + logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}") custom_components_from_file = merge_nested_dicts_with_renaming( custom_components_from_file, custom_component_dict ) - return merge_nested_dicts_with_renaming( - native_components, custom_components_from_file - ) + return merge_nested_dicts_with_renaming(native_components, custom_components_from_file) def merge_nested_dicts(dict1, dict2): diff --git a/src/backend/langflow/interface/utilities/base.py b/src/backend/langflow/interface/utilities/base.py index 24c8faeab..2f2235586 100644 --- a/src/backend/langflow/interface/utilities/base.py +++ b/src/backend/langflow/interface/utilities/base.py @@ -29,16 +29,14 @@ class UtilityCreator(LangChainTypeCreator): if self.type_dict is None: settings_service = get_settings_service() self.type_dict = { - utility_name: import_class(f"langchain.utilities.{utility_name}") - for utility_name in utilities.__all__ + utility_name: import_class(f"langchain.utilities.{utility_name}") for utility_name in utilities.__all__ } self.type_dict["SQLDatabase"] = utilities.SQLDatabase # Filter according to settings.utilities self.type_dict = { name: utility for name, utility in self.type_dict.items() - if name in settings_service.settings.UTILITIES - or settings_service.settings.DEV + if name in settings_service.settings.UTILITIES or settings_service.settings.DEV } return self.type_dict diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py index ca5083df7..6f7ec9329 100644 --- a/src/backend/langflow/interface/utils.py +++ b/src/backend/langflow/interface/utils.py @@ -43,9 +43,7 @@ def try_setting_streaming_options(langchain_object): llm = None if hasattr(langchain_object, "llm"): llm = langchain_object.llm - elif hasattr(langchain_object, "llm_chain") and hasattr( - langchain_object.llm_chain, "llm" - ): + elif hasattr(langchain_object, "llm_chain") and hasattr(langchain_object.llm_chain, "llm"): llm = langchain_object.llm_chain.llm if isinstance(llm, BaseLanguageModel): @@ -79,9 +77,7 @@ def set_langchain_cache(settings): if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"): try: - cache_class = import_class( - f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}" - ) + cache_class = import_class(f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}") logger.debug(f"Setting up LLM caching with {cache_class.__name__}") langchain.llm_cache = cache_class() diff --git a/src/backend/langflow/interface/vector_store/base.py b/src/backend/langflow/interface/vector_store/base.py index bf349ca2b..d04689469 100644 --- a/src/backend/langflow/interface/vector_store/base.py +++ b/src/backend/langflow/interface/vector_store/base.py @@ -22,9 +22,7 @@ class VectorstoreCreator(LangChainTypeCreator): def type_to_loader_dict(self) -> Dict: if self.type_dict is None: self.type_dict: dict[str, Any] = { - vectorstore_name: import_class( - f"langchain.vectorstores.{vectorstore_name}" - ) + vectorstore_name: import_class(f"langchain.vectorstores.{vectorstore_name}") for vectorstore_name in vectorstores.__all__ } return self.type_dict @@ -48,8 +46,7 @@ class VectorstoreCreator(LangChainTypeCreator): return [ vectorstore for vectorstore in self.type_to_loader_dict.keys() - if vectorstore in settings_service.settings.VECTORSTORES - or settings_service.settings.DEV + if vectorstore in settings_service.settings.VECTORSTORES or settings_service.settings.DEV ] diff --git a/src/backend/langflow/interface/wrappers/base.py b/src/backend/langflow/interface/wrappers/base.py index de631101a..66dca7f58 100644 --- a/src/backend/langflow/interface/wrappers/base.py +++ b/src/backend/langflow/interface/wrappers/base.py @@ -16,8 +16,7 @@ class WrapperCreator(LangChainTypeCreator): def type_to_loader_dict(self) -> Dict: if self.type_dict is None: self.type_dict = { - wrapper.__name__: wrapper - for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase] + wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase] } return self.type_dict diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 1369d7c4f..d93bd8004 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -84,9 +84,7 @@ def get_static_files_dir(): return frontend_path / "frontend" -def setup_app( - static_files_dir: Optional[Path] = None, backend_only: bool = False -) -> FastAPI: +def setup_app(static_files_dir: Optional[Path] = None, backend_only: bool = False) -> FastAPI: """Setup the FastAPI app.""" # get the directory of the current file if not static_files_dir: diff --git a/src/backend/langflow/processing/base.py b/src/backend/langflow/processing/base.py index 6c56a59dd..d1896398e 100644 --- a/src/backend/langflow/processing/base.py +++ b/src/backend/langflow/processing/base.py @@ -34,9 +34,7 @@ def get_langfuse_callback(trace_id): if langfuse := LangfuseInstance.get(): logger.debug("Langfuse credentials found") try: - trace = langfuse.trace( - CreateTrace(name="langflow-" + trace_id, id=trace_id) - ) + trace = langfuse.trace(CreateTrace(name="langflow-" + trace_id, id=trace_id)) return trace.getNewHandler() except Exception as exc: logger.error(f"Error initializing langfuse callback: {exc}") @@ -44,9 +42,7 @@ def get_langfuse_callback(trace_id): return None -def flush_langfuse_callback_if_present( - callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]] -): +def flush_langfuse_callback_if_present(callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]): """ If langfuse callback is present, run callback.langfuse.flush() """ @@ -88,15 +84,9 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa # if langfuse callback is present, run callback.langfuse.flush() flush_langfuse_callback_if_present(callbacks) - intermediate_steps = ( - output.get("intermediate_steps", []) if isinstance(output, dict) else [] - ) + intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else [] - result = ( - output.get(langchain_object.output_keys[0]) - if isinstance(output, dict) - else output - ) + result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output try: thought = format_actions(intermediate_steps) if intermediate_steps else "" except Exception as exc: diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 8d04b4e24..951e196fb 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -112,9 +112,7 @@ def load_langchain_object( logger.debug("Loaded LangChain object") if langchain_object is None: - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." - ) + raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.") return langchain_object, artifacts, session_id @@ -164,9 +162,7 @@ async def process_graph_cached( if clear_cache: session_service.clear_session(session_id) if session_id is None: - session_id = session_service.generate_key( - session_id=session_id, data_graph=data_graph - ) + session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph) # Load the graph using SessionService graph, artifacts = session_service.load_session(session_id, data_graph) built_object = graph.build() @@ -179,9 +175,7 @@ async def process_graph_cached( return Result(result=result, session_id=session_id) -def load_flow_from_json( - flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True -): +def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True): """ Load flow from a JSON file or a JSON object. @@ -198,9 +192,7 @@ def load_flow_from_json( elif isinstance(flow, dict): flow_graph = flow else: - raise TypeError( - "Input must be either a file path (str) or a JSON object (dict)" - ) + raise TypeError("Input must be either a file path (str) or a JSON object (dict)") graph_data = flow_graph["data"] if tweaks is not None: @@ -226,18 +218,14 @@ def load_flow_from_json( return graph -def validate_input( - graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] -) -> List[Dict[str, Any]]: +def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]: if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): raise ValueError("graph_data and tweaks should be dictionaries") nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") if not isinstance(nodes, list): - raise ValueError( - "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" - ) + raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key") return nodes @@ -246,9 +234,7 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: template_data = node.get("data", {}).get("node", {}).get("template") if not isinstance(template_data, dict): - logger.warning( - f"Template data for node {node.get('id')} should be a dictionary" - ) + logger.warning(f"Template data for node {node.get('id')} should be a dictionary") return for tweak_name, tweak_value in node_tweaks.items(): @@ -257,9 +243,7 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: template_data[tweak_name][key] = tweak_value -def process_tweaks( - graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] -) -> Dict[str, Any]: +def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: """ This function is used to tweak the graph data using the node id and the tweaks dict. @@ -280,8 +264,6 @@ def process_tweaks( if node_tweaks := tweaks.get(node_id): apply_tweaks(node, node_tweaks) else: - logger.warning( - "Each node should be a dictionary with an 'id' key of type str" - ) + logger.warning("Each node should be a dictionary with an 'id' key of type str") return graph_data diff --git a/src/backend/langflow/server.py b/src/backend/langflow/server.py index 3a2943444..9fe432744 100644 --- a/src/backend/langflow/server.py +++ b/src/backend/langflow/server.py @@ -10,11 +10,7 @@ class LangflowApplication(BaseApplication): super().__init__() def load_config(self): - config = { - key: value - for key, value in self.options.items() - if key in self.cfg.settings and value is not None - } + config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): self.cfg.set(key.lower(), value) diff --git a/src/backend/langflow/services/auth/utils.py b/src/backend/langflow/services/auth/utils.py index f5504c8ce..79bb43faa 100644 --- a/src/backend/langflow/services/auth/utils.py +++ b/src/backend/langflow/services/auth/utils.py @@ -20,12 +20,8 @@ oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login") API_KEY_NAME = "x-api-key" -api_key_query = APIKeyQuery( - name=API_KEY_NAME, scheme_name="API key query", auto_error=False -) -api_key_header = APIKeyHeader( - name=API_KEY_NAME, scheme_name="API key header", auto_error=False -) +api_key_query = APIKeyQuery(name=API_KEY_NAME, scheme_name="API key query", auto_error=False) +api_key_header = APIKeyHeader(name=API_KEY_NAME, scheme_name="API key header", auto_error=False) # Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py @@ -118,23 +114,17 @@ def get_current_active_user(current_user: Annotated[User, Depends(get_current_us return current_user -def get_current_active_superuser( - current_user: Annotated[User, Depends(get_current_user)] -) -> User: +def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User: if not current_user.is_active: raise HTTPException(status_code=401, detail="Inactive user") if not current_user.is_superuser: - raise HTTPException( - status_code=400, detail="The user doesn't have enough privileges" - ) + raise HTTPException(status_code=400, detail="The user doesn't have enough privileges") return current_user def verify_password(plain_password, hashed_password): settings_service = get_settings_service() - return settings_service.auth_settings.pwd_context.verify( - plain_password, hashed_password - ) + return settings_service.auth_settings.pwd_context.verify(plain_password, hashed_password) def get_password_hash(password): @@ -223,22 +213,16 @@ def get_user_id_from_token(token: str) -> UUID: return UUID(int=0) -def create_user_tokens( - user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False -) -> dict: +def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False) -> dict: settings_service = get_settings_service() - access_token_expires = timedelta( - minutes=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES - ) + access_token_expires = timedelta(minutes=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_token( data={"sub": str(user_id)}, expires_delta=access_token_expires, ) - refresh_token_expires = timedelta( - minutes=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES - ) + refresh_token_expires = timedelta(minutes=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES) refresh_token = create_token( data={"sub": str(user_id), "type": "rf"}, expires_delta=refresh_token_expires, @@ -268,9 +252,7 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)) token_type: str = payload.get("type") # type: ignore if user_id is None or token_type is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") return create_user_tokens(user_id, db) @@ -281,9 +263,7 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)) ) from e -def authenticate_user( - username: str, password: str, db: Session = Depends(get_session) -) -> Optional[User]: +def authenticate_user(username: str, password: str, db: Session = Depends(get_session)) -> Optional[User]: user = get_user_by_username(db, username) if not user: @@ -318,9 +298,7 @@ def encrypt_api_key(api_key: str, settings_service=Depends(get_settings_service) return encrypted_key -def decrypt_api_key( - encrypted_api_key: str, settings_service=Depends(get_settings_service) -): +def decrypt_api_key(encrypted_api_key: str, settings_service=Depends(get_settings_service)): fernet = get_fernet(settings_service) # Two-way decryption if isinstance(encrypted_api_key, str): diff --git a/src/backend/langflow/services/cache/factory.py b/src/backend/langflow/services/cache/factory.py index 3288ca993..10e657bc5 100644 --- a/src/backend/langflow/services/cache/factory.py +++ b/src/backend/langflow/services/cache/factory.py @@ -26,9 +26,7 @@ class CacheServiceFactory(ServiceFactory): if redis_cache.is_connected(): logger.debug("Redis cache is connected") return redis_cache - logger.warning( - "Redis cache is not connected, falling back to in-memory cache" - ) + logger.warning("Redis cache is not connected, falling back to in-memory cache") return InMemoryCache() elif settings_service.settings.CACHE_TYPE == "memory": diff --git a/src/backend/langflow/services/cache/service.py b/src/backend/langflow/services/cache/service.py index da76a2b5c..3ee8d001b 100644 --- a/src/backend/langflow/services/cache/service.py +++ b/src/backend/langflow/services/cache/service.py @@ -68,10 +68,7 @@ class InMemoryCache(BaseCacheService, Service): Retrieve an item from the cache without acquiring the lock. """ if item := self._cache.get(key): - if ( - self.expiration_time is None - or time.time() - item["time"] < self.expiration_time - ): + if self.expiration_time is None or time.time() - item["time"] < self.expiration_time: # Move the key to the end to make it recently used self._cache.move_to_end(key) # Check if the value is pickled @@ -118,11 +115,7 @@ class InMemoryCache(BaseCacheService, Service): """ with self._lock: existing_value = self._get_without_lock(key) - if ( - existing_value is not None - and isinstance(existing_value, dict) - and isinstance(value, dict) - ): + if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict): existing_value.update(value) value = existing_value @@ -276,9 +269,7 @@ class RedisCache(BaseCacheService, Service): if not result: raise ValueError("RedisCache could not set the value.") except TypeError as exc: - raise TypeError( - "RedisCache only accepts values that can be pickled. " - ) from exc + raise TypeError("RedisCache only accepts values that can be pickled. ") from exc def upsert(self, key, value): """ @@ -290,11 +281,7 @@ class RedisCache(BaseCacheService, Service): value: The value to insert or update. """ existing_value = self.get(key) - if ( - existing_value is not None - and isinstance(existing_value, dict) - and isinstance(value, dict) - ): + if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict): existing_value.update(value) value = existing_value diff --git a/src/backend/langflow/services/cache/utils.py b/src/backend/langflow/services/cache/utils.py index bd6b4fb0a..bb8ba9a9e 100644 --- a/src/backend/langflow/services/cache/utils.py +++ b/src/backend/langflow/services/cache/utils.py @@ -83,9 +83,7 @@ def clear_old_cache_files(max_cache_size: int = 3): cache_files = list(cache_dir.glob("*.dill")) if len(cache_files) > max_cache_size: - cache_files_sorted_by_mtime = sorted( - cache_files, key=lambda x: x.stat().st_mtime, reverse=True - ) + cache_files_sorted_by_mtime = sorted(cache_files, key=lambda x: x.stat().st_mtime, reverse=True) for cache_file in cache_files_sorted_by_mtime[max_cache_size:]: with contextlib.suppress(OSError): diff --git a/src/backend/langflow/services/chat/service.py b/src/backend/langflow/services/chat/service.py index 59488d431..58d7cf384 100644 --- a/src/backend/langflow/services/chat/service.py +++ b/src/backend/langflow/services/chat/service.py @@ -59,9 +59,7 @@ class ChatService(Service): """Send the last chat message to the client.""" client_id = self.chat_cache.current_client_id if client_id in self.active_connections: - chat_response = self.chat_history.get_history( - client_id, filter_messages=False - )[-1] + chat_response = self.chat_history.get_history(client_id, filter_messages=False)[-1] if chat_response.is_bot: # Process FileResponse if isinstance(chat_response, FileResponse): @@ -88,9 +86,7 @@ class ChatService(Service): data_type=self.last_cached_object_dict["type"], ) - self.chat_history.add_message( - self.chat_cache.current_client_id, chat_response - ) + self.chat_history.add_message(self.chat_cache.current_client_id, chat_response) async def connect(self, client_id: str, websocket: WebSocket): self.active_connections[client_id] = websocket @@ -121,9 +117,7 @@ class ChatService(Service): if "after sending" in str(exc): logger.error(f"Error closing connection: {exc}") - async def process_message( - self, client_id: str, payload: Dict, langchain_object: Any - ): + async def process_message(self, client_id: str, payload: Dict, langchain_object: Any): # Process the graph data and chat message chat_inputs = payload.pop("inputs", {}) chatkey = payload.pop("chatKey", None) @@ -211,15 +205,11 @@ class ChatService(Service): continue with self.chat_cache.set_client_id(client_id): - if langchain_object := self.cache_service.get(client_id).get( - "result" - ): + if langchain_object := self.cache_service.get(client_id).get("result"): await self.process_message(client_id, payload, langchain_object) else: - raise RuntimeError( - f"Could not find a build result for client_id {client_id}" - ) + raise RuntimeError(f"Could not find a build result for client_id {client_id}") except Exception as exc: # Handle any exceptions that might occur logger.exception(f"Error handling websocket: {exc}") diff --git a/src/backend/langflow/services/chat/utils.py b/src/backend/langflow/services/chat/utils.py index 85b86a801..604f4f5a5 100644 --- a/src/backend/langflow/services/chat/utils.py +++ b/src/backend/langflow/services/chat/utils.py @@ -15,9 +15,7 @@ async def process_graph( if langchain_object is None: # Raise user facing error - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." - ) + raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.") # Generate result and thought try: diff --git a/src/backend/langflow/services/database/models/api_key/crud.py b/src/backend/langflow/services/database/models/api_key/crud.py index 0e0ae6137..806848218 100644 --- a/src/backend/langflow/services/database/models/api_key/crud.py +++ b/src/backend/langflow/services/database/models/api_key/crud.py @@ -18,9 +18,7 @@ def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]: return [ApiKeyRead.from_orm(api_key) for api_key in api_keys] -def create_api_key( - session: Session, api_key_create: ApiKeyCreate, user_id: UUID -) -> UnmaskedApiKeyRead: +def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead: # Generate a random API key with 32 bytes of randomness generated_api_key = f"sk-{secrets.token_urlsafe(32)}" diff --git a/src/backend/langflow/services/database/models/user/crud.py b/src/backend/langflow/services/database/models/user/crud.py index 32b94982c..dd6cbdc7a 100644 --- a/src/backend/langflow/services/database/models/user/crud.py +++ b/src/backend/langflow/services/database/models/user/crud.py @@ -19,9 +19,7 @@ def get_user_by_id(db: Session, id: UUID) -> Union[User, None]: return db.query(User).filter(User.id == id).first() -def update_user( - user_db: Optional[User], user: UserUpdate, db: Session = Depends(get_session) -) -> User: +def update_user(user_db: Optional[User], user: UserUpdate, db: Session = Depends(get_session)) -> User: if not user_db: raise HTTPException(status_code=404, detail="User not found") @@ -37,9 +35,7 @@ def update_user( changed = True if not changed: - raise HTTPException( - status_code=status.HTTP_304_NOT_MODIFIED, detail="Nothing to update" - ) + raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, detail="Nothing to update") user_db.updated_at = datetime.now(timezone.utc) flag_modified(user_db, "updated_at") diff --git a/src/backend/langflow/services/database/service.py b/src/backend/langflow/services/database/service.py index dfcda9a5c..2dfac285f 100644 --- a/src/backend/langflow/services/database/service.py +++ b/src/backend/langflow/services/database/service.py @@ -34,10 +34,7 @@ class DatabaseService(Service): def _create_engine(self) -> "Engine": """Create the engine for the database.""" settings_service = get_settings_service() - if ( - settings_service.settings.DATABASE_URL - and settings_service.settings.DATABASE_URL.startswith("sqlite") - ): + if settings_service.settings.DATABASE_URL and settings_service.settings.DATABASE_URL.startswith("sqlite"): connect_args = {"check_same_thread": False} else: connect_args = {} @@ -49,9 +46,7 @@ class DatabaseService(Service): def __exit__(self, exc_type, exc_value, traceback): if exc_type is not None: # If an exception has been raised - logger.error( - f"Session rollback because of exception: {exc_type.__name__} {exc_value}" - ) + logger.error(f"Session rollback because of exception: {exc_type.__name__} {exc_value}") self._session.rollback() else: self._session.commit() @@ -99,9 +94,7 @@ class DatabaseService(Service): expected_columns = list(model.__fields__.keys()) try: - available_columns = [ - col["name"] for col in inspector.get_columns(table) - ] + available_columns = [col["name"] for col in inspector.get_columns(table)] except sa.exc.NoSuchTableError: logger.error(f"Missing table: {table}") return False @@ -153,9 +146,7 @@ class DatabaseService(Service): try: command.check(alembic_cfg) except Exception as exc: - if isinstance(exc, util.exc.CommandError) or isinstance( - exc, util.exc.AutogenerateDiffsDetected - ): + if isinstance(exc, util.exc.CommandError) or isinstance(exc, util.exc.AutogenerateDiffsDetected): command.upgrade(alembic_cfg, "head") # We should check the schema health after running migrations @@ -174,10 +165,7 @@ class DatabaseService(Service): # We will check that all models are in the database # and that the database is up to date with all columns sql_models = [models.Flow, models.User, models.ApiKey] - return [ - TableResults(sql_model.__tablename__, self.check_table(sql_model)) - for sql_model in sql_models - ] + return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models] def check_table(self, model): results = [] @@ -185,9 +173,7 @@ class DatabaseService(Service): table_name = model.__tablename__ expected_columns = list(model.__fields__.keys()) try: - available_columns = [ - col["name"] for col in inspector.get_columns(table_name) - ] + available_columns = [col["name"] for col in inspector.get_columns(table_name)] results.append(Result(name=table_name, type="table", success=True)) except sa.exc.NoSuchTableError: logger.error(f"Missing table: {table_name}") @@ -218,9 +204,7 @@ class DatabaseService(Service): try: table.create(self.engine, checkfirst=True) except OperationalError as oe: - logger.warning( - f"Table {table} already exists, skipping. Exception: {oe}" - ) + logger.warning(f"Table {table} already exists, skipping. Exception: {oe}") except Exception as exc: logger.error(f"Error creating table {table}: {exc}") raise RuntimeError(f"Error creating table {table}") from exc @@ -232,9 +216,7 @@ class DatabaseService(Service): if table not in table_names: logger.error("Something went wrong creating the database and tables.") logger.error("Please check your database settings.") - raise RuntimeError( - "Something went wrong creating the database and tables." - ) + raise RuntimeError("Something went wrong creating the database and tables.") logger.debug("Database and tables created successfully") diff --git a/src/backend/langflow/services/database/utils.py b/src/backend/langflow/services/database/utils.py index 610196a51..aae567e91 100644 --- a/src/backend/langflow/services/database/utils.py +++ b/src/backend/langflow/services/database/utils.py @@ -13,9 +13,7 @@ def initialize_database(): logger.debug("Initializing database") from langflow.services import service_manager, ServiceType - database_service: "DatabaseService" = service_manager.get( - ServiceType.DATABASE_SERVICE - ) + database_service: "DatabaseService" = service_manager.get(ServiceType.DATABASE_SERVICE) try: database_service.create_db_and_tables() except Exception as exc: @@ -41,9 +39,7 @@ def initialize_database(): # This means there's wrong revision in the DB # We need to delete the alembic_version table # and run the migrations again - logger.warning( - "Wrong revision in DB, deleting alembic_version table and running migrations again" - ) + logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again") with session_getter(database_service) as session: session.execute("DROP TABLE alembic_version") database_service.run_migrations() diff --git a/src/backend/langflow/services/manager.py b/src/backend/langflow/services/manager.py index 10fd6b699..fce0e9106 100644 --- a/src/backend/langflow/services/manager.py +++ b/src/backend/langflow/services/manager.py @@ -53,15 +53,10 @@ class ServiceManager: self._create_service(dependency) # Collect the dependent services - dependent_services = { - dep.value: self.services[dep] - for dep in self.dependencies.get(service_name, []) - } + dependent_services = {dep.value: self.services[dep] for dep in self.dependencies.get(service_name, [])} # Create the actual service - self.services[service_name] = self.factories[service_name].create( - **dependent_services - ) + self.services[service_name] = self.factories[service_name].create(**dependent_services) self.services[service_name].set_ready() def _validate_service_creation(self, service_name: ServiceType): @@ -69,9 +64,7 @@ class ServiceManager: Validate whether the service can be created. """ if service_name not in self.factories: - raise ValueError( - f"No factory registered for the service class '{service_name.name}'" - ) + raise ValueError(f"No factory registered for the service class '{service_name.name}'") def update(self, service_name: ServiceType): """ @@ -144,9 +137,7 @@ def initialize_session_service(): initialize_settings_service() - service_manager.register_factory( - cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE] - ) + service_manager.register_factory(cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE]) service_manager.register_factory( session_service_factory.SessionServiceFactory(), diff --git a/src/backend/langflow/services/plugins/langfuse.py b/src/backend/langflow/services/plugins/langfuse.py index d375caa9e..103cfe260 100644 --- a/src/backend/langflow/services/plugins/langfuse.py +++ b/src/backend/langflow/services/plugins/langfuse.py @@ -23,10 +23,7 @@ class LangfuseInstance: settings_manager = get_settings_service() - if ( - settings_manager.settings.LANGFUSE_PUBLIC_KEY - and settings_manager.settings.LANGFUSE_SECRET_KEY - ): + if settings_manager.settings.LANGFUSE_PUBLIC_KEY and settings_manager.settings.LANGFUSE_SECRET_KEY: logger.debug("Langfuse credentials found") cls._instance = Langfuse( public_key=settings_manager.settings.LANGFUSE_PUBLIC_KEY, diff --git a/src/backend/langflow/services/session/utils.py b/src/backend/langflow/services/session/utils.py index 374d85540..1d62508a3 100644 --- a/src/backend/langflow/services/session/utils.py +++ b/src/backend/langflow/services/session/utils.py @@ -3,6 +3,4 @@ import string def session_id_generator(size=6): - return "".join( - random.SystemRandom().choices(string.ascii_uppercase + string.digits, k=size) - ) + return "".join(random.SystemRandom().choices(string.ascii_uppercase + string.digits, k=size)) diff --git a/src/backend/langflow/services/settings/auth.py b/src/backend/langflow/services/settings/auth.py index 7ac7461a0..432822f4c 100644 --- a/src/backend/langflow/services/settings/auth.py +++ b/src/backend/langflow/services/settings/auth.py @@ -26,9 +26,7 @@ class AuthSettings(BaseSettings): REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 12 # API Key to execute /process endpoint - API_KEY_SECRET_KEY: Optional[ - str - ] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a" + API_KEY_SECRET_KEY: Optional[str] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a" API_KEY_ALGORITHM: str = "HS256" API_V1_STR: str = "/api/v1" diff --git a/src/backend/langflow/services/settings/base.py b/src/backend/langflow/services/settings/base.py index 13688e889..0db5da6b4 100644 --- a/src/backend/langflow/services/settings/base.py +++ b/src/backend/langflow/services/settings/base.py @@ -83,9 +83,7 @@ class Settings(BaseSettings): @validator("DATABASE_URL", pre=True) def set_database_url(cls, value, values): if not value: - logger.debug( - "No database_url provided, trying LANGFLOW_DATABASE_URL env variable" - ) + logger.debug("No database_url provided, trying LANGFLOW_DATABASE_URL env variable") if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"): value = langflow_database_url logger.debug("Using LANGFLOW_DATABASE_URL env variable.") @@ -95,9 +93,7 @@ class Settings(BaseSettings): # so we need to migrate to the new format # if there is a database in that location if not values["CONFIG_DIR"]: - raise ValueError( - "CONFIG_DIR not set, please set it or provide a DATABASE_URL" - ) + raise ValueError("CONFIG_DIR not set, please set it or provide a DATABASE_URL") new_path = f"{values['CONFIG_DIR']}/langflow.db" if Path("./langflow.db").exists(): @@ -121,22 +117,15 @@ class Settings(BaseSettings): if os.getenv("LANGFLOW_COMPONENTS_PATH"): logger.debug("Adding LANGFLOW_COMPONENTS_PATH to components_path") langflow_component_path = os.getenv("LANGFLOW_COMPONENTS_PATH") - if ( - Path(langflow_component_path).exists() - and langflow_component_path not in value - ): + if Path(langflow_component_path).exists() and langflow_component_path not in value: if isinstance(langflow_component_path, list): for path in langflow_component_path: if path not in value: value.append(path) - logger.debug( - f"Extending {langflow_component_path} to components_path" - ) + logger.debug(f"Extending {langflow_component_path} to components_path") elif langflow_component_path not in value: value.append(langflow_component_path) - logger.debug( - f"Appending {langflow_component_path} to components_path" - ) + logger.debug(f"Appending {langflow_component_path} to components_path") if not value: value = [BASE_COMPONENTS_PATH] diff --git a/src/backend/langflow/services/settings/factory.py b/src/backend/langflow/services/settings/factory.py index 4ba9f3f82..713f13f82 100644 --- a/src/backend/langflow/services/settings/factory.py +++ b/src/backend/langflow/services/settings/factory.py @@ -10,6 +10,4 @@ class SettingsServiceFactory(ServiceFactory): def create(self): # Here you would have logic to create and configure a SettingsService langflow_dir = Path(__file__).parent.parent.parent - return SettingsService.load_settings_from_yaml( - str(langflow_dir / "config.yaml") - ) + return SettingsService.load_settings_from_yaml(str(langflow_dir / "config.yaml")) diff --git a/src/backend/langflow/services/settings/service.py b/src/backend/langflow/services/settings/service.py index cdededcea..e9e535911 100644 --- a/src/backend/langflow/services/settings/service.py +++ b/src/backend/langflow/services/settings/service.py @@ -30,9 +30,7 @@ class SettingsService(Service): for key in settings_dict: if key not in Settings.__fields__.keys(): raise KeyError(f"Key {key} not found in settings") - logger.debug( - f"Loading {len(settings_dict[key])} {key} from {file_path}" - ) + logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}") settings = Settings(**settings_dict) if not settings.CONFIG_DIR: diff --git a/src/backend/langflow/services/settings/utils.py b/src/backend/langflow/services/settings/utils.py index fae96ff28..1fd308e72 100644 --- a/src/backend/langflow/services/settings/utils.py +++ b/src/backend/langflow/services/settings/utils.py @@ -14,9 +14,7 @@ def set_secure_permissions(file_path): import win32security user, domain, _ = win32security.LookupAccountName("", win32api.GetUserName()) - sd = win32security.GetFileSecurity( - file_path, win32security.DACL_SECURITY_INFORMATION - ) + sd = win32security.GetFileSecurity(file_path, win32security.DACL_SECURITY_INFORMATION) dacl = win32security.ACL() # Set the new DACL for the file: read and write access for the owner, no access for everyone else @@ -26,9 +24,7 @@ def set_secure_permissions(file_path): user, ) sd.SetSecurityDescriptorDacl(1, dacl, 0) - win32security.SetFileSecurity( - file_path, win32security.DACL_SECURITY_INFORMATION, sd - ) + win32security.SetFileSecurity(file_path, win32security.DACL_SECURITY_INFORMATION, sd) else: print("Unsupported OS") diff --git a/src/backend/langflow/services/store/schema.py b/src/backend/langflow/services/store/schema.py index 841905260..9788a28ce 100644 --- a/src/backend/langflow/services/store/schema.py +++ b/src/backend/langflow/services/store/schema.py @@ -63,9 +63,7 @@ class ListComponentResponse(BaseModel): if all(["id" in tag and "name" in tag for tag in v]): return v else: - return [ - TagResponse(**tag.get("tags_id")) for tag in v if tag.get("tags_id") - ] + return [TagResponse(**tag.get("tags_id")) for tag in v if tag.get("tags_id")] class ListComponentResponseModel(BaseModel): diff --git a/src/backend/langflow/services/store/service.py b/src/backend/langflow/services/store/service.py index 9f7934d93..02075a804 100644 --- a/src/backend/langflow/services/store/service.py +++ b/src/backend/langflow/services/store/service.py @@ -21,9 +21,7 @@ if TYPE_CHECKING: from contextlib import contextmanager from contextvars import ContextVar -user_data_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar( - "user_data", default=None -) +user_data_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar("user_data", default=None) @contextmanager @@ -31,9 +29,7 @@ def user_data_context(store_service: "StoreService", api_key: Optional[str] = No # Fetch and set user data to the context variable if api_key: try: - user_data = store_service._get( - f"{store_service.base_url}/users/me", api_key, params={"fields": "id"} - ) + user_data = store_service._get(f"{store_service.base_url}/users/me", api_key, params={"fields": "id"}) user_data_var.set(user_data) except HTTPStatusError as exc: if exc.response.status_code == 403: @@ -77,9 +73,7 @@ class StoreService(Service): # will make a property return that data # Without making the request multiple times - def _get( - self, url: str, api_key: str, params: Dict[str, Any] = None - ) -> List[Dict[str, Any]]: + def _get(self, url: str, api_key: str, params: Dict[str, Any] = None) -> List[Dict[str, Any]]: """Utility method to perform GET requests.""" if api_key: headers = {"Authorization": f"Bearer {api_key}"} @@ -99,9 +93,7 @@ class StoreService(Service): # For now we are calling it just for testing try: headers = {"Authorization": f"Bearer {api_key}"} - response = httpx.post( - webhook_url, headers=headers, json={"component_id": str(component_id)} - ) + response = httpx.post(webhook_url, headers=headers, json={"component_id": str(component_id)}) response.raise_for_status() return response.json() except HTTPError as exc: @@ -146,9 +138,7 @@ class StoreService(Service): if tags: tags_filter = {"tags": {"_and": []}} for tag in tags: - tags_filter["tags"]["_and"].append( - {"_some": {"tags_id": {"name": {"_eq": tag}}}} - ) + tags_filter["tags"]["_and"].append({"_some": {"tags_id": {"name": {"_eq": tag}}}}) filter_conditions.append(tags_filter) if date_from: @@ -166,13 +156,7 @@ class StoreService(Service): params["fields"] = ",".join(fields) if filter_by_user: - params["deep"] = json.dumps( - { - "components": { - "_filter": {"user_created": {"token": {"_eq": api_key}}} - } - } - ) + params["deep"] = json.dumps({"components": {"_filter": {"user_created": {"token": {"_eq": api_key}}}}}) else: # params["filter"] = json.dumps({"status": {"_eq": "public"}}) filter_conditions.append({"status": {"_in": ["public", "Public"]}}) @@ -192,13 +176,7 @@ class StoreService(Service): params = {"aggregate": json.dumps({"count": "*"})} filter_conditions = [] if filter_conditions is None else filter_conditions if filter_by_user: - params["deep"] = json.dumps( - { - "components": { - "_filter": {"user_created": {"token": {"_eq": api_key}}} - } - } - ) + params["deep"] = json.dumps({"components": {"_filter": {"user_created": {"token": {"_eq": api_key}}}}}) else: filter_conditions.append({"status": {"_in": ["public", "Public"]}}) @@ -250,9 +228,7 @@ class StoreService(Service): if tags: tags_filter = {"tags": {"_and": []}} for tag in tags: - tags_filter["tags"]["_and"].append( - {"_some": {"tags_id": {"name": {"_eq": tag}}}} - ) + tags_filter["tags"]["_and"].append({"_some": {"tags_id": {"name": {"_eq": tag}}}}) filter_conditions.append(tags_filter) if is_component is not None: @@ -286,9 +262,7 @@ class StoreService(Service): # component.tags = [tags_id.tags_id for tags_id in component.tags] return results_objects, filter_conditions - def get_liked_by_user_components( - self, component_ids: List[UUID], api_key: str - ) -> List[UUID]: + def get_liked_by_user_components(self, component_ids: List[UUID], api_key: str) -> List[UUID]: # Get fields id # filter should be "id is in component_ids AND liked_by directus_users_id token is api_key" # return the ids @@ -310,9 +284,7 @@ class StoreService(Service): return [result["id"] for result in results] # Which of the components is parent of the user's components - def get_components_in_users_collection( - self, component_ids: List[UUID], api_key: str - ): + def get_components_in_users_collection(self, component_ids: List[UUID], api_key: str): user_data = user_data_var.get() if not user_data: raise ValueError("No user data") @@ -332,18 +304,14 @@ class StoreService(Service): def download(self, api_key: str, component_id: str) -> DownloadComponentResponse: url = f"{self.components_url}/{component_id}" - params = { - "fields": ",".join(["id", "name", "description", "data", "is_component"]) - } + params = {"fields": ",".join(["id", "name", "description", "data", "is_component"])} component = self._get(url, api_key, params) self.call_webhook(api_key, self.download_webhook_url, component_id) return DownloadComponentResponse(**component) - def upload( - self, api_key: str, component_data: StoreComponentCreate - ) -> ComponentResponse: + def upload(self, api_key: str, component_data: StoreComponentCreate) -> ComponentResponse: headers = {"Authorization": f"Bearer {api_key}"} component_dict = component_data.dict(exclude_unset=True) # Parent is a UUID, but the store expects a string @@ -353,9 +321,7 @@ class StoreService(Service): component_dict = process_tags_for_post(component_dict) try: - response = httpx.post( - self.components_url, headers=headers, json=component_dict - ) + response = httpx.post(self.components_url, headers=headers, json=component_dict) response.raise_for_status() component = response.json()["data"] return ComponentResponse(**component) diff --git a/src/backend/langflow/services/task/backends/celery.py b/src/backend/langflow/services/task/backends/celery.py index eae985f3a..f23374549 100644 --- a/src/backend/langflow/services/task/backends/celery.py +++ b/src/backend/langflow/services/task/backends/celery.py @@ -10,9 +10,7 @@ class CeleryBackend(TaskBackend): def __init__(self): self.celery_app = celery_app - def launch_task( - self, task_func: Callable[..., Any], *args: Any, **kwargs: Any - ) -> tuple[str, AsyncResult]: + def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> tuple[str, AsyncResult]: # I need to type the delay method to make it easier from celery import Task # type: ignore diff --git a/src/backend/langflow/services/task/service.py b/src/backend/langflow/services/task/service.py index 807505c3a..6730b1d91 100644 --- a/src/backend/langflow/services/task/service.py +++ b/src/backend/langflow/services/task/service.py @@ -63,9 +63,7 @@ class TaskService(Service): result = task.get() return task.id, result - async def launch_task( - self, task_func: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: + async def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: logger.debug(f"Launching task {task_func} with args {args} and kwargs {kwargs}") logger.debug(f"Using backend {self.backend}") task = self.backend.launch_task(task_func, *args, **kwargs) diff --git a/src/backend/langflow/services/utils.py b/src/backend/langflow/services/utils.py index c5976ef61..f0fd91f1f 100644 --- a/src/backend/langflow/services/utils.py +++ b/src/backend/langflow/services/utils.py @@ -71,16 +71,12 @@ def get_or_create_super_user(session: Session, username, password, is_default): ) return None else: - logger.debug( - "User with superuser credentials exists but is not a superuser." - ) + logger.debug("User with superuser credentials exists but is not a superuser.") return None if user: if verify_password(password, user.password): - raise ValueError( - "User with superuser credentials exists but is not a superuser." - ) + raise ValueError("User with superuser credentials exists but is not a superuser.") else: raise ValueError("Incorrect superuser credentials") @@ -109,21 +105,15 @@ def setup_superuser(settings_service, session: Session): username = settings_service.auth_settings.SUPERUSER password = settings_service.auth_settings.SUPERUSER_PASSWORD - is_default = (username == DEFAULT_SUPERUSER) and ( - password == DEFAULT_SUPERUSER_PASSWORD - ) + is_default = (username == DEFAULT_SUPERUSER) and (password == DEFAULT_SUPERUSER_PASSWORD) try: - user = get_or_create_super_user( - session=session, username=username, password=password, is_default=is_default - ) + user = get_or_create_super_user(session=session, username=username, password=password, is_default=is_default) if user is not None: logger.debug("Superuser created successfully.") except Exception as exc: logger.exception(exc) - raise RuntimeError( - "Could not create superuser. Please create a superuser manually." - ) from exc + raise RuntimeError("Could not create superuser. Please create a superuser manually.") from exc finally: settings_service.auth_settings.reset_credentials() @@ -137,9 +127,7 @@ def teardown_superuser(settings_service, session): if not settings_service.auth_settings.AUTO_LOGIN: try: - logger.debug( - "AUTO_LOGIN is set to False. Removing default superuser if exists." - ) + logger.debug("AUTO_LOGIN is set to False. Removing default superuser if exists.") username = DEFAULT_SUPERUSER from langflow.services.database.models.user.user import User @@ -187,9 +175,7 @@ def initialize_session_service(): initialize_settings_service() - service_manager.register_factory( - cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE] - ) + service_manager.register_factory(cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE]) service_manager.register_factory( session_service_factory.SessionServiceFactory(), @@ -206,17 +192,13 @@ def initialize_services(): service_manager.register_factory(factory, dependencies=dependencies) except Exception as exc: logger.exception(exc) - raise RuntimeError( - "Could not initialize services. Please check your settings." - ) from exc + raise RuntimeError("Could not initialize services. Please check your settings.") from exc # Test cache connection service_manager.get(ServiceType.CACHE_SERVICE) # Setup the superuser initialize_database() - setup_superuser( - service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()) - ) + setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session())) try: get_db_service().migrate_flows_if_auto_login() except Exception as exc: diff --git a/src/backend/langflow/template/frontend_node/base.py b/src/backend/langflow/template/frontend_node/base.py index 4f8b6ad7e..75af144a7 100644 --- a/src/backend/langflow/template/frontend_node/base.py +++ b/src/backend/langflow/template/frontend_node/base.py @@ -67,11 +67,7 @@ class FrontendNode(BaseModel): def process_base_classes(self) -> None: """Removes unwanted base classes from the list of base classes.""" - self.base_classes = [ - base_class - for base_class in self.base_classes - if base_class not in CLASSES_TO_REMOVE - ] + self.base_classes = [base_class for base_class in self.base_classes if base_class not in CLASSES_TO_REMOVE] def to_dict(self) -> dict: """Returns a dict representation of the frontend node.""" @@ -130,9 +126,7 @@ class FrontendNode(BaseModel): return _type @staticmethod - def handle_special_field( - field, key: str, _type: str, SPECIAL_FIELD_HANDLERS - ) -> str: + def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str: """Handles special field by using the respective handler if present.""" handler = SPECIAL_FIELD_HANDLERS.get(key) return handler(field) if handler else _type @@ -144,11 +138,7 @@ class FrontendNode(BaseModel): field.field_type = "file" field.suffixes = [".json", ".yaml", ".yml"] field.file_types = ["json", "yaml", "yml"] - elif ( - _type.startswith("Dict") - or _type.startswith("Mapping") - or _type.startswith("dict") - ): + elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"): field.field_type = "dict" return _type @@ -159,9 +149,7 @@ class FrontendNode(BaseModel): field.value = value["default"] @staticmethod - def handle_specific_field_values( - field: TemplateField, key: str, name: Optional[str] = None - ) -> None: + def handle_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None: """Handles specific field values for certain fields.""" if key == "headers": field.value = """{"Authorization": "Bearer "}""" @@ -169,9 +157,7 @@ class FrontendNode(BaseModel): FrontendNode._handle_api_key_specific_field_values(field, key, name) @staticmethod - def _handle_model_specific_field_values( - field: TemplateField, key: str, name: Optional[str] = None - ) -> None: + def _handle_model_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None: """Handles specific field values related to models.""" model_dict = { "OpenAI": constants.OPENAI_MODELS, @@ -184,9 +170,7 @@ class FrontendNode(BaseModel): field.is_list = True @staticmethod - def _handle_api_key_specific_field_values( - field: TemplateField, key: str, name: Optional[str] = None - ) -> None: + def _handle_api_key_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None: """Handles specific field values related to API keys.""" if "api_key" in key and "OpenAI" in str(name): field.display_name = "OpenAI API Key" @@ -225,10 +209,7 @@ class FrontendNode(BaseModel): @staticmethod def should_be_password(key: str, show: bool) -> bool: """Determines whether the field should be a password field.""" - return ( - any(text in key.lower() for text in {"password", "token", "api", "key"}) - and show - ) + return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show @staticmethod def should_be_multiline(key: str) -> bool: diff --git a/src/backend/langflow/template/frontend_node/chains.py b/src/backend/langflow/template/frontend_node/chains.py index b678dec3b..dcd8cfcc7 100644 --- a/src/backend/langflow/template/frontend_node/chains.py +++ b/src/backend/langflow/template/frontend_node/chains.py @@ -133,7 +133,9 @@ class SeriesCharacterChainNode(FrontendNode): ), ], ) - description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa + description: str = ( + "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa + ) base_classes: list[str] = [ "LLMChain", "BaseCustomChain", diff --git a/src/backend/langflow/template/frontend_node/documentloaders.py b/src/backend/langflow/template/frontend_node/documentloaders.py index 0be2ebe98..20ec27009 100644 --- a/src/backend/langflow/template/frontend_node/documentloaders.py +++ b/src/backend/langflow/template/frontend_node/documentloaders.py @@ -3,9 +3,7 @@ from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode -def build_file_field( - suffixes: list, fileTypes: list, name: str = "file_path" -) -> TemplateField: +def build_file_field(suffixes: list, fileTypes: list, name: str = "file_path") -> TemplateField: """Build a template field for a document loader.""" return TemplateField( field_type="file", @@ -27,32 +25,22 @@ class DocumentLoaderFrontNode(FrontendNode): "AirbyteJSONLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]), "CoNLLULoader": build_file_field(suffixes=[".csv"], fileTypes=["csv"]), "CSVLoader": build_file_field(suffixes=[".csv"], fileTypes=["csv"]), - "UnstructuredEmailLoader": build_file_field( - suffixes=[".eml"], fileTypes=["eml"] - ), + "UnstructuredEmailLoader": build_file_field(suffixes=[".eml"], fileTypes=["eml"]), "EverNoteLoader": build_file_field(suffixes=[".xml"], fileTypes=["xml"]), "FacebookChatLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]), "BSHTMLLoader": build_file_field(suffixes=[".html"], fileTypes=["html"]), - "UnstructuredHTMLLoader": build_file_field( - suffixes=[".html"], fileTypes=["html"] - ), + "UnstructuredHTMLLoader": build_file_field(suffixes=[".html"], fileTypes=["html"]), "UnstructuredImageLoader": build_file_field( suffixes=[".jpg", ".jpeg", ".png", ".gif", ".bmp"], fileTypes=["jpg", "jpeg", "png", "gif", "bmp"], ), - "UnstructuredMarkdownLoader": build_file_field( - suffixes=[".md"], fileTypes=["md"] - ), + "UnstructuredMarkdownLoader": build_file_field(suffixes=[".md"], fileTypes=["md"]), "PyPDFLoader": build_file_field(suffixes=[".pdf"], fileTypes=["pdf"]), - "UnstructuredPowerPointLoader": build_file_field( - suffixes=[".pptx", ".ppt"], fileTypes=["pptx", "ppt"] - ), + "UnstructuredPowerPointLoader": build_file_field(suffixes=[".pptx", ".ppt"], fileTypes=["pptx", "ppt"]), "SRTLoader": build_file_field(suffixes=[".srt"], fileTypes=["srt"]), "TelegramChatLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]), "TextLoader": build_file_field(suffixes=[".txt"], fileTypes=["txt"]), - "UnstructuredWordDocumentLoader": build_file_field( - suffixes=[".docx", ".doc"], fileTypes=["docx", "doc"] - ), + "UnstructuredWordDocumentLoader": build_file_field(suffixes=[".docx", ".doc"], fileTypes=["docx", "doc"]), } def add_extra_fields(self) -> None: diff --git a/src/backend/langflow/template/frontend_node/embeddings.py b/src/backend/langflow/template/frontend_node/embeddings.py index 665328e78..4c608fe54 100644 --- a/src/backend/langflow/template/frontend_node/embeddings.py +++ b/src/backend/langflow/template/frontend_node/embeddings.py @@ -70,9 +70,7 @@ class EmbeddingFrontendNode(FrontendNode): field.advanced = True split_name = field.name.split("_") title_name = " ".join([s.capitalize() for s in split_name]) - field.display_name = title_name.replace("Openai", "OpenAI").replace( - "Api", "API" - ) + field.display_name = title_name.replace("Openai", "OpenAI").replace("Api", "API") if "api_key" in field.name: field.password = True diff --git a/src/backend/langflow/template/frontend_node/formatter/field_formatters.py b/src/backend/langflow/template/frontend_node/formatter/field_formatters.py index a67387df7..c7dca1065 100644 --- a/src/backend/langflow/template/frontend_node/formatter/field_formatters.py +++ b/src/backend/langflow/template/frontend_node/formatter/field_formatters.py @@ -112,10 +112,7 @@ class PasswordFieldFormatter(FieldFormatter): def format(self, field: TemplateField, name: Optional[str] = None) -> None: key = field.name show = field.show - if ( - any(text in key.lower() for text in {"password", "token", "api", "key"}) - and show - ): + if any(text in key.lower() for text in {"password", "token", "api", "key"}) and show: field.password = True @@ -157,9 +154,5 @@ class DictCodeFileFormatter(FieldFormatter): field.field_type = "file" field.suffixes = [".json", ".yaml", ".yml"] field.file_types = ["json", "yaml", "yml"] - elif ( - _type.startswith("Dict") - or _type.startswith("Mapping") - or _type.startswith("dict") - ): + elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"): field.field_type = "dict" diff --git a/src/backend/langflow/template/frontend_node/llms.py b/src/backend/langflow/template/frontend_node/llms.py index b8e007a27..613b2c1d0 100644 --- a/src/backend/langflow/template/frontend_node/llms.py +++ b/src/backend/langflow/template/frontend_node/llms.py @@ -54,9 +54,9 @@ class LLMFrontendNode(FrontendNode): @staticmethod def format_openai_field(field: TemplateField): if "openai" in field.name.lower(): - field.display_name = ( - field.name.title().replace("Openai", "OpenAI").replace("_", " ") - ).replace("Api", "API") + field.display_name = (field.name.title().replace("Openai", "OpenAI").replace("_", " ")).replace( + "Api", "API" + ) if "key" not in field.name.lower() and "token" not in field.name.lower(): field.password = False @@ -109,10 +109,7 @@ class LLMFrontendNode(FrontendNode): if field.name in SHOW_FIELDS: field.show = True - if "api" in field.name and ( - "key" in field.name - or ("token" in field.name and "tokens" not in field.name) - ): + if "api" in field.name and ("key" in field.name or ("token" in field.name and "tokens" not in field.name)): field.password = True field.show = True # Required should be False to support diff --git a/src/backend/langflow/template/frontend_node/memories.py b/src/backend/langflow/template/frontend_node/memories.py index 019dc0fa8..bbf1c9a8d 100644 --- a/src/backend/langflow/template/frontend_node/memories.py +++ b/src/backend/langflow/template/frontend_node/memories.py @@ -76,9 +76,7 @@ class MemoryFrontendNode(FrontendNode): field.show = True field.advanced = False field.value = "" - field.info = ( - INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO - ) + field.info = INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO if field.name == "memory_key": field.value = "chat_history" diff --git a/src/backend/langflow/template/frontend_node/prompts.py b/src/backend/langflow/template/frontend_node/prompts.py index f0ebc35aa..dccd66301 100644 --- a/src/backend/langflow/template/frontend_node/prompts.py +++ b/src/backend/langflow/template/frontend_node/prompts.py @@ -36,10 +36,7 @@ class PromptFrontendNode(FrontendNode): field.field_type = "prompt" field.advanced = False - if ( - "Union" in field.field_type - and "BaseMessagePromptTemplate" in field.field_type - ): + if "Union" in field.field_type and "BaseMessagePromptTemplate" in field.field_type: field.field_type = "BaseMessagePromptTemplate" # All prompt fields should be password=False diff --git a/src/backend/langflow/utils/payload.py b/src/backend/langflow/utils/payload.py index cac23a0d6..02cca5c71 100644 --- a/src/backend/langflow/utils/payload.py +++ b/src/backend/langflow/utils/payload.py @@ -81,9 +81,7 @@ def build_json(root, graph) -> Dict: raise ValueError(f"No child with type {node_type} found") values = [build_json(child, graph) for child in children] value = ( - list(values) - if value["list"] - else next(iter(values), None) # type: ignore + list(values) if value["list"] else next(iter(values), None) # type: ignore ) final_dict[key] = value diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index c23c559e4..8eb7f8b25 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -15,12 +15,8 @@ def remove_ansi_escape_codes(text): return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text) -def build_template_from_function( - name: str, type_to_loader_dict: Dict, add_function: bool = False -): - classes = [ - item.__annotations__["return"].__name__ for item in type_to_loader_dict.values() - ] +def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False): + classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()] # Raise error if name is not in chains if name not in classes: @@ -41,9 +37,7 @@ def build_template_from_function( for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items][ - "default" - ] = get_default_factory( + variables[class_field_items]["default"] = get_default_factory( module=_class.__base__.__module__, function=value_ ) except Exception: @@ -52,9 +46,7 @@ def build_template_from_function( variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] - if class_field_items in docs.params - else "" + docs.params[class_field_items] if class_field_items in docs.params else "" ) # Adding function to base classes to allow # the output to be a function @@ -69,9 +61,7 @@ def build_template_from_function( } -def build_template_from_class( - name: str, type_to_cls_dict: Dict, add_function: bool = False -): +def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False): classes = [item.__name__ for item in type_to_cls_dict.values()] # Raise error if name is not in chains @@ -95,9 +85,7 @@ def build_template_from_class( for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items][ - "default" - ] = get_default_factory( + variables[class_field_items]["default"] = get_default_factory( module=_class.__base__.__module__, function=value_ ) except Exception: @@ -106,9 +94,7 @@ def build_template_from_class( variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] - if class_field_items in docs.params - else "" + docs.params[class_field_items] if class_field_items in docs.params else "" ) base_classes = get_base_classes(_class) # Adding function to base classes to allow @@ -140,9 +126,7 @@ def build_template_from_method( # Check if the method exists in this class if not hasattr(_class, method_name): - raise ValueError( - f"Method {method_name} not found in class {class_name}" - ) + raise ValueError(f"Method {method_name} not found in class {class_name}") # Get the method method = getattr(_class, method_name) @@ -161,12 +145,8 @@ def build_template_from_method( "_type": _type, **{ name: { - "default": param.default - if param.default != param.empty - else None, - "type": param.annotation - if param.annotation != param.empty - else None, + "default": param.default if param.default != param.empty else None, + "type": param.annotation if param.annotation != param.empty else None, "required": param.default == param.empty, } for name, param in params.items() @@ -253,9 +233,7 @@ def sync_to_async(func): return async_wrapper -def format_dict( - dictionary: Dict[str, Any], class_name: Optional[str] = None -) -> Dict[str, Any]: +def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]: """ Formats a dictionary by removing certain keys and modifying the values of other keys. @@ -330,9 +308,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str: The modified type string. """ if any(list_type in _type for list_type in ["List", "Sequence", "Set"]): - _type = ( - _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] - ) + _type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] value["list"] = True else: value["list"] = False @@ -436,9 +412,7 @@ def set_headers_value(value: Dict[str, Any]) -> None: value["value"] = """{"Authorization": "Bearer "}""" -def add_options_to_field( - value: Dict[str, Any], class_name: Optional[str], key: str -) -> None: +def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None: """ Adds options to the field based on the class name and key. """ diff --git a/src/backend/langflow/utils/validate.py b/src/backend/langflow/utils/validate.py index f8a9c1d1d..dc0244aca 100644 --- a/src/backend/langflow/utils/validate.py +++ b/src/backend/langflow/utils/validate.py @@ -41,9 +41,7 @@ def validate_code(code): # Evaluate the function definition for node in tree.body: if isinstance(node, ast.FunctionDef): - code_obj = compile( - ast.Module(body=[node], type_ignores=[]), "", "exec" - ) + code_obj = compile(ast.Module(body=[node], type_ignores=[]), "", "exec") try: exec(code_obj) except Exception as e: @@ -63,8 +61,7 @@ def eval_function(function_string: str): ( obj for name, obj in namespace.items() - if isinstance(obj, types.FunctionType) - and obj.__code__.co_filename == "" + if isinstance(obj, types.FunctionType) and obj.__code__.co_filename == "" ), None, ) @@ -88,23 +85,15 @@ def execute_function(code, function_name, *args, **kwargs): exec_globals, locals(), ) - exec_globals[alias.asname or alias.name] = importlib.import_module( - alias.name - ) + exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Module {alias.name} not found. Please install it and try again." - ) from e + raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e function_code = next( - node - for node in module.body - if isinstance(node, ast.FunctionDef) and node.name == function_name + node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name ) function_code.parent = None - code_obj = compile( - ast.Module(body=[function_code], type_ignores=[]), "", "exec" - ) + code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "", "exec") try: exec(code_obj, exec_globals, locals()) except Exception as exc: @@ -131,23 +120,15 @@ def create_function(code, function_name): if isinstance(node, ast.Import): for alias in node.names: try: - exec_globals[alias.asname or alias.name] = importlib.import_module( - alias.name - ) + exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Module {alias.name} not found. Please install it and try again." - ) from e + raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e function_code = next( - node - for node in module.body - if isinstance(node, ast.FunctionDef) and node.name == function_name + node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name ) function_code.parent = None - code_obj = compile( - ast.Module(body=[function_code], type_ignores=[]), "", "exec" - ) + code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "", "exec") with contextlib.suppress(Exception): exec(code_obj, exec_globals, locals()) exec_globals[function_name] = locals()[function_name] @@ -178,32 +159,20 @@ def create_class(code, class_name): if isinstance(node, ast.Import): for alias in node.names: try: - exec_globals[alias.asname or alias.name] = importlib.import_module( - alias.name - ) + exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Module {alias.name} not found. Please install it and try again." - ) from e + raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e elif isinstance(node, ast.ImportFrom): try: imported_module = importlib.import_module(node.module) for alias in node.names: exec_globals[alias.name] = getattr(imported_module, alias.name) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Module {node.module} not found. Please install it and try again." - ) from e + raise ModuleNotFoundError(f"Module {node.module} not found. Please install it and try again.") from e - class_code = next( - node - for node in module.body - if isinstance(node, ast.ClassDef) and node.name == class_name - ) + class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name) class_code.parent = None - code_obj = compile( - ast.Module(body=[class_code], type_ignores=[]), "", "exec" - ) + code_obj = compile(ast.Module(body=[class_code], type_ignores=[]), "", "exec") # This suppresses import errors # with contextlib.suppress(Exception): exec(code_obj, exec_globals, locals()) diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index 264e94c7f..e466792cf 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -30,9 +30,7 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex": vertex.build() return vertex except SoftTimeLimitExceeded as e: - raise self.retry( - exc=SoftTimeLimitExceeded("Task took too long"), countdown=2 - ) from e + raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e @celery_app.task(acks_late=True) @@ -47,9 +45,7 @@ def process_graph_cached_task( if clear_cache: session_service.clear_session(session_id) if session_id is None: - session_id = session_service.generate_key( - session_id=session_id, data_graph=data_graph - ) + session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph) # Load the graph using SessionService graph, artifacts = session_service.load_session(session_id, data_graph) built_object = graph.build() diff --git a/tests/conftest.py b/tests/conftest.py index f58c007f0..857e26f06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,31 +26,17 @@ if TYPE_CHECKING: def pytest_configure(): - pytest.BASIC_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "basic_example.json" - ) - pytest.COMPLEX_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "complex_example.json" - ) - pytest.OPENAPI_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "Openapi.json" - ) - pytest.GROUPED_CHAT_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "grouped_chat.json" - ) - pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "one_group_chat.json" - ) - pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = ( - Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" - ) + pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json" + pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json" + pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json" + pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json" + pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json" + pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = ( Path(__file__).parent.absolute() / "data" / "BasicChatwithPromptandHistory.json" ) - pytest.VECTOR_STORE_PATH = ( - Path(__file__).parent.absolute() / "data" / "Vector_store.json" - ) + pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json" pytest.CODE_WITH_SYNTAX_ERROR = """ def get_text(): retun "Hello World" @@ -68,9 +54,7 @@ async def async_client() -> AsyncGenerator: @pytest.fixture(name="session") def session_fixture(): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) SQLModel.metadata.create_all(engine) with Session(engine) as session: yield session @@ -106,9 +90,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env): monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") # monkeypatch langflow.services.task.manager.USE_CELERY to True # monkeypatch.setattr(manager, "USE_CELERY", True) - monkeypatch.setattr( - celery_app, "celery_app", celery_app.make_celery("langflow", Config) - ) + monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config)) # def get_session_override(): # return session @@ -259,11 +241,7 @@ def active_user(client): is_superuser=False, ) # check if user exists - if ( - active_user := session.query(User) - .filter(User.username == user.username) - .first() - ): + if active_user := session.query(User).filter(User.username == user.username).first(): return active_user session.add(user) session.commit() @@ -286,9 +264,7 @@ def flow(client, json_flow: str, active_user): from langflow.services.database.models.flow.flow import FlowCreate loaded_json = json.loads(json_flow) - flow_data = FlowCreate( - name="test_flow", data=loaded_json.get("data"), user_id=active_user.id - ) + flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id) flow = Flow(**flow_data.dict()) with session_getter(get_db_service()) as session: session.add(flow) @@ -315,9 +291,7 @@ def added_vector_store(client, json_vector_store, logged_in_headers): vector_store = orjson.loads(json_vector_store) data = vector_store["data"] vector_store = FlowCreate(name="Vector Store", description="description", data=data) - response = client.post( - "api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == vector_store.name assert response.json()["data"] == vector_store.data diff --git a/tests/locust/locustfile.py b/tests/locust/locustfile.py index aca0d1de9..1fc91ee2c 100644 --- a/tests/locust/locustfile.py +++ b/tests/locust/locustfile.py @@ -66,9 +66,7 @@ class NameTest(FastHttpUser): result1, session_id = self.process(name, self.flow_id, payload1) payload2 = { - "inputs": { - "text": "What is my name? Please, answer like this: Your name is " - }, + "inputs": {"text": "What is my name? Please, answer like this: Your name is "}, "session_id": session_id, "sync": False, } @@ -88,9 +86,7 @@ class NameTest(FastHttpUser): logged_in_headers = {"Authorization": f"Bearer {a_token}"} print("Logged in") with open( - Path(__file__).parent.parent - / "data" - / "BasicChatwithPromptandHistory.json", + Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json", "r", ) as f: json_flow = f.read() @@ -115,11 +111,7 @@ class NameTest(FastHttpUser): ) print(response.json()) user_id = next( - ( - user["id"] - for user in response.json()["users"] - if user["username"] == "superuser" - ), + (user["id"] for user in response.json()["users"] if user["username"] == "superuser"), None, ) # Create api key diff --git a/tests/test_api_key.py b/tests/test_api_key.py index 43b91fa43..7988793d4 100644 --- a/tests/test_api_key.py +++ b/tests/test_api_key.py @@ -6,9 +6,7 @@ from langflow.services.database.models.api_key import ApiKeyCreate def api_key(client, logged_in_headers, active_user): api_key = ApiKeyCreate(name="test-api-key") - response = client.post( - "api/v1/api_key", data=api_key.json(), headers=logged_in_headers - ) + response = client.post("api/v1/api_key", data=api_key.json(), headers=logged_in_headers) assert response.status_code == 200, response.text return response.json() @@ -28,9 +26,7 @@ def test_get_api_keys(client, logged_in_headers, api_key): def test_create_api_key(client, logged_in_headers): api_key_name = "test-api-key" - response = client.post( - "api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers - ) + response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers) assert response.status_code == 200 data = response.json() assert "name" in data and data["name"] == api_key_name diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index eb20a0571..771fb91cc 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -96,10 +96,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers): assert template["_type"] == "ConversationChain" # Test the description object - assert ( - chain["description"] - == "Chain to have a conversation and load context from memory." - ) + assert chain["description"] == "Chain to have a conversation and load context from memory." def test_llm_chain(client: TestClient, logged_in_headers): @@ -293,10 +290,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers): assert template["_type"] == "LLMMathChain" # Test the description object - assert ( - chain["description"] - == "Chain that interprets a prompt and executes python code to do math." - ) + assert chain["description"] == "Chain that interprets a prompt and executes python code to do math." def test_series_character_chain(client: TestClient, logged_in_headers): @@ -402,10 +396,7 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers): "info": "", } # Test the description object - assert ( - chain["description"] - == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." - ) + assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." def test_time_travel_guide_chain(client: TestClient, logged_in_headers): diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index 47c9cbfb2..9cc279d2a 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -113,9 +113,7 @@ def test_custom_component_init(): """ function_entrypoint_name = "build" - custom_component = CustomComponent( - code=code_default, function_entrypoint_name=function_entrypoint_name - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name) assert custom_component.code == code_default assert custom_component.function_entrypoint_name == function_entrypoint_name @@ -124,9 +122,7 @@ def test_custom_component_build_template_config(): """ Test the build_template_config property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") config = custom_component.build_template_config assert isinstance(config, dict) @@ -135,9 +131,7 @@ def test_custom_component_get_function(): """ Test the get_function property of the CustomComponent class. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") my_function = custom_component.get_function assert isinstance(my_function, types.FunctionType) @@ -222,9 +216,7 @@ def test_custom_component_get_function_entrypoint_args(): Test the get_function_entrypoint_args property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") args = custom_component.get_function_entrypoint_args assert len(args) == 4 assert args[0]["name"] == "self" @@ -237,9 +229,7 @@ def test_custom_component_get_function_entrypoint_return_type(): Test the get_function_entrypoint_return_type property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type assert return_type == ["Document"] @@ -248,9 +238,7 @@ def test_custom_component_get_main_class_name(): """ Test the get_main_class_name property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") class_name = custom_component.get_main_class_name assert class_name == "YourComponent" @@ -260,9 +248,7 @@ def test_custom_component_get_function_valid(): Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") my_function = custom_component.get_function assert callable(my_function) @@ -297,9 +283,7 @@ def test_code_parser_parse_callable_details_no_args(): parser = CodeParser("") node = ast.FunctionDef( name="test", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -345,9 +329,7 @@ def test_code_parser_parse_function_def_not_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="test", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -365,9 +347,7 @@ def test_code_parser_parse_function_def_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="__init__", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -402,9 +382,7 @@ def test_custom_component_get_code_tree_syntax_error(): Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. """ - custom_component = CustomComponent( - code="import os as", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="import os as", function_entrypoint_name="build") with pytest.raises(CodeSyntaxError): custom_component.get_code_tree(custom_component.code) @@ -458,9 +436,7 @@ def test_custom_component_build_not_implemented(): Test the build method of the CustomComponent class raises the NotImplementedError. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") with pytest.raises(NotImplementedError): custom_component.build() @@ -494,9 +470,7 @@ def test_flow(db): } # Create flow - flow = FlowCreate( - id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data - ) + flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data) # Add to database db.add(flow) diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index b65f58d0a..ba54b7023 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -18,9 +18,7 @@ def test_python_function_tool(): with pytest.raises(SyntaxError): code = pytest.CODE_WITH_SYNTAX_ERROR func = get_function(code) - func = PythonFunctionTool( - name="Test", description="Testing", code=code, func=func - ) + func = PythonFunctionTool(name="Test", description="Testing", code=code, func=func) def test_python_function(): diff --git a/tests/test_database.py b/tests/test_database.py index 4b3465e97..a7a04e76c 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -27,9 +27,7 @@ def json_style(): ) -def test_create_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) @@ -39,9 +37,7 @@ def test_create_flow( assert response.json()["data"] == flow.data # flow is optional so we can create a flow without a flow flow = FlowCreate(name="Test Flow") - response = client.post( - "api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data @@ -82,9 +78,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he assert response.json()["data"] == flow.data -def test_update_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] @@ -97,9 +91,7 @@ def test_update_flow( description="updated description", data=data, ) - response = client.patch( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) + response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) assert response.status_code == 200 assert response.json()["name"] == updated_flow.name @@ -107,9 +99,7 @@ def test_update_flow( # assert response.json()["data"] == updated_flow.data -def test_delete_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) @@ -120,9 +110,7 @@ def test_delete_flow( assert response.json()["message"] == "Flow deleted successfully" -def test_create_flows( - client: TestClient, session: Session, json_flow: str, logged_in_headers -): +def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -133,9 +121,7 @@ def test_create_flows( ] ) # Make request to endpoint - response = client.post( - "api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers) # Check response status code assert response.status_code == 201 # Check response data @@ -149,9 +135,7 @@ def test_create_flows( assert response_data[1]["data"] == data -def test_upload_file( - client: TestClient, session: Session, json_flow: str, logged_in_headers -): +def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -218,9 +202,7 @@ def test_download_file( assert response_data[1]["data"] == data -def test_create_flow_with_invalid_data( - client: TestClient, active_user, logged_in_headers -): +def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers): flow = {"name": "a" * 256, "data": "Invalid flow data"} response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers) assert response.status_code == 422 @@ -232,29 +214,19 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers assert response.status_code == 404 -def test_update_flow_idempotency( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers): flow_data = orjson.loads(json_flow) data = flow_data["data"] flow_data = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post( - "api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers) flow_id = response.json()["id"] updated_flow = FlowCreate(name="Updated Flow", description="description", data=data) - response1 = client.put( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) - response2 = client.put( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) + response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) + response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) assert response1.json() == response2.json() -def test_update_nonexistent_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow_data = orjson.loads(json_flow) data = flow_data["data"] uuid = uuid4() @@ -263,9 +235,7 @@ def test_update_nonexistent_flow( description="description", data=data, ) - response = client.patch( - f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers - ) + response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers) assert response.status_code == 404 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 48650afdc..aba0a1a78 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -31,10 +31,7 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1): href, headers=headers, ) - if ( - task_status_response.status_code == 200 - and task_status_response.json()["status"] == "SUCCESS" - ): + if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS": return task_status_response.json() time.sleep(sleep_time) return None # Return None if task did not complete in time @@ -130,11 +127,7 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if ( - existing_api_key := session.query(ApiKey) - .filter(ApiKey.api_key == api_key.api_key) - .first() - ): + if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): return existing_api_key session.add(api_key) session.commit() @@ -193,9 +186,7 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key): } invalid_id = uuid.uuid4() - response = client.post( - f"api/v1/process/{invalid_id}", headers=headers, json=post_data - ) + response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data) assert response.status_code == 404 assert f"Flow {invalid_id} not found" in response.json()["detail"] @@ -236,9 +227,7 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) - monkeypatch.setattr( - endpoints, "process_graph_cached_task", mock_process_graph_cached_task - ) + monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task) api_key = created_api_key.api_key headers = {"x-api-key": api_key} @@ -510,9 +499,7 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a @pytest.mark.async_test -def test_vector_store_in_process( - distributed_client, added_vector_store, created_api_key -): +def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "What is Langflow?"}} @@ -563,9 +550,7 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key): # Test function without loop @pytest.mark.async_test -def test_async_task_processing_vector_store( - client, added_vector_store, created_api_key -): +def test_async_task_processing_vector_store(client, added_vector_store, created_api_key): headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "How do I upload examples?"}} @@ -594,6 +579,4 @@ def test_async_task_processing_vector_store( # Validate that the task completed successfully and the result is as expected assert "result" in task_status_json, task_status_json assert "output" in task_status_json["result"], task_status_json["result"] - assert "Langflow" in task_status_json["result"]["output"], task_status_json[ - "result" - ] + assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"] diff --git a/tests/test_frontend_nodes.py b/tests/test_frontend_nodes.py index 00fe9fcb1..a45339d4b 100644 --- a/tests/test_frontend_nodes.py +++ b/tests/test_frontend_nodes.py @@ -39,9 +39,7 @@ def test_template_field_defaults(sample_template_field: TemplateField): assert sample_template_field.name == "test_field" -def test_template_to_dict( - sample_template: Template, sample_template_field: TemplateField -): +def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField): template_dict = sample_template.to_dict() assert template_dict["_type"] == "test_template" assert len(template_dict) == 2 # _type and test_field diff --git a/tests/test_graph.py b/tests/test_graph.py index 1a24e0e4b..1c63e62e8 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -47,13 +47,7 @@ def sample_nodes(): return [ { "id": "node1", - "data": { - "node": { - "template": { - "some_field": {"show": True, "advanced": False, "name": "Name1"} - } - } - }, + "data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}}, }, { "id": "node2", @@ -71,11 +65,7 @@ def sample_nodes(): }, { "id": "node3", - "data": { - "node": { - "template": {"unrelated_field": {"show": True, "advanced": True}} - } - }, + "data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}}, }, ] @@ -158,15 +148,9 @@ def test_get_node_neighbors_basic(basic_graph): # Root Node is an Agent, it requires an LLMChain and tools # We need to check if there is a Chain in the one of the neighbors' # data attribute in the type key - assert any( - "ConversationBufferMemory" in neighbor.data["type"] - for neighbor, val in neighbors.items() - if val - ) + assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) - assert any( - "OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val - ) + assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) def test_get_node(basic_graph): @@ -339,9 +323,7 @@ def test_find_last_node(grouped_chat_json_flow): def test_ungroup_node(grouped_chat_json_flow): grouped_chat_data = json.loads(grouped_chat_json_flow).get("data") - group_node = grouped_chat_data["nodes"][ - 2 - ] # Assuming the first node is a group node + group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node base_flow = copy.deepcopy(grouped_chat_data) ungroup_node(group_node["data"], base_flow) # after ungroup_node is called, the base_flow and grouped_chat_data should be different @@ -393,14 +375,9 @@ def test_process_flow_one_group(one_grouped_chat_json_flow): assert "edges" in processed_flow # Now get the node that has ChatOpenAI in its id - chat_openai_node = next( - (node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None - ) + chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None) assert chat_openai_node is not None - assert ( - chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] - == "test" - ) + assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test" def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow): @@ -449,17 +426,11 @@ def test_update_template(sample_template, sample_nodes): assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False - assert ( - node1_updated["data"]["node"]["template"]["some_field"]["display_name"] - == "Name1" - ) + assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1" assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True - assert ( - node2_updated["data"]["node"]["template"]["other_field"]["display_name"] - == "DisplayName2" - ) + assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2" # Ensure node3 remains unchanged assert node3_updated == sample_nodes[2] @@ -490,9 +461,7 @@ def test_set_new_target_handle(): "data": { "node": { "flow": True, - "template": { - "field_1": {"proxy": {"field": "new_field", "id": "new_id"}} - }, + "template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}}, } } } @@ -512,9 +481,7 @@ def test_update_source_handle(): "nodes": [{"id": "some_node"}, {"id": "last_node"}], "edges": [{"source": "some_node"}], } - updated_edge = update_source_handle( - new_edge, flow_data["nodes"], flow_data["edges"] - ) + updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"]) assert updated_edge["source"] == "last_node" assert updated_edge["data"]["sourceHandle"]["id"] == "last_node" diff --git a/tests/test_login.py b/tests/test_login.py index 252d27f33..399c7b761 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -9,9 +9,7 @@ from langflow.services.auth.utils import get_password_hash def test_user(): return User( username="testuser", - password=get_password_hash( - "testpassword" - ), # Assuming password needs to be hashed + password=get_password_hash("testpassword"), # Assuming password needs to be hashed is_active=True, is_superuser=False, ) @@ -23,17 +21,13 @@ def test_login_successful(client, test_user): session.add(test_user) session.commit() - response = client.post( - "api/v1/login", data={"username": "testuser", "password": "testpassword"} - ) + response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"}) assert response.status_code == 200 assert "access_token" in response.json() def test_login_unsuccessful_wrong_username(client): - response = client.post( - "api/v1/login", data={"username": "wrongusername", "password": "testpassword"} - ) + response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"}) assert response.status_code == 401 assert response.json()["detail"] == "Incorrect username or password" @@ -43,8 +37,6 @@ def test_login_unsuccessful_wrong_password(client, test_user, session): session.add(test_user) session.commit() - response = client.post( - "api/v1/login", data={"username": "testuser", "password": "wrongpassword"} - ) + response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"}) assert response.status_code == 401 assert response.json()["detail"] == "Incorrect username or password" diff --git a/tests/test_setup_superuser.py b/tests/test_setup_superuser.py index c4b80167d..1d3ed5d68 100644 --- a/tests/test_setup_superuser.py +++ b/tests/test_setup_superuser.py @@ -94,9 +94,7 @@ from langflow.services.utils import ( @patch("langflow.services.deps.get_settings_service") @patch("langflow.services.deps.get_session") -def test_teardown_superuser_default_superuser( - mock_get_session, mock_get_settings_service -): +def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service): mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = True mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER @@ -122,9 +120,7 @@ def test_teardown_superuser_default_superuser( @patch("langflow.services.deps.get_settings_service") @patch("langflow.services.deps.get_session") -def test_teardown_superuser_no_default_superuser( - mock_get_session, mock_get_settings_service -): +def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service): ADMIN_USER_NAME = "admin_user" mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = False diff --git a/tests/test_template.py b/tests/test_template.py index 81f2a6020..5ee2e71da 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -65,9 +65,7 @@ def test_build_template_from_function(): assert "base_classes" in result # Test with add_function=True - result_with_function = build_template_from_function( - "ExampleClass1", type_to_loader_dict, add_function=True - ) + result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True) assert result_with_function is not None assert "function" in result_with_function["base_classes"] diff --git a/tests/test_user.py b/tests/test_user.py index 32e408cdf..3b4e71527 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -85,15 +85,11 @@ def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_head assert response.json()["detail"] == "The user doesn't have enough privileges" -def test_data_consistency_after_update( - client, active_user, logged_in_headers, super_user_headers -): +def test_data_consistency_after_update(client, active_user, logged_in_headers, super_user_headers): user_id = active_user.id update_data = UserUpdate(is_active=False) - response = client.patch( - f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers - ) + response = client.patch(f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers) assert response.status_code == 200, response.json() # Fetch the updated user from the database @@ -167,17 +163,13 @@ def test_patch_user(client, active_user, logged_in_headers): username="newname", ) - response = client.patch( - f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers - ) + response = client.patch(f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers) assert response.status_code == 200, response.json() update_data = UserUpdate( profile_image="new_image", ) - response = client.patch( - f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers - ) + response = client.patch(f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers) assert response.status_code == 200, response.json() @@ -205,9 +197,7 @@ def test_patch_user_wrong_id(client, active_user, logged_in_headers): username="newname", ) - response = client.patch( - f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers - ) + response = client.patch(f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers) assert response.status_code == 422, response.json() assert response.json() == { "detail": [ diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 5016eb704..c4c9ee322 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -31,9 +31,7 @@ def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers): # Assuming your websocket_endpoint uses chat_service which caches data from stream_build access_token = logged_in_headers["Authorization"].split(" ")[1] with pytest.raises(WebSocketDisconnect): - with client.websocket_connect( - f"api/v1/chat/non_existing_client_id?token={access_token}" - ) as websocket: + with client.websocket_connect(f"api/v1/chat/non_existing_client_id?token={access_token}") as websocket: websocket.send_json({"type": "test"}) data = websocket.receive_json() assert "Please, build the flow before sending messages" in data["message"]