Add dependencies and fix typing in callback.py, ConversationChain.py, constants.py, and custom_component.py

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-28 21:06:25 -03:00
commit 4460c6e13c
4 changed files with 12 additions and 11 deletions

View file

@ -7,7 +7,7 @@ from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHand
from loguru import logger
from langflow.api.v1.schemas import ChatResponse, PromptResponse
from langflow.services.deps import get_chat_service
from langflow.services.deps import get_chat_service, get_socket_service
from langflow.utils.util import remove_ansi_escape_codes
if TYPE_CHECKING:
@ -21,7 +21,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
def __init__(self, session_id: str):
self.chat_service = get_chat_service()
self.client_id = session_id
self.socketio_service: "SocketIOService" = self.chat_service.socketio_service
self.socketio_service: "SocketIOService" = get_socket_service()
self.sid = session_id
# self.socketio_service = self.chat_service.active_connections[self.client_id]

View file

@ -34,12 +34,12 @@ class ConversationChainComponent(CustomComponent):
result = chain.invoke({chain.input_key: input_value})
# result is an AIMessage which is a subclass of BaseMessage
# We need to check if it is a string or a BaseMessage
result_str = ""
result_str: str = ""
if hasattr(result, "content") and isinstance(result.content, str):
self.status = "is message"
result_str = result.content
elif isinstance(result, str):
self.status = "is_string"
result_str = result
else:
# is dict

View file

@ -1,4 +1,4 @@
from typing import Callable, Dict, Union
from typing import Callable, Dict, NewType, Union
from langchain.agents.agent import AgentExecutor
from langchain.chains.base import Chain
@ -22,9 +22,10 @@ class Object:
pass
# Text = NewType("Text", str)
class Text(str):
pass
Text = NewType("Text", str)
# class Text(str):
# pass
#
class Data:

View file

@ -117,7 +117,7 @@ class CustomComponent(Component):
def to_records(
self, data: Any, text_key: str = "text", data_key: str = "data"
) -> List[dict]:
) -> List[Record]:
"""
Convert data into a list of records.
@ -312,7 +312,7 @@ class CustomComponent(Component):
) -> Any:
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = {"input_value": input_value}
return await graph.run(input_value_dict)
return await graph.run(input_value_dict, stream=False)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
if not self._user_id: