Refactor file handling in ChatInput and ChatOutput components

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-20 10:44:15 -03:00
commit d1bc88715e
10 changed files with 99 additions and 15 deletions

View file

@ -115,6 +115,7 @@ async def build_vertex(
vertex_id: str,
background_tasks: BackgroundTasks,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
):
@ -139,7 +140,6 @@ async def build_vertex(
start_time = time.perf_counter()
next_runnable_vertices = []
top_level_vertices = []
messages = []
try:
start_time = time.perf_counter()
cache = await chat_service.get_cache(flow_id)
@ -169,6 +169,7 @@ async def build_vertex(
vertex_id=vertex_id,
user_id=current_user.id,
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
)
result_data_response = ResultDataResponse(**result_dict.model_dump())

View file

@ -31,6 +31,13 @@ TEXT_FILE_TYPES = [
"tsx",
]
IMG_FILE_TYPES = [
"jpg",
"jpeg",
"png",
"bmp",
]
def is_hidden(path: Path) -> bool:
return path.name.startswith(".")

View file

@ -1,6 +1,7 @@
import warnings
from typing import Optional, Union
from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
from langflow.field_typing import Text
from langflow.helpers.record import records_to_text
from langflow.interface.custom.custom_component import CustomComponent
@ -41,6 +42,13 @@ class ChatComponent(CustomComponent):
"info": "In case of Message being a Record, this template will be used to convert it to text.",
"advanced": True,
},
"files": {
"field_type": "file",
"display_name": "Files",
"file_types": TEXT_FILE_TYPES + IMG_FILE_TYPES,
"info": "Files to be sent with the message.",
"advanced": True,
},
}
def store_message(
@ -84,6 +92,7 @@ class ChatComponent(CustomComponent):
sender: Optional[str] = "User",
sender_name: Optional[str] = "User",
input_value: Optional[Union[str, Record]] = None,
files: Optional[list[str]] = None,
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
record_template: str = "Text: {text}\nData: {data}",
@ -95,6 +104,7 @@ class ChatComponent(CustomComponent):
input_value.data["sender"] = sender
input_value.data["sender_name"] = sender_name
input_value.data["session_id"] = session_id
input_value.data["files"] = files
else:
input_value_record = Record(
text=input_value,
@ -102,6 +112,7 @@ class ChatComponent(CustomComponent):
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
"files": files,
},
)
elif isinstance(input_value, Record):
@ -122,17 +133,21 @@ class ChatComponent(CustomComponent):
sender: Optional[str] = "User",
sender_name: Optional[str] = "User",
input_value: Optional[str] = None,
files: Optional[list[str]] = None,
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
record_template: str = "Text: {text}\nData: {data}",
) -> Union[Text, Record]:
input_value_record: Optional[Record] = None
if files and not return_record:
raise ValueError("Files can only be provided when Return Record is enabled.")
if return_record:
if isinstance(input_value, Record):
# Update the data of the record
input_value.data["sender"] = sender
input_value.data["sender_name"] = sender_name
input_value.data["session_id"] = session_id
input_value.data["files"] = files
else:
input_value_record = Record(
text=input_value,
@ -140,6 +155,7 @@ class ChatComponent(CustomComponent):
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
"files": files,
},
)
elif isinstance(input_value, Record):

View file

@ -25,6 +25,7 @@ class ChatInput(ChatComponent):
sender: Optional[str] = "User",
sender_name: Optional[str] = "User",
input_value: Optional[str] = None,
files: Optional[list[str]] = None,
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
) -> Union[Text, Record]:
@ -32,6 +33,7 @@ class ChatInput(ChatComponent):
sender=sender,
sender_name=sender_name,
input_value=input_value,
files=files,
session_id=session_id,
return_record=return_record,
)

View file

@ -18,6 +18,7 @@ class ChatOutput(ChatComponent):
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
record_template: Optional[str] = "{text}",
files: Optional[list[str]] = None,
) -> Union[Text, Record]:
return super().build_with_record(
sender=sender,
@ -26,4 +27,5 @@ class ChatOutput(ChatComponent):
session_id=session_id,
return_record=return_record,
record_template=record_template or "",
files=files,
)

View file

@ -666,6 +666,7 @@ class Graph:
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
vertex_id: str,
inputs_dict: Optional[Dict[str, str]] = None,
files: Optional[list[str]] = None,
user_id: Optional[str] = None,
):
"""
@ -688,7 +689,7 @@ class Graph:
vertex = self.get_vertex(vertex_id)
try:
if not vertex.frozen or not vertex._built:
await vertex.build(user_id=user_id, inputs=inputs_dict)
await vertex.build(user_id=user_id, inputs=inputs_dict, files=files)
if vertex.result is not None:
params = vertex._built_object_repr()

View file

@ -368,7 +368,7 @@ class Vertex:
self.load_from_db_fields = load_from_db_fields
self._raw_params = params.copy()
def update_raw_params(self, new_params: Dict[str, str], overwrite: bool = False):
def update_raw_params(self, new_params: Dict[str, str | list[str]], overwrite: bool = False):
"""
Update the raw parameters of the vertex with the given new parameters.
@ -419,6 +419,7 @@ class Vertex:
sender=artifacts.get("sender"),
sender_name=artifacts.get("sender_name"),
session_id=artifacts.get("session_id"),
files=[{"path": file} if isinstance(file, str) else file for file in artifacts.get("files", [])],
component_id=self.id,
).model_dump(exclude_none=True)
]
@ -673,6 +674,7 @@ class Vertex:
self,
user_id=None,
inputs: Optional[Dict[str, Any]] = None,
files: Optional[list[str]] = None,
requester: Optional["Vertex"] = None,
**kwargs,
) -> Any:
@ -690,9 +692,14 @@ class Vertex:
return await self.get_requester_result(requester)
self._reset()
if self._is_chat_input() and inputs:
inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")}
self.update_raw_params(inputs, overwrite=True)
if self._is_chat_input() and (inputs or files):
chat_input = {}
if inputs:
chat_input.update({"input_value": inputs.get(INPUT_FIELD_NAME, "")})
if files:
chat_input.update({"files": files})
self.update_raw_params(chat_input, overwrite=True)
# Run steps
for step in self.steps:

View file

@ -342,14 +342,13 @@ class ChatVertex(Vertex):
sender = self.params.get("sender", None)
sender_name = self.params.get("sender_name", None)
message = self.params.get(INPUT_FIELD_NAME, None)
files = [{"path": file} if isinstance(file, str) else file for file in self.params.get("files", [])]
if isinstance(message, str):
message = unescape_string(message)
stream_url = None
if isinstance(self._built_object, AIMessage):
artifacts = ChatOutputResponse.from_message(
self._built_object,
sender=sender,
sender_name=sender_name,
self._built_object, sender=sender, sender_name=sender_name, files=files
)
elif not isinstance(self._built_object, UnbuiltObject):
if isinstance(self._built_object, dict):
@ -369,10 +368,7 @@ class ChatVertex(Vertex):
message = self._built_object
artifacts = ChatOutputResponse(
message=message,
sender=sender,
sender_name=sender_name,
stream_url=stream_url,
message=message, sender=sender, sender_name=sender_name, stream_url=stream_url, files=files
)
self.will_stream = stream_url is not None
@ -410,6 +406,7 @@ class ChatVertex(Vertex):
message=complete_message,
sender=self.params.get("sender", ""),
sender_name=self.params.get("sender_name", ""),
files=[{"path": file} if isinstance(file, str) else file for file in self.params.get("files", [])],
).model_dump()
self.params[INPUT_FIELD_NAME] = complete_message
self._built_object = Record(text=complete_message, data=self.artifacts)

View file

@ -2,7 +2,18 @@ import enum
from typing import Dict, List, Optional, Union
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import TypedDict
from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
class File(TypedDict):
"""File schema."""
path: str
name: str
type: str
class ChatOutputResponse(BaseModel):
@ -14,7 +25,46 @@ class ChatOutputResponse(BaseModel):
session_id: Optional[str] = None
stream_url: Optional[str] = None
component_id: Optional[str] = None
files: List[str] = []
files: List[File] = []
@field_validator("files", mode="before")
def validate_files(cls, files):
"""Validate files."""
if not files:
return files
for file in files:
if not isinstance(file, dict):
raise ValueError("Files must be a list of dictionaries.")
if not all(key in file for key in ["path", "name", "type"]):
# If any of the keys are missing, we should extract the
# values from the file path
path = file.get("path")
if not path:
raise ValueError("File path is required.")
name = file.get("name")
if not name:
name = path.split("/")[-1]
file["name"] = name
_type = file.get("type")
if not _type:
# get the file type from the path
extension = path.split(".")[-1]
file_types = set(TEXT_FILE_TYPES + IMG_FILE_TYPES)
if extension and extension in file_types:
_type = extension
else:
for file_type in file_types:
if file_type in path:
_type = file_type
break
if not _type:
raise ValueError("File type is required.")
file["type"] = _type
return files
@classmethod
def from_message(

View file

@ -19,6 +19,7 @@ export type ChatOutputType = {
sender: string;
sender_name: string;
stream_url?: string;
files?: Array<{ path: string; type: string; name: string }>;
};
export type chatInputType = {