ruff format

commit e4d8687a09
Author: Gabriel Luiz Freitas Almeida
Date:   2023-11-13 18:32:34 -03:00
115 changed files with 494 additions and 1505 deletions
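
Note on the pattern throughout this diff: ruff format collapses calls and expressions that fit within the configured line length, so most hunks below replace a Black-style multi-line wrap with a single longer line. A minimal sketch of the before/after shape, assuming a line-length setting of roughly 120 characters (the width is inferred from the diff, not confirmed by the repo config):

import typer

app = typer.Typer()

@app.command()
def serve(
    # Before (88-column wrap):
    #     host: str = typer.Option(
    #         "127.0.0.1", help="Host to bind the server to."
    #     ),
    # After, with a longer line limit, ruff format keeps the parameter on one line:
    host: str = typer.Option("127.0.0.1", help="Host to bind the server to."),
):
    print(f"Would bind to {host}")

if __name__ == "__main__":
    app()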


@ -98,12 +98,8 @@ def update_settings(
@app.command()
def run(
host: str = typer.Option(
"127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"
),
workers: int = typer.Option(
1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"
),
host: str = typer.Option("127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"),
workers: int = typer.Option(1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"),
timeout: int = typer.Option(300, help="Worker timeout in seconds."),
port: int = typer.Option(7860, help="Port to listen on.", envvar="LANGFLOW_PORT"),
components_path: Optional[Path] = typer.Option(
@ -111,19 +107,11 @@ def run(
help="Path to the directory containing custom components.",
envvar="LANGFLOW_COMPONENTS_PATH",
),
config: str = typer.Option(
Path(__file__).parent / "config.yaml", help="Path to the configuration file."
),
config: str = typer.Option(Path(__file__).parent / "config.yaml", help="Path to the configuration file."),
# .env file param
env_file: Path = typer.Option(
None, help="Path to the .env file containing environment variables."
),
log_level: str = typer.Option(
"critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"
),
log_file: Path = typer.Option(
"logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"
),
env_file: Path = typer.Option(None, help="Path to the .env file containing environment variables."),
log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
log_file: Path = typer.Option("logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"),
cache: Optional[str] = typer.Option(
envvar="LANGFLOW_LANGCHAIN_CACHE",
help="Type of cache to use. (InMemoryCache, SQLiteCache)",
@ -202,9 +190,7 @@ def run(
def run_on_mac_or_linux(host, port, log_level, options, app, open_browser=True):
webapp_process = Process(
target=run_langflow, args=(host, port, log_level, options, app)
)
webapp_process = Process(target=run_langflow, args=(host, port, log_level, options, app))
webapp_process.start()
status_code = 0
while status_code != 200:
@ -280,9 +266,7 @@ def print_banner(host, port):
)
# Create a panel with the title and the info text, and a border around it
panel = Panel(
f"{title}\n{info_text}", box=box.ROUNDED, border_style="blue", expand=False
)
panel = Panel(f"{title}\n{info_text}", box=box.ROUNDED, border_style="blue", expand=False)
# Print the banner with a separator line before and after
rprint(panel)
@ -314,12 +298,8 @@ def run_langflow(host, port, log_level, options, app):
@app.command()
def superuser(
username: str = typer.Option(..., prompt=True, help="Username for the superuser."),
password: str = typer.Option(
..., prompt=True, hide_input=True, help="Password for the superuser."
),
log_level: str = typer.Option(
"critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"
),
password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."),
log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
):
"""
Create a superuser.


@ -2,9 +2,7 @@ API_WORDS = ["api", "key", "token"]
def has_api_terms(word: str):
return "api" in word and (
"key" in word or ("token" in word and "tokens" not in word)
)
return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))
def remove_api_keys(flow: dict):
@ -14,11 +12,7 @@ def remove_api_keys(flow: dict):
node_data = node.get("data").get("node")
template = node_data.get("template")
for value in template.values():
if (
isinstance(value, dict)
and has_api_terms(value["name"])
and value.get("password")
):
if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"):
value["value"] = None
return flow
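
As a usage sketch, the collapsed predicate above behaves like this (examples are illustrative, not from the repo):

def has_api_terms(word: str):
    return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))

assert has_api_terms("openai_api_key")   # "api" + "key"
assert has_api_terms("api_token")        # "api" + singular "token"
assert not has_api_terms("api_tokens")   # plural "tokens" is deliberately excluded
assert not has_api_terms("temperature")  # no "api" at all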
@ -39,9 +33,7 @@ def build_input_keys_response(langchain_object, artifacts):
input_keys_response["input_keys"][key] = value
# If the object has memory, that memory will have a memory_variables attribute
# memory variables should be removed from the input keys
if hasattr(langchain_object, "memory") and hasattr(
langchain_object.memory, "memory_variables"
):
if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"):
# Remove memory variables from input keys
input_keys_response["input_keys"] = {
key: value
@ -51,9 +43,7 @@ def build_input_keys_response(langchain_object, artifacts):
# Add memory variables to memory_keys
input_keys_response["memory_keys"] = langchain_object.memory.memory_variables
if hasattr(langchain_object, "prompt") and hasattr(
langchain_object.prompt, "template"
):
if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"):
input_keys_response["template"] = langchain_object.prompt.template
return input_keys_response


@ -79,9 +79,7 @@ def save_store_api_key(
try:
api_key = api_key_request.api_key
# Encrypt the API key
encrypted = auth_utils.encrypt_api_key(
api_key, settings_service=settings_service
)
encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service)
current_user.store_api_key = encrypted
db.commit()
return {"detail": "API Key saved"}
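
The encrypt/decrypt round trip that auth_utils presumably performs can be sketched with Fernet; this is an assumption about the general shape, not Langflow's actual auth_utils implementation:

from cryptography.fernet import Fernet

secret_key = Fernet.generate_key()  # in practice this would come from app settings
fernet = Fernet(secret_key)

encrypted = fernet.encrypt(b"sk-example-api-key").decode()
assert fernet.decrypt(encrypted.encode()) == b"sk-example-api-key"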


@ -79,9 +79,7 @@ def validate_prompt(template: str):
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
)
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
try:
PromptTemplate(template=template, input_variables=input_variables)
@ -132,9 +130,7 @@ def check_input_variables(input_variables: list):
return input_variables
def build_error_message(
input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables
):
def build_error_message(input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables):
input_variables_str = ", ".join([f"'{var}'" for var in input_variables])
error_string = f"Invalid input variables: {input_variables_str}. "
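
For context, pulling {placeholders} out of a template (the input_variables these functions receive) can be done with a regex along these lines; this helper is hypothetical and may not match Langflow's actual extraction logic:

import re

def extract_variables(template: str) -> list:
    # Single-brace placeholders like {name}; escaped {{literals}} are skipped
    return re.findall(r"(?<!\{)\{([A-Za-z_][A-Za-z0-9_]*)\}(?!\})", template)

print(extract_variables("Tell me about {topic} in {language}."))  # ['topic', 'language']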


@ -28,9 +28,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.websocket.send_json(resp.dict())
async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message="",


@ -37,13 +37,9 @@ async def chat(
await websocket.accept()
user = await get_current_user(token, db)
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
if not user.is_active:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
if client_id in chat_service.cache_service:
await chat_service.handle_websocket(client_id, websocket)
@ -59,9 +55,7 @@ async def chat(
logger.error(f"Error in chat websocket: {exc}")
message = exc.detail if isinstance(exc, HTTPException) else str(exc)
if "Could not validate credentials" in str(exc):
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
else:
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message)
@ -103,15 +97,10 @@ async def init_build(
@router.get("/build/{flow_id}/status", response_model=BuiltResponse)
async def build_status(
flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)
):
async def build_status(flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)):
"""Check the flow_id is in the cache_service."""
try:
built = (
flow_id in cache_service
and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
)
built = flow_id in cache_service and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
return BuiltResponse(
built=built,
@ -173,9 +162,7 @@ async def stream_build(
params = vertex._built_object_repr()
valid = True
logger.debug(f"Building node {str(vertex.vertex_type)}")
logger.debug(
f"Output: {params[:100]}{'...' if len(params) > 100 else ''}"
)
logger.debug(f"Output: {params[:100]}{'...' if len(params) > 100 else ''}")
if vertex.artifacts:
# The artifacts will be prompt variables
# passed to build_input_keys_response
@ -187,9 +174,7 @@ async def stream_build(
valid = False
update_build_status(cache_service, flow_id, BuildStatus.FAILURE)
vertex_id = (
vertex.parent_node_id if vertex.parent_is_top_level else vertex.id
)
vertex_id = vertex.parent_node_id if vertex.parent_is_top_level else vertex.id
if vertex_id in graph.top_level_nodes:
response = {
"valid": valid,
@ -203,9 +188,7 @@ async def stream_build(
langchain_object = graph.build()
# Now we need to check the input_keys to send them to the client
if hasattr(langchain_object, "input_keys"):
input_keys_response = build_input_keys_response(
langchain_object, artifacts
)
input_keys_response = build_input_keys_response(langchain_object, artifacts)
else:
input_keys_response = {
"input_keys": None,


@ -92,12 +92,7 @@ async def process(
)
# Get the flow that matches the flow_id and belongs to the user
flow = (
session.query(Flow)
.filter(Flow.id == flow_id)
.filter(Flow.user_id == api_key_user.id)
.first()
)
flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -111,9 +106,7 @@ async def process(
logger.error(f"Error processing tweaks: {exc}")
if sync:
task_id, result = await task_service.launch_and_await_task(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached,
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
@ -133,13 +126,9 @@ async def process(
)
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(
session_id=session_id, data_graph=graph_data
)
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
task_id, task = await task_service.launch_task(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached,
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
@ -162,18 +151,12 @@ async def process(
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
except Exception as e:
# Log stack trace
logger.exception(e)


@ -64,12 +64,7 @@ def read_flow(
current_user: User = Depends(get_current_active_user),
):
"""Read a flow."""
if user_flow := (
session.query(Flow)
.filter(Flow.id == flow_id)
.filter(Flow.user_id == current_user.id)
.first()
):
if user_flow := (session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == current_user.id).first()):
return user_flow
else:
raise HTTPException(status_code=404, detail="Flow not found")


@ -44,9 +44,7 @@ async def login_to_get_access_token(
@router.get("/auto_login")
async def auto_login(
db: Session = Depends(get_session), settings_service=Depends(get_settings_service)
):
async def auto_login(db: Session = Depends(get_session), settings_service=Depends(get_settings_service)):
if settings_service.auth_settings.AUTO_LOGIN:
return create_user_longterm_token(db)
@ -60,9 +58,7 @@ async def auto_login(
@router.post("/refresh")
async def refresh_token(
token: str, current_user: Session = Depends(get_current_active_user)
):
async def refresh_token(token: str, current_user: Session = Depends(get_current_active_user)):
if token:
return create_refresh_token(token)
else:


@ -149,9 +149,7 @@ class StreamData(BaseModel):
data: dict
def __str__(self) -> str:
return (
f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
)
return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
class CustomComponentCode(BaseModel):
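
StreamData.__str__ above emits standard server-sent-events framing: an event line, a data line, and a blank-line terminator. A minimal reproduction, with the stdlib json module standing in for orjson_dumps:

import json

def sse_frame(event: str, data: dict) -> str:
    return f"event: {event}\ndata: {json.dumps(data)}\n\n"

print(sse_frame("message", {"chunk": "Hello"}), end="")
# event: message
# data: {"chunk": "Hello"}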


@ -28,9 +28,7 @@ def get_user_store_api_key(
settings_service=Depends(get_settings_service),
):
if not user.store_api_key:
raise HTTPException(
status_code=400, detail="You must have a store API key set."
)
raise HTTPException(status_code=400, detail="You must have a store API key set.")
decrypted = auth_utils.decrypt_api_key(user.store_api_key, settings_service)
return decrypted
@ -107,9 +105,7 @@ def get_components(
)
except HTTPStatusError as exc:
if exc.response.status_code == 403:
raise ValueError(
"You are not authorized to access this public resource"
)
raise ValueError("You are not authorized to access this public resource")
try:
if result:
if len(result) >= limit:
@ -124,25 +120,19 @@ def get_components(
comp_count = 0
except HTTPStatusError as exc:
if exc.response.status_code == 403:
raise ValueError(
"You are not authorized to access this public resource"
)
raise ValueError("You are not authorized to access this public resource")
if store_api_Key and result:
# Now, from the result, we need to get the components
# the user likes and set the liked_by_user to True
try:
updated_result = update_components_with_user_data(
result, store_service, store_api_Key
)
updated_result = update_components_with_user_data(result, store_service, store_api_Key)
authorized = True
result = updated_result
except Exception:
# If we get an error here, it means the user is not authorized
authorized = False
return ListComponentResponseModel(
results=result, authorized=authorized, count=comp_count
)
return ListComponentResponseModel(results=result, authorized=authorized, count=comp_count)
except Exception as exc:
if isinstance(exc, HTTPStatusError):
if exc.response.status_code == 403:
@ -237,9 +227,7 @@ def like_component(
):
try:
result = store_service.like_component(store_api_Key, component_id)
likes_count = store_service.get_component_likes_count(
store_api_Key, component_id
)
likes_count = store_service.get_component_likes_count(store_api_Key, component_id)
return UsersLikesResponse(likes_count=likes_count, liked_by_user=result)
except Exception as exc:


@ -46,9 +46,7 @@ def add_user(
session.refresh(new_user)
except IntegrityError as e:
session.rollback()
raise HTTPException(
status_code=400, detail="This username is unavailable."
) from e
raise HTTPException(status_code=400, detail="This username is unavailable.") from e
return new_user
@ -96,14 +94,10 @@ def patch_user(
Update an existing user's data.
"""
if not user.is_superuser and user.id != user_id:
raise HTTPException(
status_code=403, detail="You don't have the permission to update this user"
)
raise HTTPException(status_code=403, detail="You don't have the permission to update this user")
if user_update.password:
if not user.is_superuser:
raise HTTPException(
status_code=400, detail="You can't change your password here"
)
raise HTTPException(status_code=400, detail="You can't change your password here")
user_update.password = get_password_hash(user_update.password)
if user_db := get_user_by_id(session, user_id):
@ -123,16 +117,12 @@ def reset_password(
Reset a user's password.
"""
if user_id != user.id:
raise HTTPException(
status_code=400, detail="You can't change another user's password"
)
raise HTTPException(status_code=400, detail="You can't change another user's password")
if not user:
raise HTTPException(status_code=404, detail="User not found")
if verify_password(user_update.password, user.password):
raise HTTPException(
status_code=400, detail="You can't use your current password"
)
raise HTTPException(status_code=400, detail="You can't use your current password")
new_password = get_password_hash(user_update.password)
user.password = new_password
session.commit()
@ -151,13 +141,9 @@ def delete_user(
Delete a user from the database.
"""
if current_user.id == user_id:
raise HTTPException(
status_code=400, detail="You can't delete your own user account"
)
raise HTTPException(status_code=400, detail="You can't delete your own user account")
elif not current_user.is_superuser:
raise HTTPException(
status_code=403, detail="You don't have the permission to delete this user"
)
raise HTTPException(status_code=403, detail="You don't have the permission to delete this user")
user_db = session.query(User).filter(User.id == user_id).first()
if not user_db:


@ -41,9 +41,7 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
add_new_variables_to_template(input_variables, prompt_request)
remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
)
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
update_input_variables_field(input_variables, prompt_request)
@ -58,19 +56,12 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
def get_old_custom_fields(prompt_request):
try:
if (
len(prompt_request.frontend_node.custom_fields) == 1
and prompt_request.name == ""
):
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
prompt_request.name = list(
prompt_request.frontend_node.custom_fields.keys()
)[0]
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
old_custom_fields = prompt_request.frontend_node.custom_fields[
prompt_request.name
].copy()
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name].copy()
except KeyError:
old_custom_fields = []
prompt_request.frontend_node.custom_fields[prompt_request.name] = []
@ -92,40 +83,26 @@ def add_new_variables_to_template(input_variables, prompt_request):
)
if variable in prompt_request.frontend_node.template:
# Set the new field with the old value
template_field.value = prompt_request.frontend_node.template[variable][
"value"
]
template_field.value = prompt_request.frontend_node.template[variable]["value"]
prompt_request.frontend_node.template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if (
variable
not in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
variable
)
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
):
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if (
variable
in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[
prompt_request.name
].remove(variable)
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
# Remove the variable from the template
prompt_request.frontend_node.template.pop(variable, None)
@ -137,6 +114,4 @@ def remove_old_variables_from_template(
def update_input_variables_field(input_variables, prompt_request):
if "input_variables" in prompt_request.frontend_node.template:
prompt_request.frontend_node.template["input_variables"][
"value"
] = input_variables
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables


@ -71,7 +71,9 @@ class ConversationalAgent(CustomComponent):
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
)
agent = OpenAIFunctionsAgent(
llm=llm, tools=tools, prompt=prompt # type: ignore
llm=llm,
tools=tools,
prompt=prompt, # type: ignore
)
return AgentExecutor(
agent=agent,


@ -18,9 +18,7 @@ class PromptRunner(CustomComponent):
"code": {"show": False},
}
def build(
self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}
) -> Document:
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Document:
chain = prompt | llm
# The input is an empty dict because the prompt is already filled
result = chain.invoke(input=inputs)


@ -18,9 +18,7 @@ class MetalRetrieverComponent(CustomComponent):
"code": {"show": False},
}
def build(
self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None
) -> BaseRetriever:
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> BaseRetriever:
try:
metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id)
except Exception as e:


@ -10,9 +10,7 @@ from langchain.agents.agent_toolkits.base import BaseToolkit
class MetaphorToolkit(CustomComponent):
display_name: str = "Metaphor"
description: str = "Metaphor Toolkit"
documentation = (
"https://python.langchain.com/docs/integrations/tools/metaphor_search"
)
documentation = "https://python.langchain.com/docs/integrations/tools/metaphor_search"
beta = True
# api key should be password = True
field_config = {
@ -33,9 +31,7 @@ class MetaphorToolkit(CustomComponent):
@tool
def search(query: str):
"""Call search engine with a query."""
return client.search(
query, use_autoprompt=use_autoprompt, num_results=search_num_results
)
return client.search(query, use_autoprompt=use_autoprompt, num_results=search_num_results)
@tool
def get_contents(ids: List[str]):


@ -30,9 +30,7 @@ class GetRequest(CustomComponent):
},
}
def get_document(
self, session: requests.Session, url: str, headers: Optional[dict], timeout: int
) -> Document:
def get_document(self, session: requests.Session, url: str, headers: Optional[dict], timeout: int) -> Document:
try:
response = session.get(url, headers=headers, timeout=int(timeout))
try:


@ -21,9 +21,7 @@ class JSONDocumentBuilder(CustomComponent):
description: str = "Build a Document containing a JSON object using a key and another Document page content."
output_types: list[str] = ["Document"]
beta = True
documentation: str = (
"https://docs.langflow.org/components/utilities#json-document-builder"
)
documentation: str = "https://docs.langflow.org/components/utilities#json-document-builder"
field_config = {
"key": {"display_name": "Key"},
@ -38,18 +36,11 @@ class JSONDocumentBuilder(CustomComponent):
documents = None
if isinstance(document, list):
documents = [
Document(
page_content=orjson_dumps({key: doc.page_content}, indent_2=False)
)
for doc in document
Document(page_content=orjson_dumps({key: doc.page_content}, indent_2=False)) for doc in document
]
elif isinstance(document, Document):
documents = Document(
page_content=orjson_dumps({key: document.page_content}, indent_2=False)
)
documents = Document(page_content=orjson_dumps({key: document.page_content}, indent_2=False))
else:
raise TypeError(
f"Expected Document or list of Documents, got {type(document)}"
)
raise TypeError(f"Expected Document or list of Documents, got {type(document)}")
self.repr_value = documents
return documents


@ -65,16 +65,12 @@ class PostRequest(CustomComponent):
if not isinstance(document, list) and isinstance(document, Document):
documents: list[Document] = [document]
elif isinstance(document, list) and all(
isinstance(doc, Document) for doc in document
):
elif isinstance(document, list) and all(isinstance(doc, Document) for doc in document):
documents = document
else:
raise ValueError("document must be a Document or a list of Documents")
with requests.Session() as session:
documents = [
self.post_document(session, doc, url, headers) for doc in documents
]
documents = [self.post_document(session, doc, url, headers) for doc in documents]
self.repr_value = documents
return documents


@ -39,9 +39,7 @@ class UpdateRequest(CustomComponent):
) -> Document:
try:
if method == "PATCH":
response = session.patch(
url, headers=headers, data=document.page_content
)
response = session.patch(url, headers=headers, data=document.page_content)
elif method == "PUT":
response = session.put(url, headers=headers, data=document.page_content)
else:
@ -78,17 +76,12 @@ class UpdateRequest(CustomComponent):
if not isinstance(document, list) and isinstance(document, Document):
documents: list[Document] = [document]
elif isinstance(document, list) and all(
isinstance(doc, Document) for doc in document
):
elif isinstance(document, list) and all(isinstance(doc, Document) for doc in document):
documents = document
else:
raise ValueError("document must be a Document or a list of Documents")
with requests.Session() as session:
documents = [
self.update_document(session, doc, url, headers, method)
for doc in documents
]
documents = [self.update_document(session, doc, url, headers, method) for doc in documents]
self.repr_value = documents
return documents


@ -86,8 +86,7 @@ class ChromaComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,
@ -104,6 +103,4 @@ class ChromaComponent(CustomComponent):
client_settings=chroma_settings,
)
return Chroma(
persist_directory=persist_directory, client_settings=chroma_settings
)
return Chroma(persist_directory=persist_directory, client_settings=chroma_settings)


@ -10,9 +10,7 @@ from langchain.schema import BaseRetriever
class VectaraComponent(CustomComponent):
display_name: str = "Vectara"
description: str = "Implementation of Vector Store using Vectara"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/vectara"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara"
beta = True
# api key should be password = True
field_config = {


@ -8,9 +8,7 @@ if TYPE_CHECKING:
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
@ -18,9 +16,7 @@ class SourceHandle(BaseModel):
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
type: str = Field(..., description="Type of the target handle.")
@ -49,23 +45,17 @@ class Edge:
def validate_handles(self) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
else:
self.valid_handles = (
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} "
f"has invalid handles"
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has invalid handles"
)
def __setstate__(self, state):
@ -87,11 +77,7 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
# Get what type of input the target node is expecting
self.matched_type = next(
@ -103,8 +89,7 @@ class Edge:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} "
f"has no matched type"
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has no matched type"
)
def __repr__(self) -> str:
@ -117,8 +102,4 @@ class Edge:
return hash(self.__repr__())
def __eq__(self, __value: object) -> bool:
return (
self.__repr__() == __value.__repr__()
if isinstance(__value, Edge)
else False
)
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
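
The handle-validation rule reformatted above reduces to: an edge is valid if any of the source's base classes is accepted by the target's input types, or the target's own type is among the source's base classes. A standalone illustration with made-up class names:

source_base_classes = ["LLMChain", "Chain"]
target_input_types = ["Chain"]
target_type = "BaseChain"

valid = (
    any(base in target_input_types for base in source_base_classes)
    or target_type in source_base_classes
)
print(valid)  # True: "Chain" is one of the target's accepted input types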


@ -104,9 +104,7 @@ class Graph:
return
for node in self.nodes:
if not self._validate_node(node):
raise ValueError(
f"{node.vertex_type} is not connected to any other components"
)
raise ValueError(f"{node.vertex_type} is not connected to any other components")
def _validate_node(self, node: Vertex) -> bool:
"""Validates a node."""
@ -119,9 +117,7 @@ class Graph:
def get_nodes_with_target(self, node: Vertex) -> List[Vertex]:
"""Returns the nodes connected to a node."""
connected_nodes: List[Vertex] = [
edge.source for edge in self.edges if edge.target == node
]
connected_nodes: List[Vertex] = [edge.source for edge in self.edges if edge.target == node]
return connected_nodes
def build(self) -> Chain:
@ -149,9 +145,7 @@ class Graph:
def dfs(node):
if state[node] == 1:
# We have a cycle
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
raise ValueError("Graph contains a cycle, cannot perform topological sort")
if state[node] == 0:
state[node] = 1
for edge in node.edges:
@ -245,7 +239,5 @@ class Graph:
def __repr__(self):
node_ids = [node.id for node in self.nodes]
edges_repr = "\n".join(
[f"{edge.source.id} --> {edge.target.id}" for edge in self.edges]
)
edges_repr = "\n".join([f"{edge.source.id} --> {edge.target.id}" for edge in self.edges])
return f"Graph:\nNodes: {node_ids}\nConnections:\n{edges_repr}"
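
The dfs hunk above belongs to a three-state (unvisited/visiting/done) topological sort with cycle detection. A self-contained sketch of that algorithm, with simplified names that are not Langflow's:

from typing import Dict, List

def topological_sort(adjacency: Dict[str, List[str]]) -> List[str]:
    # 0 = unvisited, 1 = on the current DFS stack, 2 = fully processed
    state: Dict[str, int] = {node: 0 for node in adjacency}
    order: List[str] = []

    def dfs(node: str) -> None:
        if state[node] == 1:
            # A back edge to a node still on the stack means a cycle
            raise ValueError("Graph contains a cycle, cannot perform topological sort")
        if state[node] == 0:
            state[node] = 1
            for neighbor in adjacency.get(node, []):
                dfs(neighbor)
            state[node] = 2
            order.append(node)

    for node in adjacency:
        dfs(node)
    return order[::-1]

print(topological_sort({"a": ["b"], "b": ["c"], "c": []}))  # ['a', 'b', 'c']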


@ -47,10 +47,7 @@ class VertexTypesDict(LazyLoadDictBase):
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
**{
t: types.CustomComponentVertex
for t in custom_component_creator.to_list()
},
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
}


@ -28,23 +28,14 @@ def ungroup_node(group_node_data, base_flow):
g_edges = flow["data"]["edges"]
# Redirect edges to the correct proxy node
updated_edges = get_updated_edges(
base_flow, g_nodes, g_edges, group_node_data["id"]
)
updated_edges = get_updated_edges(base_flow, g_nodes, g_edges, group_node_data["id"])
# Update template values
update_template(template, g_nodes)
nodes = [
n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]
] + g_nodes
nodes = [n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]] + g_nodes
edges = (
[
e
for e in base_flow["edges"]
if e["target"] != group_node_data["id"]
and e["source"] != group_node_data["id"]
]
[e for e in base_flow["edges"] if e["target"] != group_node_data["id"] and e["source"] != group_node_data["id"]]
+ g_edges
+ updated_edges
)
@ -66,11 +57,7 @@ def process_flow(flow_object):
if node_id in processed_nodes:
return
if (
node.get("data")
and node["data"].get("node")
and node["data"]["node"].get("flow")
):
if node.get("data") and node["data"].get("node") and node["data"]["node"].get("flow"):
process_flow(node["data"]["node"]["flow"]["data"])
new_nodes = ungroup_node(node["data"], cloned_flow)
# Add new nodes to the queue for future processing
@ -108,26 +95,16 @@ def update_template(template, g_nodes):
if node_index != -1:
display_name = None
show = g_nodes[node_index]["data"]["node"]["template"][field]["show"]
advanced = g_nodes[node_index]["data"]["node"]["template"][field][
"advanced"
]
advanced = g_nodes[node_index]["data"]["node"]["template"][field]["advanced"]
if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]:
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
"display_name"
]
display_name = g_nodes[node_index]["data"]["node"]["template"][field]["display_name"]
else:
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
"name"
]
display_name = g_nodes[node_index]["data"]["node"]["template"][field]["name"]
g_nodes[node_index]["data"]["node"]["template"][field] = value
g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show
g_nodes[node_index]["data"]["node"]["template"][field][
"advanced"
] = advanced
g_nodes[node_index]["data"]["node"]["template"][field][
"display_name"
] = display_name
g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] = advanced
g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name
def update_target_handle(new_edge, g_nodes, group_node_id):


@ -51,9 +51,7 @@ class Vertex:
self.params.pop(target_param, None)
continue
if target_param in self.params and not is_basic_type(
self.params[target_param]
):
if target_param in self.params and not is_basic_type(self.params[target_param]):
# edge.source.params = {}
edge.source._build_params()
edge.source._built_object = UnbuiltObject()
@ -99,29 +97,17 @@ class Vertex:
def _parse_data(self) -> None:
self.data = self._data["data"]
self.output = self.data["node"]["base_classes"]
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
self.required_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
]
self.optional_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
]
# Add the template_dicts[key]["input_types"] to the optional_inputs
self.optional_inputs.extend(
[
input_type
for value in template_dicts.values()
for input_type in value.get("input_types", [])
]
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
)
template_dict = self.data["node"]["template"]
@ -160,11 +146,7 @@ class Vertex:
# and use that as the value for the param
# If the type is "str", then we need to get the value of the "value" key
# and use that as the value for the param
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = self.params.copy() if self.params else {}
for edge in self.edges:
@ -209,11 +191,7 @@ class Vertex:
# before passing it to the build method
_value = value.get("value")
if isinstance(_value, list):
params[key] = {
k: v
for item in value.get("value", [])
for k, v in item.items()
}
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
elif isinstance(_value, dict):
params[key] = _value
elif value.get("type") == "int" and value.get("value") is not None:
@ -304,9 +282,7 @@ class Vertex:
self._extend_params_list_with_result(key, result)
self.params[key] = result
def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
@ -358,9 +334,7 @@ class Vertex:
self._update_built_object_and_artifacts(result)
except Exception as exc:
logger.exception(exc)
raise ValueError(
f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}"
) from exc
raise ValueError(f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}") from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -408,8 +382,4 @@ class Vertex:
def _built_object_repr(self):
# Add a message with an emoji, stars for success,
return (
"Built successfully ✨"
if self._built_object is not None
else "Failed to build 😵‍💫"
)
return "Built successfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"


@ -107,11 +107,9 @@ class DocumentLoaderVertex(Vertex):
# show how many documents are in the list?
if self._built_object:
avg_length = sum(
len(doc.page_content)
for doc in self._built_object
if hasattr(doc, "page_content")
) / len(self._built_object)
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
self._built_object
)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
Documents: {self._built_object[:3]}..."""
@ -184,9 +182,7 @@ class TextSplitterVertex(Vertex):
# show how many documents are in the list?
if self._built_object:
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
@ -232,27 +228,18 @@ class PromptVertex(Vertex):
**kwargs,
) -> Any:
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
if "input_variables" not in self.params or self.params["input_variables"] is None:
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -262,9 +249,7 @@ class PromptVertex(Vertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
self.params["input_variables"] = list(set(self.params["input_variables"]))
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
@ -272,11 +257,7 @@ class PromptVertex(Vertex):
return self._built_object
def _built_object_repr(self):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
return super()._built_object_repr()
# We'll build the prompt with the artifacts
# to show the user what the prompt looks like
@ -286,9 +267,7 @@ class PromptVertex(Vertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -296,11 +275,7 @@ class PromptVertex(Vertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
except KeyError:
return str(self._built_object)


@ -42,9 +42,7 @@ class AgentCreator(LangChainTypeCreator):
add_function=True,
method_name=self.from_method_nodes[name],
)
return build_template_from_class(
name, self.type_to_loader_dict, add_function=True
)
return build_template_from_class(name, self.type_to_loader_dict, add_function=True)
except ValueError as exc:
raise ValueError("Agent not found") from exc
except AttributeError as exc:
@ -56,15 +54,8 @@ class AgentCreator(LangChainTypeCreator):
names = []
settings_service = get_settings_service()
for _, agent in self.type_to_loader_dict.items():
agent_name = (
agent.function_name()
if hasattr(agent, "function_name")
else agent.__name__
)
if (
agent_name in settings_service.settings.AGENTS
or settings_service.settings.DEV
):
agent_name = agent.function_name() if hasattr(agent, "function_name") else agent.__name__
if agent_name in settings_service.settings.AGENTS or settings_service.settings.DEV:
names.append(agent_name)
return names


@ -66,7 +66,8 @@ class JsonAgent(CustomAgentExecutor):
prompt=prompt,
)
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names # type: ignore
llm_chain=llm_chain,
allowed_tools=tool_names, # type: ignore
)
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
@ -90,11 +91,7 @@ class CSVAgent(CustomAgentExecutor):
@classmethod
def from_toolkit_and_llm(
cls,
path: str,
llm: BaseLanguageModel,
pandas_kwargs: Optional[dict] = None,
**kwargs: Any
cls, path: str, llm: BaseLanguageModel, pandas_kwargs: Optional[dict] = None, **kwargs: Any
):
import pandas as pd # type: ignore
@ -115,7 +112,9 @@ class CSVAgent(CustomAgentExecutor):
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
llm_chain=llm_chain,
allowed_tools=tool_names,
**kwargs, # type: ignore
)
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
@ -139,9 +138,7 @@ class VectorStoreAgent(CustomAgentExecutor):
super().__init__(*args, **kwargs)
@classmethod
def from_toolkit_and_llm(
cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any
):
def from_toolkit_and_llm(cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any):
"""Construct a vectorstore agent from an LLM and tools."""
toolkit = VectorStoreToolkit(vectorstore_info=vectorstoreinfo, llm=llm)
@ -154,11 +151,11 @@ class VectorStoreAgent(CustomAgentExecutor):
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
llm_chain=llm_chain,
allowed_tools=tool_names,
**kwargs, # type: ignore
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
@ -179,9 +176,7 @@ class SQLAgent(CustomAgentExecutor):
super().__init__(*args, **kwargs)
@classmethod
def from_toolkit_and_llm(
cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any
):
def from_toolkit_and_llm(cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any):
"""Construct an SQL agent from an LLM and tools."""
db = SQLDatabase.from_uri(database_uri)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
@ -199,9 +194,7 @@ class SQLAgent(CustomAgentExecutor):
llmchain = LLMChain(
llm=llm,
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
),
prompt=PromptTemplate(template=QUERY_CHECKER, input_variables=["query", "dialect"]),
)
tools = [
@ -224,7 +217,9 @@ class SQLAgent(CustomAgentExecutor):
)
tool_names = {tool.name for tool in tools} # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
llm_chain=llm_chain,
allowed_tools=tool_names,
**kwargs, # type: ignore
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
@ -255,10 +250,7 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
@classmethod
def from_toolkit_and_llm(
cls,
llm: BaseLanguageModel,
vectorstoreroutertoolkit: VectorStoreRouterToolkit,
**kwargs: Any
cls, llm: BaseLanguageModel, vectorstoreroutertoolkit: VectorStoreRouterToolkit, **kwargs: Any
):
"""Construct a vector store router agent from an LLM and tools."""
@ -274,11 +266,11 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
llm_chain=llm_chain,
allowed_tools=tool_names,
**kwargs, # type: ignore
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)


@ -30,13 +30,8 @@ class LangChainTypeCreator(BaseModel, ABC):
settings_service = get_settings_service()
if self.name_docs_dict is None:
try:
type_settings = getattr(
settings_service.settings, self.type_name.upper()
)
self.name_docs_dict = {
name: value_dict["documentation"]
for name, value_dict in type_settings.items()
}
type_settings = getattr(settings_service.settings, self.type_name.upper())
self.name_docs_dict = {name: value_dict["documentation"] for name, value_dict in type_settings.items()}
except AttributeError as exc:
logger.error(f"Error getting settings for {self.type_name}: {exc}")


@ -33,8 +33,7 @@ class ChainCreator(LangChainTypeCreator):
if self.type_dict is None:
settings_service = get_settings_service()
self.type_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
chain_name: import_class(f"langchain.chains.{chain_name}") for chain_name in chains.__all__
}
from langflow.interface.chains.custom import CUSTOM_CHAINS
@ -45,8 +44,7 @@ class ChainCreator(LangChainTypeCreator):
self.type_dict = {
name: chain
for name, chain in self.type_dict.items()
if name in settings_service.settings.CHAINS
or settings_service.settings.DEV
if name in settings_service.settings.CHAINS or settings_service.settings.DEV
}
return self.type_dict
@ -61,9 +59,7 @@ class ChainCreator(LangChainTypeCreator):
method_name=self.from_method_nodes[name],
add_function=True,
)
return build_template_from_class(
name, self.type_to_loader_dict, add_function=True
)
return build_template_from_class(name, self.type_to_loader_dict, add_function=True)
except ValueError as exc:
raise ValueError(f"Chain {name} not found: {exc}") from exc
except AttributeError as exc:
@ -73,11 +69,7 @@ class ChainCreator(LangChainTypeCreator):
def to_list(self) -> List[str]:
names = []
for _, chain in self.type_to_loader_dict.items():
chain_name = (
chain.function_name()
if hasattr(chain, "function_name")
else chain.__name__
)
chain_name = chain.function_name() if hasattr(chain, "function_name") else chain.__name__
names.append(chain_name)
return names


@ -41,9 +41,7 @@ class BaseCustomConversationChain(ConversationChain):
values["template"] = values["template"].format(**format_dict)
values["template"] = values["template"]
values["input_variables"] = extract_input_variables_from_prompt(
values["template"]
)
values["input_variables"] = extract_input_variables_from_prompt(values["template"])
values["prompt"].template = values["template"]
values["prompt"].input_variables = values["input_variables"]
return values
@ -54,9 +52,7 @@ class SeriesCharacterChain(BaseCustomConversationChain):
character: str
series: str
template: Optional[
str
] = """I want you to act like {character} from {series}.
template: Optional[str] = """I want you to act like {character} from {series}.
I want you to respond and answer like {character}. do not write any explanations. only answer like {character}.
You must know all of the knowledge of {character}.
Current conversation:
@ -71,9 +67,7 @@ Human: {input}
class MidJourneyPromptChain(BaseCustomConversationChain):
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
template: Optional[
str
] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program.
template: Optional[str] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program.
Your job is to provide detailed and creative descriptions that will inspire unique and interesting images from the AI.
Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible.
For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures.
@ -87,9 +81,7 @@ class MidJourneyPromptChain(BaseCustomConversationChain):
class TimeTravelGuideChain(BaseCustomConversationChain):
template: Optional[
str
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
template: Optional[str] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
Current conversation:
{history}
Human: {input}


@ -127,22 +127,14 @@ class CodeParser:
num_defaults = len(node.args.defaults)
num_missing_defaults = num_args - num_defaults
missing_defaults = [None] * num_missing_defaults
default_values = [
ast.unparse(default).strip("'") if default else None
for default in node.args.defaults
]
default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
# Now check all default values to see if there
# are any "None" values in the middle
default_values = [
None if value == "None" else value for value in default_values
]
default_values = [None if value == "None" else value for value in default_values]
defaults = missing_defaults + default_values
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.args, defaults)
]
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
return args
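
The defaults-padding logic above exists because ast stores only the trailing positional defaults, so they must be left-padded with None to line up with the full argument list. A standalone demonstration:

import ast

node = ast.parse("def f(a, b, c=1, d='x'): pass").body[0]

num_args = len(node.args.args)          # 4 positional args
num_defaults = len(node.args.defaults)  # 2 trailing defaults
missing_defaults = [None] * (num_args - num_defaults)
defaults = missing_defaults + [ast.unparse(d) for d in node.args.defaults]
print(defaults)  # [None, None, '1', "'x'"]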
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -160,17 +152,11 @@ class CodeParser:
"""
Parses the keyword-only arguments of a function or method node.
"""
kw_defaults = [None] * (
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
) + [
ast.unparse(default) if default else None
for default in node.args.kw_defaults
]
kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
ast.unparse(default) if default else None for default in node.args.kw_defaults
]
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.kwonlyargs, kw_defaults)
]
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
return args
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -254,9 +240,7 @@ class CodeParser:
Extracts global variables from the code.
"""
global_var = {
"targets": [
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
],
"targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
"value": ast.unparse(node.value),
}
self.data["global_vars"].append(global_var)


@ -17,9 +17,7 @@ class ComponentFunctionEntrypointNameNullError(HTTPException):
class Component(BaseModel):
ERROR_CODE_NULL = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = (
"The name of the entrypoint function must be provided."
)
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = "The name of the entrypoint function must be provided."
code: Optional[str]
function_entrypoint_name = "build"


@ -53,9 +53,9 @@ class CustomComponent(Component, extra=Extra.allow):
reader = DirectoryReader("", False)
for type_hint in TYPE_HINT_LIST:
if reader._is_type_hint_used_in_args(
type_hint, code
) and not reader._is_type_hint_imported(type_hint, code):
if reader._is_type_hint_used_in_args(type_hint, code) and not reader._is_type_hint_imported(
type_hint, code
):
error_detail = {
"error": "Type hint Error",
"traceback": f"Type hint '{type_hint}' is used but not imported in the code.",
@ -74,20 +74,14 @@ class CustomComponent(Component, extra=Extra.allow):
return ""
tree = self.get_code_tree(self.code)
component_classes = [
cls
for cls in tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
component_classes = [cls for cls in tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return ""
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
if not build_methods:
@ -103,8 +97,7 @@ class CustomComponent(Component, extra=Extra.allow):
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method."
" Try using PromptTemplate instead."
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
),
},
)
@ -119,20 +112,14 @@ class CustomComponent(Component, extra=Extra.allow):
return []
tree = self.get_code_tree(self.code)
component_classes = [
cls
for cls in tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
component_classes = [cls for cls in tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return []
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
if not build_methods:
@ -230,11 +217,7 @@ class CustomComponent(Component, extra=Extra.allow):
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (
session.query(Flow)
.filter(Flow.name == flow_name)
.filter(Flow.user_id == self.user_id)
).first()
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")


@ -76,9 +76,7 @@ class DirectoryReader:
for menu in data["menu"]
]
filtered = [menu for menu in items if menu["components"]]
logger.debug(
f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}'
)
logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}')
return {"menu": filtered}
def validate_code(self, file_content):
@ -111,9 +109,7 @@ class DirectoryReader:
Walk through the directory path and return a list of all .py files.
"""
if not (safe_path := self.get_safe_path()):
raise CustomComponentPathValueError(
f"The path needs to start with '{self.base_path}'."
)
raise CustomComponentPathValueError(f"The path needs to start with '{self.base_path}'.")
file_list = []
for root, _, files in os.walk(safe_path):
@ -158,9 +154,7 @@ class DirectoryReader:
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef):
for arg in node.args.args:
if self._is_type_hint_in_arg_annotation(
arg.annotation, type_hint_name
):
if self._is_type_hint_in_arg_annotation(arg.annotation, type_hint_name):
return True
except SyntaxError:
# Returns False if the code is not valid Python
@ -178,16 +172,14 @@ class DirectoryReader:
and annotation.value.id == type_hint_name
)
def is_type_hint_used_but_not_imported(
self, type_hint_name: str, code: str
) -> bool:
def is_type_hint_used_but_not_imported(self, type_hint_name: str, code: str) -> bool:
"""
Check if a type hint is used but not imported in the given code.
"""
try:
return self._is_type_hint_used_in_args(
type_hint_name, code
) and not self._is_type_hint_imported(type_hint_name, code)
return self._is_type_hint_used_in_args(type_hint_name, code) and not self._is_type_hint_imported(
type_hint_name, code
)
except SyntaxError:
# Returns True if there's something wrong with the code
# TODO : Find a better way to handle this
@ -208,9 +200,9 @@ class DirectoryReader:
return False, "Syntax error"
elif not self.validate_build(file_content):
return False, "Missing build function"
elif self._is_type_hint_used_in_args(
"Optional", file_content
) and not self._is_type_hint_imported("Optional", file_content):
elif self._is_type_hint_used_in_args("Optional", file_content) and not self._is_type_hint_imported(
"Optional", file_content
):
return (
False,
"Type hint 'Optional' is used but not imported in the code.",
@ -226,9 +218,7 @@ class DirectoryReader:
from the .py files in the directory.
"""
response = {"menu": []}
logger.debug(
"-------------------- Building component menu list --------------------"
)
logger.debug("-------------------- Building component menu list --------------------")
for file_path in file_paths:
menu_name = os.path.basename(os.path.dirname(file_path))
@ -248,9 +238,7 @@ class DirectoryReader:
# first check if it's already CamelCase
if "_" in component_name:
component_name_camelcase = " ".join(
word.title() for word in component_name.split("_")
)
component_name_camelcase = " ".join(word.title() for word in component_name.split("_"))
else:
component_name_camelcase = component_name
@ -266,7 +254,5 @@ class DirectoryReader:
logger.debug(f"Component info: {component_info}")
if menu_result not in response["menu"]:
response["menu"].append(menu_result)
logger.debug(
"-------------------- Component menu list built --------------------"
)
logger.debug("-------------------- Component menu list built --------------------")
return response
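The "Optional used but not imported" check above can be sketched in isolation. This standalone version is a simplification: it assumes "used" means an Optional[...] subscript anywhere in the tree, rather than the reader's narrower argument-annotation walk, and it mirrors the reader in treating unparseable code as a failure:
import ast

def optional_used_but_not_imported(code: str) -> bool:
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return True  # mirrors the reader: broken code is treated as a failure
    imported = any(
        isinstance(n, ast.ImportFrom) and n.module == "typing"
        and any(a.name == "Optional" for a in n.names)
        for n in ast.walk(tree)
    )
    used = any(
        isinstance(n, ast.Subscript) and isinstance(n.value, ast.Name) and n.value.id == "Optional"
        for n in ast.walk(tree)
    )
    return used and not imported

optional_used_but_not_imported("def build(x: Optional[str]): ...")  # True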

View file

@ -46,34 +46,26 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
# Memories
memory_type_to_cls_dict: dict[str, Any] = {
memory_name: import_class(f"langchain.memory.{memory_name}")
for memory_name in memory.__all__
memory_name: import_class(f"langchain.memory.{memory_name}") for memory_name in memory.__all__
}
# Wrappers
wrapper_type_to_cls_dict: dict[str, Any] = {
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
}
wrapper_type_to_cls_dict: dict[str, Any] = {wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]}
# Embeddings
embedding_type_to_cls_dict: dict[str, Any] = {
embedding_name: import_class(f"langchain.embeddings.{embedding_name}")
for embedding_name in embeddings.__all__
embedding_name: import_class(f"langchain.embeddings.{embedding_name}") for embedding_name in embeddings.__all__
}
# Document Loaders
documentloaders_type_to_cls_dict: dict[str, Any] = {
documentloader_name: import_class(
f"langchain.document_loaders.{documentloader_name}"
)
documentloader_name: import_class(f"langchain.document_loaders.{documentloader_name}")
for documentloader_name in document_loaders.__all__
}
# Text Splitters
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
inspect.getmembers(text_splitter, inspect.isclass)
)
textsplitter_type_to_cls_dict: dict[str, Any] = dict(inspect.getmembers(text_splitter, inspect.isclass))
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore
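import_class itself is not part of this diff; a plausible minimal implementation of such a dotted-path importer, offered only as an assumption about its behavior:
import importlib

def import_class(class_path: str):
    # e.g. "langchain.memory.ConversationBufferMemory" -> the class object
    module_path, class_name = class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)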

View file

@ -35,8 +35,7 @@ class DocumentLoaderCreator(LangChainTypeCreator):
return [
documentloader.__name__
for documentloader in self.type_to_loader_dict.values()
if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS
or settings_service.settings.DEV
if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS or settings_service.settings.DEV
]
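The whitelist-or-DEV filter repeated across these creators boils down to the following pattern (the loader names here are illustrative, not taken from the real settings):
class PyPDFLoader: ...
class ExperimentalLoader: ...

ALLOWED = {"PyPDFLoader"}  # stands in for settings_service.settings.DOCUMENTLOADERS
DEV = False                # settings.DEV would expose everything
visible = [cls.__name__ for cls in (PyPDFLoader, ExperimentalLoader) if cls.__name__ in ALLOWED or DEV]
# ['PyPDFLoader']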

View file

@ -37,8 +37,7 @@ class EmbeddingCreator(LangChainTypeCreator):
return [
embedding.__name__
for embedding in self.type_to_loader_dict.values()
if embedding.__name__ in settings_service.settings.EMBEDDINGS
or settings_service.settings.DEV
if embedding.__name__ in settings_service.settings.EMBEDDINGS or settings_service.settings.DEV
]

View file

@ -104,10 +104,7 @@ def import_prompt(prompt: str) -> Type[PromptTemplate]:
def import_wrapper(wrapper: str) -> Any:
"""Import wrapper from wrapper name"""
if (
isinstance(wrapper_creator.type_dict, dict)
and wrapper in wrapper_creator.type_dict
):
if isinstance(wrapper_creator.type_dict, dict) and wrapper in wrapper_creator.type_dict:
return wrapper_creator.type_dict.get(wrapper)

View file

@ -2,8 +2,6 @@ def initialize_vertexai(class_object, params):
if credentials_path := params.get("credentials"):
from google.oauth2 import service_account # type: ignore
credentials_object = service_account.Credentials.from_service_account_file(
filename=credentials_path
)
credentials_object = service_account.Credentials.from_service_account_file(filename=credentials_path)
params["credentials"] = credentials_object
return class_object(**params)

View file

@ -44,15 +44,10 @@ def build_vertex_in_params(params: Dict) -> Dict:
from langflow.graph.vertex.base import Vertex
# If any of the values in params is a Vertex, we will build it
return {
key: value.build() if isinstance(value, Vertex) else value
for key, value in params.items()
}
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
def instantiate_class(
node_type: str, base_type: str, params: Dict, user_id=None
) -> Any:
def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
params = convert_kwargs(params)
@ -64,9 +59,7 @@ def instantiate_class(
return custom_node(**params)
logger.debug(f"Instantiating {node_type} of type {base_type}")
class_object = import_by_type(_type=base_type, name=node_type)
return instantiate_based_on_type(
class_object, base_type, node_type, params, user_id=user_id
)
return instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
def convert_params_to_sets(params):
@ -194,9 +187,7 @@ def instantiate_memory(node_type, class_object, params):
# I want to catch a specific attribute error that happens
# when the object does not have a cursor attribute
except Exception as exc:
if "object has no attribute 'cursor'" in str(
exc
) or 'object has no field "conn"' in str(exc):
if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc):
raise AttributeError(
(
"Failed to build connection to database."
@ -235,9 +226,7 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params:
if class_method := getattr(class_object, method, None):
agent = class_method(**params)
tools = params.get("tools", [])
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, handle_parsing_errors=True
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True)
return load_agent_executor(class_object, params)
@ -290,11 +279,7 @@ def instantiate_embedding(node_type, class_object, params: Dict):
try:
return class_object(**params)
except ValidationError:
params = {
key: value
for key, value in params.items()
if key in class_object.__fields__
}
params = {key: value for key, value in params.items() if key in class_object.__fields__}
return class_object(**params)
@ -304,9 +289,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
if "texts" in params:
params["documents"] = params.pop("texts")
if "documents" in params:
params["documents"] = [
doc for doc in params["documents"] if isinstance(doc, Document)
]
params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)]
if initializer := vecstore_initializer.get(class_object.__name__):
vecstore = initializer(class_object, params)
else:
@ -321,9 +304,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
return vecstore
def instantiate_documentloader(
node_type: str, class_object: Type[BaseLoader], params: Dict
):
def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict):
if "file_filter" in params:
# file_filter will be a string but we need a function
# that will be used to filter the files using file_filter
@ -332,17 +313,13 @@ def instantiate_documentloader(
# in x and if it is, we will return True
file_filter = params.pop("file_filter")
extensions = file_filter.split(",")
params["file_filter"] = lambda x: any(
extension.strip() in x for extension in extensions
)
params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions)
metadata = params.pop("metadata", None)
if metadata and isinstance(metadata, str):
try:
metadata = orjson.loads(metadata)
except json.JSONDecodeError as exc:
raise ValueError(
"The metadata you provided is not a valid JSON string."
) from exc
raise ValueError("The metadata you provided is not a valid JSON string.") from exc
if node_type == "WebBaseLoader":
if web_path := params.pop("web_path", None):
@ -375,16 +352,12 @@ def instantiate_textsplitter(
"Try changing the chunk_size of the Text Splitter."
) from exc
if (
"separator_type" in params and params["separator_type"] == "Text"
) or "separator_type" not in params:
if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params:
params.pop("separator_type", None)
# separators might come in as an escaped string like \\n
# so we need to convert it to a string
if "separators" in params:
params["separators"] = (
params["separators"].encode().decode("unicode-escape")
)
params["separators"] = params["separators"].encode().decode("unicode-escape")
text_splitter = class_object(**params)
else:
from langchain.text_splitter import Language
@ -411,8 +384,7 @@ def replace_zero_shot_prompt_with_prompt_template(nodes):
tools = [
tool
for tool in nodes
if tool["type"] != "chatOutputNode"
and "Tool" in tool["data"]["node"]["base_classes"]
if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"]
]
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
break
@ -426,9 +398,7 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs)
# agent has hidden args for memory. might need to be supported
# memory = params["memory"]
# if allowed_tools is not a list or set, make it a list
if not isinstance(allowed_tools, (list, set)) and isinstance(
allowed_tools, BaseTool
):
if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool):
allowed_tools = [allowed_tools]
tool_names = [tool.name for tool in allowed_tools]
# Agent class requires an output_parser but Agent classes
@ -456,10 +426,7 @@ def build_prompt_template(prompt, tools):
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
tool_strings = "\n".join(
[
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
for tool in tools
]
[f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools]
)
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
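Two of the one-line transformations above, shown in isolation:
extensions = "txt, md".split(",")
file_filter = lambda x: any(ext.strip() in x for ext in extensions)
file_filter("notes.md")   # True
file_filter("image.png")  # False

"\\n".encode().decode("unicode-escape")  # '\n' (a literal backslash-n becomes a real newline)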

View file

@ -30,9 +30,7 @@ def check_tools_in_params(params: Dict):
def instantiate_from_template(class_object, params: Dict):
from_template_params = {
"template": params.pop("prompt", params.pop("template", ""))
}
from_template_params = {"template": params.pop("prompt", params.pop("template", ""))}
if not from_template_params.get("template"):
raise ValueError("Prompt template is required")
return class_object.from_template(**from_template_params)
@ -48,9 +46,7 @@ def handle_format_kwargs(prompt, params: Dict):
def handle_partial_variables(prompt, format_kwargs: Dict):
partial_variables = format_kwargs.copy()
partial_variables = {
key: value for key, value in partial_variables.items() if value
}
partial_variables = {key: value for key, value in partial_variables.items() if value}
# Remove handle_keys otherwise LangChain raises an error
partial_variables.pop("handle_keys", None)
if partial_variables and hasattr(prompt, "partial"):
@ -62,9 +58,7 @@ def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict):
variable = params[input_variable]
if isinstance(variable, str):
format_kwargs[input_variable] = variable
elif isinstance(variable, BaseOutputParser) and hasattr(
variable, "get_format_instructions"
):
elif isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions"):
format_kwargs[input_variable] = variable.get_format_instructions()
elif is_instance_of_list_or_document(variable):
format_kwargs = format_document(variable, input_variable, format_kwargs)
@ -107,8 +101,7 @@ def try_to_load_json(content):
def needs_handle_keys(variable):
return is_instance_of_list_or_document(variable) or (
isinstance(variable, BaseOutputParser)
and hasattr(variable, "get_format_instructions")
isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions")
)

View file

@ -17,9 +17,7 @@ import orjson
def docs_in_params(params: dict) -> bool:
"""Check if params has documents OR texts and one of them is not an empty list,
If any of them is not an empty list, return True, else return False"""
return ("documents" in params and params["documents"]) or (
"texts" in params and params["texts"]
)
return ("documents" in params and params["documents"]) or ("texts" in params and params["texts"])
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
@ -31,9 +29,7 @@ def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dic
from pymongo import MongoClient
import certifi
client: MongoClient = MongoClient(
MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where()
)
client: MongoClient = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where())
db_name = params.pop("db_name", None)
collection_name = params.pop("collection_name", None)
if not db_name or not collection_name:
@ -141,9 +137,7 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict):
pinecone_env = os.getenv("PINECONE_ENV")
if pinecone_api_key is None or pinecone_env is None:
raise ValueError(
"Pinecone API key and environment must be provided in the params"
)
raise ValueError("Pinecone API key and environment must be provided in the params")
# initialize pinecone
pinecone.init(
@ -177,19 +171,13 @@ def initialize_chroma(class_object: Type[Chroma], params: dict):
import chromadb # type: ignore
settings_params = {
key: params[key]
for key, value_ in params.items()
if key.startswith("chroma_server_") and value_
key: params[key] for key, value_ in params.items() if key.startswith("chroma_server_") and value_
}
chroma_settings = chromadb.config.Settings(**settings_params)
params["client_settings"] = chroma_settings
else:
# remove all chroma_server_ keys from params
params = {
key: value
for key, value in params.items()
if not key.startswith("chroma_server_")
}
params = {key: value for key, value in params.items() if not key.startswith("chroma_server_")}
persist = params.pop("persist", False)
if not docs_in_params(params):

View file

@ -38,8 +38,7 @@ class LLMCreator(LangChainTypeCreator):
return [
llm.__name__
for llm in self.type_to_loader_dict.values()
if llm.__name__ in settings_service.settings.LLMS
or settings_service.settings.DEV
if llm.__name__ in settings_service.settings.LLMS or settings_service.settings.DEV
]

View file

@ -53,8 +53,7 @@ class MemoryCreator(LangChainTypeCreator):
return [
memory.__name__
for memory in self.type_to_loader_dict.values()
if memory.__name__ in settings_service.settings.MEMORIES
or settings_service.settings.DEV
if memory.__name__ in settings_service.settings.MEMORIES or settings_service.settings.DEV
]

View file

@ -26,17 +26,14 @@ class OutputParserCreator(LangChainTypeCreator):
if self.type_dict is None:
settings_service = get_settings_service()
self.type_dict = {
output_parser_name: import_class(
f"langchain.output_parsers.{output_parser_name}"
)
output_parser_name: import_class(f"langchain.output_parsers.{output_parser_name}")
# if output_parser_name is not lower case it is a class
for output_parser_name in output_parsers.__all__
}
self.type_dict = {
name: output_parser
for name, output_parser in self.type_dict.items()
if name in settings_service.settings.OUTPUT_PARSERS
or settings_service.settings.DEV
if name in settings_service.settings.OUTPUT_PARSERS or settings_service.settings.DEV
}
return self.type_dict

View file

@ -36,8 +36,7 @@ class PromptCreator(LangChainTypeCreator):
self.type_dict = {
name: prompt
for name, prompt in self.type_dict.items()
if name in settings_service.settings.PROMPTS
or settings_service.settings.DEV
if name in settings_service.settings.PROMPTS or settings_service.settings.DEV
}
return self.type_dict

View file

@ -42,17 +42,13 @@ class BaseCustomPrompt(PromptTemplate):
values["template"] = values["template"].format(**format_dict)
values["template"] = values["template"]
values["input_variables"] = extract_input_variables_from_prompt(
values["template"]
)
values["input_variables"] = extract_input_variables_from_prompt(values["template"])
return values
class SeriesCharacterPrompt(BaseCustomPrompt):
# Add a very descriptive description for the prompt generator
description: Optional[
str
] = "A prompt that asks the AI to act like a character from a series."
description: Optional[str] = "A prompt that asks the AI to act like a character from a series."
character: str
series: str
template: str = """I want you to act like {character} from {series}.
@ -68,6 +64,4 @@ Human: {input}
input_variables: List[str] = ["character", "series"]
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {
"SeriesCharacterPrompt": SeriesCharacterPrompt
}
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {"SeriesCharacterPrompt": SeriesCharacterPrompt}

View file

@ -39,9 +39,7 @@ class RetrieverCreator(LangChainTypeCreator):
method_name=self.from_method_nodes[name],
)
else:
return build_template_from_class(
name, type_to_cls_dict=self.type_to_loader_dict
)
return build_template_from_class(name, type_to_cls_dict=self.type_to_loader_dict)
except ValueError as exc:
raise ValueError(f"Retriever {name} not found") from exc
except AttributeError as exc:
@ -53,8 +51,7 @@ class RetrieverCreator(LangChainTypeCreator):
return [
retriever
for retriever in self.type_to_loader_dict.keys()
if retriever in settings_service.settings.RETRIEVERS
or settings_service.settings.DEV
if retriever in settings_service.settings.RETRIEVERS or settings_service.settings.DEV
]

View file

@ -35,8 +35,7 @@ class TextSplitterCreator(LangChainTypeCreator):
return [
textsplitter.__name__
for textsplitter in self.type_to_loader_dict.values()
if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS
or settings_service.settings.DEV
if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS or settings_service.settings.DEV
]

View file

@ -32,13 +32,10 @@ class ToolkitCreator(LangChainTypeCreator):
if self.type_dict is None:
settings_service = get_settings_service()
self.type_dict = {
toolkit_name: import_class(
f"langchain.agents.agent_toolkits.{toolkit_name}"
)
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is not lower case it is a class
for toolkit_name in agent_toolkits.__all__
if not toolkit_name.islower()
and toolkit_name in settings_service.settings.TOOLKITS
if not toolkit_name.islower() and toolkit_name in settings_service.settings.TOOLKITS
}
return self.type_dict
@ -61,9 +58,7 @@ class ToolkitCreator(LangChainTypeCreator):
def get_create_function(self, name: str) -> Callable:
if loader_name := self.create_functions.get(name):
return import_module(
f"from langchain.agents.agent_toolkits import {loader_name[0]}"
)
return import_module(f"from langchain.agents.agent_toolkits import {loader_name[0]}")
else:
raise ValueError("Toolkit not found")

View file

@ -31,9 +31,7 @@ TOOL_INPUTS = {
placeholder="",
value="",
),
"llm": TemplateField(
field_type="BaseLanguageModel", required=True, is_list=False, show=True
),
"llm": TemplateField(field_type="BaseLanguageModel", required=True, is_list=False, show=True),
"func": TemplateField(
field_type="function",
required=True,
@ -76,10 +74,7 @@ class ToolCreator(LangChainTypeCreator):
tool_name = tool_params.get("name") or tool
if (
tool_name in settings_service.settings.TOOLS
or settings_service.settings.DEV
):
if tool_name in settings_service.settings.TOOLS or settings_service.settings.DEV:
if tool_name == "JsonSpec":
tool_params["path"] = tool_params.pop("dict_") # type: ignore
all_tools[tool_name] = {

View file

@ -21,16 +21,12 @@ def get_func_tool_params(func, **kwargs) -> Union[Dict, None]:
for keyword in tool.keywords:
if keyword.arg == "name":
try:
tool_params["name"] = ast.literal_eval(
keyword.value
)
tool_params["name"] = ast.literal_eval(keyword.value)
except ValueError:
break
elif keyword.arg == "description":
try:
tool_params["description"] = ast.literal_eval(
keyword.value
)
tool_params["description"] = ast.literal_eval(keyword.value)
except ValueError:
continue
@ -43,9 +39,7 @@ def get_func_tool_params(func, **kwargs) -> Union[Dict, None]:
else:
# get the class object from the return statement
try:
class_obj = eval(
compile(ast.Expression(tool), "<string>", "eval")
)
class_obj = eval(compile(ast.Expression(tool), "<string>", "eval"))
except Exception:
return None

View file

@ -114,14 +114,10 @@ def add_new_custom_field(
# If options is a list, then it's a dropdown
# If options is None, then it's a list of strings
is_list = isinstance(field_config.get("options"), list)
field_config["is_list"] = (
is_list or field_config.get("is_list", False) or field_contains_list
)
field_config["is_list"] = is_list or field_config.get("is_list", False) or field_contains_list
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
field_config.pop("name", None)
required = field_config.pop("required", field_required)
@ -185,9 +181,7 @@ def extract_type_from_optional(field_type):
def build_frontend_node(custom_component: CustomComponent):
"""Build a frontend node for a custom component"""
try:
return (
CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__)
)
return CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__)
except Exception as exc:
logger.error(f"Error while building base frontend node: {exc}")
@ -236,9 +230,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
if "name" not in extra_field or extra_field["name"] == "self":
continue
field_name, field_type, field_value, field_required = get_field_properties(
extra_field
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
config = field_config.get(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
@ -273,8 +265,7 @@ def add_base_classes(frontend_node, return_types: List[str]):
status_code=400,
detail={
"error": (
"Invalid return type should be one of: "
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
"Invalid return type should be one of: " f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
),
"traceback": traceback.format_exc(),
},
@ -296,8 +287,7 @@ def add_output_types(frontend_node, return_types: List[str]):
status_code=400,
detail={
"error": (
"Invalid return type should be one of: "
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
"Invalid return type should be one of: " f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
),
"traceback": traceback.format_exc(),
},
@ -325,16 +315,10 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
add_extra_fields(frontend_node, field_config, entrypoint_args)
logger.debug("Added extra fields")
frontend_node = add_code_field(
frontend_node, custom_component.code, field_config.get("code", {})
)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
logger.debug("Added code field")
add_base_classes(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_output_types(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
logger.debug("Added base classes")
return frontend_node
except Exception as exc:
@ -343,9 +327,7 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -377,9 +359,7 @@ def build_valid_menu(valid_components):
valid_menu[menu_name] = {}
for component in menu_item["components"]:
logger.debug(
f"Building component: {component.get('name'), component.get('output_types')}"
)
logger.debug(f"Building component: {component.get('name'), component.get('output_types')}")
try:
component_name = component["name"]
component_code = component["code"]
@ -388,9 +368,7 @@ def build_valid_menu(valid_components):
component_extractor = CustomComponent(code=component_code)
component_extractor.is_check_valid()
component_template = build_langchain_template_custom_component(
component_extractor
)
component_template = build_langchain_template_custom_component(component_extractor)
component_template["output_types"] = component_output_types
if len(component_output_types) == 1:
component_name = component_output_types[0]
@ -398,9 +376,7 @@ def build_valid_menu(valid_components):
file_name = component.get("file").split(".")[0]
if "_" in file_name:
# turn .py file into camelcase
component_name = "".join(
[word.capitalize() for word in file_name.split("_")]
)
component_name = "".join([word.capitalize() for word in file_name.split("_")])
else:
component_name = file_name
@ -409,9 +385,7 @@ def build_valid_menu(valid_components):
except Exception as exc:
logger.error(f"Error loading Component: {component['output_types']}")
logger.exception(
f"Error while building custom component {component_output_types}: {exc}"
)
logger.exception(f"Error while building custom component {component_output_types}: {exc}")
return valid_menu
@ -449,20 +423,14 @@ def build_invalid_menu(invalid_components):
logger.debug(f"Added {component_name} to invalid menu to {menu_name}")
except Exception as exc:
logger.exception(
f"Error while creating custom component [{component_name}]: {str(exc)}"
)
logger.exception(f"Error while creating custom component [{component_name}]: {str(exc)}")
return invalid_menu
def merge_nested_dicts_with_renaming(dict1, dict2):
for key, value in dict2.items():
if (
key in dict1
and isinstance(value, dict)
and isinstance(dict1.get(key), dict)
):
if key in dict1 and isinstance(value, dict) and isinstance(dict1.get(key), dict):
for sub_key, sub_value in value.items():
if sub_key in dict1[key]:
new_key = get_new_key(dict1[key], sub_key)
@ -479,9 +447,7 @@ def build_langchain_custom_component_list_from_path(path: str):
file_list = load_files_from_path(path)
reader = DirectoryReader(path, False)
valid_components, invalid_components = build_and_validate_all_files(
reader, file_list
)
valid_components, invalid_components = build_and_validate_all_files(reader, file_list)
valid_menu = build_valid_menu(valid_components)
invalid_menu = build_invalid_menu(invalid_components)
@ -495,18 +461,14 @@ def get_all_types_dict(settings_service):
# need to merge all the keys into one dict
custom_components_from_file: dict[str, Any] = {}
if settings_service.settings.COMPONENTS_PATH:
logger.info(
f"Building custom components from {settings_service.settings.COMPONENTS_PATH}"
)
logger.info(f"Building custom components from {settings_service.settings.COMPONENTS_PATH}")
custom_component_dicts = []
processed_paths = []
for path in settings_service.settings.COMPONENTS_PATH:
if str(path) in processed_paths:
continue
custom_component_dict = build_langchain_custom_component_list_from_path(
str(path)
)
custom_component_dict = build_langchain_custom_component_list_from_path(str(path))
custom_component_dicts.append(custom_component_dict)
processed_paths.append(str(path))
@ -516,16 +478,12 @@ def get_all_types_dict(settings_service):
if not custom_component_dict:
continue
category = list(custom_component_dict.keys())[0]
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict
)
return merge_nested_dicts_with_renaming(
native_components, custom_components_from_file
)
return merge_nested_dicts_with_renaming(native_components, custom_components_from_file)
def merge_nested_dicts(dict1, dict2):
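The snake_case-to-CamelCase fallback used for component file names above, in isolation:
file_name = "my_custom_component"
component_name = "".join(word.capitalize() for word in file_name.split("_"))
# 'MyCustomComponent'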

View file

@ -29,16 +29,14 @@ class UtilityCreator(LangChainTypeCreator):
if self.type_dict is None:
settings_service = get_settings_service()
self.type_dict = {
utility_name: import_class(f"langchain.utilities.{utility_name}")
for utility_name in utilities.__all__
utility_name: import_class(f"langchain.utilities.{utility_name}") for utility_name in utilities.__all__
}
self.type_dict["SQLDatabase"] = utilities.SQLDatabase
# Filter according to settings.utilities
self.type_dict = {
name: utility
for name, utility in self.type_dict.items()
if name in settings_service.settings.UTILITIES
or settings_service.settings.DEV
if name in settings_service.settings.UTILITIES or settings_service.settings.DEV
}
return self.type_dict

View file

@ -43,9 +43,7 @@ def try_setting_streaming_options(langchain_object):
llm = None
if hasattr(langchain_object, "llm"):
llm = langchain_object.llm
elif hasattr(langchain_object, "llm_chain") and hasattr(
langchain_object.llm_chain, "llm"
):
elif hasattr(langchain_object, "llm_chain") and hasattr(langchain_object.llm_chain, "llm"):
llm = langchain_object.llm_chain.llm
if isinstance(llm, BaseLanguageModel):
@ -79,9 +77,7 @@ def set_langchain_cache(settings):
if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"):
try:
cache_class = import_class(
f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}"
)
cache_class = import_class(f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}")
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
langchain.llm_cache = cache_class()

View file

@ -22,9 +22,7 @@ class VectorstoreCreator(LangChainTypeCreator):
def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
self.type_dict: dict[str, Any] = {
vectorstore_name: import_class(
f"langchain.vectorstores.{vectorstore_name}"
)
vectorstore_name: import_class(f"langchain.vectorstores.{vectorstore_name}")
for vectorstore_name in vectorstores.__all__
}
return self.type_dict
@ -48,8 +46,7 @@ class VectorstoreCreator(LangChainTypeCreator):
return [
vectorstore
for vectorstore in self.type_to_loader_dict.keys()
if vectorstore in settings_service.settings.VECTORSTORES
or settings_service.settings.DEV
if vectorstore in settings_service.settings.VECTORSTORES or settings_service.settings.DEV
]

View file

@ -16,8 +16,7 @@ class WrapperCreator(LangChainTypeCreator):
def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
self.type_dict = {
wrapper.__name__: wrapper
for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase]
wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase]
}
return self.type_dict

View file

@ -84,9 +84,7 @@ def get_static_files_dir():
return frontend_path / "frontend"
def setup_app(
static_files_dir: Optional[Path] = None, backend_only: bool = False
) -> FastAPI:
def setup_app(static_files_dir: Optional[Path] = None, backend_only: bool = False) -> FastAPI:
"""Setup the FastAPI app."""
# get the directory of the current file
if not static_files_dir:

View file

@ -34,9 +34,7 @@ def get_langfuse_callback(trace_id):
if langfuse := LangfuseInstance.get():
logger.debug("Langfuse credentials found")
try:
trace = langfuse.trace(
CreateTrace(name="langflow-" + trace_id, id=trace_id)
)
trace = langfuse.trace(CreateTrace(name="langflow-" + trace_id, id=trace_id))
return trace.getNewHandler()
except Exception as exc:
logger.error(f"Error initializing langfuse callback: {exc}")
@ -44,9 +42,7 @@ def get_langfuse_callback(trace_id):
return None
def flush_langfuse_callback_if_present(
callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]
):
def flush_langfuse_callback_if_present(callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]):
"""
If langfuse callback is present, run callback.langfuse.flush()
"""
@ -88,15 +84,9 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
# if langfuse callback is present, run callback.langfuse.flush()
flush_langfuse_callback_if_present(callbacks)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else []
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output
try:
thought = format_actions(intermediate_steps) if intermediate_steps else ""
except Exception as exc:

View file

@ -112,9 +112,7 @@ def load_langchain_object(
logger.debug("Loaded LangChain object")
if langchain_object is None:
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
return langchain_object, artifacts, session_id
@ -164,9 +162,7 @@ async def process_graph_cached(
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Load the graph using SessionService
graph, artifacts = session_service.load_session(session_id, data_graph)
built_object = graph.build()
@ -179,9 +175,7 @@ async def process_graph_cached(
return Result(result=result, session_id=session_id)
def load_flow_from_json(
flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True
):
def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True):
"""
Load flow from a JSON file or a JSON object.
@ -198,9 +192,7 @@ def load_flow_from_json(
elif isinstance(flow, dict):
flow_graph = flow
else:
raise TypeError(
"Input must be either a file path (str) or a JSON object (dict)"
)
raise TypeError("Input must be either a file path (str) or a JSON object (dict)")
graph_data = flow_graph["data"]
if tweaks is not None:
@ -226,18 +218,14 @@ def load_flow_from_json(
return graph
def validate_input(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
raise ValueError("graph_data and tweaks should be dictionaries")
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
if not isinstance(nodes, list):
raise ValueError(
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
)
raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
return nodes
@ -246,9 +234,7 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data = node.get("data", {}).get("node", {}).get("template")
if not isinstance(template_data, dict):
logger.warning(
f"Template data for node {node.get('id')} should be a dictionary"
)
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
return
for tweak_name, tweak_value in node_tweaks.items():
@ -257,9 +243,7 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data[tweak_name][key] = tweak_value
def process_tweaks(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
This function is used to tweak the graph data using the node id and the tweaks dict.
@ -280,8 +264,6 @@ def process_tweaks(
if node_tweaks := tweaks.get(node_id):
apply_tweaks(node, node_tweaks)
else:
logger.warning(
"Each node should be a dictionary with an 'id' key of type str"
)
logger.warning("Each node should be a dictionary with an 'id' key of type str")
return graph_data
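A rough sketch of the tweak flow; the diff truncates which template key receives the value, so writing to "value" below is an assumption:
graph_data = {"data": {"nodes": [
    {"id": "node-1", "data": {"node": {"template": {"temperature": {"value": 0.1}}}}},
]}}
tweaks = {"node-1": {"temperature": 0.7}}

for node in graph_data["data"]["nodes"]:
    template = node["data"]["node"]["template"]
    for name, new_value in tweaks.get(node["id"], {}).items():
        if name in template:  # unknown tweak names are silently ignored
            template[name]["value"] = new_value  # assumed target key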

View file

@ -10,11 +10,7 @@ class LangflowApplication(BaseApplication):
super().__init__()
def load_config(self):
config = {
key: value
for key, value in self.options.items()
if key in self.cfg.settings and value is not None
}
config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None}
for key, value in config.items():
self.cfg.set(key.lower(), value)

View file

@ -20,12 +20,8 @@ oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")
API_KEY_NAME = "x-api-key"
api_key_query = APIKeyQuery(
name=API_KEY_NAME, scheme_name="API key query", auto_error=False
)
api_key_header = APIKeyHeader(
name=API_KEY_NAME, scheme_name="API key header", auto_error=False
)
api_key_query = APIKeyQuery(name=API_KEY_NAME, scheme_name="API key query", auto_error=False)
api_key_header = APIKeyHeader(name=API_KEY_NAME, scheme_name="API key header", auto_error=False)
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
@ -118,23 +114,17 @@ def get_current_active_user(current_user: Annotated[User, Depends(get_current_us
return current_user
def get_current_active_superuser(
current_user: Annotated[User, Depends(get_current_user)]
) -> User:
def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
if not current_user.is_active:
raise HTTPException(status_code=401, detail="Inactive user")
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges"
)
raise HTTPException(status_code=400, detail="The user doesn't have enough privileges")
return current_user
def verify_password(plain_password, hashed_password):
settings_service = get_settings_service()
return settings_service.auth_settings.pwd_context.verify(
plain_password, hashed_password
)
return settings_service.auth_settings.pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
@ -223,22 +213,16 @@ def get_user_id_from_token(token: str) -> UUID:
return UUID(int=0)
def create_user_tokens(
user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False
) -> dict:
def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False) -> dict:
settings_service = get_settings_service()
access_token_expires = timedelta(
minutes=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
access_token_expires = timedelta(minutes=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_token(
data={"sub": str(user_id)},
expires_delta=access_token_expires,
)
refresh_token_expires = timedelta(
minutes=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
refresh_token_expires = timedelta(minutes=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
refresh_token = create_token(
data={"sub": str(user_id), "type": "rf"},
expires_delta=refresh_token_expires,
@ -268,9 +252,7 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
token_type: str = payload.get("type") # type: ignore
if user_id is None or token_type is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
return create_user_tokens(user_id, db)
@ -281,9 +263,7 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
) from e
def authenticate_user(
username: str, password: str, db: Session = Depends(get_session)
) -> Optional[User]:
def authenticate_user(username: str, password: str, db: Session = Depends(get_session)) -> Optional[User]:
user = get_user_by_username(db, username)
if not user:
@ -318,9 +298,7 @@ def encrypt_api_key(api_key: str, settings_service=Depends(get_settings_service)
return encrypted_key
def decrypt_api_key(
encrypted_api_key: str, settings_service=Depends(get_settings_service)
):
def decrypt_api_key(encrypted_api_key: str, settings_service=Depends(get_settings_service)):
fernet = get_fernet(settings_service)
# Two-way decryption
if isinstance(encrypted_api_key, str):
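The two-way encryption these helpers wrap is Fernet from the cryptography package; a minimal sketch with a throwaway key (the real key comes from get_fernet and the auth settings):
from cryptography.fernet import Fernet

fernet = Fernet(Fernet.generate_key())
token = fernet.encrypt(b"sk-example")
assert fernet.decrypt(token) == b"sk-example"  # two-way: the original key comes back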

View file

@ -26,9 +26,7 @@ class CacheServiceFactory(ServiceFactory):
if redis_cache.is_connected():
logger.debug("Redis cache is connected")
return redis_cache
logger.warning(
"Redis cache is not connected, falling back to in-memory cache"
)
logger.warning("Redis cache is not connected, falling back to in-memory cache")
return InMemoryCache()
elif settings_service.settings.CACHE_TYPE == "memory":

View file

@ -68,10 +68,7 @@ class InMemoryCache(BaseCacheService, Service):
Retrieve an item from the cache without acquiring the lock.
"""
if item := self._cache.get(key):
if (
self.expiration_time is None
or time.time() - item["time"] < self.expiration_time
):
if self.expiration_time is None or time.time() - item["time"] < self.expiration_time:
# Move the key to the end to make it recently used
self._cache.move_to_end(key)
# Check if the value is pickled
@ -118,11 +115,7 @@ class InMemoryCache(BaseCacheService, Service):
"""
with self._lock:
existing_value = self._get_without_lock(key)
if (
existing_value is not None
and isinstance(existing_value, dict)
and isinstance(value, dict)
):
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
value = existing_value
@ -276,9 +269,7 @@ class RedisCache(BaseCacheService, Service):
if not result:
raise ValueError("RedisCache could not set the value.")
except TypeError as exc:
raise TypeError(
"RedisCache only accepts values that can be pickled. "
) from exc
raise TypeError("RedisCache only accepts values that can be pickled. ") from exc
def upsert(self, key, value):
"""
@ -290,11 +281,7 @@ class RedisCache(BaseCacheService, Service):
value: The value to insert or update.
"""
existing_value = self.get(key)
if (
existing_value is not None
and isinstance(existing_value, dict)
and isinstance(value, dict)
):
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
value = existing_value
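The cache logic above leans on two small dict behaviors:
from collections import OrderedDict

cache = OrderedDict(a=1, b=2)
cache.move_to_end("a")  # "a" becomes the most recently used entry
list(cache)             # ['b', 'a']

existing, incoming = {"x": 1}, {"y": 2}
existing.update(incoming)  # upsert merges dicts instead of replacing them; now {'x': 1, 'y': 2}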

View file

@ -83,9 +83,7 @@ def clear_old_cache_files(max_cache_size: int = 3):
cache_files = list(cache_dir.glob("*.dill"))
if len(cache_files) > max_cache_size:
cache_files_sorted_by_mtime = sorted(
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
)
cache_files_sorted_by_mtime = sorted(cache_files, key=lambda x: x.stat().st_mtime, reverse=True)
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
with contextlib.suppress(OSError):

View file

@ -59,9 +59,7 @@ class ChatService(Service):
"""Send the last chat message to the client."""
client_id = self.chat_cache.current_client_id
if client_id in self.active_connections:
chat_response = self.chat_history.get_history(
client_id, filter_messages=False
)[-1]
chat_response = self.chat_history.get_history(client_id, filter_messages=False)[-1]
if chat_response.is_bot:
# Process FileResponse
if isinstance(chat_response, FileResponse):
@ -88,9 +86,7 @@ class ChatService(Service):
data_type=self.last_cached_object_dict["type"],
)
self.chat_history.add_message(
self.chat_cache.current_client_id, chat_response
)
self.chat_history.add_message(self.chat_cache.current_client_id, chat_response)
async def connect(self, client_id: str, websocket: WebSocket):
self.active_connections[client_id] = websocket
@ -121,9 +117,7 @@ class ChatService(Service):
if "after sending" in str(exc):
logger.error(f"Error closing connection: {exc}")
async def process_message(
self, client_id: str, payload: Dict, langchain_object: Any
):
async def process_message(self, client_id: str, payload: Dict, langchain_object: Any):
# Process the graph data and chat message
chat_inputs = payload.pop("inputs", {})
chatkey = payload.pop("chatKey", None)
@ -211,15 +205,11 @@ class ChatService(Service):
continue
with self.chat_cache.set_client_id(client_id):
if langchain_object := self.cache_service.get(client_id).get(
"result"
):
if langchain_object := self.cache_service.get(client_id).get("result"):
await self.process_message(client_id, payload, langchain_object)
else:
raise RuntimeError(
f"Could not find a build result for client_id {client_id}"
)
raise RuntimeError(f"Could not find a build result for client_id {client_id}")
except Exception as exc:
# Handle any exceptions that might occur
logger.exception(f"Error handling websocket: {exc}")

View file

@ -15,9 +15,7 @@ async def process_graph(
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
# Generate result and thought
try:

View file

@ -18,9 +18,7 @@ def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]:
return [ApiKeyRead.from_orm(api_key) for api_key in api_keys]
def create_api_key(
session: Session, api_key_create: ApiKeyCreate, user_id: UUID
) -> UnmaskedApiKeyRead:
def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead:
# Generate a random API key with 32 bytes of randomness
generated_api_key = f"sk-{secrets.token_urlsafe(32)}"

View file

@ -19,9 +19,7 @@ def get_user_by_id(db: Session, id: UUID) -> Union[User, None]:
return db.query(User).filter(User.id == id).first()
def update_user(
user_db: Optional[User], user: UserUpdate, db: Session = Depends(get_session)
) -> User:
def update_user(user_db: Optional[User], user: UserUpdate, db: Session = Depends(get_session)) -> User:
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
@ -37,9 +35,7 @@ def update_user(
changed = True
if not changed:
raise HTTPException(
status_code=status.HTTP_304_NOT_MODIFIED, detail="Nothing to update"
)
raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, detail="Nothing to update")
user_db.updated_at = datetime.now(timezone.utc)
flag_modified(user_db, "updated_at")

View file

@ -34,10 +34,7 @@ class DatabaseService(Service):
def _create_engine(self) -> "Engine":
"""Create the engine for the database."""
settings_service = get_settings_service()
if (
settings_service.settings.DATABASE_URL
and settings_service.settings.DATABASE_URL.startswith("sqlite")
):
if settings_service.settings.DATABASE_URL and settings_service.settings.DATABASE_URL.startswith("sqlite"):
connect_args = {"check_same_thread": False}
else:
connect_args = {}
@ -49,9 +46,7 @@ class DatabaseService(Service):
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None: # If an exception has been raised
logger.error(
f"Session rollback because of exception: {exc_type.__name__} {exc_value}"
)
logger.error(f"Session rollback because of exception: {exc_type.__name__} {exc_value}")
self._session.rollback()
else:
self._session.commit()
@ -99,9 +94,7 @@ class DatabaseService(Service):
expected_columns = list(model.__fields__.keys())
try:
available_columns = [
col["name"] for col in inspector.get_columns(table)
]
available_columns = [col["name"] for col in inspector.get_columns(table)]
except sa.exc.NoSuchTableError:
logger.error(f"Missing table: {table}")
return False
@ -153,9 +146,7 @@ class DatabaseService(Service):
try:
command.check(alembic_cfg)
except Exception as exc:
if isinstance(exc, util.exc.CommandError) or isinstance(
exc, util.exc.AutogenerateDiffsDetected
):
if isinstance(exc, util.exc.CommandError) or isinstance(exc, util.exc.AutogenerateDiffsDetected):
command.upgrade(alembic_cfg, "head")
# We should check the schema health after running migrations
@ -174,10 +165,7 @@ class DatabaseService(Service):
# We will check that all models are in the database
# and that the database is up to date with all columns
sql_models = [models.Flow, models.User, models.ApiKey]
return [
TableResults(sql_model.__tablename__, self.check_table(sql_model))
for sql_model in sql_models
]
return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models]
def check_table(self, model):
results = []
@ -185,9 +173,7 @@ class DatabaseService(Service):
table_name = model.__tablename__
expected_columns = list(model.__fields__.keys())
try:
available_columns = [
col["name"] for col in inspector.get_columns(table_name)
]
available_columns = [col["name"] for col in inspector.get_columns(table_name)]
results.append(Result(name=table_name, type="table", success=True))
except sa.exc.NoSuchTableError:
logger.error(f"Missing table: {table_name}")
@ -218,9 +204,7 @@ class DatabaseService(Service):
try:
table.create(self.engine, checkfirst=True)
except OperationalError as oe:
logger.warning(
f"Table {table} already exists, skipping. Exception: {oe}"
)
logger.warning(f"Table {table} already exists, skipping. Exception: {oe}")
except Exception as exc:
logger.error(f"Error creating table {table}: {exc}")
raise RuntimeError(f"Error creating table {table}") from exc
@ -232,9 +216,7 @@ class DatabaseService(Service):
if table not in table_names:
logger.error("Something went wrong creating the database and tables.")
logger.error("Please check your database settings.")
raise RuntimeError(
"Something went wrong creating the database and tables."
)
raise RuntimeError("Something went wrong creating the database and tables.")
logger.debug("Database and tables created successfully")

View file

@ -13,9 +13,7 @@ def initialize_database():
logger.debug("Initializing database")
from langflow.services import service_manager, ServiceType
database_service: "DatabaseService" = service_manager.get(
ServiceType.DATABASE_SERVICE
)
database_service: "DatabaseService" = service_manager.get(ServiceType.DATABASE_SERVICE)
try:
database_service.create_db_and_tables()
except Exception as exc:
@ -41,9 +39,7 @@ def initialize_database():
# This means there's wrong revision in the DB
# We need to delete the alembic_version table
# and run the migrations again
logger.warning(
"Wrong revision in DB, deleting alembic_version table and running migrations again"
)
logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again")
with session_getter(database_service) as session:
session.execute("DROP TABLE alembic_version")
database_service.run_migrations()

View file

@ -53,15 +53,10 @@ class ServiceManager:
self._create_service(dependency)
# Collect the dependent services
dependent_services = {
dep.value: self.services[dep]
for dep in self.dependencies.get(service_name, [])
}
dependent_services = {dep.value: self.services[dep] for dep in self.dependencies.get(service_name, [])}
# Create the actual service
self.services[service_name] = self.factories[service_name].create(
**dependent_services
)
self.services[service_name] = self.factories[service_name].create(**dependent_services)
self.services[service_name].set_ready()
def _validate_service_creation(self, service_name: ServiceType):
@ -69,9 +64,7 @@ class ServiceManager:
Validate whether the service can be created.
"""
if service_name not in self.factories:
raise ValueError(
f"No factory registered for the service class '{service_name.name}'"
)
raise ValueError(f"No factory registered for the service class '{service_name.name}'")
def update(self, service_name: ServiceType):
"""
@ -144,9 +137,7 @@ def initialize_session_service():
initialize_settings_service()
service_manager.register_factory(
cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE]
)
service_manager.register_factory(cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE])
service_manager.register_factory(
session_service_factory.SessionServiceFactory(),

View file

@ -23,10 +23,7 @@ class LangfuseInstance:
settings_manager = get_settings_service()
if (
settings_manager.settings.LANGFUSE_PUBLIC_KEY
and settings_manager.settings.LANGFUSE_SECRET_KEY
):
if settings_manager.settings.LANGFUSE_PUBLIC_KEY and settings_manager.settings.LANGFUSE_SECRET_KEY:
logger.debug("Langfuse credentials found")
cls._instance = Langfuse(
public_key=settings_manager.settings.LANGFUSE_PUBLIC_KEY,

View file

@ -3,6 +3,4 @@ import string
def session_id_generator(size=6):
return "".join(
random.SystemRandom().choices(string.ascii_uppercase + string.digits, k=size)
)
return "".join(random.SystemRandom().choices(string.ascii_uppercase + string.digits, k=size))

View file

@ -26,9 +26,7 @@ class AuthSettings(BaseSettings):
REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 12
# API Key to execute /process endpoint
API_KEY_SECRET_KEY: Optional[
str
] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a"
API_KEY_SECRET_KEY: Optional[str] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a"
API_KEY_ALGORITHM: str = "HS256"
API_V1_STR: str = "/api/v1"

View file

@ -83,9 +83,7 @@ class Settings(BaseSettings):
@validator("DATABASE_URL", pre=True)
def set_database_url(cls, value, values):
if not value:
logger.debug(
"No database_url provided, trying LANGFLOW_DATABASE_URL env variable"
)
logger.debug("No database_url provided, trying LANGFLOW_DATABASE_URL env variable")
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
value = langflow_database_url
logger.debug("Using LANGFLOW_DATABASE_URL env variable.")
@ -95,9 +93,7 @@ class Settings(BaseSettings):
# so we need to migrate to the new format
# if there is a database in that location
if not values["CONFIG_DIR"]:
raise ValueError(
"CONFIG_DIR not set, please set it or provide a DATABASE_URL"
)
raise ValueError("CONFIG_DIR not set, please set it or provide a DATABASE_URL")
new_path = f"{values['CONFIG_DIR']}/langflow.db"
if Path("./langflow.db").exists():
@ -121,22 +117,15 @@ class Settings(BaseSettings):
if os.getenv("LANGFLOW_COMPONENTS_PATH"):
logger.debug("Adding LANGFLOW_COMPONENTS_PATH to components_path")
langflow_component_path = os.getenv("LANGFLOW_COMPONENTS_PATH")
if (
Path(langflow_component_path).exists()
and langflow_component_path not in value
):
if Path(langflow_component_path).exists() and langflow_component_path not in value:
if isinstance(langflow_component_path, list):
for path in langflow_component_path:
if path not in value:
value.append(path)
logger.debug(
f"Extending {langflow_component_path} to components_path"
)
logger.debug(f"Extending {langflow_component_path} to components_path")
elif langflow_component_path not in value:
value.append(langflow_component_path)
logger.debug(
f"Appending {langflow_component_path} to components_path"
)
logger.debug(f"Appending {langflow_component_path} to components_path")
if not value:
value = [BASE_COMPONENTS_PATH]

View file

@ -10,6 +10,4 @@ class SettingsServiceFactory(ServiceFactory):
def create(self):
# Here you would have logic to create and configure a SettingsService
langflow_dir = Path(__file__).parent.parent.parent
return SettingsService.load_settings_from_yaml(
str(langflow_dir / "config.yaml")
)
return SettingsService.load_settings_from_yaml(str(langflow_dir / "config.yaml"))

View file

@@ -30,9 +30,7 @@ class SettingsService(Service):
for key in settings_dict:
if key not in Settings.__fields__.keys():
raise KeyError(f"Key {key} not found in settings")
logger.debug(
f"Loading {len(settings_dict[key])} {key} from {file_path}"
)
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")
settings = Settings(**settings_dict)
if not settings.CONFIG_DIR:

View file

@@ -14,9 +14,7 @@ def set_secure_permissions(file_path):
import win32security
user, domain, _ = win32security.LookupAccountName("", win32api.GetUserName())
sd = win32security.GetFileSecurity(
file_path, win32security.DACL_SECURITY_INFORMATION
)
sd = win32security.GetFileSecurity(file_path, win32security.DACL_SECURITY_INFORMATION)
dacl = win32security.ACL()
# Set the new DACL for the file: read and write access for the owner, no access for everyone else
@@ -26,9 +24,7 @@ def set_secure_permissions(file_path):
user,
)
sd.SetSecurityDescriptorDacl(1, dacl, 0)
win32security.SetFileSecurity(
file_path, win32security.DACL_SECURITY_INFORMATION, sd
)
win32security.SetFileSecurity(file_path, win32security.DACL_SECURITY_INFORMATION, sd)
else:
print("Unsupported OS")

View file

@@ -63,9 +63,7 @@ class ListComponentResponse(BaseModel):
if all(["id" in tag and "name" in tag for tag in v]):
return v
else:
return [
TagResponse(**tag.get("tags_id")) for tag in v if tag.get("tags_id")
]
return [TagResponse(**tag.get("tags_id")) for tag in v if tag.get("tags_id")]
class ListComponentResponseModel(BaseModel):

View file

@@ -21,9 +21,7 @@ if TYPE_CHECKING:
from contextlib import contextmanager
from contextvars import ContextVar
user_data_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"user_data", default=None
)
user_data_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar("user_data", default=None)
@contextmanager
@@ -31,9 +29,7 @@ def user_data_context(store_service: "StoreService", api_key: Optional[str] = No
# Fetch and set user data to the context variable
if api_key:
try:
user_data = store_service._get(
f"{store_service.base_url}/users/me", api_key, params={"fields": "id"}
)
user_data = store_service._get(f"{store_service.base_url}/users/me", api_key, params={"fields": "id"})
user_data_var.set(user_data)
except HTTPStatusError as exc:
if exc.response.status_code == 403:
@@ -77,9 +73,7 @@ class StoreService(Service):
# will make a property return that data
# Without making the request multiple times
def _get(
self, url: str, api_key: str, params: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
def _get(self, url: str, api_key: str, params: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""Utility method to perform GET requests."""
if api_key:
headers = {"Authorization": f"Bearer {api_key}"}
@@ -99,9 +93,7 @@ class StoreService(Service):
# For now we are calling it just for testing
try:
headers = {"Authorization": f"Bearer {api_key}"}
response = httpx.post(
webhook_url, headers=headers, json={"component_id": str(component_id)}
)
response = httpx.post(webhook_url, headers=headers, json={"component_id": str(component_id)})
response.raise_for_status()
return response.json()
except HTTPError as exc:
@@ -146,9 +138,7 @@ class StoreService(Service):
if tags:
tags_filter = {"tags": {"_and": []}}
for tag in tags:
tags_filter["tags"]["_and"].append(
{"_some": {"tags_id": {"name": {"_eq": tag}}}}
)
tags_filter["tags"]["_and"].append({"_some": {"tags_id": {"name": {"_eq": tag}}}})
filter_conditions.append(tags_filter)
if date_from:
@@ -166,13 +156,7 @@ class StoreService(Service):
params["fields"] = ",".join(fields)
if filter_by_user:
params["deep"] = json.dumps(
{
"components": {
"_filter": {"user_created": {"token": {"_eq": api_key}}}
}
}
)
params["deep"] = json.dumps({"components": {"_filter": {"user_created": {"token": {"_eq": api_key}}}}})
else:
# params["filter"] = json.dumps({"status": {"_eq": "public"}})
filter_conditions.append({"status": {"_in": ["public", "Public"]}})
@@ -192,13 +176,7 @@ class StoreService(Service):
params = {"aggregate": json.dumps({"count": "*"})}
filter_conditions = [] if filter_conditions is None else filter_conditions
if filter_by_user:
params["deep"] = json.dumps(
{
"components": {
"_filter": {"user_created": {"token": {"_eq": api_key}}}
}
}
)
params["deep"] = json.dumps({"components": {"_filter": {"user_created": {"token": {"_eq": api_key}}}}})
else:
filter_conditions.append({"status": {"_in": ["public", "Public"]}})
@@ -250,9 +228,7 @@ class StoreService(Service):
if tags:
tags_filter = {"tags": {"_and": []}}
for tag in tags:
tags_filter["tags"]["_and"].append(
{"_some": {"tags_id": {"name": {"_eq": tag}}}}
)
tags_filter["tags"]["_and"].append({"_some": {"tags_id": {"name": {"_eq": tag}}}})
filter_conditions.append(tags_filter)
if is_component is not None:
@@ -286,9 +262,7 @@ class StoreService(Service):
# component.tags = [tags_id.tags_id for tags_id in component.tags]
return results_objects, filter_conditions
def get_liked_by_user_components(
self, component_ids: List[UUID], api_key: str
) -> List[UUID]:
def get_liked_by_user_components(self, component_ids: List[UUID], api_key: str) -> List[UUID]:
# Get fields id
# filter should be "id is in component_ids AND liked_by directus_users_id token is api_key"
# return the ids
@@ -310,9 +284,7 @@ class StoreService(Service):
return [result["id"] for result in results]
# Which of the components is parent of the user's components
def get_components_in_users_collection(
self, component_ids: List[UUID], api_key: str
):
def get_components_in_users_collection(self, component_ids: List[UUID], api_key: str):
user_data = user_data_var.get()
if not user_data:
raise ValueError("No user data")
@@ -332,18 +304,14 @@ class StoreService(Service):
def download(self, api_key: str, component_id: str) -> DownloadComponentResponse:
url = f"{self.components_url}/{component_id}"
params = {
"fields": ",".join(["id", "name", "description", "data", "is_component"])
}
params = {"fields": ",".join(["id", "name", "description", "data", "is_component"])}
component = self._get(url, api_key, params)
self.call_webhook(api_key, self.download_webhook_url, component_id)
return DownloadComponentResponse(**component)
def upload(
self, api_key: str, component_data: StoreComponentCreate
) -> ComponentResponse:
def upload(self, api_key: str, component_data: StoreComponentCreate) -> ComponentResponse:
headers = {"Authorization": f"Bearer {api_key}"}
component_dict = component_data.dict(exclude_unset=True)
# Parent is a UUID, but the store expects a string
@@ -353,9 +321,7 @@ class StoreService(Service):
component_dict = process_tags_for_post(component_dict)
try:
response = httpx.post(
self.components_url, headers=headers, json=component_dict
)
response = httpx.post(self.components_url, headers=headers, json=component_dict)
response.raise_for_status()
component = response.json()["data"]
return ComponentResponse(**component)

View file

@@ -10,9 +10,7 @@ class CeleryBackend(TaskBackend):
def __init__(self):
self.celery_app = celery_app
def launch_task(
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> tuple[str, AsyncResult]:
def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> tuple[str, AsyncResult]:
# I need to type the delay method to make it easier
from celery import Task # type: ignore

View file

@@ -63,9 +63,7 @@ class TaskService(Service):
result = task.get()
return task.id, result
async def launch_task(
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
async def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
logger.debug(f"Launching task {task_func} with args {args} and kwargs {kwargs}")
logger.debug(f"Using backend {self.backend}")
task = self.backend.launch_task(task_func, *args, **kwargs)

View file

@@ -71,16 +71,12 @@ def get_or_create_super_user(session: Session, username, password, is_default):
)
return None
else:
logger.debug(
"User with superuser credentials exists but is not a superuser."
)
logger.debug("User with superuser credentials exists but is not a superuser.")
return None
if user:
if verify_password(password, user.password):
raise ValueError(
"User with superuser credentials exists but is not a superuser."
)
raise ValueError("User with superuser credentials exists but is not a superuser.")
else:
raise ValueError("Incorrect superuser credentials")
@@ -109,21 +105,15 @@ def setup_superuser(settings_service, session: Session):
username = settings_service.auth_settings.SUPERUSER
password = settings_service.auth_settings.SUPERUSER_PASSWORD
is_default = (username == DEFAULT_SUPERUSER) and (
password == DEFAULT_SUPERUSER_PASSWORD
)
is_default = (username == DEFAULT_SUPERUSER) and (password == DEFAULT_SUPERUSER_PASSWORD)
try:
user = get_or_create_super_user(
session=session, username=username, password=password, is_default=is_default
)
user = get_or_create_super_user(session=session, username=username, password=password, is_default=is_default)
if user is not None:
logger.debug("Superuser created successfully.")
except Exception as exc:
logger.exception(exc)
raise RuntimeError(
"Could not create superuser. Please create a superuser manually."
) from exc
raise RuntimeError("Could not create superuser. Please create a superuser manually.") from exc
finally:
settings_service.auth_settings.reset_credentials()
@@ -137,9 +127,7 @@ def teardown_superuser(settings_service, session):
if not settings_service.auth_settings.AUTO_LOGIN:
try:
logger.debug(
"AUTO_LOGIN is set to False. Removing default superuser if exists."
)
logger.debug("AUTO_LOGIN is set to False. Removing default superuser if exists.")
username = DEFAULT_SUPERUSER
from langflow.services.database.models.user.user import User
@@ -187,9 +175,7 @@ def initialize_session_service():
initialize_settings_service()
service_manager.register_factory(
cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE]
)
service_manager.register_factory(cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE])
service_manager.register_factory(
session_service_factory.SessionServiceFactory(),
@@ -206,17 +192,13 @@ def initialize_services():
service_manager.register_factory(factory, dependencies=dependencies)
except Exception as exc:
logger.exception(exc)
raise RuntimeError(
"Could not initialize services. Please check your settings."
) from exc
raise RuntimeError("Could not initialize services. Please check your settings.") from exc
# Test cache connection
service_manager.get(ServiceType.CACHE_SERVICE)
# Setup the superuser
initialize_database()
setup_superuser(
service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session())
)
setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()))
try:
get_db_service().migrate_flows_if_auto_login()
except Exception as exc:

View file

@@ -67,11 +67,7 @@ class FrontendNode(BaseModel):
def process_base_classes(self) -> None:
"""Removes unwanted base classes from the list of base classes."""
self.base_classes = [
base_class
for base_class in self.base_classes
if base_class not in CLASSES_TO_REMOVE
]
self.base_classes = [base_class for base_class in self.base_classes if base_class not in CLASSES_TO_REMOVE]
def to_dict(self) -> dict:
"""Returns a dict representation of the frontend node."""
@@ -130,9 +126,7 @@ class FrontendNode(BaseModel):
return _type
@staticmethod
def handle_special_field(
field, key: str, _type: str, SPECIAL_FIELD_HANDLERS
) -> str:
def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str:
"""Handles special field by using the respective handler if present."""
handler = SPECIAL_FIELD_HANDLERS.get(key)
return handler(field) if handler else _type
@@ -144,11 +138,7 @@ class FrontendNode(BaseModel):
field.field_type = "file"
field.suffixes = [".json", ".yaml", ".yml"]
field.file_types = ["json", "yaml", "yml"]
elif (
_type.startswith("Dict")
or _type.startswith("Mapping")
or _type.startswith("dict")
):
elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"):
field.field_type = "dict"
return _type
@@ -159,9 +149,7 @@ class FrontendNode(BaseModel):
field.value = value["default"]
@staticmethod
def handle_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def handle_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values for certain fields."""
if key == "headers":
field.value = """{"Authorization": "Bearer <token>"}"""
@@ -169,9 +157,7 @@ class FrontendNode(BaseModel):
FrontendNode._handle_api_key_specific_field_values(field, key, name)
@staticmethod
def _handle_model_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def _handle_model_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to models."""
model_dict = {
"OpenAI": constants.OPENAI_MODELS,
@@ -184,9 +170,7 @@ class FrontendNode(BaseModel):
field.is_list = True
@staticmethod
def _handle_api_key_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def _handle_api_key_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to API keys."""
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
@@ -225,10 +209,7 @@ class FrontendNode(BaseModel):
@staticmethod
def should_be_password(key: str, show: bool) -> bool:
"""Determines whether the field should be a password field."""
return (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and show
)
return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show
@staticmethod
def should_be_multiline(key: str) -> bool:

View file

@@ -133,7 +133,9 @@ class SeriesCharacterChainNode(FrontendNode):
),
],
)
description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa
description: str = (
"SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa
)
base_classes: list[str] = [
"LLMChain",
"BaseCustomChain",

View file

@@ -3,9 +3,7 @@ from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
def build_file_field(
suffixes: list, fileTypes: list, name: str = "file_path"
) -> TemplateField:
def build_file_field(suffixes: list, fileTypes: list, name: str = "file_path") -> TemplateField:
"""Build a template field for a document loader."""
return TemplateField(
field_type="file",
@@ -27,32 +25,22 @@ class DocumentLoaderFrontNode(FrontendNode):
"AirbyteJSONLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]),
"CoNLLULoader": build_file_field(suffixes=[".csv"], fileTypes=["csv"]),
"CSVLoader": build_file_field(suffixes=[".csv"], fileTypes=["csv"]),
"UnstructuredEmailLoader": build_file_field(
suffixes=[".eml"], fileTypes=["eml"]
),
"UnstructuredEmailLoader": build_file_field(suffixes=[".eml"], fileTypes=["eml"]),
"EverNoteLoader": build_file_field(suffixes=[".xml"], fileTypes=["xml"]),
"FacebookChatLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]),
"BSHTMLLoader": build_file_field(suffixes=[".html"], fileTypes=["html"]),
"UnstructuredHTMLLoader": build_file_field(
suffixes=[".html"], fileTypes=["html"]
),
"UnstructuredHTMLLoader": build_file_field(suffixes=[".html"], fileTypes=["html"]),
"UnstructuredImageLoader": build_file_field(
suffixes=[".jpg", ".jpeg", ".png", ".gif", ".bmp"],
fileTypes=["jpg", "jpeg", "png", "gif", "bmp"],
),
"UnstructuredMarkdownLoader": build_file_field(
suffixes=[".md"], fileTypes=["md"]
),
"UnstructuredMarkdownLoader": build_file_field(suffixes=[".md"], fileTypes=["md"]),
"PyPDFLoader": build_file_field(suffixes=[".pdf"], fileTypes=["pdf"]),
"UnstructuredPowerPointLoader": build_file_field(
suffixes=[".pptx", ".ppt"], fileTypes=["pptx", "ppt"]
),
"UnstructuredPowerPointLoader": build_file_field(suffixes=[".pptx", ".ppt"], fileTypes=["pptx", "ppt"]),
"SRTLoader": build_file_field(suffixes=[".srt"], fileTypes=["srt"]),
"TelegramChatLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]),
"TextLoader": build_file_field(suffixes=[".txt"], fileTypes=["txt"]),
"UnstructuredWordDocumentLoader": build_file_field(
suffixes=[".docx", ".doc"], fileTypes=["docx", "doc"]
),
"UnstructuredWordDocumentLoader": build_file_field(suffixes=[".docx", ".doc"], fileTypes=["docx", "doc"]),
}
def add_extra_fields(self) -> None:

View file

@@ -70,9 +70,7 @@ class EmbeddingFrontendNode(FrontendNode):
field.advanced = True
split_name = field.name.split("_")
title_name = " ".join([s.capitalize() for s in split_name])
field.display_name = title_name.replace("Openai", "OpenAI").replace(
"Api", "API"
)
field.display_name = title_name.replace("Openai", "OpenAI").replace("Api", "API")
if "api_key" in field.name:
field.password = True

View file

@@ -112,10 +112,7 @@ class PasswordFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
show = field.show
if (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and show
):
if any(text in key.lower() for text in {"password", "token", "api", "key"}) and show:
field.password = True
@@ -157,9 +154,5 @@ class DictCodeFileFormatter(FieldFormatter):
field.field_type = "file"
field.suffixes = [".json", ".yaml", ".yml"]
field.file_types = ["json", "yaml", "yml"]
elif (
_type.startswith("Dict")
or _type.startswith("Mapping")
or _type.startswith("dict")
):
elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"):
field.field_type = "dict"

View file

@@ -54,9 +54,9 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_openai_field(field: TemplateField):
if "openai" in field.name.lower():
field.display_name = (
field.name.title().replace("Openai", "OpenAI").replace("_", " ")
).replace("Api", "API")
field.display_name = (field.name.title().replace("Openai", "OpenAI").replace("_", " ")).replace(
"Api", "API"
)
if "key" not in field.name.lower() and "token" not in field.name.lower():
field.password = False
@@ -109,10 +109,7 @@ class LLMFrontendNode(FrontendNode):
if field.name in SHOW_FIELDS:
field.show = True
if "api" in field.name and (
"key" in field.name
or ("token" in field.name and "tokens" not in field.name)
):
if "api" in field.name and ("key" in field.name or ("token" in field.name and "tokens" not in field.name)):
field.password = True
field.show = True
# Required should be False to support

View file

@@ -76,9 +76,7 @@ class MemoryFrontendNode(FrontendNode):
field.show = True
field.advanced = False
field.value = ""
field.info = (
INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
)
field.info = INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
if field.name == "memory_key":
field.value = "chat_history"

View file

@@ -36,10 +36,7 @@ class PromptFrontendNode(FrontendNode):
field.field_type = "prompt"
field.advanced = False
if (
"Union" in field.field_type
and "BaseMessagePromptTemplate" in field.field_type
):
if "Union" in field.field_type and "BaseMessagePromptTemplate" in field.field_type:
field.field_type = "BaseMessagePromptTemplate"
# All prompt fields should be password=False

View file

@@ -81,9 +81,7 @@ def build_json(root, graph) -> Dict:
raise ValueError(f"No child with type {node_type} found")
values = [build_json(child, graph) for child in children]
value = (
list(values)
if value["list"]
else next(iter(values), None) # type: ignore
list(values) if value["list"] else next(iter(values), None) # type: ignore
)
final_dict[key] = value

View file

@@ -15,12 +15,8 @@ def remove_ansi_escape_codes(text):
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
def build_template_from_function(
name: str, type_to_loader_dict: Dict, add_function: bool = False
):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
# Raise error if name is not in chains
if name not in classes:
@@ -41,9 +37,7 @@ def build_template_from_function(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items][
"default"
] = get_default_factory(
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
)
except Exception:
@@ -52,9 +46,7 @@ def build_template_from_function(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
# Adding function to base classes to allow
# the output to be a function
@@ -69,9 +61,7 @@ def build_template_from_function(
}
def build_template_from_class(
name: str, type_to_cls_dict: Dict, add_function: bool = False
):
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
@@ -95,9 +85,7 @@ def build_template_from_class(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items][
"default"
] = get_default_factory(
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
)
except Exception:
@@ -106,9 +94,7 @@ def build_template_from_class(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
base_classes = get_base_classes(_class)
# Adding function to base classes to allow
@@ -140,9 +126,7 @@ def build_template_from_method(
# Check if the method exists in this class
if not hasattr(_class, method_name):
raise ValueError(
f"Method {method_name} not found in class {class_name}"
)
raise ValueError(f"Method {method_name} not found in class {class_name}")
# Get the method
method = getattr(_class, method_name)
@@ -161,12 +145,8 @@ def build_template_from_method(
"_type": _type,
**{
name: {
"default": param.default
if param.default != param.empty
else None,
"type": param.annotation
if param.annotation != param.empty
else None,
"default": param.default if param.default != param.empty else None,
"type": param.annotation if param.annotation != param.empty else None,
"required": param.default == param.empty,
}
for name, param in params.items()
@@ -253,9 +233,7 @@ def sync_to_async(func):
return async_wrapper
def format_dict(
dictionary: Dict[str, Any], class_name: Optional[str] = None
) -> Dict[str, Any]:
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
"""
Formats a dictionary by removing certain keys and modifying the
values of other keys.
@@ -330,9 +308,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
The modified type string.
"""
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
_type = (
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
)
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
value["list"] = True
else:
value["list"] = False
@@ -436,9 +412,7 @@ def set_headers_value(value: Dict[str, Any]) -> None:
value["value"] = """{"Authorization": "Bearer <token>"}"""
def add_options_to_field(
value: Dict[str, Any], class_name: Optional[str], key: str
) -> None:
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
"""
Adds options to the field based on the class name and key.
"""

View file

@@ -41,9 +41,7 @@ def validate_code(code):
# Evaluate the function definition
for node in tree.body:
if isinstance(node, ast.FunctionDef):
code_obj = compile(
ast.Module(body=[node], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[node], type_ignores=[]), "<string>", "exec")
try:
exec(code_obj)
except Exception as e:
@@ -63,8 +61,7 @@ def eval_function(function_string: str):
(
obj
for name, obj in namespace.items()
if isinstance(obj, types.FunctionType)
and obj.__code__.co_filename == "<string>"
if isinstance(obj, types.FunctionType) and obj.__code__.co_filename == "<string>"
),
None,
)
@@ -88,23 +85,15 @@ def execute_function(code, function_name, *args, **kwargs):
exec_globals,
locals(),
)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
function_code = next(
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
try:
exec(code_obj, exec_globals, locals())
except Exception as exc:
@@ -131,23 +120,15 @@ def create_function(code, function_name):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
function_code = next(
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())
exec_globals[function_name] = locals()[function_name]
@@ -178,32 +159,20 @@ def create_class(code, class_name):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
elif isinstance(node, ast.ImportFrom):
try:
imported_module = importlib.import_module(node.module)
for alias in node.names:
exec_globals[alias.name] = getattr(imported_module, alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {node.module} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {node.module} not found. Please install it and try again.") from e
class_code = next(
node
for node in module.body
if isinstance(node, ast.ClassDef) and node.name == class_name
)
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
class_code.parent = None
code_obj = compile(
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
# This suppresses import errors
# with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())

View file

@@ -30,9 +30,7 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex":
vertex.build()
return vertex
except SoftTimeLimitExceeded as e:
raise self.retry(
exc=SoftTimeLimitExceeded("Task took too long"), countdown=2
) from e
raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e
@celery_app.task(acks_late=True)
@@ -47,9 +45,7 @@ def process_graph_cached_task(
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Load the graph using SessionService
graph, artifacts = session_service.load_session(session_id, data_graph)
built_object = graph.build()

Some files were not shown because too many files have changed in this diff