Add dependencies and fix typing in callback.py, ConversationChain.py, constants.py, and custom_component.py
This commit is contained in:
parent
0c24390d9a
commit
4460c6e13c
4 changed files with 12 additions and 11 deletions
|
|
@ -7,7 +7,7 @@ from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHand
|
|||
from loguru import logger
|
||||
|
||||
from langflow.api.v1.schemas import ChatResponse, PromptResponse
|
||||
from langflow.services.deps import get_chat_service
|
||||
from langflow.services.deps import get_chat_service, get_socket_service
|
||||
from langflow.utils.util import remove_ansi_escape_codes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -21,7 +21,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
|
|||
def __init__(self, session_id: str):
|
||||
self.chat_service = get_chat_service()
|
||||
self.client_id = session_id
|
||||
self.socketio_service: "SocketIOService" = self.chat_service.socketio_service
|
||||
self.socketio_service: "SocketIOService" = get_socket_service()
|
||||
self.sid = session_id
|
||||
# self.socketio_service = self.chat_service.active_connections[self.client_id]
|
||||
|
||||
|
|
|
|||
|
|
@ -34,12 +34,12 @@ class ConversationChainComponent(CustomComponent):
|
|||
result = chain.invoke({chain.input_key: input_value})
|
||||
# result is an AIMessage which is a subclass of BaseMessage
|
||||
# We need to check if it is a string or a BaseMessage
|
||||
result_str = ""
|
||||
result_str: str = ""
|
||||
if hasattr(result, "content") and isinstance(result.content, str):
|
||||
self.status = "is message"
|
||||
|
||||
result_str = result.content
|
||||
elif isinstance(result, str):
|
||||
self.status = "is_string"
|
||||
|
||||
result_str = result
|
||||
else:
|
||||
# is dict
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Dict, Union
|
||||
from typing import Callable, Dict, NewType, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
|
|
@ -22,9 +22,10 @@ class Object:
|
|||
pass
|
||||
|
||||
|
||||
# Text = NewType("Text", str)
|
||||
class Text(str):
|
||||
pass
|
||||
Text = NewType("Text", str)
|
||||
# class Text(str):
|
||||
# pass
|
||||
#
|
||||
|
||||
|
||||
class Data:
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class CustomComponent(Component):
|
|||
|
||||
def to_records(
|
||||
self, data: Any, text_key: str = "text", data_key: str = "data"
|
||||
) -> List[dict]:
|
||||
) -> List[Record]:
|
||||
"""
|
||||
Convert data into a list of records.
|
||||
|
||||
|
|
@ -312,7 +312,7 @@ class CustomComponent(Component):
|
|||
) -> Any:
|
||||
graph = await self.load_flow(flow_id, tweaks)
|
||||
input_value_dict = {"input_value": input_value}
|
||||
return await graph.run(input_value_dict)
|
||||
return await graph.run(input_value_dict, stream=False)
|
||||
|
||||
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
|
||||
if not self._user_id:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue