Format Python files

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-06 14:54:14 -03:00
commit 2450161fc1
68 changed files with 408 additions and 972 deletions

View file

@ -86,9 +86,7 @@ def validate_prompt(template: str):
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
)
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
try:
PromptTemplate(template=template, input_variables=input_variables)
@ -123,9 +121,7 @@ def fix_variable(var, invalid_chars, wrong_variables):
# Handle variables starting with a number
if var[0].isdigit():
invalid_chars.append(var[0])
new_var, invalid_chars, wrong_variables = fix_variable(
var[1:], invalid_chars, wrong_variables
)
new_var, invalid_chars, wrong_variables = fix_variable(var[1:], invalid_chars, wrong_variables)
# Temporarily replace {{ and }} to avoid treating them as invalid
new_var = new_var.replace("{{", "ᴛᴇᴍᴘᴏᴘᴇɴ").replace("}}", "ᴛᴇᴍᴘʟsᴇ")
@ -152,9 +148,7 @@ def check_variable(var, invalid_chars, wrong_variables, empty_variables):
return wrong_variables, empty_variables
def check_for_errors(
input_variables, fixed_variables, wrong_variables, empty_variables
):
def check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables):
if any(var for var in input_variables if var not in fixed_variables):
error_message = (
f"Error: Input variables contain invalid characters or formats. \n"
@ -179,17 +173,11 @@ def check_input_variables(input_variables):
if is_json_like(var):
continue
new_var, wrong_variables, empty_variables = fix_variable(
var, invalid_chars, wrong_variables
)
wrong_variables, empty_variables = check_variable(
var, INVALID_CHARACTERS, wrong_variables, empty_variables
)
new_var, wrong_variables, empty_variables = fix_variable(var, invalid_chars, wrong_variables)
wrong_variables, empty_variables = check_variable(var, INVALID_CHARACTERS, wrong_variables, empty_variables)
fixed_variables.append(new_var)
variables_to_check.append(var)
check_for_errors(
variables_to_check, fixed_variables, wrong_variables, empty_variables
)
check_for_errors(variables_to_check, fixed_variables, wrong_variables, empty_variables)
return fixed_variables

View file

@ -33,9 +33,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message="",
@ -73,9 +71,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
try:
# This is to emulate the stream of tokens
for resp in resps:
await self.socketio_service.emit_token(
to=self.sid, data=resp.model_dump()
)
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
except Exception as exc:
logger.error(f"Error sending response: {exc}")
@ -101,9 +97,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
resp = PromptResponse(
prompt=text,
)
await self.socketio_service.emit_message(
to=self.sid, data=resp.model_dump()
)
await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
log = f"Thought: {action.log}"
@ -113,9 +107,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
logs = log.split("\n")
for log in logs:
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.socketio_service.emit_token(
to=self.sid, data=resp.model_dump()
)
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
else:
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

View file

@ -101,12 +101,8 @@ async def build_vertex(
cache = chat_service.get_cache(flow_id)
if not cache:
# If there's no cache
logger.warning(
f"No cache found for {flow_id}. Building graph starting at {vertex_id}"
)
graph = build_and_cache_graph(
flow_id=flow_id, session=next(get_session()), chat_service=chat_service
)
logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}")
graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
else:
graph = cache.get("result")
result_data_response = ResultDataResponse(results={})
@ -126,9 +122,7 @@ async def build_vertex(
else:
raise ValueError(f"No result found for vertex {vertex_id}")
next_vertices_ids = vertex.successors_ids
next_vertices_ids = [
v for v in next_vertices_ids if graph.should_run_vertex(v)
]
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
result_data_response = ResultDataResponse(**result_dict.model_dump())
@ -211,9 +205,7 @@ async def build_vertex_stream(
else:
graph = cache.get("result")
else:
session_data = await session_service.load_session(
session_id, flow_id=flow_id
)
session_data = await session_service.load_session(session_id, flow_id=flow_id)
graph, artifacts = session_data if session_data else (None, None)
if not graph:
raise ValueError(f"No graph found for {flow_id}.")

View file

@ -52,9 +52,7 @@ def get_all(
raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post(
"/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True
)
@router.post("/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
@ -113,9 +111,7 @@ async def run_flow_with_caching(
outputs = []
if session_id:
session_data = await session_service.load_session(
session_id, flow_id=flow_id
)
session_data = await session_service.load_session(session_id, flow_id=flow_id)
graph, artifacts = session_data if session_data else (None, None)
task_result: Any = None
if not graph:
@ -134,11 +130,7 @@ async def run_flow_with_caching(
else:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -162,18 +154,12 @@ async def run_flow_with_caching(
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
@router.post(
@ -202,8 +188,7 @@ async def process(
"""
# Raise a depreciation warning
logger.warning(
"The /process endpoint is deprecated and will be removed in a future version. "
"Please use /run instead."
"The /process endpoint is deprecated and will be removed in a future version. " "Please use /run instead."
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -275,16 +260,12 @@ async def custom_component(
built_frontend_node = build_custom_component_template(component, user_id=user.id)
built_frontend_node = update_frontend_node_with_template_values(
built_frontend_node, raw_code.frontend_node
)
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
return built_frontend_node
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
async def reload_custom_component(
path: str, user: User = Depends(get_current_active_user)
):
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
from langflow.interface.custom.utils import build_custom_component_template
try:

View file

@ -57,11 +57,7 @@ def read_flows(
try:
auth_settings = settings_service.auth_settings
if auth_settings.AUTO_LOGIN:
flows = session.exec(
select(Flow).where(
Flow.user_id == None | Flow.user_id == current_user.id
)
).all() # noqa
flows = session.exec(select(Flow).where(Flow.user_id == None | Flow.user_id == current_user.id)).all() # noqa
else:
flows = current_user.flows
@ -71,7 +67,8 @@ def read_flows(
try:
example_flows = session.exec(
select(Flow).where(
Flow.user_id == None, Flow.folder == STARTER_FOLDER_NAME # noqa
Flow.user_id == None,
Flow.folder == STARTER_FOLDER_NAME, # noqa
)
).all() # noqa
for example_flow in example_flows:
@ -98,7 +95,9 @@ def read_flow(
if auth_settings.AUTO_LOGIN:
# If auto login is enable user_id can be current_user.id or None
# so write an OR
stmt = stmt.where((Flow.user_id == current_user.id) | (Flow.user_id == None))
stmt = stmt.where(
(Flow.user_id == current_user.id) | (Flow.user_id == None) # noqa
) # noqa
if user_flow := session.exec(stmt).first():
return user_flow
else:

View file

@ -158,9 +158,7 @@ class StreamData(BaseModel):
data: dict
def __str__(self) -> str:
return (
f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
)
return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
class CustomComponentCode(BaseModel):

View file

@ -41,9 +41,7 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
add_new_variables_to_template(input_variables, prompt_request)
remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
)
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
update_input_variables_field(input_variables, prompt_request)
@ -58,19 +56,12 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
def get_old_custom_fields(prompt_request):
try:
if (
len(prompt_request.frontend_node.custom_fields) == 1
and prompt_request.name == ""
):
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
prompt_request.name = list(
prompt_request.frontend_node.custom_fields.keys()
)[0]
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
old_custom_fields = prompt_request.frontend_node.custom_fields[
prompt_request.name
]
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name]
if old_custom_fields is None:
old_custom_fields = []
@ -87,40 +78,26 @@ def add_new_variables_to_template(input_variables, prompt_request):
template_field = DefaultPromptField(name=variable, display_name=variable)
if variable in prompt_request.frontend_node.template:
# Set the new field with the old value
template_field.value = prompt_request.frontend_node.template[variable][
"value"
]
template_field.value = prompt_request.frontend_node.template[variable]["value"]
prompt_request.frontend_node.template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if (
variable
not in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
variable
)
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
):
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if (
variable
in prompt_request.frontend_node.custom_fields[prompt_request.name]
):
prompt_request.frontend_node.custom_fields[
prompt_request.name
].remove(variable)
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
# Remove the variable from the template
prompt_request.frontend_node.template.pop(variable, None)
@ -132,6 +109,4 @@ def remove_old_variables_from_template(
def update_input_variables_field(input_variables, prompt_request):
if "input_variables" in prompt_request.frontend_node.template:
prompt_request.frontend_node.template["input_variables"][
"value"
] = input_variables
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables

View file

@ -35,9 +35,7 @@ def retrieve_file_paths(
glob = "**/*" if recursive else "*"
paths = walk_level(path_obj, depth) if depth else path_obj.glob(glob)
file_paths = [
Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)
]
file_paths = [Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)]
return file_paths
@ -70,16 +68,12 @@ def get_elements(
if use_multithreading:
records = parallel_load_records(file_paths, silent_errors, max_concurrency)
else:
records = [
parse_file_to_record(file_path, silent_errors) for file_path in file_paths
]
records = [parse_file_to_record(file_path, silent_errors) for file_path in file_paths]
records = list(filter(None, records))
return records
def parallel_load_records(
file_paths: List[str], silent_errors: bool, max_concurrency: int
) -> List[Optional[Record]]:
def parallel_load_records(file_paths: List[str], silent_errors: bool, max_concurrency: int) -> List[Optional[Record]]:
with futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
loaded_files = executor.map(
lambda file_path: parse_file_to_record(file_path, silent_errors),

View file

@ -45,9 +45,7 @@ class ChatComponent(CustomComponent):
return []
if not session_id or not sender or not sender_name:
raise ValueError(
"All of session_id, sender, and sender_name must be provided."
)
raise ValueError("All of session_id, sender, and sender_name must be provided.")
if isinstance(message, Record):
record = message
record.data.update(

View file

@ -27,12 +27,12 @@ class APIRequest(CustomComponent):
"headers": {
"display_name": "Headers",
"info": "The headers to send with the request.",
"input_types": ["dict"]
"input_types": ["dict"],
},
"body": {
"display_name": "Body",
"info": "The body to send with the request (for POST, PATCH, PUT).",
"input_types": ["dict"]
"input_types": ["dict"],
},
"timeout": {
"display_name": "Timeout",
@ -58,9 +58,7 @@ class APIRequest(CustomComponent):
data = body if body else None
data = json.dumps(data)
try:
response = await client.request(
method, url, headers=headers, content=data, timeout=timeout
)
response = await client.request(method, url, headers=headers, content=data, timeout=timeout)
try:
response_json = response.json()
result = orjson_dumps(response_json, indent_2=False)
@ -96,16 +94,9 @@ class APIRequest(CustomComponent):
if headers is None:
headers = {}
urls = url if isinstance(url, list) else [url]
bodies = (
body
if isinstance(body, list)
else [body] if body else [None] * len(urls)
)
bodies = body if isinstance(body, list) else [body] if body else [None] * len(urls)
async with httpx.AsyncClient() as client:
results = await asyncio.gather(
*[
self.make_request(client, method, u, headers, rec, timeout)
for u, rec in zip(urls, bodies)
]
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]
)
return results

View file

@ -57,20 +57,13 @@ class DirectoryComponent(CustomComponent):
if types is None:
types = []
resolved_path = self.resolve_path(path)
file_paths = retrieve_file_paths(
resolved_path, types, load_hidden, recursive, depth
)
file_paths = retrieve_file_paths(resolved_path, types, load_hidden, recursive, depth)
loaded_records = []
if use_multithreading:
loaded_records = parallel_load_records(
file_paths, silent_errors, max_concurrency
)
loaded_records = parallel_load_records(file_paths, silent_errors, max_concurrency)
else:
loaded_records = [
parse_file_to_record(file_path, silent_errors)
for file_path in file_paths
]
loaded_records = [parse_file_to_record(file_path, silent_errors) for file_path in file_paths]
loaded_records = list(filter(None, loaded_records))
self.status = loaded_records
return loaded_records

View file

@ -11,9 +11,7 @@ class FileLoaderComponent(CustomComponent):
beta = True
def build_config(self):
loader_options = ["Automatic"] + [
loader_info["name"] for loader_info in LOADERS_INFO
]
loader_options = ["Automatic"] + [loader_info["name"] for loader_info in LOADERS_INFO]
file_types = []
suffixes = []
@ -105,9 +103,7 @@ class FileLoaderComponent(CustomComponent):
if isinstance(selected_loader_info, dict):
loader_import: str = selected_loader_info["import"]
else:
raise ValueError(
f"Loader info for {loader} is not a dict\nLoader info:\n{selected_loader_info}"
)
raise ValueError(f"Loader info for {loader} is not a dict\nLoader info:\n{selected_loader_info}")
module_name, class_name = loader_import.rsplit(".", 1)
try:
@ -115,9 +111,7 @@ class FileLoaderComponent(CustomComponent):
loader_module = __import__(module_name, fromlist=[class_name])
loader_instance = getattr(loader_module, class_name)
except ImportError as e:
raise ValueError(
f"Loader {loader} could not be imported\nLoader info:\n{selected_loader_info}"
) from e
raise ValueError(f"Loader {loader} could not be imported\nLoader info:\n{selected_loader_info}") from e
result = loader_instance(file_path=file_path)
docs = result.load()

View file

@ -19,7 +19,6 @@ class URLComponent(CustomComponent):
self,
urls: list[str],
) -> Record:
loader = WebBaseLoader(web_paths=urls)
docs = loader.load()
records = self.to_records(docs)

View file

@ -18,9 +18,7 @@ class NotifyComponent(CustomComponent):
},
}
def build(
self, name: str, record: Optional[Record] = None, append: bool = False
) -> Record:
def build(self, name: str, record: Optional[Record] = None, append: bool = False) -> Record:
if record and not isinstance(record, Record):
if isinstance(record, str):
record = Record(text=record)

View file

@ -39,10 +39,7 @@ class RunFlowComponent(CustomComponent):
records.append(record)
return records
async def build(
self, input_value: Text, flow_name: str, tweaks: NestedDict
) -> Record:
async def build(self, input_value: Text, flow_name: str, tweaks: NestedDict) -> Record:
results: List[Optional[ResultData]] = await self.run_flow(
input_value=input_value, flow_name=flow_name, tweaks=tweaks
)

View file

@ -11,7 +11,6 @@ class ExtractKeyFromRecordComponent(CustomComponent):
}
def build(self, record: Record, key: str, silent_error: bool = True) -> dict:
data = getattr(record, key)
self.status = data
return data

View file

@ -9,9 +9,7 @@ class UUIDGeneratorComponent(CustomComponent):
display_name = "Unique ID Generator"
description = "Generates a unique ID."
def update_build_config(
self, build_config: dict, field_name: Text, field_value: Any
):
def update_build_config(self, build_config: dict, field_name: Text, field_value: Any):
if field_name == "unique_id":
build_config[field_name]["value"] = str(uuid.uuid4())
return build_config

View file

@ -0,0 +1,47 @@
from typing import Any
from langflow import CustomComponent
from langflow.schema import Record
from langflow.template.field.base import TemplateField
class RecordComponent(CustomComponent):
    """Create a Record from a user-chosen number of key/value dict fields.

    The user sets ``n_keys``; ``update_build_config`` then materializes that
    many dict-typed template fields, and ``build`` merges all of them into a
    single :class:`Record`.
    """

    display_name = "Record Numbers"
    description = "A component to create a record from key-value pairs."
    # "n_keys" must be rendered first so the dynamic fields appear below it.
    field_order = ["n_keys"]

    def update_build_config(self, build_config: dict, field_name: str, field_value: Any) -> dict:
        """Rebuild the dynamic key/value fields when ``n_keys`` changes.

        Args:
            build_config: The component's current build configuration.
            field_name: Name of the field that changed (expected: "n_keys").
            field_value: The new value of that field; ``None`` means "no change".

        Returns:
            The (possibly mutated) build configuration. Returning the config
            keeps this method consistent with the other components'
            ``update_build_config`` implementations.
        """
        if field_value is None:
            # Nothing to update; still return the config for the caller.
            return build_config
        if int(field_value) == 0:
            # Drop all previously generated dynamic fields, keeping only the
            # permanent ones.
            keep = ["n_keys", "code"]
            for key in build_config.copy():
                if key in keep:
                    continue
                del build_config[key]
        build_config[field_name]["value"] = int(field_value)
        # Add one dict field per requested key slot.
        for i in range(int(field_value)):
            field = TemplateField(
                name=f"Key and Value {i}",
                field_type="dict",
                display_name="",
                info="The key for the record.",
                input_types=["Text"],
            )
            build_config[field.name] = field.to_dict()
        return build_config

    def build_config(self):
        """Return the static build configuration (only the field counter)."""
        return {
            "n_keys": {
                "display_name": "Number of Fields",
                "refresh": True,
                "info": "The number of keys to create in the record.",
            },
        }

    def build(self, n_keys: int, **kwargs) -> Record:
        """Merge every dict-valued dynamic field into one Record.

        Args:
            n_keys: Number of key/value fields (drives the dynamic fields).
            **kwargs: The dynamic dict fields created by ``update_build_config``.

        Returns:
            A Record whose data is the union of all provided dicts.
        """
        # Guard against non-dict values so a half-filled field does not crash
        # the build with an AttributeError.
        data = {k: v for d in kwargs.values() if isinstance(d, dict) for k, v in d.items()}
        record = Record(data=data)
        # Expose the result in the UI, matching the sibling record components.
        self.status = record
        return record

View file

@ -0,0 +1,51 @@
from typing import Any, List
from langflow import CustomComponent
from langflow.schema import Record
from langflow.template.field.base import TemplateField
class RecordComponent2(CustomComponent):
    """Build a Record from a user-defined list of text keys.

    The user supplies a list of key names via ``keys``; one string-typed
    template field is generated per key, and ``build`` packs the resulting
    values into a single :class:`Record`.
    """

    display_name = "Record Text"
    description = "A component to create a record from key-value pairs."
    # Render the "keys" list before any of the generated fields.
    field_order = ["keys"]

    def update_build_config(self, build_config: dict, field_name: str, field_value: Any):
        """Regenerate one string field per entry in the ``keys`` list."""
        field_value = [] if field_value is None else field_value
        if field_name is None:
            return build_config
        if not field_value:
            # No keys requested: strip every generated field, preserving the
            # permanent ones.
            permanent = {"keys", "code"}
            for existing in list(build_config):
                if existing not in permanent:
                    del build_config[existing]
        build_config[field_name]["value"] = field_value
        # One text field per non-empty string key.
        for key_name in field_value:
            if isinstance(key_name, str) and key_name != "":
                template_field = TemplateField(
                    name=key_name,
                    field_type="str",
                    display_name="",
                    info="The key for the record.",
                )
                build_config[template_field.name] = template_field.to_dict()

    def build_config(self):
        """Return the static build configuration (only the key list)."""
        return {
            "keys": {
                "display_name": "Keys",
                "refresh": True,
                "info": "The number of keys to create in the record.",
                "input_types": [],
            },
        }

    def build(self, keys: List[str], **kwargs) -> Record:
        """Pack the per-key field values into a Record and surface it."""
        result = Record(data=kwargs)
        self.status = result
        return result

View file

@ -9,9 +9,7 @@ from langflow.schema.schema import Record
class LanguageRecursiveTextSplitterComponent(CustomComponent):
display_name: str = "Language Recursive Text Splitter"
description: str = "Split text into chunks of a specified length based on language."
documentation: str = (
"https://docs.langflow.org/components/text-splitters#languagerecursivetextsplitter"
)
documentation: str = "https://docs.langflow.org/components/text-splitters#languagerecursivetextsplitter"
def build_config(self):
options = [x.value for x in Language]

View file

@ -11,9 +11,7 @@ from langflow.utils.util import build_loader_repr_from_documents
class RecursiveCharacterTextSplitterComponent(CustomComponent):
display_name: str = "Recursive Character Text Splitter"
description: str = "Split text into chunks of a specified length."
documentation: str = (
"https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
)
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
def build_config(self):
return {

View file

@ -85,8 +85,7 @@ class ChromaComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,
@ -107,9 +106,7 @@ class ChromaComponent(CustomComponent):
documents.append(_input)
if documents is not None and embedding is not None:
if len(documents) == 0:
raise ValueError(
"If documents are provided, there must be at least one document."
)
raise ValueError("If documents are provided, there must be at least one document.")
chroma = Chroma.from_documents(
documents=documents, # type: ignore
persist_directory=index_directory,

View file

@ -92,8 +92,7 @@ class ChromaSearchComponent(LCVectorStoreComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,

View file

@ -33,9 +33,7 @@ class FAISSSearchComponent(LCVectorStoreComponent):
if not folder_path:
raise ValueError("Folder path is required to save the FAISS index.")
path = self.resolve_path(folder_path)
vector_store = FAISS.load_local(
folder_path=Text(path), embeddings=embedding, index_name=index_name
)
vector_store = FAISS.load_local(folder_path=Text(path), embeddings=embedding, index_name=index_name)
if not vector_store:
raise ValueError("Failed to load the FAISS index.")

View file

@ -9,9 +9,7 @@ from langflow.schema.schema import Record
class MongoDBAtlasComponent(CustomComponent):
display_name = "MongoDB Atlas"
description = (
"Construct a `MongoDB Atlas Vector Search` vector store from raw documents."
)
description = "Construct a `MongoDB Atlas Vector Search` vector store from raw documents."
icon = "MongoDB"
def build_config(self):
@ -39,9 +37,7 @@ class MongoDBAtlasComponent(CustomComponent):
try:
from pymongo import MongoClient
except ImportError:
raise ImportError(
"Please install pymongo to use MongoDB Atlas Vector Store"
)
raise ImportError("Please install pymongo to use MongoDB Atlas Vector Store")
try:
mongo_client: MongoClient = MongoClient(mongodb_atlas_cluster_uri)
collection = mongo_client[db_name][collection_name]

View file

@ -67,9 +67,7 @@ class RedisComponent(CustomComponent):
documents.append(_input)
if not documents:
if schema is None:
raise ValueError(
"If no documents are provided, a schema must be provided."
)
raise ValueError("If no documents are provided, a schema must be provided.")
redis_vs = Redis.from_existing_index(
embedding=embedding,
index_name=redis_index_name,

View file

@ -33,7 +33,6 @@ class RedisSearchComponent(RedisComponent, LCVectorStoreComponent):
"input_value": {"display_name": "Input"},
"index_name": {"display_name": "Index Name", "value": "your_index"},
"code": {"show": False, "display_name": "Code"},
"embedding": {"display_name": "Embedding"},
"schema": {"display_name": "Schema", "file_types": [".yaml"]},
"redis_server_url": {

View file

@ -35,9 +35,7 @@ class SupabaseComponent(CustomComponent):
supabase_url: str = "",
table_name: str = "",
) -> Union[VectorStore, SupabaseVectorStore, BaseRetriever]:
supabase: Client = create_client(
supabase_url, supabase_key=supabase_service_key
)
supabase: Client = create_client(supabase_url, supabase_key=supabase_service_key)
documents = []
for _input in inputs:
if isinstance(_input, Record):

View file

@ -38,9 +38,7 @@ class SupabaseSearchComponent(LCVectorStoreComponent):
supabase_url: str = "",
table_name: str = "",
) -> List[Record]:
supabase: Client = create_client(
supabase_url, supabase_key=supabase_service_key
)
supabase: Client = create_client(supabase_url, supabase_key=supabase_service_key)
vector_store = SupabaseVectorStore(
client=supabase,
embedding=embedding,

View file

@ -15,9 +15,7 @@ from langflow.schema.schema import Record
class VectaraComponent(CustomComponent):
display_name: str = "Vectara"
description: str = "Implementation of Vector Store using Vectara"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/vectara"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara"
beta = True
icon = "Vectara"
field_config = {

View file

@ -11,9 +11,7 @@ from langflow.schema import Record
class VectaraSearchComponent(VectaraComponent, LCVectorStoreComponent):
display_name: str = "Vectara Search"
description: str = "Search a Vectara Vector Store for similar documents."
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/vectara"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara"
beta = True
icon = "Vectara"

View file

@ -12,9 +12,7 @@ from langflow.schema.schema import Record
class WeaviateVectorStoreComponent(CustomComponent):
display_name: str = "Weaviate"
description: str = "Implementation of Vector Store using Weaviate"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/weaviate"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/weaviate"
beta = True
field_config = {
"url": {"display_name": "Weaviate URL", "value": "http://localhost:8080"},

View file

@ -11,9 +11,7 @@ from langflow.schema import Record
class WeaviateSearchVectorStore(WeaviateVectorStoreComponent, LCVectorStoreComponent):
display_name: str = "Weaviate Search"
description: str = "Search a Weaviate Vector Store for similar documents."
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/weaviate"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/weaviate"
beta = True
icon = "Weaviate"

View file

@ -37,14 +37,8 @@ class LCVectorStoreComponent(CustomComponent):
"""
docs: List[Document] = []
if (
input_value
and isinstance(input_value, str)
and hasattr(vector_store, "search")
):
docs = vector_store.search(
query=input_value, search_type=search_type.lower()
)
if input_value and isinstance(input_value, str) and hasattr(vector_store, "search"):
docs = vector_store.search(query=input_value, search_type=search_type.lower())
else:
raise ValueError("Invalid inputs provided.")
return docs_to_records(docs)

View file

@ -16,9 +16,7 @@ class PGVectorComponent(CustomComponent):
display_name: str = "PGVector"
description: str = "Implementation of Vector Store using PostgreSQL"
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/pgvector"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
def build_config(self):
"""

View file

@ -15,9 +15,7 @@ class PGVectorSearchComponent(PGVectorComponent, LCVectorStoreComponent):
display_name: str = "PGVector Search"
description: str = "Search a PGVector Store for similar documents."
documentation = (
"https://python.langchain.com/docs/integrations/vectorstores/pgvector"
)
documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
def build_config(self):
"""

View file

@ -12,9 +12,7 @@ if TYPE_CHECKING:
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
@ -22,9 +20,7 @@ class SourceHandle(BaseModel):
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
type: str = Field(..., description="Type of the target handle.")
@ -53,24 +49,16 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
else:
self.valid_handles = (
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has invalid handles"
)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
def __setstate__(self, state):
self.source_id = state["source_id"]
@ -87,11 +75,7 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
# Get what type of input the target node is expecting
self.matched_type = next(
@ -102,10 +86,7 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has no matched type"
)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
def __repr__(self) -> str:
return (
@ -118,10 +99,7 @@ class Edge:
def __eq__(self, __o: object) -> bool:
# Create a better way to compare edges
return (
self._source_handle == __o._source_handle
and self._target_handle == __o._target_handle
)
return self._source_handle == __o._source_handle and self._target_handle == __o._target_handle
class ContractEdge(Edge):
@ -178,9 +156,7 @@ class ContractEdge(Edge):
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
):
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
try:
monitor_service = get_monitor_service()
clean_params = build_clean_params(target)

View file

@ -76,9 +76,7 @@ class Graph:
"""Returns the state of the graph."""
return self.state_manager.get_state(name, run_id=self._run_id)
def update_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
def update_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None:
"""Updates the state of the graph."""
if caller:
# If there is a caller which is a vertex_id, I want to activate
@ -110,12 +108,9 @@ class Graph:
def reset_activated_vertices(self):
self.activated_vertices = []
def append_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
def append_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None:
"""Appends the state of the graph."""
if caller:
self.activate_state_vertices(name, caller)
self.state_manager.append_state(name, record, run_id=self._run_id)
@ -161,10 +156,7 @@ class Graph:
"""Runs the graph with the given inputs."""
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
if input_components and (
vertex_id not in input_components
or vertex.display_name not in input_components
):
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
continue
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
@ -187,11 +179,7 @@ class Graph:
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
if (
not vertex.result
and not stream
and hasattr(vertex, "consume_async_generator")
):
if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"):
await vertex.consume_async_generator()
if vertex.display_name in outputs or vertex.id in outputs:
vertex_outputs.append(vertex.result)
@ -269,9 +257,7 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
return parent_child_map
def increment_run_count(self):
@ -456,11 +442,7 @@ class Graph:
"""Updates the edges of a vertex."""
# Vertex has edges, so we need to update the edges
for edge in vertex.edges:
if (
edge not in self.edges
and edge.source_id in self.vertex_map
and edge.target_id in self.vertex_map
):
if edge not in self.edges and edge.source_id in self.vertex_map and edge.target_id in self.vertex_map:
self.edges.append(edge)
def _build_graph(self) -> None:
@ -485,11 +467,7 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -510,9 +488,7 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(
f"{vertex.display_name} is not connected to any other components"
)
raise ValueError(f"{vertex.display_name} is not connected to any other components")
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -574,9 +550,7 @@ class Graph:
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
)
tasks.append(task)
vertex_task_run_count[vertex_id] = (
vertex_task_run_count.get(vertex_id, 0) + 1
)
vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
await self._execute_tasks(tasks)
logger.debug("Graph processing complete")
@ -618,9 +592,7 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
raise ValueError("Graph contains a cycle, cannot perform topological sort")
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -644,10 +616,7 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
def get_all_successors(self, vertex, recursive=True, flat=True):
# Recursively get the successors of the current vertex
@ -688,10 +657,7 @@ class Graph:
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -737,9 +703,7 @@ class Graph:
edges_added.add((source.id, target.id))
return edges
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -772,18 +736,14 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -795,9 +755,7 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
@ -865,8 +823,7 @@ class Graph:
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if self.in_degree_map[vertex.id] == 0
and (not filter_graphs or vertex.is_input)
if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
)
layers: List[List[str]] = []
visited = set(queue)
@ -940,9 +897,7 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -983,9 +938,7 @@ class Graph:
first_layer = vertices_layers[0]
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {
vertex_id for vertex_id in chain.from_iterable(vertices_layers)
}
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
# Return just the first layer
return first_layer
@ -996,15 +949,11 @@ class Graph:
self.vertices_to_run.remove(vertex_id)
return should_run
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(
component.value in vertex for component in InterfaceComponentTypes
)
return any(component.value in vertex for component in InterfaceComponentTypes)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -1016,22 +965,16 @@ class Graph:
]
return sorted_vertices
def sort_by_avg_build_time(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
return vertices_ids
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
return sorted_vertices

View file

@ -47,10 +47,7 @@ class VertexTypesDict(LazyLoadDictBase):
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
**{
t: types.CustomComponentVertex
for t in custom_component_creator.to_list()
},
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
**{t: types.ChatVertex for t in CHAT_COMPONENTS},
**{t: types.RoutingVertex for t in ROUTING_COMPONENTS},

View file

@ -59,13 +59,8 @@ class Vertex:
self.updated_raw_params = False
self.id: str = data["id"]
self.is_state = False
self.is_input = any(
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
)
self.is_output = any(
output_component_name in self.id
for output_component_name in OUTPUT_COMPONENTS
)
self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS)
self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS)
self.has_session_id = None
self._custom_component = None
self.has_external_input = False
@ -106,17 +101,11 @@ class Vertex:
def set_state(self, state: str):
self.state = VertexStates[state]
if (
self.state == VertexStates.INACTIVE
and self.graph.in_degree_map[self.id] < 2
):
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2:
# If the vertex is inactive and has only one in degree
# it means that it is not a merge point in the graph
self.graph.inactivated_vertices.add(self.id)
elif (
self.state == VertexStates.ACTIVE
and self.id in self.graph.inactivated_vertices
):
elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactivated_vertices:
self.graph.inactivated_vertices.remove(self.id)
@property
@ -133,9 +122,7 @@ class Vertex:
# If the Vertex.type is a power component
# then we need to return the built object
# instead of the result dict
if self.is_interface_component and not isinstance(
self._built_object, UnbuiltObject
):
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
result = self._built_object
# if it is not a dict or a string and hasattr model_dump then
# return the model_dump
@ -147,11 +134,7 @@ class Vertex:
if isinstance(self._built_result, UnbuiltResult):
return {}
return (
self._built_result
if isinstance(self._built_result, dict)
else {"result": self._built_result}
)
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
def set_artifacts(self) -> None:
pass
@ -221,31 +204,19 @@ class Vertex:
self.selected_output_type = self.data["node"].get("selected_output_type")
self.is_input = self.data["node"].get("is_input") or self.is_input
self.is_output = self.data["node"].get("is_output") or self.is_output
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
self.has_session_id = "session_id" in template_dicts
self.required_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
]
self.optional_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
]
# Add the template_dicts[key]["input_types"] to the optional_inputs
self.optional_inputs.extend(
[
input_type
for value in template_dicts.values()
for input_type in value.get("input_types", [])
]
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
)
template_dict = self.data["node"]["template"]
@ -292,11 +263,7 @@ class Vertex:
self.updated_raw_params = False
return
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = {}
for edge in self.edges:
@ -317,10 +284,7 @@ class Vertex:
# we don't know the key of the dict but we need to set the value
# to the vertex that is the source of the edge
param_dict = template_dict[param_key]["value"]
params[param_key] = {
key: self.graph.get_vertex(edge.source_id)
for key in param_dict.keys()
}
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
@ -356,11 +320,7 @@ class Vertex:
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
if isinstance(val, list):
params[key] = {
k: v
for item in value.get("value", [])
for k, v in item.items()
}
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
@ -485,9 +445,7 @@ class Vertex:
if isinstance(self._built_object, str):
self._built_result = self._built_object
result = await generate_result(
self._built_object, inputs, self.has_external_output, session_id
)
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
self._built_result = result
async def _build_each_node_in_params_dict(self, user_id=None):
@ -507,9 +465,7 @@ class Vertex:
elif key not in self.params or self.updated_raw_params:
self.params[key] = value
async def _build_dict_of_nodes_and_update_params(
self, key, nodes: Dict[str, "Vertex"], user_id=None
):
async def _build_dict_of_nodes_and_update_params(self, key, nodes: Dict[str, "Vertex"], user_id=None):
"""
Iterates over a dictionary of nodes, builds each and updates the params dictionary.
"""
@ -529,9 +485,7 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
async def get_result(
self, requester: Optional["Vertex"] = None, user_id=None, timeout=None
) -> Any:
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
# PLEASE REVIEW THIS IF STATEMENT
# Check if the Vertex was built already
if self._built:
@ -565,9 +519,7 @@ class Vertex:
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
@ -634,9 +586,7 @@ class Vertex:
except Exception as exc:
logger.exception(exc)
raise ValueError(
f"Error building node {self.display_name}: {str(exc)}"
) from exc
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -664,9 +614,7 @@ class Vertex:
logger.warning(message)
elif isinstance(self._built_object, (Iterator, AsyncIterator)):
if self.display_name in ["Text Output"]:
raise ValueError(
f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
)
raise ValueError(f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead.")
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
self._built = False
@ -728,24 +676,16 @@ class Vertex:
return self._built_object
# Get the requester edge
requester_edge = next(
(edge for edge in self.edges if edge.target_id == requester.id), None
)
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
# Return the result of the requester edge
return (
None
if requester_edge is None
else await requester_edge.get_result(source=self, target=requester)
)
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:
self.edges.append(edge)
def __repr__(self) -> str:
return (
f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
)
return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
def __eq__(self, __o: object) -> bool:
try:
@ -766,8 +706,4 @@ class Vertex:
def _built_object_repr(self):
# Add a message with an emoji, stars for sucess,
return (
"Built sucessfully ✨"
if self._built_object is not None
else "Failed to build 😵‍💫"
)
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"

View file

@ -1,7 +1,6 @@
import ast
import json
from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional,
Union)
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
import yaml
from langchain_core.messages import AIMessage
@ -124,11 +123,9 @@ class DocumentLoaderVertex(Vertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(
len(record.text)
for record in self._built_object
if hasattr(record, "text")
) / len(self._built_object)
avg_length = sum(len(record.text) for record in self._built_object if hasattr(record, "text")) / len(
self._built_object
)
return f"""{self.display_name}({len(self._built_object)} records)
\nAvg. Record Length (characters): {int(avg_length)}
Records: {self._built_object[:3]}..."""
@ -201,9 +198,7 @@ class TextSplitterVertex(Vertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
@ -250,27 +245,18 @@ class PromptVertex(Vertex):
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
if "input_variables" not in self.params or self.params["input_variables"] is None:
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -280,20 +266,14 @@ class PromptVertex(Vertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
self.params["input_variables"] = list(set(self.params["input_variables"]))
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
def _built_object_repr(self):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
@ -305,9 +285,7 @@ class PromptVertex(Vertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -315,11 +293,7 @@ class PromptVertex(Vertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
except KeyError:
return str(self._built_object)

View file

@ -30,8 +30,5 @@ def records_to_text(template: str, records: list[Record]) -> list[str]:
records = [records]
# Check if there are any format strings in the template
formated_records = [
template.format(text=record.text, data=record.data, **record.data)
for record in records
]
formated_records = [template.format(text=record.text, data=record.data, **record.data) for record in records]
return "\n".join(formated_records)

View file

@ -89,11 +89,6 @@ def create_new_project(
)
db_flow = Flow.model_validate(new_project, from_attributes=True)
session.add(db_flow)
flows = session.exec(
select(Flow).where(
Flow.name == project_name,
)
).all()
def get_all_flows_similar_to_project(session, project_name):
@ -117,7 +112,6 @@ def delete_start_projects(session):
def create_or_update_starter_projects():
with session_scope() as session:
starter_projects = load_starter_projects()
delete_start_projects(session)
@ -132,9 +126,7 @@ def create_or_update_starter_projects():
project_icon_bg_color,
) = get_project_data(project)
if project_name and project_data:
for existing_project in get_all_flows_similar_to_project(
session, project_name
):
for existing_project in get_all_flows_similar_to_project(session, project_name):
session.delete(existing_project)
create_new_project(

View file

@ -95,9 +95,7 @@ class CodeParser:
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
if alias.asname:
self.data["imports"].append(
(node.module, f"{alias.name} as {alias.asname}")
)
self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}"))
else:
self.data["imports"].append((node.module, alias.name))
@ -146,9 +144,7 @@ class CodeParser:
return_type = None
if node.returns:
return_type_str = ast.unparse(node.returns)
eval_env = self.construct_eval_env(
return_type_str, tuple(self.data["imports"])
)
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
try:
return_type = eval(return_type_str, eval_env)
@ -190,22 +186,14 @@ class CodeParser:
num_defaults = len(node.args.defaults)
num_missing_defaults = num_args - num_defaults
missing_defaults = [None] * num_missing_defaults
default_values = [
ast.unparse(default).strip("'") if default else None
for default in node.args.defaults
]
default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
# Now check all default values to see if there
# are any "None" values in the middle
default_values = [
None if value == "None" else value for value in default_values
]
default_values = [None if value == "None" else value for value in default_values]
defaults = missing_defaults + default_values
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.args, defaults)
]
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
return args
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -223,17 +211,11 @@ class CodeParser:
"""
Parses the keyword-only arguments of a function or method node.
"""
kw_defaults = [None] * (
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
) + [
ast.unparse(default) if default else None
for default in node.args.kw_defaults
kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
ast.unparse(default) if default else None for default in node.args.kw_defaults
]
args = [
self.parse_arg(arg, default)
for arg, default in zip(node.args.kwonlyargs, kw_defaults)
]
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
return args
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
@ -337,9 +319,7 @@ class CodeParser:
Extracts global variables from the code.
"""
global_var = {
"targets": [
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
],
"targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
"value": ast.unparse(node.value),
}
self.data["global_vars"].append(global_var)

View file

@ -21,9 +21,7 @@ class ComponentFunctionEntrypointNameNullError(HTTPException):
class Component:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = (
"The name of the entrypoint function must be provided."
)
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
code: Optional[str] = None
_function_entrypoint_name: str = "build"

View file

@ -77,17 +77,13 @@ class CustomComponent(Component):
def update_state(self, name: str, value: Any):
try:
self.vertex.graph.update_state(
name=name, record=value, caller=self.vertex.id
)
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def append_state(self, name: str, value: Any):
try:
self.vertex.graph.append_state(
name=name, record=value, caller=self.vertex.id
)
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
except Exception as e:
raise ValueError(f"Error appending state: {e}")
@ -138,9 +134,7 @@ class CustomComponent(Component):
def build_config(self):
return self.field_config
def update_build_config(
self, build_config: dict, field_name: str, field_value: Any
):
def update_build_config(self, build_config: dict, field_name: str, field_value: Any):
build_config[field_name] = field_value
return build_config
@ -148,9 +142,7 @@ class CustomComponent(Component):
def tree(self):
return self.get_code_tree(self.code or "")
def to_records(
self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False
) -> List[Record]:
def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]:
"""
Converts input data into a list of Record objects.
@ -199,9 +191,7 @@ class CustomComponent(Component):
return records
def create_references_from_records(
self, records: List[Record], include_data: bool = False
) -> str:
def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str:
"""
Create references from a list of records.
@ -240,20 +230,14 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -310,9 +294,7 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential
@ -322,9 +304,7 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
return credential_service.list_credentials(user_id=self._user_id, session=session)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -363,11 +343,7 @@ class CustomComponent(Component):
if not self._flows_records:
self.list_flows()
if not flow_id and self._flows_records:
flow_ids = [
flow.data["id"]
for flow in self._flows_records
if flow.data["name"] == flow_name
]
flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
@ -389,9 +365,7 @@ class CustomComponent(Component):
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.exec(
select(Flow)
.where(Flow.user_id == self._user_id)
.where(Flow.is_component == False)
select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa
).all()
flows_records = [flow.to_record() for flow in flows]

View file

@ -80,13 +80,9 @@ class DirectoryReader:
except Exception as e:
logger.error(f"Error while loading component: {e}")
continue
items.append(
{"name": menu["name"], "path": menu["path"], "components": components}
)
items.append({"name": menu["name"], "path": menu["path"], "components": components})
filtered = [menu for menu in items if menu["components"]]
logger.debug(
f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}'
)
logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}')
return {"menu": filtered}
def validate_code(self, file_content):
@ -119,9 +115,7 @@ class DirectoryReader:
Walk through the directory path and return a list of all .py files.
"""
if not (safe_path := self.get_safe_path()):
raise CustomComponentPathValueError(
f"The path needs to start with '{self.base_path}'."
)
raise CustomComponentPathValueError(f"The path needs to start with '{self.base_path}'.")
file_list = []
safe_path_obj = Path(safe_path)
@ -131,11 +125,7 @@ class DirectoryReader:
# any folders below [folder] will be ignored
# basically the parent folder of the file should be a
# folder in the safe_path
if (
file_path.is_file()
and file_path.parent.parent == safe_path_obj
and not file_path.name.startswith("__")
):
if file_path.is_file() and file_path.parent.parent == safe_path_obj and not file_path.name.startswith("__"):
file_list.append(str(file_path))
return file_list
@ -173,9 +163,7 @@ class DirectoryReader:
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef):
for arg in node.args.args:
if self._is_type_hint_in_arg_annotation(
arg.annotation, type_hint_name
):
if self._is_type_hint_in_arg_annotation(arg.annotation, type_hint_name):
return True
except SyntaxError:
# Returns False if the code is not valid Python
@ -193,16 +181,14 @@ class DirectoryReader:
and annotation.value.id == type_hint_name
)
def is_type_hint_used_but_not_imported(
self, type_hint_name: str, code: str
) -> bool:
def is_type_hint_used_but_not_imported(self, type_hint_name: str, code: str) -> bool:
"""
Check if a type hint is used but not imported in the given code.
"""
try:
return self._is_type_hint_used_in_args(
return self._is_type_hint_used_in_args(type_hint_name, code) and not self._is_type_hint_imported(
type_hint_name, code
) and not self._is_type_hint_imported(type_hint_name, code)
)
except SyntaxError:
# Returns True if there's something wrong with the code
# TODO : Find a better way to handle this
@ -223,9 +209,9 @@ class DirectoryReader:
return False, "Syntax error"
elif not self.validate_build(file_content):
return False, "Missing build function"
elif self._is_type_hint_used_in_args(
elif self._is_type_hint_used_in_args("Optional", file_content) and not self._is_type_hint_imported(
"Optional", file_content
) and not self._is_type_hint_imported("Optional", file_content):
):
return (
False,
"Type hint 'Optional' is used but not imported in the code.",
@ -241,18 +227,14 @@ class DirectoryReader:
from the .py files in the directory.
"""
response = {"menu": []}
logger.debug(
"-------------------- Building component menu list --------------------"
)
logger.debug("-------------------- Building component menu list --------------------")
for file_path in file_paths:
menu_name = os.path.basename(os.path.dirname(file_path))
filename = os.path.basename(file_path)
validation_result, result_content = self.process_file(file_path)
if not validation_result:
logger.error(
f"Error while processing file {file_path}: {result_content}"
)
logger.error(f"Error while processing file {file_path}: {result_content}")
menu_result = self.find_menu(response, menu_name) or {
"name": menu_name,
@ -265,9 +247,7 @@ class DirectoryReader:
# first check if it's already CamelCase
if "_" in component_name:
component_name_camelcase = " ".join(
word.title() for word in component_name.split("_")
)
component_name_camelcase = " ".join(word.title() for word in component_name.split("_"))
else:
component_name_camelcase = component_name
@ -275,9 +255,7 @@ class DirectoryReader:
try:
output_types = self.get_output_types_from_code(result_content)
except Exception as exc:
logger.exception(
f"Error while getting output types from code: {str(exc)}"
)
logger.exception(f"Error while getting output types from code: {str(exc)}")
output_types = [component_name_camelcase]
else:
output_types = [component_name_camelcase]
@ -293,9 +271,7 @@ class DirectoryReader:
if menu_result not in response["menu"]:
response["menu"].append(menu_result)
logger.debug(
"-------------------- Component menu list built --------------------"
)
logger.debug("-------------------- Component menu list built --------------------")
return response
@staticmethod

View file

@ -32,18 +32,14 @@ class UpdateBuildConfigError(Exception):
pass
def add_output_types(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add output types to the frontend node"""
for return_type in return_types:
if return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -75,18 +71,14 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List
frontend_node.field_order = field_order
def add_base_classes(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -163,14 +155,10 @@ def add_new_custom_field(
# If options is a list, then it's a dropdown
# If options is None, then it's a list of strings
is_list = isinstance(field_config.get("options"), list)
field_config["is_list"] = (
is_list or field_config.get("is_list", False) or field_contains_list
)
field_config["is_list"] = is_list or field_config.get("is_list", False) or field_contains_list
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
@ -209,9 +197,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
]:
continue
field_name, field_type, field_value, field_required = get_field_properties(
extra_field
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
config = _field_config.pop(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
@ -221,17 +207,13 @@ def add_extra_fields(frontend_node, field_config, function_args):
field_required,
config,
)
if "kwargs" in function_args_names and not all(
key in function_args_names for key in field_config.keys()
):
if "kwargs" in function_args_names and not all(key in function_args_names for key in field_config.keys()):
for field_name, field_config in _field_config.copy().items():
if "name" not in field_config or field_name == "code":
continue
config = _field_config.get(field_name, {})
config = config.model_dump() if isinstance(config, BaseModel) else config
field_name, field_type, field_value, field_required = get_field_properties(
extra_field=config
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field=config)
frontend_node = add_new_custom_field(
frontend_node,
field_name,
@ -269,9 +251,7 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -383,16 +363,10 @@ def build_custom_component_template(
add_extra_fields(frontend_node, field_config, entrypoint_args)
frontend_node = add_code_field(
frontend_node, custom_component.code, field_config.get("code", {})
)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
add_base_classes(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_output_types(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
reorder_fields(frontend_node, custom_instance._get_field_order())
@ -403,9 +377,7 @@ def build_custom_component_template(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -431,9 +403,7 @@ def build_custom_components(settings_service):
if not settings_service.settings.COMPONENTS_PATH:
return {}
logger.info(
f"Building custom components from {settings_service.settings.COMPONENTS_PATH}"
)
logger.info(f"Building custom components from {settings_service.settings.COMPONENTS_PATH}")
custom_components_from_file = {}
processed_paths = set()
for path in settings_service.settings.COMPONENTS_PATH:
@ -444,9 +414,7 @@ def build_custom_components(settings_service):
custom_component_dict = build_custom_component_list_from_path(path_str)
if custom_component_dict:
category = next(iter(custom_component_dict))
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict
)
@ -467,14 +435,10 @@ def update_field_dict(
if "refresh" in field_dict:
if call:
try:
custom_component_instance.update_build_config(
build_config, update_field, update_field_value
)
custom_component_instance.update_build_config(build_config, update_field, update_field_value)
except Exception as exc:
logger.error(f"Error while running update_build_config: {str(exc)}")
raise UpdateBuildConfigError(
f"Error while running update_build_config: {str(exc)}"
) from exc
raise UpdateBuildConfigError(f"Error while running update_build_config: {str(exc)}") from exc
field_dict["refresh"] = True
# Let's check if "range_spec" is a RangeSpec object

View file

@ -144,13 +144,9 @@ async def instantiate_based_on_type(
return class_object(**params)
async def instantiate_custom_component(
node_type, class_object, params, user_id, vertex
):
async def instantiate_custom_component(node_type, class_object, params, user_id, vertex):
params_copy = params.copy()
class_object: Type["CustomComponent"] = eval_custom_component_code(
params_copy.pop("code")
)
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
custom_component: "CustomComponent" = class_object(
user_id=user_id,
parameters=params_copy,
@ -226,9 +222,7 @@ def instantiate_memory(node_type, class_object, params):
# I want to catch a specific attribute error that happens
# when the object does not have a cursor attribute
except Exception as exc:
if "object has no attribute 'cursor'" in str(
exc
) or 'object has no field "conn"' in str(exc):
if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc):
raise AttributeError(
(
"Failed to build connection to database."
@ -271,9 +265,7 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params:
if class_method := getattr(class_object, method, None):
agent = class_method(**params)
tools = params.get("tools", [])
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, handle_parsing_errors=True
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True)
return load_agent_executor(class_object, params)
@ -329,11 +321,7 @@ def instantiate_embedding(node_type, class_object, params: Dict):
try:
return class_object(**params)
except ValidationError:
params = {
key: value
for key, value in params.items()
if key in class_object.model_fields
}
params = {key: value for key, value in params.items() if key in class_object.model_fields}
return class_object(**params)
@ -345,9 +333,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
if "texts" in params:
params["documents"] = params.pop("texts")
if "documents" in params:
params["documents"] = [
doc for doc in params["documents"] if isinstance(doc, Document)
]
params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)]
if initializer := vecstore_initializer.get(class_object.__name__):
vecstore = initializer(class_object, params)
else:
@ -362,9 +348,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
return vecstore
def instantiate_documentloader(
node_type: str, class_object: Type[BaseLoader], params: Dict
):
def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict):
if "file_filter" in params:
# file_filter will be a string but we need a function
# that will be used to filter the files using file_filter
@ -373,17 +357,13 @@ def instantiate_documentloader(
# in x and if it is, we will return True
file_filter = params.pop("file_filter")
extensions = file_filter.split(",")
params["file_filter"] = lambda x: any(
extension.strip() in x for extension in extensions
)
params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions)
metadata = params.pop("metadata", None)
if metadata and isinstance(metadata, str):
try:
metadata = orjson.loads(metadata)
except json.JSONDecodeError as exc:
raise ValueError(
"The metadata you provided is not a valid JSON string."
) from exc
raise ValueError("The metadata you provided is not a valid JSON string.") from exc
if node_type == "WebBaseLoader":
if web_path := params.pop("web_path", None):
@ -416,16 +396,12 @@ def instantiate_textsplitter(
"Try changing the chunk_size of the Text Splitter."
) from exc
if (
"separator_type" in params and params["separator_type"] == "Text"
) or "separator_type" not in params:
if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params:
params.pop("separator_type", None)
# separators might come in as an escaped string like \\n
# so we need to convert it to a string
if "separators" in params:
params["separators"] = (
params["separators"].encode().decode("unicode-escape")
)
params["separators"] = params["separators"].encode().decode("unicode-escape")
text_splitter = class_object(**params)
else:
from langchain.text_splitter import Language
@ -452,8 +428,7 @@ def replace_zero_shot_prompt_with_prompt_template(nodes):
tools = [
tool
for tool in nodes
if tool["type"] != "chatOutputNode"
and "Tool" in tool["data"]["node"]["base_classes"]
if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"]
]
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
break
@ -467,9 +442,7 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs)
# agent has hidden args for memory. might need to be support
# memory = params["memory"]
# if allowed_tools is not a list or set, make it a list
if not isinstance(allowed_tools, (list, set)) and isinstance(
allowed_tools, BaseTool
):
if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool):
allowed_tools = [allowed_tools]
tool_names = [tool.name for tool in allowed_tools]
# Agent class requires an output_parser but Agent classes
@ -497,10 +470,7 @@ def build_prompt_template(prompt, tools):
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
tool_strings = "\n".join(
[
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
for tool in tools
]
[f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools]
)
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)

View file

@ -66,6 +66,4 @@ def get_all_types_dict(settings_service):
"""Get all types dictionary combining native and custom components."""
native_components = build_langchain_types_dict()
custom_components_from_file = build_custom_components(settings_service)
return merge_nested_dicts_with_renaming(
native_components, custom_components_from_file
)
return merge_nested_dicts_with_renaming(native_components, custom_components_from_file)

View file

@ -43,9 +43,7 @@ def try_setting_streaming_options(langchain_object):
llm = None
if hasattr(langchain_object, "llm"):
llm = langchain_object.llm
elif hasattr(langchain_object, "llm_chain") and hasattr(
langchain_object.llm_chain, "llm"
):
elif hasattr(langchain_object, "llm_chain") and hasattr(langchain_object.llm_chain, "llm"):
llm = langchain_object.llm_chain.llm
if isinstance(llm, BaseLanguageModel):
@ -71,9 +69,7 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]:
# Extract the variable name from either the single or double brace match
if match.group(1): # Match found in double braces
variable_name = (
"{{" + match.group(1) + "}}"
) # Re-add single braces for JSON strings
variable_name = "{{" + match.group(1) + "}}" # Re-add single braces for JSON strings
else: # Match found in single braces
variable_name = match.group(2)
if variable_name is not None:
@ -109,9 +105,7 @@ def set_langchain_cache(settings):
if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"):
try:
cache_class = import_class(
f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}"
)
cache_class = import_class(f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}")
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
set_llm_cache(cache_class())

View file

@ -20,9 +20,7 @@ from langflow.utils.logger import configure
def get_lifespan(fix_migration=False, socketio_server=None):
@asynccontextmanager
async def lifespan(app: FastAPI):
initialize_services(
fix_migration=fix_migration, socketio_server=socketio_server
)
initialize_services(fix_migration=fix_migration, socketio_server=socketio_server)
setup_llm_caching()
LangfuseInstance.update()
create_or_update_starter_projects()
@ -36,9 +34,7 @@ def create_app():
"""Create the FastAPI app and include the router."""
configure()
socketio_server = socketio.AsyncServer(
async_mode="asgi", cors_allowed_origins="*", logger=True
)
socketio_server = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*", logger=True)
lifespan = get_lifespan(socketio_server=socketio_server)
app = FastAPI(lifespan=lifespan)
origins = ["*"]
@ -105,9 +101,7 @@ def get_static_files_dir():
return frontend_path / "frontend"
def setup_app(
static_files_dir: Optional[Path] = None, backend_only: bool = False
) -> FastAPI:
def setup_app(static_files_dir: Optional[Path] = None, backend_only: bool = False) -> FastAPI:
"""Setup the FastAPI app."""
# get the directory of the current file
if not static_files_dir:

View file

@ -126,9 +126,7 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
result = await runnable.ainvoke(inputs)
else:
raise ValueError(
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
)
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
# Check if the result is a list of AIMessages
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
result = [r.content for r in result]
@ -137,9 +135,7 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
return result
async def process_inputs_dict(
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
):
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
if isinstance(built_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -174,9 +170,7 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
return await process_runnable(built_object, inputs)
async def generate_result(
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
):
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
if isinstance(inputs, dict):
result = await process_inputs_dict(built_object, inputs)
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
@ -215,9 +209,7 @@ async def run_graph(
else:
graph_data = graph._graph_data
if not session_id and session_service is not None:
session_id = session_service.generate_key(
session_id=flow_id, data_graph=graph_data
)
session_id = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
if inputs is None:
inputs = {}
@ -232,18 +224,14 @@ async def run_graph(
return outputs, session_id
def validate_input(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
raise ValueError("graph_data and tweaks should be dictionaries")
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
if not isinstance(nodes, list):
raise ValueError(
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
)
raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
return nodes
@ -252,9 +240,7 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data = node.get("data", {}).get("node", {}).get("template")
if not isinstance(template_data, dict):
logger.warning(
f"Template data for node {node.get('id')} should be a dictionary"
)
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
return
for tweak_name, tweak_value in node_tweaks.items():
@ -269,9 +255,7 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
vertex.params[tweak_name] = tweak_value
def process_tweaks(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
This function is used to tweak the graph data using the node id and the tweaks dict.
@ -307,8 +291,6 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
if node_tweaks := tweaks.get(node_id):
apply_tweaks_on_vertex(vertex, node_tweaks)
else:
logger.warning(
"Each node should be a Vertex with an 'id' attribute of type str"
)
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
return graph

View file

@ -63,9 +63,7 @@ class Record(BaseModel):
return self.data.get(key, self._default_value)
except KeyError:
# Fallback to default behavior to raise AttributeError for undefined attributes
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{key}'"
)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
def __setattr__(self, key, value):
"""

View file

@ -22,9 +22,7 @@ async def process_graph(
if build_result is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
# Generate result and thought
try:
@ -50,7 +48,5 @@ async def process_graph(
raise e
async def run_build_result(
build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str
):
async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str):
return build_result(inputs=chat_inputs.message)

View file

@ -22,9 +22,7 @@ class FlowBase(SQLModel):
icon_bg_color: Optional[str] = Field(default=None, nullable=True)
data: Optional[Dict] = Field(default=None, nullable=True)
is_component: Optional[bool] = Field(default=False, nullable=True)
updated_at: Optional[datetime] = Field(
default_factory=datetime.utcnow, nullable=True
)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, nullable=True)
folder: Optional[str] = Field(default=None, nullable=True)
@field_validator("icon_bg_color")

View file

@ -36,10 +36,7 @@ class DatabaseService(Service):
def _create_engine(self) -> "Engine":
"""Create the engine for the database."""
settings_service = get_settings_service()
if (
settings_service.settings.DATABASE_URL
and settings_service.settings.DATABASE_URL.startswith("sqlite")
):
if settings_service.settings.DATABASE_URL and settings_service.settings.DATABASE_URL.startswith("sqlite"):
connect_args = {"check_same_thread": False}
else:
connect_args = {}
@ -51,9 +48,7 @@ class DatabaseService(Service):
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None: # If an exception has been raised
logger.error(
f"Session rollback because of exception: {exc_type.__name__} {exc_value}"
)
logger.error(f"Session rollback because of exception: {exc_type.__name__} {exc_value}")
self._session.rollback()
else:
self._session.commit()
@ -70,9 +65,7 @@ class DatabaseService(Service):
settings_service = get_settings_service()
if settings_service.auth_settings.AUTO_LOGIN:
with Session(self.engine) as session:
flows = session.exec(
select(models.Flow).where(models.Flow.user_id is None)
).all()
flows = session.exec(select(models.Flow).where(models.Flow.user_id is None)).all()
if flows:
logger.debug("Migrating flows to default superuser")
username = settings_service.auth_settings.SUPERUSER
@ -102,9 +95,7 @@ class DatabaseService(Service):
expected_columns = list(model.model_fields.keys())
try:
available_columns = [
col["name"] for col in inspector.get_columns(table)
]
available_columns = [col["name"] for col in inspector.get_columns(table)]
except sa.exc.NoSuchTableError:
logger.error(f"Missing table: {table}")
return False
@ -161,9 +152,7 @@ class DatabaseService(Service):
try:
command.check(alembic_cfg)
except Exception as exc:
if isinstance(
exc, (util.exc.CommandError, util.exc.AutogenerateDiffsDetected)
):
if isinstance(exc, (util.exc.CommandError, util.exc.AutogenerateDiffsDetected)):
command.upgrade(alembic_cfg, "head")
time.sleep(3)
@ -199,10 +188,7 @@ class DatabaseService(Service):
# We will check that all models are in the database
# and that the database is up to date with all columns
sql_models = [models.Flow, models.User, models.ApiKey]
return [
TableResults(sql_model.__tablename__, self.check_table(sql_model))
for sql_model in sql_models
]
return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models]
def check_table(self, model):
results = []
@ -211,9 +197,7 @@ class DatabaseService(Service):
expected_columns = list(model.__fields__.keys())
available_columns = []
try:
available_columns = [
col["name"] for col in inspector.get_columns(table_name)
]
available_columns = [col["name"] for col in inspector.get_columns(table_name)]
results.append(Result(name=table_name, type="table", success=True))
except sa.exc.NoSuchTableError:
logger.error(f"Missing table: {table_name}")
@ -244,9 +228,7 @@ class DatabaseService(Service):
try:
table.create(self.engine, checkfirst=True)
except OperationalError as oe:
logger.warning(
f"Table {table} already exists, skipping. Exception: {oe}"
)
logger.warning(f"Table {table} already exists, skipping. Exception: {oe}")
except Exception as exc:
logger.error(f"Error creating table {table}: {exc}")
raise RuntimeError(f"Error creating table {table}") from exc
@ -258,9 +240,7 @@ class DatabaseService(Service):
if table not in table_names:
logger.error("Something went wrong creating the database and tables.")
logger.error("Please check your database settings.")
raise RuntimeError(
"Something went wrong creating the database and tables."
)
raise RuntimeError("Something went wrong creating the database and tables.")
logger.debug("Database and tables created successfully")

View file

@ -10,9 +10,7 @@ if TYPE_CHECKING:
class TransactionModel(BaseModel):
id: Optional[int] = Field(default=None, alias="id")
timestamp: Optional[datetime] = Field(
default_factory=datetime.now, alias="timestamp"
)
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
source: str
target: str
target_args: dict
@ -53,12 +51,8 @@ class MessageModel(BaseModel):
@classmethod
def from_record(cls, record: "Record"):
# first check if the record has all the required fields
if not record.data or (
"sender" not in record.data and "sender_name" not in record.data
):
raise ValueError(
"The record does not have the required fields 'sender' and 'sender_name' in the data."
)
if not record.data or ("sender" not in record.data and "sender_name" not in record.data):
raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.")
return cls(
sender=record.data["sender"],
sender_name=record.data["sender_name"],

View file

@ -58,12 +58,10 @@ class Settings(BaseSettings):
STORE: Optional[bool] = True
STORE_URL: Optional[str] = "https://api.langflow.store"
DOWNLOAD_WEBHOOK_URL: Optional[str] = (
"https://api.langflow.store/flows/trigger/ec611a61-8460-4438-b187-a4f65e5559d4"
)
LIKE_WEBHOOK_URL: Optional[str] = (
"https://api.langflow.store/flows/trigger/64275852-ec00-45c1-984e-3bff814732da"
)
DOWNLOAD_WEBHOOK_URL: Optional[
str
] = "https://api.langflow.store/flows/trigger/ec611a61-8460-4438-b187-a4f65e5559d4"
LIKE_WEBHOOK_URL: Optional[str] = "https://api.langflow.store/flows/trigger/64275852-ec00-45c1-984e-3bff814732da"
STORAGE_TYPE: str = "local"
@ -95,9 +93,7 @@ class Settings(BaseSettings):
@validator("DATABASE_URL", pre=True)
def set_database_url(cls, value, values):
if not value:
logger.debug(
"No database_url provided, trying LANGFLOW_DATABASE_URL env variable"
)
logger.debug("No database_url provided, trying LANGFLOW_DATABASE_URL env variable")
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
value = langflow_database_url
logger.debug("Using LANGFLOW_DATABASE_URL env variable.")
@ -107,9 +103,7 @@ class Settings(BaseSettings):
# so we need to migrate to the new format
# if there is a database in that location
if not values["CONFIG_DIR"]:
raise ValueError(
"CONFIG_DIR not set, please set it or provide a DATABASE_URL"
)
raise ValueError("CONFIG_DIR not set, please set it or provide a DATABASE_URL")
new_path = f"{values['CONFIG_DIR']}/langflow.db"
if Path("./langflow.db").exists():
@ -133,22 +127,15 @@ class Settings(BaseSettings):
if os.getenv("LANGFLOW_COMPONENTS_PATH"):
logger.debug("Adding LANGFLOW_COMPONENTS_PATH to components_path")
langflow_component_path = os.getenv("LANGFLOW_COMPONENTS_PATH")
if (
Path(langflow_component_path).exists()
and langflow_component_path not in value
):
if Path(langflow_component_path).exists() and langflow_component_path not in value:
if isinstance(langflow_component_path, list):
for path in langflow_component_path:
if path not in value:
value.append(path)
logger.debug(
f"Extending {langflow_component_path} to components_path"
)
logger.debug(f"Extending {langflow_component_path} to components_path")
elif langflow_component_path not in value:
value.append(langflow_component_path)
logger.debug(
f"Appending {langflow_component_path} to components_path"
)
logger.debug(f"Appending {langflow_component_path} to components_path")
if not value:
value = [BASE_COMPONENTS_PATH]
@ -160,9 +147,7 @@ class Settings(BaseSettings):
logger.debug(f"Components path: {value}")
return value
model_config = SettingsConfigDict(
validate_assignment=True, extra="ignore", env_prefix="LANGFLOW_"
)
model_config = SettingsConfigDict(validate_assignment=True, extra="ignore", env_prefix="LANGFLOW_")
# @model_validator()
# @classmethod

View file

@ -96,9 +96,7 @@ async def build_vertex(
)
# Emit the vertex build response
response = VertexBuildResponse(
valid=valid, params=params, id=vertex.id, data=result_dict
)
response = VertexBuildResponse(valid=valid, params=params, id=vertex.id, data=result_dict)
await sio.emit("vertex_build", data=response.model_dump(), to=sid)
except Exception as exc:

View file

@ -74,9 +74,7 @@ class TaskService(Service):
result = await result
return task.id, result
async def launch_task(
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
async def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
logger.debug(f"Launching task {task_func} with args {args} and kwargs {kwargs}")
logger.debug(f"Using backend {self.backend}")
task = self.backend.launch_task(task_func, *args, **kwargs)

View file

@ -92,16 +92,12 @@ def get_or_create_super_user(session: Session, username, password, is_default):
)
return None
else:
logger.debug(
"User with superuser credentials exists but is not a superuser."
)
logger.debug("User with superuser credentials exists but is not a superuser.")
return None
if user:
if verify_password(password, user.password):
raise ValueError(
"User with superuser credentials exists but is not a superuser."
)
raise ValueError("User with superuser credentials exists but is not a superuser.")
else:
raise ValueError("Incorrect superuser credentials")
@ -130,21 +126,15 @@ def setup_superuser(settings_service, session: Session):
username = settings_service.auth_settings.SUPERUSER
password = settings_service.auth_settings.SUPERUSER_PASSWORD
is_default = (username == DEFAULT_SUPERUSER) and (
password == DEFAULT_SUPERUSER_PASSWORD
)
is_default = (username == DEFAULT_SUPERUSER) and (password == DEFAULT_SUPERUSER_PASSWORD)
try:
user = get_or_create_super_user(
session=session, username=username, password=password, is_default=is_default
)
user = get_or_create_super_user(session=session, username=username, password=password, is_default=is_default)
if user is not None:
logger.debug("Superuser created successfully.")
except Exception as exc:
logger.exception(exc)
raise RuntimeError(
"Could not create superuser. Please create a superuser manually."
) from exc
raise RuntimeError("Could not create superuser. Please create a superuser manually.") from exc
finally:
settings_service.auth_settings.reset_credentials()
@ -158,9 +148,7 @@ def teardown_superuser(settings_service, session):
if not settings_service.auth_settings.AUTO_LOGIN:
try:
logger.debug(
"AUTO_LOGIN is set to False. Removing default superuser if exists."
)
logger.debug("AUTO_LOGIN is set to False. Removing default superuser if exists.")
username = DEFAULT_SUPERUSER
from langflow.services.database.models.user.model import User
@ -210,9 +198,7 @@ def initialize_session_service():
initialize_settings_service()
service_manager.register_factory(
cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE]
)
service_manager.register_factory(cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE])
service_manager.register_factory(
session_service_factory.SessionServiceFactory(),
@ -229,9 +215,7 @@ def initialize_services(fix_migration: bool = False, socketio_server=None):
service_manager.register_factory(factory, dependencies=dependencies)
except Exception as exc:
logger.exception(exc)
raise RuntimeError(
"Could not initialize services. Please check your settings."
) from exc
raise RuntimeError("Could not initialize services. Please check your settings.") from exc
# Test cache connection
service_manager.get(ServiceType.CACHE_SERVICE)
@ -241,9 +225,7 @@ def initialize_services(fix_migration: bool = False, socketio_server=None):
except Exception as exc:
logger.error(exc)
raise exc
setup_superuser(
service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session())
)
setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()))
try:
get_db_service().migrate_flows_if_auto_login()
except Exception as exc:

View file

@ -68,9 +68,7 @@ class TemplateField(BaseModel):
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
range_spec: Optional[RangeSpec] = Field(
default=None, serialization_alias="rangeSpec"
)
range_spec: Optional[RangeSpec] = Field(default=None, serialization_alias="rangeSpec")
"""Range specification for the field. Defaults to None."""
title_case: bool = False
@ -119,10 +117,6 @@ class TemplateField(BaseModel):
if not isinstance(value, list):
raise ValueError("file_types must be a list")
return [
(
f".{file_type}"
if isinstance(file_type, str) and not file_type.startswith(".")
else file_type
)
(f".{file_type}" if isinstance(file_type, str) and not file_type.startswith(".") else file_type)
for file_type in value
]

View file

@ -174,9 +174,7 @@ class FrontendNode(BaseModel):
return _type
@staticmethod
def handle_special_field(
field, key: str, _type: str, SPECIAL_FIELD_HANDLERS
) -> str:
def handle_special_field(field, key: str, _type: str, SPECIAL_FIELD_HANDLERS) -> str:
"""Handles special field by using the respective handler if present."""
handler = SPECIAL_FIELD_HANDLERS.get(key)
return handler(field) if handler else _type
@ -187,11 +185,7 @@ class FrontendNode(BaseModel):
if "dict" in _type.lower() and field.name == "dict_":
field.field_type = "file"
field.file_types = [".json", ".yaml", ".yml"]
elif (
_type.startswith("Dict")
or _type.startswith("Mapping")
or _type.startswith("dict")
):
elif _type.startswith("Dict") or _type.startswith("Mapping") or _type.startswith("dict"):
field.field_type = "dict"
return _type
@ -202,9 +196,7 @@ class FrontendNode(BaseModel):
field.value = value["default"]
@staticmethod
def handle_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def handle_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values for certain fields."""
if key == "headers":
field.value = """{"Authorization": "Bearer <token>"}"""
@ -212,9 +204,7 @@ class FrontendNode(BaseModel):
FrontendNode._handle_api_key_specific_field_values(field, key, name)
@staticmethod
def _handle_model_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def _handle_model_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to models."""
model_dict = {
"OpenAI": constants.OPENAI_MODELS,
@ -227,9 +217,7 @@ class FrontendNode(BaseModel):
field.is_list = True
@staticmethod
def _handle_api_key_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
def _handle_api_key_specific_field_values(field: TemplateField, key: str, name: Optional[str] = None) -> None:
"""Handles specific field values related to API keys."""
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
@ -269,10 +257,7 @@ class FrontendNode(BaseModel):
@staticmethod
def should_be_password(key: str, show: bool) -> bool:
"""Determines whether the field should be a password field."""
return (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and show
)
return any(text in key.lower() for text in {"password", "token", "api", "key"}) and show
@staticmethod
def should_be_multiline(key: str) -> bool:

View file

@ -80,9 +80,7 @@ class MemoryFrontendNode(FrontendNode):
field.show = True
field.advanced = False
field.value = ""
field.info = (
INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
)
field.info = INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
if field.name == "memory_key":
field.value = "chat_history"

View file

@ -45,9 +45,7 @@ class Template(BaseModel):
"""Returns the field with the given name."""
field = next((field for field in self.fields if field.name == field_name), None)
if field is None:
raise ValueError(
f"Field {field_name} not found in template {self.type_name}"
)
raise ValueError(f"Field {field_name} not found in template {self.type_name}")
return field
def update_field(self, field_name: str, field: TemplateField) -> None:

View file

@ -15,12 +15,8 @@ def remove_ansi_escape_codes(text):
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
def build_template_from_function(
name: str, type_to_loader_dict: Dict, add_function: bool = False
):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
# Raise error if name is not in chains
if name not in classes:
@ -41,10 +37,8 @@ def build_template_from_function(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__, function=value_
)
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
)
except Exception:
variables[class_field_items]["default"] = None
@ -52,9 +46,7 @@ def build_template_from_function(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
# Adding function to base classes to allow
# the output to be a function
@ -69,9 +61,7 @@ def build_template_from_function(
}
def build_template_from_class(
name: str, type_to_cls_dict: Dict, add_function: bool = False
):
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
@ -95,11 +85,9 @@ def build_template_from_class(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
except Exception:
variables[class_field_items]["default"] = None
@ -107,9 +95,7 @@ def build_template_from_class(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
base_classes = get_base_classes(_class)
# Adding function to base classes to allow
@ -141,9 +127,7 @@ def build_template_from_method(
# Check if the method exists in this class
if not hasattr(_class, method_name):
raise ValueError(
f"Method {method_name} not found in class {class_name}"
)
raise ValueError(f"Method {method_name} not found in class {class_name}")
# Get the method
method = getattr(_class, method_name)
@ -162,14 +146,8 @@ def build_template_from_method(
"_type": _type,
**{
name: {
"default": (
param.default if param.default != param.empty else None
),
"type": (
param.annotation
if param.annotation != param.empty
else None
),
"default": (param.default if param.default != param.empty else None),
"type": (param.annotation if param.annotation != param.empty else None),
"required": param.default == param.empty,
}
for name, param in params.items()
@ -256,9 +234,7 @@ def sync_to_async(func):
return async_wrapper
def format_dict(
dictionary: Dict[str, Any], class_name: Optional[str] = None
) -> Dict[str, Any]:
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
"""
Formats a dictionary by removing certain keys and modifying the
values of other keys.
@ -344,9 +320,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
The modified type string.
"""
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
_type = (
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
)
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
value["list"] = True
else:
value["list"] = False
@ -449,9 +423,7 @@ def set_headers_value(value: Dict[str, Any]) -> None:
value["value"] = """{"Authorization": "Bearer <token>"}"""
def add_options_to_field(
value: Dict[str, Any], class_name: Optional[str], key: str
) -> None:
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
"""
Adds options to the field based on the class name and key.
"""

View file

@ -43,9 +43,7 @@ def validate_code(code):
# Evaluate the function definition
for node in tree.body:
if isinstance(node, ast.FunctionDef):
code_obj = compile(
ast.Module(body=[node], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[node], type_ignores=[]), "<string>", "exec")
try:
exec(code_obj)
except Exception as e:
@ -89,23 +87,15 @@ def execute_function(code, function_name, *args, **kwargs):
exec_globals,
locals(),
)
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
function_code = next(
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
try:
exec(code_obj, exec_globals, locals())
except Exception as exc:
@ -132,23 +122,15 @@ def create_function(code, function_name):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
function_code = next(
node
for node in module.body
if isinstance(node, ast.FunctionDef) and node.name == function_name
node for node in module.body if isinstance(node, ast.FunctionDef) and node.name == function_name
)
function_code.parent = None
code_obj = compile(
ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[function_code], type_ignores=[]), "<string>", "exec")
with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())
exec_globals[function_name] = locals()[function_name]
@ -210,13 +192,9 @@ def prepare_global_scope(code, module):
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
exec_globals[alias.asname or alias.name] = importlib.import_module(alias.name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
raise ModuleNotFoundError(f"Module {alias.name} not found. Please install it and try again.") from e
elif isinstance(node, ast.ImportFrom) and node.module is not None:
try:
imported_module = importlib.import_module(node.module)
@ -237,11 +215,7 @@ def extract_class_code(module, class_name):
:param class_name: Name of the class to extract
:return: AST node of the specified class
"""
class_code = next(
node
for node in module.body
if isinstance(node, ast.ClassDef) and node.name == class_name
)
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
class_code.parent = None
return class_code
@ -254,9 +228,7 @@ def compile_class_code(class_code):
:param class_code: AST node of the class
:return: Compiled code object of the class
"""
code_obj = compile(
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
)
code_obj = compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
return code_obj
@ -300,9 +272,7 @@ def get_default_imports(code_string):
langflow_imports = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
necessary_imports = find_names_in_code(code_string, langflow_imports)
langflow_module = importlib.import_module("langflow.field_typing")
default_imports.update(
{name: getattr(langflow_module, name) for name in necessary_imports}
)
default_imports.update({name: getattr(langflow_module, name) for name in necessary_imports})
return default_imports