Fix import errors and type annotations

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-28 20:15:23 -03:00
commit f86f6e6281
24 changed files with 138 additions and 245 deletions

View file

@ -222,7 +222,7 @@ def build_and_cache_graph(
graph: Optional[Graph] = None,
):
"""Build and cache the graph."""
flow: Flow = session.get(Flow, flow_id)
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
other_graph = Graph.from_payload(flow.data, flow_id)
@ -236,10 +236,12 @@ def build_and_cache_graph(
def format_syntax_error_message(exc: SyntaxError) -> str:
"""Format a SyntaxError message for returning to the frontend."""
if exc.text is None:
return f"Syntax error in code. Error on line {exc.lineno}"
return f"Syntax error in code. Error on line {exc.lineno}: {exc.text.strip()}"
def get_causing_exception(exc: Exception) -> Exception:
def get_causing_exception(exc: BaseException) -> BaseException:
"""Get the causing exception from an exception."""
if hasattr(exc, "__cause__") and exc.__cause__:
return get_causing_exception(exc.__cause__)

View file

@ -4,117 +4,16 @@ from uuid import UUID
from langchain.schema import AgentAction, AgentFinish
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from loguru import logger
from langflow.api.v1.schemas import ChatResponse, PromptResponse
from langflow.services.deps import get_chat_service
from langflow.utils.util import remove_ansi_escape_codes
from loguru import logger
if TYPE_CHECKING:
from langflow.services.socket.service import SocketIOService
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""
def __init__(self, client_id: str = None):
self.chat_service = get_chat_service()
self.client_id = client_id
self.websocket = self.chat_service.active_connections[self.client_id]
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.websocket.send_json(resp.model_dump())
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message="",
type="stream",
intermediate_steps=f"Tool input: {input_str}",
)
await self.websocket.send_json(resp.model_dump())
async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get("observation_prefix", "Tool output: ")
split_output = output.split()
first_word = split_output[0]
rest_of_output = split_output[1:]
# Create a formatted message.
intermediate_steps = f"{observation_prefix}{first_word}"
# Create a ChatResponse instance.
resp = ChatResponse(
message="",
type="stream",
intermediate_steps=intermediate_steps,
)
rest_of_resps = [
ChatResponse(
message="",
type="stream",
intermediate_steps=f"{word}",
)
for word in rest_of_output
]
resps = [resp] + rest_of_resps
# Try to send the response, handle potential errors.
try:
# This is to emulate the stream of tokens
for resp in resps:
await self.websocket.send_json(resp.model_dump())
except Exception as exc:
logger.error(f"Error sending response: {exc}")
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
async def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""
# This runs when first sending the prompt
# to the LLM, adding it will send the final prompt
# to the frontend
if "Prompt after formatting" in text:
text = text.replace("Prompt after formatting:\n", "")
text = remove_ansi_escape_codes(text)
resp = PromptResponse(
prompt=text,
)
await self.websocket.send_json(resp.model_dump())
self.chat_service.chat_history.add_message(self.client_id, resp)
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
log = f"Thought: {action.log}"
# if there are line breaks, split them and send them
# as separate messages
if "\n" in log:
logs = log.split("\n")
for log in logs:
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.websocket.send_json(resp.model_dump())
else:
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.websocket.send_json(resp.model_dump())
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
resp = ChatResponse(
message="",
type="stream",
intermediate_steps=finish.log,
)
await self.websocket.send_json(resp.model_dump())
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""
@ -130,7 +29,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message="",
@ -168,7 +69,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
try:
# This is to emulate the stream of tokens
for resp in resps:
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
await self.socketio_service.emit_token(
to=self.sid, data=resp.model_dump()
)
except Exception as exc:
logger.error(f"Error sending response: {exc}")
@ -194,7 +97,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
resp = PromptResponse(
prompt=text,
)
await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())
await self.socketio_service.emit_message(
to=self.sid, data=resp.model_dump()
)
self.chat_service.chat_history.add_message(self.client_id, resp)
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
@ -205,7 +110,9 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
logs = log.split("\n")
for log in logs:
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
await self.socketio_service.emit_token(
to=self.sid, data=resp.model_dump()
)
else:
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
@ -232,5 +139,7 @@ class StreamingLLMCallbackHandler(BaseCallbackHandler):
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
loop = asyncio.get_event_loop()
coroutine = self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
coroutine = self.socketio_service.emit_token(
to=self.sid, data=resp.model_dump()
)
asyncio.run_coroutine_threadsafe(coroutine, loop)

View file

@ -9,6 +9,7 @@ from sqlmodel import Session, select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (
CustomComponentCode,
InputValueRequest,
ProcessResponse,
RunResponse,
TaskStatusResponse,
@ -54,7 +55,7 @@ def get_all(
async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[Union[List[dict], dict]] = None,
inputs: Optional[InputValueRequest] = None,
tweaks: Optional[dict] = None,
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
@ -62,6 +63,11 @@ async def run_flow_with_caching(
session_service: SessionService = Depends(get_session_service),
):
try:
if inputs is not None:
input_values_dict: dict[str, Union[str, list[str]]] = inputs.model_dump()
else:
input_values_dict = {}
if session_id:
session_data = await session_service.load_session(
session_id, flow_id=flow_id
@ -74,7 +80,7 @@ async def run_flow_with_caching(
graph=graph,
flow_id=flow_id,
session_id=session_id,
inputs=inputs,
inputs=input_values_dict,
artifacts=artifacts,
session_service=session_service,
stream=stream,
@ -99,7 +105,7 @@ async def run_flow_with_caching(
graph=graph_data,
flow_id=flow_id,
session_id=session_id,
inputs=inputs,
inputs=input_values_dict,
artifacts={},
session_service=session_service,
stream=stream,

View file

@ -59,4 +59,4 @@ class RetrievalQAComponent(CustomComponent):
final_result = "\n".join([str(result_str), references_str])
self.status = final_result
return final_result
return final_result # OK

View file

@ -102,7 +102,7 @@ class GatherRecordsComponent(CustomComponent):
silent_errors: bool,
max_concurrency: int,
use_multithreading: bool,
) -> List[Record]:
) -> List[Optional[Record]]:
if use_multithreading:
records = self.parallel_load_records(
file_paths, silent_errors, max_concurrency

View file

@ -79,6 +79,7 @@ class ChatComponent(CustomComponent):
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
) -> Union[Text, Record]:
input_value_record: Optional[Record] = None
if return_record:
if isinstance(input_value, Record):
# Update the data of the record
@ -86,7 +87,7 @@ class ChatComponent(CustomComponent):
input_value.data["sender_name"] = sender_name
input_value.data["session_id"] = session_id
else:
input_value = Record(
input_value_record = Record(
text=input_value,
data={
"sender": sender,
@ -96,7 +97,11 @@ class ChatComponent(CustomComponent):
)
if not input_value:
input_value = ""
self.status = input_value
if return_record and input_value_record:
result = input_value_record
else:
result = input_value
self.status = result
if session_id:
self.store_message(input_value, session_id, sender, sender_name)
return input_value
self.store_message(result, session_id, sender, sender_name)
return result

View file

@ -150,6 +150,7 @@ class ChatLiteLLMComponent(CustomComponent):
LLM = ChatLiteLLM(
model=model,
client=None,
streaming=streaming,
temperature=temperature,
model_kwargs=model_kwargs if model_kwargs is not None else {},

View file

@ -36,7 +36,7 @@ class RunnableExecComponent(CustomComponent):
runnable: Runnable,
output_key: str = "output",
) -> Text:
result = runnable.invoke({input_key: inputs})
result = runnable.invoke({input_key: input_value})
result = result.get(output_key)
self.status = result
return result

View file

@ -52,7 +52,9 @@ class Graph:
self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self.inactive_vertices = set()
self.inactive_vertices: set = set()
self.edges: List[ContractEdge] = []
self.vertices: List[Vertex] = []
self._build_graph()
self.build_graph_maps()
self.define_vertices_lists()
@ -100,7 +102,7 @@ class Graph:
async def run(
self, inputs: Dict[str, Union[str, list[str]]], stream: bool
) -> List["ResultData"]:
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
@ -108,7 +110,7 @@ class Graph:
# of the vertices that are inputs
# if the value is a list, we need to run multiple times
outputs = []
inputs_values = inputs.get(INPUT_FIELD_NAME)
inputs_values = inputs.get(INPUT_FIELD_NAME, "")
if not isinstance(inputs_values, list):
inputs_values = [inputs_values]
for input_value in inputs_values:
@ -245,7 +247,7 @@ class Graph:
return False
return True
def update(self, other: "Graph") -> None:
def update(self, other: "Graph") -> "Graph":
# Existing vertices in self graph
existing_vertex_ids = set(vertex.id for vertex in self.vertices)
# Vertex IDs in the other graph
@ -274,7 +276,7 @@ class Graph:
if not self_vertex.pinned:
self_vertex._built = False
self_vertex.result = None
self_vertex.artifacts = None
self_vertex.artifacts = {}
self_vertex.set_top_level(self.top_level_vertices)
self.reset_all_edges_of_vertex(self_vertex)
@ -623,7 +625,7 @@ class Graph:
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
)
layers = []
layers: List[List[str]] = []
current_layer = 0
while queue:

View file

@ -1,9 +1,11 @@
from enum import Enum
from typing import Any, Optional
from langflow.graph.utils import serialize_field
from pydantic import BaseModel, Field, field_serializer
from langflow.graph.utils import serialize_field
from langflow.utils.schemas import ContainsEnumMeta
class ResultData(BaseModel):
results: Optional[Any] = Field(default_factory=dict)
@ -18,7 +20,7 @@ class ResultData(BaseModel):
return serialize_field(value)
class InterfaceComponentTypes(str, Enum):
class InterfaceComponentTypes(str, Enum, metaclass=ContainsEnumMeta):
# ChatInput and ChatOutput are the only ones that are
# power components
ChatInput = "ChatInput"
@ -26,6 +28,14 @@ class InterfaceComponentTypes(str, Enum):
TextInput = "TextInput"
TextOutput = "TextOutput"
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
else:
return True
INPUT_COMPONENTS = [
InterfaceComponentTypes.ChatInput,

View file

@ -77,7 +77,7 @@ class Vertex:
self.should_run = True
self.result: Optional[ResultData] = None
try:
self.is_interface_component = InterfaceComponentTypes(self.vertex_type)
self.is_interface_component = self.vertex_type in InterfaceComponentTypes
except ValueError:
self.is_interface_component = False
@ -107,29 +107,6 @@ class Vertex:
def add_build_time(self, time):
self.build_times.append(time)
# Build a result dict for each edge
# like so: {edge.target.id: {edge.target_param: self._built_object}}
async def get_result_dict(self, force: bool = False) -> Dict[str, Dict[str, Any]]:
"""
Returns a dictionary with the result of the build process.
"""
edge_results = {}
for edge in self.edges:
target = self.graph.get_vertex(edge.target_id)
if edge.is_fulfilled and isinstance(
await edge.get_result(
source=self,
target=target,
),
str,
):
if edge.target_id not in edge_results:
edge_results[edge.target_id] = {}
edge_results[edge.target_id][edge.target_param] = await edge.get_result(
source=self, target=target
)
return edge_results
def set_result(self, result: ResultData) -> None:
self.result = result
@ -626,7 +603,7 @@ class Vertex:
return self.get_requester_result(requester)
self._reset()
if self.is_input:
if self.is_input and inputs is not None:
self.update_raw_params(inputs)
# Run steps

View file

@ -1,4 +1,5 @@
import warnings
from typing import Callable
import emoji
@ -30,7 +31,7 @@ def getattr_return_bool(value):
return value
ATTR_FUNC_MAPPING = {
ATTR_FUNC_MAPPING: dict[str, Callable] = {
"display_name": getattr_return_str,
"description": getattr_return_str,
"beta": getattr_return_bool,

View file

@ -35,6 +35,7 @@ from langflow.utils import validate
if TYPE_CHECKING:
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
@ -292,8 +293,9 @@ class CustomComponent(Component):
def get_function(self):
return validate.create_function(self.code, self.function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
from langflow.processing.process import build_sorted_vertices, process_tweaks
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
from langflow.graph.graph.base import Graph
from langflow.processing.process import process_tweaks
db_service = get_db_service()
with session_getter(db_service) as session:
@ -302,7 +304,15 @@ class CustomComponent(Component):
raise ValueError(f"Flow {flow_id} not found")
if tweaks:
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
return await build_sorted_vertices(graph_data, self.user_id)
graph = Graph(**graph_data)
return graph
async def run_flow(
self, input_value: str, flow_id: str, tweaks: Optional[dict] = None
) -> Any:
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = {"input_value": input_value}
return await graph.run(input_value_dict)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
if not self._user_id:

View file

@ -1,14 +1,11 @@
from typing import Dict, Optional, Tuple, Union
from uuid import UUID
from typing import Dict, Tuple
from loguru import logger
from langflow.graph import Graph
async def build_sorted_vertices(
data_graph, flow_id: Optional[Union[str, UUID]] = None
) -> Tuple[Graph, Dict]:
async def build_sorted_vertices(data_graph, flow_id: str) -> Tuple[Graph, Dict]:
"""
Build langchain object from data_graph.
"""

View file

@ -4,7 +4,7 @@ from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from loguru import logger
from langflow.api.v1.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler
from langflow.api.v1.callback import StreamingLLMCallbackHandler
from langflow.processing.process import fix_memory_inputs, format_actions
from langflow.services.deps import get_plugins_service
@ -15,10 +15,7 @@ if TYPE_CHECKING:
def setup_callbacks(sync, trace_id, **kwargs):
"""Setup callbacks for langchain object"""
callbacks = []
if sync:
callbacks.append(StreamingLLMCallbackHandler(**kwargs))
else:
callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs))
callbacks.append(StreamingLLMCallbackHandler(**kwargs))
plugin_service = get_plugins_service()
plugin_callbacks = plugin_service.get_callbacks(_id=trace_id)
@ -42,7 +39,9 @@ def get_langfuse_callback(trace_id):
return None
def flush_langfuse_callback_if_present(callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]):
def flush_langfuse_callback_if_present(
callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]
):
"""
If langfuse callback is present, run callback.langfuse.flush()
"""
@ -83,9 +82,15 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
# if langfuse callback is present, run callback.langfuse.flush()
flush_langfuse_callback_if_present(callbacks)
intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else []
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
try:
thought = format_actions(intermediate_steps) if intermediate_steps else ""
except Exception as exc:

View file

@ -13,12 +13,7 @@ from pydantic import BaseModel
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.run import (
build_sorted_vertices,
get_memory_key,
update_memory_keys,
)
from langflow.services.deps import get_session_service
from langflow.interface.run import get_memory_key, update_memory_keys
from langflow.services.session.service import SessionService
@ -203,62 +198,12 @@ class Result(BaseModel):
session_id: str
async def process_graph_cached(
data_graph: Dict[str, Any],
inputs: Optional[Union[dict, List[dict]]] = None,
clear_cache=False,
session_id=None,
) -> Result:
session_service = get_session_service()
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
# Load the graph using SessionService
session = await session_service.load_session(
session_id, data_graph, flow_id=flow_id
)
graph, artifacts = session if session else (None, None)
if not graph:
raise ValueError("Graph not found in the session")
result = await build_graph_and_generate_result(
graph=graph,
session_id=session_id,
inputs=inputs,
artifacts=artifacts,
session_service=session_service,
)
return result
async def build_graph_and_generate_result(
graph: "Graph",
session_id: str,
inputs: Optional[Union[dict, List[dict]]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
):
"""Build the graph and generate the result"""
built_object = await graph.build()
processed_inputs = process_inputs(inputs, artifacts or {})
result = await generate_result(built_object, processed_inputs)
# langchain_object is now updated with the new memory
# we need to update the cache with the updated langchain_object
if session_id and session_service:
session_service.update_session(session_id, (graph, artifacts))
return Result(result=result, session_id=session_id)
async def run_graph(
graph: Union["Graph", dict],
flow_id: str,
stream: bool,
session_id: Optional[str] = None,
inputs: Optional[Union[dict, List[dict]]] = None,
inputs: Optional[dict[str, Union[List[str], str]]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
):

View file

@ -1,11 +1,10 @@
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Optional, Union
import duckdb
from loguru import logger
from platformdirs import user_cache_dir
from pydantic import BaseModel
from langflow.services.base import Service
from langflow.services.monitor.schema import (
@ -56,7 +55,7 @@ class MonitorService(Service):
):
# Make sure the model passed matches the table
model: Type[BaseModel] = self.table_map.get(table_name)
model = self.table_map.get(table_name)
if model is None:
raise ValueError(f"Unknown table name: {table_name}")

View file

@ -23,7 +23,10 @@ def get_table_schema_as_dict(conn: duckdb.DuckDBPyConnection, table_name: str) -
def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict:
columns = {}
for field_name, field_type in model.model_fields.items():
if hasattr(field_type.annotation, "__args__"):
if (
hasattr(field_type.annotation, "__args__")
and field_type.annotation is not None
):
field_args = field_type.annotation.__args__
else:
field_args = []
@ -82,7 +85,7 @@ def drop_and_create_table_if_schema_mismatch(
def add_row_to_table(
conn: duckdb.DuckDBPyConnection,
table_name: str,
model: Type[BaseModel],
model: Type,
monitor_data: Union[Dict[str, Any], BaseModel],
):
# Validate the data with the Pydantic model

View file

@ -14,9 +14,7 @@ class SessionService(Service):
def __init__(self, cache_service):
self.cache_service: "BaseCacheService" = cache_service
async def load_session(
self, key, data_graph: Optional[dict] = None, flow_id: Optional[str] = None
):
async def load_session(self, key, flow_id: str, data_graph: Optional[dict] = None):
# Check if the data is cached
if key in self.cache_service:
return self.cache_service.get(key)

View file

@ -30,7 +30,7 @@ class SettingsService(Service):
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
for key in settings_dict:
if key not in Settings.model_fields().keys():
if key not in Settings.model_fields.keys():
raise KeyError(f"Key {key} not found in settings")
logger.debug(
f"Loading {len(settings_dict[key])} {key} from {file_path}"

View file

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any
import socketio
import socketio # type: ignore
from loguru import logger
from langflow.services.base import Service

View file

@ -1,7 +1,7 @@
import time
from typing import Callable
import socketio
import socketio # type: ignore
from sqlmodel import select
from langflow.api.utils import format_elapsed_time

View file

@ -1,5 +1,5 @@
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
import boto3 # type: ignore
from botocore.exceptions import ClientError, NoCredentialsError # type: ignore
from loguru import logger
from .service import StorageService
@ -25,7 +25,9 @@ class S3StorageService(StorageService):
:raises Exception: If an error occurs during file saving.
"""
try:
self.s3_client.put_object(Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data)
self.s3_client.put_object(
Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data
)
logger.info(f"File {file_name} saved successfully in folder {folder}.")
except NoCredentialsError:
logger.error("Credentials not available for AWS S3.")
@ -44,8 +46,12 @@ class S3StorageService(StorageService):
:raises Exception: If an error occurs during file retrieval.
"""
try:
response = self.s3_client.get_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
logger.info(f"File {file_name} retrieved successfully from folder {folder}.")
response = self.s3_client.get_object(
Bucket=self.bucket, Key=f"{folder}/{file_name}"
)
logger.info(
f"File {file_name} retrieved successfully from folder {folder}."
)
return response["Body"].read()
except ClientError as e:
logger.error(f"Error retrieving file {file_name} from folder {folder}: {e}")
@ -61,7 +67,11 @@ class S3StorageService(StorageService):
"""
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=folder)
files = [item["Key"] for item in response.get("Contents", []) if "/" not in item["Key"][len(folder) :]]
files = [
item["Key"]
for item in response.get("Contents", [])
if "/" not in item["Key"][len(folder) :]
]
logger.info(f"{len(files)} files listed in folder {folder}.")
return files
except ClientError as e:
@ -77,7 +87,9 @@ class S3StorageService(StorageService):
:raises Exception: If an error occurs during file deletion.
"""
try:
self.s3_client.delete_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
self.s3_client.delete_object(
Bucket=self.bucket, Key=f"{folder}/{file_name}"
)
logger.info(f"File {file_name} deleted successfully from folder {folder}.")
except ClientError as e:
logger.error(f"Error deleting file {file_name} from folder {folder}: {e}")

View file

@ -1,3 +1,4 @@
import enum
from typing import Dict, List, Optional, Union
from langchain_core.messages import BaseMessage
@ -40,3 +41,13 @@ class ChatOutputResponse(BaseModel):
message = self.message.replace("\n\n", "\n")
self.message = message.replace("\n", "\n\n")
return self
class ContainsEnumMeta(enum.EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
else:
return True