parent
e58b27062c
commit
942c8dca36
62 changed files with 228 additions and 374 deletions
|
|
@ -457,11 +457,10 @@ def migration(
|
|||
"""
|
||||
Run or test migrations.
|
||||
"""
|
||||
if fix:
|
||||
if not typer.confirm(
|
||||
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
|
||||
):
|
||||
raise typer.Abort
|
||||
if fix and not typer.confirm(
|
||||
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
|
||||
):
|
||||
raise typer.Abort
|
||||
|
||||
initialize_services(fix_migration=fix)
|
||||
db_service = get_db_service()
|
||||
|
|
|
|||
|
|
@ -90,10 +90,7 @@ async def logs(
|
|||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
detail="Timestamp is required when requesting logs after the timestamp",
|
||||
)
|
||||
if lines_before <= 0:
|
||||
content = log_buffer.get_last_n(10)
|
||||
else:
|
||||
content = log_buffer.get_last_n(lines_before)
|
||||
content = log_buffer.get_last_n(10) if lines_before <= 0 else log_buffer.get_last_n(lines_before)
|
||||
else:
|
||||
if lines_before > 0:
|
||||
content = log_buffer.get_before_timestamp(timestamp=timestamp, lines=lines_before)
|
||||
|
|
|
|||
|
|
@ -344,20 +344,17 @@ async def build_flow(
|
|||
raise ValueError(msg) from exc
|
||||
event_manager.on_end_vertex(data={"build_data": build_data})
|
||||
await client_consumed_queue.get()
|
||||
if vertex_build_response.valid:
|
||||
if vertex_build_response.next_vertices_ids:
|
||||
tasks = []
|
||||
for next_vertex_id in vertex_build_response.next_vertices_ids:
|
||||
task = asyncio.create_task(
|
||||
build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager)
|
||||
)
|
||||
tasks.append(task)
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
return
|
||||
if vertex_build_response.valid and vertex_build_response.next_vertices_ids:
|
||||
tasks = []
|
||||
for next_vertex_id in vertex_build_response.next_vertices_ids:
|
||||
task = asyncio.create_task(build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager))
|
||||
tasks.append(task)
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
return
|
||||
|
||||
async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
|
||||
if not data:
|
||||
|
|
|
|||
|
|
@ -95,13 +95,12 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
|
|||
if has_input_value and input_value_is_chat:
|
||||
msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name."
|
||||
raise InvalidChatInputException(msg)
|
||||
elif "Text Input" in key or "TextInput" in key:
|
||||
if isinstance(value, dict):
|
||||
has_input_value = value.get("input_value") is not None
|
||||
input_value_is_text = input_request.input_value is not None and input_request.input_type == "text"
|
||||
if has_input_value and input_value_is_text:
|
||||
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
|
||||
raise InvalidChatInputException(msg)
|
||||
elif ("Text Input" in key or "TextInput" in key) and isinstance(value, dict):
|
||||
has_input_value = value.get("input_value") is not None
|
||||
input_value_is_text = input_request.input_value is not None and input_request.input_type == "text"
|
||||
if has_input_value and input_value_is_text:
|
||||
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
|
||||
raise InvalidChatInputException(msg)
|
||||
|
||||
|
||||
async def simple_run_flow(
|
||||
|
|
|
|||
|
|
@ -310,10 +310,7 @@ async def upload_file(
|
|||
contents = await file.read()
|
||||
data = orjson.loads(contents)
|
||||
response_list = []
|
||||
if "flows" in data:
|
||||
flow_list = FlowListCreate(**data)
|
||||
else:
|
||||
flow_list = FlowListCreate(flows=[FlowCreate(**data)])
|
||||
flow_list = FlowListCreate(**data) if "flows" in data else FlowListCreate(flows=[FlowCreate(**data)])
|
||||
# Now we set the user_id for all flows
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = current_user.id
|
||||
|
|
|
|||
|
|
@ -53,10 +53,7 @@ class BaseCrewComponent(Component):
|
|||
self,
|
||||
) -> Callable:
|
||||
def task_callback(task_output: TaskOutput):
|
||||
if self._vertex:
|
||||
vertex_id = self._vertex.id
|
||||
else:
|
||||
vertex_id = self.display_name or self.__class__.__name__
|
||||
vertex_id = self._vertex.id if self._vertex else self.display_name or self.__class__.__name__
|
||||
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")
|
||||
|
||||
return task_callback
|
||||
|
|
|
|||
|
|
@ -67,10 +67,7 @@ def build_data_from_result_data(result_data: ResultData, get_final_results_only:
|
|||
if isinstance(result_data.results, dict):
|
||||
for name, result in result_data.results.items():
|
||||
dataobj: Data | Message | None = None
|
||||
if isinstance(result, Message):
|
||||
dataobj = result
|
||||
else:
|
||||
dataobj = Data(data=result, text_key=name)
|
||||
dataobj = result if isinstance(result, Message) else Data(data=result, text_key=name)
|
||||
|
||||
data.append(dataobj)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -25,12 +25,16 @@ class ChatComponent(Component):
|
|||
msg = "Only one message can be stored at a time."
|
||||
raise ValueError(msg)
|
||||
stored_message = messages[0]
|
||||
if hasattr(self, "_event_manager") and self._event_manager and stored_message.id:
|
||||
if not isinstance(message.text, str):
|
||||
complete_message = self._stream_message(message, stored_message.id)
|
||||
message_table = update_message(message_id=stored_message.id, message={"text": complete_message})
|
||||
stored_message = Message(**message_table.model_dump())
|
||||
self.vertex._added_message = stored_message
|
||||
if (
|
||||
hasattr(self, "_event_manager")
|
||||
and self._event_manager
|
||||
and stored_message.id
|
||||
and not isinstance(message.text, str)
|
||||
):
|
||||
complete_message = self._stream_message(message, stored_message.id)
|
||||
message_table = update_message(message_id=stored_message.id, message={"text": complete_message})
|
||||
stored_message = Message(**message_table.model_dump())
|
||||
self.vertex._added_message = stored_message
|
||||
self.status = stored_message
|
||||
return stored_message
|
||||
|
||||
|
|
@ -77,8 +81,6 @@ class ChatComponent(Component):
|
|||
session_id: str | None = None,
|
||||
return_message: bool | None = False,
|
||||
) -> Message:
|
||||
message: Message | None = None
|
||||
|
||||
if isinstance(input_value, Data):
|
||||
# Update the data of the record
|
||||
message = Message.from_data(input_value)
|
||||
|
|
@ -86,10 +88,7 @@ class ChatComponent(Component):
|
|||
message = Message(
|
||||
text=input_value, sender=sender, sender_name=sender_name, files=files, session_id=session_id
|
||||
)
|
||||
if not return_message:
|
||||
message_text = message.text
|
||||
else:
|
||||
message_text = message # type: ignore
|
||||
message_text = message.text if not return_message else message
|
||||
|
||||
self.status = message_text
|
||||
if session_id and isinstance(message, Message) and isinstance(message.text, str):
|
||||
|
|
|
|||
|
|
@ -130,10 +130,7 @@ class SequentialTaskAgentComponent(Component):
|
|||
|
||||
# If there's a previous task, create a list of tasks
|
||||
if self.previous_task:
|
||||
if isinstance(self.previous_task, list):
|
||||
tasks = self.previous_task + [task]
|
||||
else:
|
||||
tasks = [self.previous_task, task]
|
||||
tasks = self.previous_task + [task] if isinstance(self.previous_task, list) else [self.previous_task, task]
|
||||
else:
|
||||
tasks = [task]
|
||||
|
||||
|
|
|
|||
|
|
@ -28,10 +28,7 @@ class SQLGeneratorComponent(LCChainComponent):
|
|||
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
|
||||
|
||||
def invoke_chain(self) -> Message:
|
||||
if self.prompt:
|
||||
prompt_template = PromptTemplate.from_template(template=self.prompt)
|
||||
else:
|
||||
prompt_template = None
|
||||
prompt_template = PromptTemplate.from_template(template=self.prompt) if self.prompt else None
|
||||
|
||||
if self.top_k < 1:
|
||||
msg = "Top K must be greater than 0."
|
||||
|
|
|
|||
|
|
@ -96,10 +96,7 @@ class GmailLoaderComponent(Component):
|
|||
msg = "From email not found."
|
||||
raise ValueError(msg)
|
||||
|
||||
if "parts" in msg["payload"]:
|
||||
parts = msg["payload"]["parts"]
|
||||
else:
|
||||
parts = [msg["payload"]]
|
||||
parts = msg["payload"]["parts"] if "parts" in msg["payload"] else [msg["payload"]]
|
||||
|
||||
for part in parts:
|
||||
if part["mimeType"] == "text/plain":
|
||||
|
|
|
|||
|
|
@ -96,17 +96,13 @@ class GoogleDriveSearchComponent(Component):
|
|||
"""
|
||||
Generates the appropriate Google Drive URL for a file based on its MIME type.
|
||||
"""
|
||||
if mime_type == "application/vnd.google-apps.document":
|
||||
return f"https://docs.google.com/document/d/{file_id}/edit"
|
||||
if mime_type == "application/vnd.google-apps.spreadsheet":
|
||||
return f"https://docs.google.com/spreadsheets/d/{file_id}/edit"
|
||||
if mime_type == "application/vnd.google-apps.presentation":
|
||||
return f"https://docs.google.com/presentation/d/{file_id}/edit"
|
||||
if mime_type == "application/vnd.google-apps.drawing":
|
||||
return f"https://docs.google.com/drawings/d/{file_id}/edit"
|
||||
if mime_type == "application/pdf":
|
||||
return f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk"
|
||||
return f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk"
|
||||
return {
|
||||
"application/vnd.google-apps.document": f"https://docs.google.com/document/d/{file_id}/edit",
|
||||
"application/vnd.google-apps.spreadsheet": f"https://docs.google.com/spreadsheets/d/{file_id}/edit",
|
||||
"application/vnd.google-apps.presentation": f"https://docs.google.com/presentation/d/{file_id}/edit",
|
||||
"application/vnd.google-apps.drawing": f"https://docs.google.com/drawings/d/{file_id}/edit",
|
||||
"application/pdf": f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk",
|
||||
}.get(mime_type, f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk")
|
||||
|
||||
def search_files(self) -> dict:
|
||||
# Load the token information from the JSON string
|
||||
|
|
|
|||
|
|
@ -39,10 +39,7 @@ class TextEmbedderComponent(Component):
|
|||
embeddings = embedding_model.embed_documents([text_content])
|
||||
|
||||
# Assuming the embedding model returns a list of embeddings, we take the first one
|
||||
if embeddings:
|
||||
embedding_vector = embeddings[0]
|
||||
else:
|
||||
embedding_vector = []
|
||||
embedding_vector = embeddings[0] if embeddings else []
|
||||
|
||||
# Create a Data object to encapsulate the results
|
||||
result_data = Data(data={"text": text_content, "embeddings": embedding_vector})
|
||||
|
|
|
|||
|
|
@ -21,25 +21,24 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings):
|
|||
"Authorization": f"Bearer {self.api_key.get_secret_value()}",
|
||||
}
|
||||
|
||||
with httpx.Client() as client:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i, text in enumerate(texts):
|
||||
futures.append((i, executor.submit(self._embed_text, client, headers, text)))
|
||||
with httpx.Client() as client, concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i, text in enumerate(texts):
|
||||
futures.append((i, executor.submit(self._embed_text, client, headers, text)))
|
||||
|
||||
for index, future in futures:
|
||||
try:
|
||||
result_data = future.result()
|
||||
assert len(result_data["data"]) == 1, "Expected one embedding"
|
||||
embeddings[index] = result_data["data"][0]["embedding"]
|
||||
except (
|
||||
httpx.HTTPStatusError,
|
||||
httpx.RequestError,
|
||||
json.JSONDecodeError,
|
||||
KeyError,
|
||||
) as e:
|
||||
logger.error(f"Error occurred: {e}")
|
||||
raise
|
||||
for index, future in futures:
|
||||
try:
|
||||
result_data = future.result()
|
||||
assert len(result_data["data"]) == 1, "Expected one embedding"
|
||||
embeddings[index] = result_data["data"][0]["embedding"]
|
||||
except (
|
||||
httpx.HTTPStatusError,
|
||||
httpx.RequestError,
|
||||
json.JSONDecodeError,
|
||||
KeyError,
|
||||
) as e:
|
||||
logger.error(f"Error occurred: {e}")
|
||||
raise
|
||||
|
||||
return embeddings # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -61,15 +61,9 @@ class FirecrawlCrawlApi(CustomComponent):
|
|||
"Could not import firecrawl integration package. " "Please install it with `pip install firecrawl-py`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
if crawlerOptions:
|
||||
crawler_options_dict = crawlerOptions.__dict__["data"]["text"]
|
||||
else:
|
||||
crawler_options_dict = {}
|
||||
crawler_options_dict = crawlerOptions.__dict__["data"]["text"] if crawlerOptions else {}
|
||||
|
||||
if pageOptions:
|
||||
page_options_dict = pageOptions.__dict__["data"]["text"]
|
||||
else:
|
||||
page_options_dict = {}
|
||||
page_options_dict = pageOptions.__dict__["data"]["text"] if pageOptions else {}
|
||||
|
||||
if not idempotency_key:
|
||||
idempotency_key = str(uuid.uuid4())
|
||||
|
|
|
|||
|
|
@ -54,15 +54,9 @@ class FirecrawlScrapeApi(CustomComponent):
|
|||
"Could not import firecrawl integration package. " "Please install it with `pip install firecrawl-py`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
if extractorOptions:
|
||||
extractor_options_dict = extractorOptions.__dict__["data"]["text"]
|
||||
else:
|
||||
extractor_options_dict = {}
|
||||
extractor_options_dict = extractorOptions.__dict__["data"]["text"] if extractorOptions else {}
|
||||
|
||||
if pageOptions:
|
||||
page_options_dict = pageOptions.__dict__["data"]["text"]
|
||||
else:
|
||||
page_options_dict = {}
|
||||
page_options_dict = pageOptions.__dict__["data"]["text"] if pageOptions else {}
|
||||
|
||||
app = FirecrawlApp(api_key=api_key)
|
||||
results = app.scrape_url(
|
||||
|
|
|
|||
|
|
@ -79,10 +79,7 @@ class AIMLModelComponent(LCModelComponent):
|
|||
aiml_api_base = self.aiml_api_base or "https://api.aimlapi.com"
|
||||
seed = self.seed
|
||||
|
||||
if isinstance(aiml_api_key, SecretStr):
|
||||
openai_api_key = aiml_api_key.get_secret_value()
|
||||
else:
|
||||
openai_api_key = aiml_api_key
|
||||
openai_api_key = aiml_api_key.get_secret_value() if isinstance(aiml_api_key, SecretStr) else aiml_api_key
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model_name,
|
||||
|
|
|
|||
|
|
@ -36,10 +36,7 @@ class CohereComponent(LCModelComponent):
|
|||
cohere_api_key = self.cohere_api_key
|
||||
temperature = self.temperature
|
||||
|
||||
if cohere_api_key:
|
||||
api_key = SecretStr(cohere_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
api_key = SecretStr(cohere_api_key) if cohere_api_key else None
|
||||
|
||||
return ChatCohere(
|
||||
temperature=temperature or 0.75,
|
||||
|
|
|
|||
|
|
@ -78,10 +78,7 @@ class MistralAIModelComponent(LCModelComponent):
|
|||
random_seed = self.random_seed
|
||||
safe_mode = self.safe_mode
|
||||
|
||||
if mistral_api_key:
|
||||
api_key = SecretStr(mistral_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
api_key = SecretStr(mistral_api_key) if mistral_api_key else None
|
||||
|
||||
return ChatMistralAI(
|
||||
max_tokens=max_tokens or None,
|
||||
|
|
|
|||
|
|
@ -102,10 +102,7 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
json_mode = bool(output_schema_dict) or self.json_mode
|
||||
seed = self.seed
|
||||
|
||||
if openai_api_key:
|
||||
api_key = SecretStr(openai_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
api_key = SecretStr(openai_api_key) if openai_api_key else None
|
||||
output = ChatOpenAI(
|
||||
max_tokens=max_tokens or None,
|
||||
model_kwargs=model_kwargs,
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class SubFlowComponent(Component):
|
|||
for vertex in inputs_vertex:
|
||||
new_vertex_inputs = []
|
||||
field_template = vertex.data["node"]["template"]
|
||||
for inp in field_template.keys():
|
||||
for inp in field_template:
|
||||
if inp not in ["code", "_type"]:
|
||||
field_template[inp]["display_name"] = (
|
||||
vertex.display_name + " - " + field_template[inp]["display_name"]
|
||||
|
|
@ -84,10 +84,10 @@ class SubFlowComponent(Component):
|
|||
|
||||
async def generate_results(self) -> list[Data]:
|
||||
tweaks: dict = {}
|
||||
for field in self._attributes.keys():
|
||||
for field in self._attributes:
|
||||
if field != "flow_name":
|
||||
[node, name] = field.split("|")
|
||||
if node not in tweaks.keys():
|
||||
if node not in tweaks:
|
||||
tweaks[node] = {}
|
||||
tweaks[node][name] = self._attributes[field]
|
||||
flow_name = self._attributes.get("flow_name")
|
||||
|
|
|
|||
|
|
@ -43,10 +43,7 @@ class CharacterTextSplitterComponent(LCTextSplitterComponent):
|
|||
return self.data_input
|
||||
|
||||
def build_text_splitter(self) -> TextSplitter:
|
||||
if self.separator:
|
||||
separator = unescape_string(self.separator)
|
||||
else:
|
||||
separator = "\n\n"
|
||||
separator = unescape_string(self.separator) if self.separator else "\n\n"
|
||||
return CharacterTextSplitter(
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
chunk_size=self.chunk_size,
|
||||
|
|
|
|||
|
|
@ -51,10 +51,7 @@ class NaturalLanguageTextSplitterComponent(LCTextSplitterComponent):
|
|||
return self.data_input
|
||||
|
||||
def build_text_splitter(self) -> TextSplitter:
|
||||
if self.separator:
|
||||
separator = unescape_string(self.separator)
|
||||
else:
|
||||
separator = "\n\n"
|
||||
separator = unescape_string(self.separator) if self.separator else "\n\n"
|
||||
return NLTKTextSplitter(
|
||||
language=self.language.lower() if self.language else "english",
|
||||
separator=separator,
|
||||
|
|
|
|||
|
|
@ -188,9 +188,7 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
schema_annotation = Any
|
||||
schema_fields[field_name] = (
|
||||
schema_annotation,
|
||||
Field(
|
||||
default=func_arg["default"] if "default" in func_arg else Undefined, description=field_description
|
||||
),
|
||||
Field(default=func_arg.get("default", Undefined), description=field_description),
|
||||
)
|
||||
|
||||
if "temp_annotation_type" in _globals:
|
||||
|
|
|
|||
|
|
@ -174,10 +174,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
documents.append(_input)
|
||||
|
||||
if self.enable_body_search:
|
||||
body_index_options = [("index_analyzer", "STANDARD")]
|
||||
else:
|
||||
body_index_options = None
|
||||
body_index_options = [("index_analyzer", "STANDARD")] if self.enable_body_search else None
|
||||
|
||||
if self.setup_mode == "Off":
|
||||
setup_mode = SetupMode.OFF
|
||||
|
|
|
|||
|
|
@ -161,10 +161,7 @@ class CassandraGraphVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
documents.append(_input)
|
||||
|
||||
if self.setup_mode == "Off":
|
||||
setup_mode = SetupMode.OFF
|
||||
else:
|
||||
setup_mode = SetupMode.SYNC
|
||||
setup_mode = SetupMode.OFF if self.setup_mode == "Off" else SetupMode.SYNC
|
||||
|
||||
if documents:
|
||||
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
|
||||
|
|
|
|||
|
|
@ -125,10 +125,7 @@ class ChromaVectorStoreComponent(LCVectorStoreComponent):
|
|||
client = Client(settings=chroma_settings)
|
||||
|
||||
# Check persist_directory and expand it if it is a relative path
|
||||
if self.persist_directory is not None:
|
||||
persist_directory = self.resolve_path(self.persist_directory)
|
||||
else:
|
||||
persist_directory = None
|
||||
persist_directory = self.resolve_path(self.persist_directory) if self.persist_directory is not None else None
|
||||
|
||||
chroma = Chroma(
|
||||
persist_directory=persist_directory,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import ast
|
||||
import contextlib
|
||||
import inspect
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
|
@ -171,11 +172,9 @@ class CodeParser:
|
|||
return_type_str = ast.unparse(node.returns)
|
||||
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
|
||||
|
||||
try:
|
||||
# Handle cases where the type is not found in the constructed environment
|
||||
with contextlib.suppress(NameError):
|
||||
return_type = eval(return_type_str, eval_env)
|
||||
except NameError:
|
||||
# Handle cases where the type is not found in the constructed environment
|
||||
pass
|
||||
|
||||
func = CallableCodeDetails(
|
||||
name=node.name,
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class BaseComponent:
|
|||
template_config[attribute] = func(value=value)
|
||||
|
||||
for key in template_config.copy():
|
||||
if key not in ATTR_FUNC_MAPPING.keys():
|
||||
if key not in ATTR_FUNC_MAPPING:
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
|
|||
field_required,
|
||||
config,
|
||||
)
|
||||
if "kwargs" in function_args_names and not all(key in function_args_names for key in field_config.keys()):
|
||||
if "kwargs" in function_args_names and not all(key in function_args_names for key in field_config):
|
||||
for field_name, field_config in _field_config.copy().items():
|
||||
if "name" not in field_config or field_name == "code":
|
||||
continue
|
||||
|
|
@ -503,35 +503,35 @@ def update_field_dict(
|
|||
call: bool = False,
|
||||
):
|
||||
"""Update the field dictionary by calling options() or value() if they are callable"""
|
||||
if ("real_time_refresh" in field_dict or "refresh_button" in field_dict) and any(
|
||||
(
|
||||
field_dict.get("real_time_refresh", False),
|
||||
field_dict.get("refresh_button", False),
|
||||
if (
|
||||
("real_time_refresh" in field_dict or "refresh_button" in field_dict)
|
||||
and any(
|
||||
(
|
||||
field_dict.get("real_time_refresh", False),
|
||||
field_dict.get("refresh_button", False),
|
||||
)
|
||||
)
|
||||
and call
|
||||
):
|
||||
if call:
|
||||
try:
|
||||
dd_build_config = dotdict(build_config)
|
||||
custom_component_instance.update_build_config(
|
||||
build_config=dd_build_config,
|
||||
field_value=update_field,
|
||||
field_name=update_field_value,
|
||||
)
|
||||
build_config = dd_build_config
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while running update_build_config: {str(exc)}")
|
||||
msg = f"Error while running update_build_config: {str(exc)}"
|
||||
raise UpdateBuildConfigError(msg) from exc
|
||||
try:
|
||||
dd_build_config = dotdict(build_config)
|
||||
custom_component_instance.update_build_config(
|
||||
build_config=dd_build_config,
|
||||
field_value=update_field,
|
||||
field_name=update_field_value,
|
||||
)
|
||||
build_config = dd_build_config
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while running update_build_config: {str(exc)}")
|
||||
msg = f"Error while running update_build_config: {str(exc)}"
|
||||
raise UpdateBuildConfigError(msg) from exc
|
||||
|
||||
return build_config
|
||||
|
||||
|
||||
def sanitize_field_config(field_config: dict | Input):
|
||||
# If any of the already existing keys are in field_config, remove them
|
||||
if isinstance(field_config, Input):
|
||||
field_dict = field_config.to_dict()
|
||||
else:
|
||||
field_dict = field_config
|
||||
field_dict = field_config.to_dict() if isinstance(field_config, Input) else field_config
|
||||
for key in [
|
||||
"name",
|
||||
"field_type",
|
||||
|
|
|
|||
|
|
@ -244,10 +244,10 @@ class CycleEdge(Edge):
|
|||
await self.honor(source, target)
|
||||
|
||||
# If the target vertex is a power component we log messages
|
||||
if target.vertex_type == "ChatOutput" and (
|
||||
isinstance(target.params.get(INPUT_FIELD_NAME), str)
|
||||
or isinstance(target.params.get(INPUT_FIELD_NAME), dict)
|
||||
if (
|
||||
target.vertex_type == "ChatOutput"
|
||||
and isinstance(target.params.get(INPUT_FIELD_NAME), str | dict)
|
||||
and target.params.get("message") == ""
|
||||
):
|
||||
if target.params.get("message") == "":
|
||||
return self.result
|
||||
return self.result
|
||||
return self.result
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
|
|
@ -1061,10 +1062,7 @@ class Graph:
|
|||
same_length = len(vertex.edges) == len(other_vertex.edges)
|
||||
if not same_length:
|
||||
return False
|
||||
for edge in vertex.edges:
|
||||
if edge not in other_vertex.edges:
|
||||
return False
|
||||
return True
|
||||
return all(edge in other_vertex.edges for edge in vertex.edges)
|
||||
|
||||
def update(self, other: Graph) -> Graph:
|
||||
# Existing vertices in self graph
|
||||
|
|
@ -1080,10 +1078,8 @@ class Graph:
|
|||
|
||||
# Remove vertices that are not in the other graph
|
||||
for vertex_id in removed_vertex_ids:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
self.remove_vertex(vertex_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# The order here matters because adding the vertex is required
|
||||
# if any of them have edges that point to any of the new vertices
|
||||
|
|
|
|||
|
|
@ -55,9 +55,7 @@ class RunnableVerticesManager:
|
|||
return False
|
||||
if vertex_id not in self.vertices_to_run:
|
||||
return False
|
||||
if not self.are_all_predecessors_fulfilled(vertex_id):
|
||||
return False
|
||||
return True
|
||||
return self.are_all_predecessors_fulfilled(vertex_id)
|
||||
|
||||
def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
|
||||
return not any(self.run_predecessors.get(vertex_id, []))
|
||||
|
|
|
|||
|
|
@ -354,12 +354,7 @@ def has_cycle(vertex_ids: list[str], edges: list[tuple[str, str]]) -> bool:
|
|||
visited: set[str] = set()
|
||||
rec_stack: set[str] = set()
|
||||
|
||||
for vertex in vertex_ids:
|
||||
if vertex not in visited:
|
||||
if dfs(vertex, visited, rec_stack):
|
||||
return True
|
||||
|
||||
return False
|
||||
return any(vertex not in visited and dfs(vertex, visited, rec_stack) for vertex in vertex_ids)
|
||||
|
||||
|
||||
def find_cycle_edge(entry_point: str, edges: list[tuple[str, str]]) -> tuple[str, str]:
|
||||
|
|
|
|||
|
|
@ -103,11 +103,10 @@ def get_artifact_type(value, build_result) -> str:
|
|||
case Message():
|
||||
result = ArtifactType.MESSAGE
|
||||
|
||||
if result == ArtifactType.UNKNOWN:
|
||||
if isinstance(build_result, Generator):
|
||||
result = ArtifactType.STREAM
|
||||
elif isinstance(value, Message) and isinstance(value.text, Generator):
|
||||
result = ArtifactType.STREAM
|
||||
if result == ArtifactType.UNKNOWN and (
|
||||
isinstance(build_result, Generator) or isinstance(value, Message) and isinstance(value.text, Generator)
|
||||
):
|
||||
result = ArtifactType.STREAM
|
||||
|
||||
return result.value
|
||||
|
||||
|
|
|
|||
|
|
@ -287,7 +287,7 @@ class Vertex:
|
|||
if not param_dict or len(param_dict) != 1:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
else:
|
||||
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
|
||||
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict}
|
||||
|
||||
else:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
|
|
@ -415,8 +415,6 @@ class Vertex:
|
|||
elif val is not None and val != "":
|
||||
params[field_name] = val
|
||||
|
||||
elif val is not None and val != "":
|
||||
params[field_name] = val
|
||||
if field.get("load_from_db"):
|
||||
load_from_db_fields.append(field_name)
|
||||
|
||||
|
|
@ -534,10 +532,7 @@ class Vertex:
|
|||
# to the frontend
|
||||
self.set_artifacts()
|
||||
artifacts = self.artifacts_raw
|
||||
if isinstance(artifacts, dict):
|
||||
messages = self.extract_messages_from_artifacts(artifacts)
|
||||
else:
|
||||
messages = []
|
||||
messages = self.extract_messages_from_artifacts(artifacts) if isinstance(artifacts, dict) else []
|
||||
result_dict = ResultData(
|
||||
results=result_dict,
|
||||
artifacts=artifacts,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Generator, Iterator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
|
@ -167,7 +168,7 @@ class ComponentVertex(Vertex):
|
|||
message_dict = artifact if isinstance(artifact, dict) else artifact.model_dump()
|
||||
if not message_dict.get("text"):
|
||||
continue
|
||||
try:
|
||||
with contextlib.suppress(KeyError):
|
||||
messages.append(
|
||||
ChatOutputResponse(
|
||||
message=message_dict["text"],
|
||||
|
|
@ -182,8 +183,6 @@ class ComponentVertex(Vertex):
|
|||
type=self.artifacts_type[key],
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
return messages
|
||||
|
||||
def _finalize_build(self):
|
||||
|
|
@ -440,11 +439,12 @@ class InterfaceVertex(ComponentVertex):
|
|||
for key, value in origin_vertex.results.items():
|
||||
if isinstance(value, AsyncIterator | Iterator):
|
||||
origin_vertex.results[key] = complete_message
|
||||
if self._custom_component:
|
||||
if hasattr(self._custom_component, "should_store_message") and hasattr(
|
||||
self._custom_component, "store_message"
|
||||
):
|
||||
self._custom_component.store_message(message)
|
||||
if (
|
||||
self._custom_component
|
||||
and hasattr(self._custom_component, "should_store_message")
|
||||
and hasattr(self._custom_component, "store_message")
|
||||
):
|
||||
self._custom_component.store_message(message)
|
||||
log_vertex_build(
|
||||
flow_id=self.graph.flow_id,
|
||||
vertex_id=self.id,
|
||||
|
|
|
|||
|
|
@ -90,17 +90,19 @@ def update_projects_components_with_latest_component_versions(project_data, all_
|
|||
)
|
||||
else:
|
||||
for attr in NODE_FORMAT_ATTRIBUTES:
|
||||
if attr in latest_node:
|
||||
if (
|
||||
attr in latest_node
|
||||
# Check if it needs to be updated
|
||||
if latest_node[attr] != node_data.get(attr):
|
||||
node_changes_log[node_data["display_name"]].append(
|
||||
{
|
||||
"attr": attr,
|
||||
"old_value": node_data.get(attr),
|
||||
"new_value": latest_node[attr],
|
||||
}
|
||||
)
|
||||
node_data[attr] = latest_node[attr]
|
||||
and latest_node[attr] != node_data.get(attr)
|
||||
):
|
||||
node_changes_log[node_data["display_name"]].append(
|
||||
{
|
||||
"attr": attr,
|
||||
"old_value": node_data.get(attr),
|
||||
"new_value": latest_node[attr],
|
||||
}
|
||||
)
|
||||
node_data[attr] = latest_node[attr]
|
||||
|
||||
for field_name, field_dict in latest_template.items():
|
||||
if field_name not in node_data["template"]:
|
||||
|
|
@ -109,17 +111,20 @@ def update_projects_components_with_latest_component_versions(project_data, all_
|
|||
# The idea here is to update some attributes of the field
|
||||
to_check_attributes = FIELD_FORMAT_ATTRIBUTES
|
||||
for attr in to_check_attributes:
|
||||
if attr in field_dict and attr in node_data["template"].get(field_name):
|
||||
if (
|
||||
attr in field_dict
|
||||
and attr in node_data["template"].get(field_name)
|
||||
# Check if it needs to be updated
|
||||
if field_dict[attr] != node_data["template"][field_name][attr]:
|
||||
node_changes_log[node_data["display_name"]].append(
|
||||
{
|
||||
"attr": f"{field_name}.{attr}",
|
||||
"old_value": node_data["template"][field_name][attr],
|
||||
"new_value": field_dict[attr],
|
||||
}
|
||||
)
|
||||
node_data["template"][field_name][attr] = field_dict[attr]
|
||||
and field_dict[attr] != node_data["template"][field_name][attr]
|
||||
):
|
||||
node_changes_log[node_data["display_name"]].append(
|
||||
{
|
||||
"attr": f"{field_name}.{attr}",
|
||||
"old_value": node_data["template"][field_name][attr],
|
||||
"new_value": field_dict[attr],
|
||||
}
|
||||
)
|
||||
node_data["template"][field_name][attr] = field_dict[attr]
|
||||
# Remove fields that are not in the latest template
|
||||
if node_data.get("display_name") != "Prompt":
|
||||
for field_name in list(node_data["template"].keys()):
|
||||
|
|
@ -274,16 +279,17 @@ def update_edges_with_latest_component_versions(project_data):
|
|||
source_handle["output_types"] = new_output_types
|
||||
|
||||
field_name = target_handle.get("fieldName")
|
||||
if field_name in target_node_data.get("template"):
|
||||
if target_handle["inputTypes"] != target_node_data.get("template").get(field_name).get("input_types"):
|
||||
edge_changes_log[target_node_data["display_name"]].append(
|
||||
{
|
||||
"attr": "inputTypes",
|
||||
"old_value": target_handle["inputTypes"],
|
||||
"new_value": target_node_data.get("template").get(field_name).get("input_types"),
|
||||
}
|
||||
)
|
||||
target_handle["inputTypes"] = target_node_data.get("template").get(field_name).get("input_types")
|
||||
if field_name in target_node_data.get("template") and target_handle["inputTypes"] != target_node_data.get(
|
||||
"template"
|
||||
).get(field_name).get("input_types"):
|
||||
edge_changes_log[target_node_data["display_name"]].append(
|
||||
{
|
||||
"attr": "inputTypes",
|
||||
"old_value": target_handle["inputTypes"],
|
||||
"new_value": target_node_data.get("template").get(field_name).get("input_types"),
|
||||
}
|
||||
)
|
||||
target_handle["inputTypes"] = target_node_data.get("template").get(field_name).get("input_types")
|
||||
escaped_source_handle = escape_json_dump(source_handle)
|
||||
escaped_target_handle = escape_json_dump(target_handle)
|
||||
try:
|
||||
|
|
@ -390,10 +396,7 @@ def get_project_data(project):
|
|||
updated_at_datetime = datetime.strptime(project_updated_at, "%Y-%m-%dT%H:%M:%S.%f")
|
||||
project_data = project.get("data")
|
||||
project_icon = project.get("icon")
|
||||
if project_icon and purely_emoji(project_icon):
|
||||
project_icon = demojize(project_icon)
|
||||
else:
|
||||
project_icon = ""
|
||||
project_icon = demojize(project_icon) if project_icon and purely_emoji(project_icon) else ""
|
||||
project_icon_bg_color = project.get("icon_bg_color")
|
||||
return (
|
||||
project_name,
|
||||
|
|
|
|||
|
|
@ -131,12 +131,7 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTr
|
|||
ValueError: If the value is not of a valid type or if the input is missing a required key.
|
||||
"""
|
||||
is_list = _info.data["is_list"]
|
||||
value = None
|
||||
if is_list:
|
||||
value = [cls._validate_value(vv, _info) for vv in v]
|
||||
else:
|
||||
value = cls._validate_value(v, _info)
|
||||
return value
|
||||
return [cls._validate_value(vv, _info) for vv in v] if is_list else cls._validate_value(v, _info)
|
||||
|
||||
|
||||
class MessageInput(StrInput, InputTraceMixin):
|
||||
|
|
|
|||
|
|
@ -89,12 +89,11 @@ def convert_kwargs(params):
|
|||
# Loop through items to avoid repeated lookups
|
||||
items_to_remove = []
|
||||
for key, value in params.items():
|
||||
if "kwargs" in key or "config" in key:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
params[key] = orjson.loads(value)
|
||||
except orjson.JSONDecodeError:
|
||||
items_to_remove.append(key)
|
||||
if ("kwargs" in key or "config" in key) and isinstance(value, str):
|
||||
try:
|
||||
params[key] = orjson.loads(value)
|
||||
except orjson.JSONDecodeError:
|
||||
items_to_remove.append(key)
|
||||
|
||||
# Remove invalid keys outside the loop to avoid modifying dict during iteration
|
||||
for key in items_to_remove:
|
||||
|
|
|
|||
|
|
@ -69,11 +69,9 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
|||
if not match:
|
||||
break
|
||||
|
||||
# Extract the variable name from either the single or double brace match
|
||||
if match.group(1): # Match found in double braces
|
||||
variable_name = "{{" + match.group(1) + "}}" # Re-add single braces for JSON strings
|
||||
else: # Match found in single braces
|
||||
variable_name = match.group(2)
|
||||
# Extract the variable name from either the single or double brace match.
|
||||
# If match found in double braces, re-add single braces for JSON strings.
|
||||
variable_name = "{{" + match.group(1) + "}}" if match.group(1) else match.group(2)
|
||||
if variable_name is not None:
|
||||
# This means there is a match
|
||||
# but there is nothing inside the braces
|
||||
|
|
|
|||
|
|
@ -77,22 +77,18 @@ class SizedLogBuffer:
|
|||
try:
|
||||
with self._wlock:
|
||||
as_list = list(self.buffer)
|
||||
i = 0
|
||||
max_index = -1
|
||||
for ts, msg in as_list:
|
||||
for i, (ts, msg) in enumerate(as_list):
|
||||
if ts >= timestamp:
|
||||
max_index = i
|
||||
break
|
||||
i += 1
|
||||
if max_index == -1:
|
||||
return self.get_last_n(lines)
|
||||
rc = {}
|
||||
i = 0
|
||||
start_from = max(max_index - lines, 0)
|
||||
for ts, msg in as_list:
|
||||
for i, (ts, msg) in enumerate(as_list):
|
||||
if start_from <= i < max_index:
|
||||
rc[ts] = msg
|
||||
i += 1
|
||||
return rc
|
||||
finally:
|
||||
self._rsemaphore.release()
|
||||
|
|
|
|||
|
|
@ -46,10 +46,7 @@ def get_messages(
|
|||
if flow_id:
|
||||
stmt = stmt.where(MessageTable.flow_id == flow_id)
|
||||
if order_by:
|
||||
if order == "DESC":
|
||||
col = getattr(MessageTable, order_by).desc()
|
||||
else:
|
||||
col = getattr(MessageTable, order_by).asc()
|
||||
col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc()
|
||||
stmt = stmt.order_by(col)
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
|
|
|||
|
|
@ -29,10 +29,7 @@ async def run_graph_internal(
|
|||
) -> tuple[list[RunOutputs], str]:
|
||||
"""Run the graph and generate the result"""
|
||||
inputs = inputs or []
|
||||
if session_id is None:
|
||||
session_id_str = flow_id
|
||||
else:
|
||||
session_id_str = session_id
|
||||
session_id_str = flow_id if session_id is None else session_id
|
||||
components = []
|
||||
inputs_list = []
|
||||
types = []
|
||||
|
|
@ -168,11 +165,7 @@ def process_tweaks(
|
|||
:return: The modified graph_data dictionary.
|
||||
:raises ValueError: If the input is not in the expected format.
|
||||
"""
|
||||
tweaks_dict = {}
|
||||
if not isinstance(tweaks, dict):
|
||||
tweaks_dict = cast(dict[str, Any], tweaks.model_dump())
|
||||
else:
|
||||
tweaks_dict = tweaks
|
||||
tweaks_dict = cast(dict[str, Any], tweaks.model_dump()) if not isinstance(tweaks, dict) else tweaks
|
||||
if "stream" not in tweaks_dict:
|
||||
tweaks_dict |= {"stream": stream}
|
||||
nodes = validate_input(graph_data, cast(dict[str, str | dict[str, Any]], tweaks_dict))
|
||||
|
|
@ -182,9 +175,7 @@ def process_tweaks(
|
|||
all_nodes_tweaks = {}
|
||||
for key, value in tweaks_dict.items():
|
||||
if isinstance(value, dict):
|
||||
if node := nodes_map.get(key):
|
||||
apply_tweaks(node, value)
|
||||
elif node := nodes_display_name_map.get(key):
|
||||
if (node := nodes_map.get(key)) or (node := nodes_display_name_map.get(key)):
|
||||
apply_tweaks(node, value)
|
||||
else:
|
||||
all_nodes_tweaks[key] = value
|
||||
|
|
|
|||
|
|
@ -41,11 +41,13 @@ def get_artifact_type(value, build_result=None) -> str:
|
|||
case list():
|
||||
result = ArtifactType.ARRAY
|
||||
|
||||
if result == ArtifactType.UNKNOWN:
|
||||
if build_result and isinstance(build_result, Generator):
|
||||
result = ArtifactType.STREAM
|
||||
elif isinstance(value, Message) and isinstance(value.text, Generator):
|
||||
result = ArtifactType.STREAM
|
||||
if result == ArtifactType.UNKNOWN and (
|
||||
build_result
|
||||
and isinstance(build_result, Generator)
|
||||
or isinstance(value, Message)
|
||||
and isinstance(value.text, Generator)
|
||||
):
|
||||
result = ArtifactType.STREAM
|
||||
|
||||
return result.value
|
||||
|
||||
|
|
@ -56,9 +58,7 @@ def post_process_raw(raw, artifact_type: str):
|
|||
elif artifact_type == ArtifactType.ARRAY.value:
|
||||
_raw = []
|
||||
for item in raw:
|
||||
if hasattr(item, "dict"):
|
||||
_raw.append(recursive_serialize_or_str(item))
|
||||
elif hasattr(item, "model_dump"):
|
||||
if hasattr(item, "dict") or hasattr(item, "model_dump"):
|
||||
_raw.append(recursive_serialize_or_str(item))
|
||||
else:
|
||||
_raw.append(str(item))
|
||||
|
|
|
|||
|
|
@ -103,10 +103,7 @@ class Message(Data):
|
|||
# they are: "text", "sender"
|
||||
if self.text is None or not self.sender:
|
||||
logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.")
|
||||
if not isinstance(self.text, str):
|
||||
text = ""
|
||||
else:
|
||||
text = self.text
|
||||
text = "" if not isinstance(self.text, str) else self.text
|
||||
|
||||
if self.sender == MESSAGE_SENDER_USER or not self.sender:
|
||||
if self.files:
|
||||
|
|
@ -160,9 +157,7 @@ class Message(Data):
|
|||
|
||||
@field_serializer("text", mode="plain")
|
||||
def serialize_text(self, value):
|
||||
if isinstance(value, AsyncIterator):
|
||||
return ""
|
||||
if isinstance(value, Iterator):
|
||||
if isinstance(value, AsyncIterator | Iterator):
|
||||
return ""
|
||||
return value
|
||||
|
||||
|
|
|
|||
|
|
@ -56,12 +56,13 @@ def get_type(payload):
|
|||
case str():
|
||||
result = LogType.TEXT
|
||||
|
||||
if result == LogType.UNKNOWN:
|
||||
if payload and isinstance(payload, Generator):
|
||||
result = LogType.STREAM
|
||||
|
||||
elif isinstance(payload, Message) and isinstance(payload.text, Generator):
|
||||
result = LogType.STREAM
|
||||
if result == LogType.UNKNOWN and (
|
||||
payload
|
||||
and isinstance(payload, Generator)
|
||||
or isinstance(payload, Message)
|
||||
and isinstance(payload.text, Generator)
|
||||
):
|
||||
result = LogType.STREAM
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -72,11 +72,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]): # type: ignore
|
|||
# Move the key to the end to make it recently used
|
||||
self._cache.move_to_end(key)
|
||||
# Check if the value is pickled
|
||||
if isinstance(item["value"], bytes):
|
||||
value = pickle.loads(item["value"])
|
||||
else:
|
||||
value = item["value"]
|
||||
return value
|
||||
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
|
||||
self.delete(key)
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -126,10 +126,7 @@ def save_uploaded_file(file: UploadFile, folder_name):
|
|||
cache_path = Path(CACHE_DIR)
|
||||
folder_path = cache_path / folder_name
|
||||
filename = file.filename
|
||||
if isinstance(filename, str) or isinstance(filename, Path):
|
||||
file_extension = Path(filename).suffix
|
||||
else:
|
||||
file_extension = ""
|
||||
file_extension = Path(filename).suffix if isinstance(filename, str | Path) else ""
|
||||
file_object = file.file
|
||||
|
||||
# Create the folder if it doesn't exist
|
||||
|
|
|
|||
|
|
@ -93,10 +93,7 @@ class CacheService(Subject, Service):
|
|||
"image": "png",
|
||||
"pandas": "csv",
|
||||
}
|
||||
if obj_type in object_extensions:
|
||||
_extension = object_extensions[obj_type]
|
||||
else:
|
||||
_extension = type(obj).__name__.lower()
|
||||
_extension = object_extensions[obj_type] if obj_type in object_extensions else type(obj).__name__.lower()
|
||||
self.current_cache[name] = {
|
||||
"obj": obj,
|
||||
"type": obj_type,
|
||||
|
|
|
|||
|
|
@ -117,10 +117,10 @@ class FlowBase(SQLModel):
|
|||
raise ValueError(msg)
|
||||
|
||||
# data must contain nodes and edges
|
||||
if "nodes" not in v.keys():
|
||||
if "nodes" not in v:
|
||||
msg = "Flow must have nodes"
|
||||
raise ValueError(msg)
|
||||
if "edges" not in v.keys():
|
||||
if "edges" not in v:
|
||||
msg = "Flow must have edges"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
|
|||
|
|
@ -46,12 +46,9 @@ class MessageBase(SQLModel):
|
|||
timestamp = message.timestamp
|
||||
if not flow_id and message.flow_id:
|
||||
flow_id = message.flow_id
|
||||
if not isinstance(message.text, str):
|
||||
# If the text is not a string, it means it could be
|
||||
# async iterator so we simply add it as an empty string
|
||||
message_text = ""
|
||||
else:
|
||||
message_text = message.text
|
||||
# If the text is not a string, it means it could be
|
||||
# async iterator so we simply add it as an empty string
|
||||
message_text = "" if not isinstance(message.text, str) else message.text
|
||||
return cls(
|
||||
sender=message.sender,
|
||||
sender_name=message.sender_name,
|
||||
|
|
|
|||
|
|
@ -31,10 +31,7 @@ def is_list_of_any(field: FieldInfo) -> bool:
|
|||
if field.annotation is None:
|
||||
return False
|
||||
try:
|
||||
if hasattr(field.annotation, "__args__"):
|
||||
union_args = field.annotation.__args__
|
||||
else:
|
||||
union_args = []
|
||||
union_args = field.annotation.__args__ if hasattr(field.annotation, "__args__") else []
|
||||
|
||||
return field.annotation.__origin__ is list or any(
|
||||
arg.__origin__ is list for arg in union_args if hasattr(arg, "__origin__")
|
||||
|
|
@ -267,10 +264,7 @@ class Settings(BaseSettings):
|
|||
final_path = new_path
|
||||
|
||||
if final_path is None:
|
||||
if is_pre_release:
|
||||
final_path = new_pre_path
|
||||
else:
|
||||
final_path = new_path
|
||||
final_path = new_pre_path if is_pre_release else new_path
|
||||
|
||||
value = f"sqlite:///{final_path}"
|
||||
|
||||
|
|
@ -370,7 +364,7 @@ def load_settings_from_yaml(file_path: str) -> Settings:
|
|||
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
|
||||
|
||||
for key in settings_dict:
|
||||
if key not in Settings.model_fields.keys():
|
||||
if key not in Settings.model_fields:
|
||||
msg = f"Key {key} not found in settings"
|
||||
raise KeyError(msg)
|
||||
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class SettingsService(Service):
|
|||
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
|
||||
|
||||
for key in settings_dict:
|
||||
if key not in Settings.model_fields.keys():
|
||||
if key not in Settings.model_fields:
|
||||
msg = f"Key {key} not found in settings"
|
||||
raise KeyError(msg)
|
||||
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")
|
||||
|
|
|
|||
|
|
@ -126,10 +126,7 @@ class StoreService(Service):
|
|||
self, url: str, api_key: str | None = None, params: dict[str, Any] | None = None
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Utility method to perform GET requests."""
|
||||
if api_key:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
else:
|
||||
headers = {}
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers, params=params, timeout=self.timeout)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -10,6 +12,7 @@ from langflow.services.tracing.schema import Log
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langfuse.client import StatefulSpanClient
|
||||
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
|
|
@ -23,7 +26,7 @@ class LangFuseTracer(BaseTracer):
|
|||
self.trace_type = trace_type
|
||||
self.trace_id = trace_id
|
||||
self.flow_id = trace_name.split(" - ")[-1]
|
||||
self.last_span = None
|
||||
self.last_span: StatefulSpanClient | None = None
|
||||
self.spans: dict = {}
|
||||
self._ready: bool = self.setup_langfuse()
|
||||
|
||||
|
|
@ -68,7 +71,7 @@ class LangFuseTracer(BaseTracer):
|
|||
trace_type: str,
|
||||
inputs: dict[str, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
vertex: Optional["Vertex"] = None,
|
||||
vertex: Vertex | None = None,
|
||||
):
|
||||
start_time = datetime.utcnow()
|
||||
if not self._ready:
|
||||
|
|
@ -86,10 +89,7 @@ class LangFuseTracer(BaseTracer):
|
|||
"start_time": start_time,
|
||||
}
|
||||
|
||||
if self.last_span:
|
||||
span = self.last_span.span(**content_span)
|
||||
else:
|
||||
span = self.trace.span(**content_span)
|
||||
span = self.last_span.span(**content_span) if self.last_span else self.trace.span(**content_span)
|
||||
|
||||
self.last_span = span
|
||||
self.spans[trace_id] = span
|
||||
|
|
@ -127,7 +127,7 @@ class LangFuseTracer(BaseTracer):
|
|||
|
||||
self._client.flush()
|
||||
|
||||
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
|
||||
def get_langchain_callback(self) -> BaseCallbackHandler | None:
|
||||
if not self._ready:
|
||||
return None
|
||||
return None # self._callback
|
||||
|
|
|
|||
|
|
@ -243,7 +243,7 @@ class TracingService(Service):
|
|||
|
||||
def _cleanup_inputs(self, inputs: dict[str, Any]):
|
||||
inputs = inputs.copy()
|
||||
for key in inputs.keys():
|
||||
for key in inputs:
|
||||
if "api_key" in key:
|
||||
inputs[key] = "*****" # avoid logging api_keys for security reasons
|
||||
return inputs
|
||||
|
|
|
|||
|
|
@ -20,10 +20,7 @@ def convert_to_langchain_type(value):
|
|||
else:
|
||||
value = value.to_lc_document()
|
||||
elif isinstance(value, Data):
|
||||
if "text" in value.data:
|
||||
value = value.to_lc_document()
|
||||
else:
|
||||
value = value.data
|
||||
value = value.to_lc_document() if "text" in value.data else value.data
|
||||
return value
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class KubernetesSecretService(VariableService, Service):
|
|||
return []
|
||||
|
||||
names = []
|
||||
for key in variables.keys():
|
||||
for key in variables:
|
||||
if key.startswith(CREDENTIAL_TYPE + "_"):
|
||||
names.append(key[len(CREDENTIAL_TYPE) + 1 :])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -97,9 +97,8 @@ class Input(BaseModel):
|
|||
def serialize_model(self, handler):
|
||||
result = handler(self)
|
||||
# If the field is str, we add the Text input type
|
||||
if self.field_type in ["str", "Text"]:
|
||||
if "input_types" not in result:
|
||||
result["input_types"] = ["Text"]
|
||||
if self.field_type in ["str", "Text"] and "input_types" not in result:
|
||||
result["input_types"] = ["Text"]
|
||||
if self.field_type == Text:
|
||||
result["type"] = "str"
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -56,9 +56,7 @@ def build_template_from_function(name: str, type_to_loader_dict: dict, add_funct
|
|||
elif name_ not in ["name"]:
|
||||
variables[class_field_items][name_] = value_
|
||||
|
||||
variables[class_field_items]["placeholder"] = (
|
||||
docs.params[class_field_items] if class_field_items in docs.params else ""
|
||||
)
|
||||
variables[class_field_items]["placeholder"] = docs.params.get(class_field_items, "")
|
||||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
base_classes = get_base_classes(_class)
|
||||
|
|
|
|||
|
|
@ -169,6 +169,7 @@ select = [
|
|||
"Q",
|
||||
"RET",
|
||||
"RSE",
|
||||
"SIM",
|
||||
"SLOT",
|
||||
"T10",
|
||||
"TID",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue