From d1bc88715e78e75228f7a7d43cebfa630f384f64 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 20 Apr 2024 10:44:15 -0300 Subject: [PATCH] Refactor file handling in ChatInput and ChatOutput components --- src/backend/base/langflow/api/v1/chat.py | 3 +- src/backend/base/langflow/base/data/utils.py | 7 +++ src/backend/base/langflow/base/io/chat.py | 16 ++++++ .../langflow/components/inputs/ChatInput.py | 2 + .../langflow/components/outputs/ChatOutput.py | 2 + src/backend/base/langflow/graph/graph/base.py | 3 +- .../base/langflow/graph/vertex/base.py | 15 ++++-- .../base/langflow/graph/vertex/types.py | 11 ++-- src/backend/base/langflow/utils/schemas.py | 54 ++++++++++++++++++- src/frontend/src/types/chat/index.ts | 1 + 10 files changed, 99 insertions(+), 15 deletions(-) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 72575e99d..af69a9b7c 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -115,6 +115,7 @@ async def build_vertex( vertex_id: str, background_tasks: BackgroundTasks, inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None, + files: Optional[list[str]] = None, chat_service: "ChatService" = Depends(get_chat_service), current_user=Depends(get_current_active_user), ): @@ -139,7 +140,6 @@ async def build_vertex( start_time = time.perf_counter() next_runnable_vertices = [] top_level_vertices = [] - messages = [] try: start_time = time.perf_counter() cache = await chat_service.get_cache(flow_id) @@ -169,6 +169,7 @@ async def build_vertex( vertex_id=vertex_id, user_id=current_user.id, inputs_dict=inputs.model_dump() if inputs else {}, + files=files, ) result_data_response = ResultDataResponse(**result_dict.model_dump()) diff --git a/src/backend/base/langflow/base/data/utils.py b/src/backend/base/langflow/base/data/utils.py index 09169393b..a152f44a8 100644 --- a/src/backend/base/langflow/base/data/utils.py +++ b/src/backend/base/langflow/base/data/utils.py @@ -31,6 +31,13 @@ TEXT_FILE_TYPES = [ "tsx", ] +IMG_FILE_TYPES = [ + "jpg", + "jpeg", + "png", + "bmp", +] + def is_hidden(path: Path) -> bool: return path.name.startswith(".") diff --git a/src/backend/base/langflow/base/io/chat.py b/src/backend/base/langflow/base/io/chat.py index ea24dc968..43b609502 100644 --- a/src/backend/base/langflow/base/io/chat.py +++ b/src/backend/base/langflow/base/io/chat.py @@ -1,6 +1,7 @@ import warnings from typing import Optional, Union +from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES from langflow.field_typing import Text from langflow.helpers.record import records_to_text from langflow.interface.custom.custom_component import CustomComponent @@ -41,6 +42,13 @@ class ChatComponent(CustomComponent): "info": "In case of Message being a Record, this template will be used to convert it to text.", "advanced": True, }, + "files": { + "field_type": "file", + "display_name": "Files", + "file_types": TEXT_FILE_TYPES + IMG_FILE_TYPES, + "info": "Files to be sent with the message.", + "advanced": True, + }, } def store_message( @@ -84,6 +92,7 @@ class ChatComponent(CustomComponent): sender: Optional[str] = "User", sender_name: Optional[str] = "User", input_value: Optional[Union[str, Record]] = None, + files: Optional[list[str]] = None, session_id: Optional[str] = None, return_record: Optional[bool] = False, record_template: str = "Text: {text}\nData: {data}", @@ -95,6 +104,7 @@ class ChatComponent(CustomComponent): input_value.data["sender"] = sender input_value.data["sender_name"] = sender_name input_value.data["session_id"] = session_id + input_value.data["files"] = files else: input_value_record = Record( text=input_value, @@ -102,6 +112,7 @@ class ChatComponent(CustomComponent): "sender": sender, "sender_name": sender_name, "session_id": session_id, + "files": files, }, ) elif isinstance(input_value, Record): @@ -122,17 +133,21 @@ class ChatComponent(CustomComponent): sender: Optional[str] = "User", sender_name: Optional[str] = "User", input_value: Optional[str] = None, + files: Optional[list[str]] = None, session_id: Optional[str] = None, return_record: Optional[bool] = False, record_template: str = "Text: {text}\nData: {data}", ) -> Union[Text, Record]: input_value_record: Optional[Record] = None + if files and not return_record: + raise ValueError("Files can only be provided when Return Record is enabled.") if return_record: if isinstance(input_value, Record): # Update the data of the record input_value.data["sender"] = sender input_value.data["sender_name"] = sender_name input_value.data["session_id"] = session_id + input_value.data["files"] = files else: input_value_record = Record( text=input_value, @@ -140,6 +155,7 @@ class ChatComponent(CustomComponent): "sender": sender, "sender_name": sender_name, "session_id": session_id, + "files": files, }, ) elif isinstance(input_value, Record): diff --git a/src/backend/base/langflow/components/inputs/ChatInput.py b/src/backend/base/langflow/components/inputs/ChatInput.py index 40c3267ab..7f9b6bb00 100644 --- a/src/backend/base/langflow/components/inputs/ChatInput.py +++ b/src/backend/base/langflow/components/inputs/ChatInput.py @@ -25,6 +25,7 @@ class ChatInput(ChatComponent): sender: Optional[str] = "User", sender_name: Optional[str] = "User", input_value: Optional[str] = None, + files: Optional[list[str]] = None, session_id: Optional[str] = None, return_record: Optional[bool] = False, ) -> Union[Text, Record]: @@ -32,6 +33,7 @@ class ChatInput(ChatComponent): sender=sender, sender_name=sender_name, input_value=input_value, + files=files, session_id=session_id, return_record=return_record, ) diff --git a/src/backend/base/langflow/components/outputs/ChatOutput.py b/src/backend/base/langflow/components/outputs/ChatOutput.py index 888b85025..cf1826c2d 100644 --- a/src/backend/base/langflow/components/outputs/ChatOutput.py +++ b/src/backend/base/langflow/components/outputs/ChatOutput.py @@ -18,6 +18,7 @@ class ChatOutput(ChatComponent): session_id: Optional[str] = None, return_record: Optional[bool] = False, record_template: Optional[str] = "{text}", + files: Optional[list[str]] = None, ) -> Union[Text, Record]: return super().build_with_record( sender=sender, @@ -26,4 +27,5 @@ class ChatOutput(ChatComponent): session_id=session_id, return_record=return_record, record_template=record_template or "", + files=files, ) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index a0661f4f2..9e7421f43 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -666,6 +666,7 @@ class Graph: set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine], vertex_id: str, inputs_dict: Optional[Dict[str, str]] = None, + files: Optional[list[str]] = None, user_id: Optional[str] = None, ): """ @@ -688,7 +689,7 @@ class Graph: vertex = self.get_vertex(vertex_id) try: if not vertex.frozen or not vertex._built: - await vertex.build(user_id=user_id, inputs=inputs_dict) + await vertex.build(user_id=user_id, inputs=inputs_dict, files=files) if vertex.result is not None: params = vertex._built_object_repr() diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index f300b2dda..337ee3dfe 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -368,7 +368,7 @@ class Vertex: self.load_from_db_fields = load_from_db_fields self._raw_params = params.copy() - def update_raw_params(self, new_params: Dict[str, str], overwrite: bool = False): + def update_raw_params(self, new_params: Dict[str, str | list[str]], overwrite: bool = False): """ Update the raw parameters of the vertex with the given new parameters. @@ -419,6 +419,7 @@ class Vertex: sender=artifacts.get("sender"), sender_name=artifacts.get("sender_name"), session_id=artifacts.get("session_id"), + files=[{"path": file} if isinstance(file, str) else file for file in artifacts.get("files", [])], component_id=self.id, ).model_dump(exclude_none=True) ] @@ -673,6 +674,7 @@ class Vertex: self, user_id=None, inputs: Optional[Dict[str, Any]] = None, + files: Optional[list[str]] = None, requester: Optional["Vertex"] = None, **kwargs, ) -> Any: @@ -690,9 +692,14 @@ class Vertex: return await self.get_requester_result(requester) self._reset() - if self._is_chat_input() and inputs: - inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")} - self.update_raw_params(inputs, overwrite=True) + if self._is_chat_input() and (inputs or files): + chat_input = {} + if inputs: + chat_input.update({"input_value": inputs.get(INPUT_FIELD_NAME, "")}) + if files: + chat_input.update({"files": files}) + + self.update_raw_params(chat_input, overwrite=True) # Run steps for step in self.steps: diff --git a/src/backend/base/langflow/graph/vertex/types.py b/src/backend/base/langflow/graph/vertex/types.py index cbebb14ea..930139056 100644 --- a/src/backend/base/langflow/graph/vertex/types.py +++ b/src/backend/base/langflow/graph/vertex/types.py @@ -342,14 +342,13 @@ class ChatVertex(Vertex): sender = self.params.get("sender", None) sender_name = self.params.get("sender_name", None) message = self.params.get(INPUT_FIELD_NAME, None) + files = [{"path": file} if isinstance(file, str) else file for file in self.params.get("files", [])] if isinstance(message, str): message = unescape_string(message) stream_url = None if isinstance(self._built_object, AIMessage): artifacts = ChatOutputResponse.from_message( - self._built_object, - sender=sender, - sender_name=sender_name, + self._built_object, sender=sender, sender_name=sender_name, files=files ) elif not isinstance(self._built_object, UnbuiltObject): if isinstance(self._built_object, dict): @@ -369,10 +368,7 @@ class ChatVertex(Vertex): message = self._built_object artifacts = ChatOutputResponse( - message=message, - sender=sender, - sender_name=sender_name, - stream_url=stream_url, + message=message, sender=sender, sender_name=sender_name, stream_url=stream_url, files=files ) self.will_stream = stream_url is not None @@ -410,6 +406,7 @@ class ChatVertex(Vertex): message=complete_message, sender=self.params.get("sender", ""), sender_name=self.params.get("sender_name", ""), + files=[{"path": file} if isinstance(file, str) else file for file in self.params.get("files", [])], ).model_dump() self.params[INPUT_FIELD_NAME] = complete_message self._built_object = Record(text=complete_message, data=self.artifacts) diff --git a/src/backend/base/langflow/utils/schemas.py b/src/backend/base/langflow/utils/schemas.py index 7d9535fff..da2e1a640 100644 --- a/src/backend/base/langflow/utils/schemas.py +++ b/src/backend/base/langflow/utils/schemas.py @@ -2,7 +2,18 @@ import enum from typing import Dict, List, Optional, Union from langchain_core.messages import BaseMessage -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, field_validator, model_validator +from typing_extensions import TypedDict + +from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES + + +class File(TypedDict): + """File schema.""" + + path: str + name: str + type: str class ChatOutputResponse(BaseModel): @@ -14,7 +25,46 @@ class ChatOutputResponse(BaseModel): session_id: Optional[str] = None stream_url: Optional[str] = None component_id: Optional[str] = None - files: List[str] = [] + files: List[File] = [] + + @field_validator("files", mode="before") + def validate_files(cls, files): + """Validate files.""" + if not files: + return files + + for file in files: + if not isinstance(file, dict): + raise ValueError("Files must be a list of dictionaries.") + + if not all(key in file for key in ["path", "name", "type"]): + # If any of the keys are missing, we should extract the + # values from the file path + path = file.get("path") + if not path: + raise ValueError("File path is required.") + + name = file.get("name") + if not name: + name = path.split("/")[-1] + file["name"] = name + _type = file.get("type") + if not _type: + # get the file type from the path + extension = path.split(".")[-1] + file_types = set(TEXT_FILE_TYPES + IMG_FILE_TYPES) + if extension and extension in file_types: + _type = extension + else: + for file_type in file_types: + if file_type in path: + _type = file_type + break + if not _type: + raise ValueError("File type is required.") + file["type"] = _type + + return files @classmethod def from_message( diff --git a/src/frontend/src/types/chat/index.ts b/src/frontend/src/types/chat/index.ts index 5b462234d..12c7a2a20 100644 --- a/src/frontend/src/types/chat/index.ts +++ b/src/frontend/src/types/chat/index.ts @@ -19,6 +19,7 @@ export type ChatOutputType = { sender: string; sender_name: string; stream_url?: string; + files?: Array<{ path: string; type: string; name: string }>; }; export type chatInputType = {