Merge remote-tracking branch 'origin/streaming' into chatUpdate

This commit is contained in:
anovazzi1 2023-05-05 12:44:02 -03:00
commit cc5ac22b85
8 changed files with 43 additions and 32 deletions

View file

@ -1,6 +1,7 @@
import asyncio
from typing import Any
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langflow.api.schemas import ChatResponse

View file

@ -10,9 +10,5 @@ chat_manager = ChatManager()
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
try:
await chat_manager.handle_websocket(client_id, websocket)
except Exception as e:
# Log stack trace
logger.exception(e)
raise e
await chat_manager.handle_websocket(client_id, websocket)

View file

@ -5,6 +5,7 @@ from typing import Dict, List
from fastapi import WebSocket
from langflow.api.callback import StreamingLLMCallbackHandler
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache import cache_manager
from langflow.cache.manager import Subject
@ -143,7 +144,7 @@ class ChatManager:
break
response = ChatResponse(
message=result or "",
message="",
intermediate_steps=intermediate_steps.strip(),
type="end",
files=file_responses,
@ -175,14 +176,11 @@ class ChatManager:
# Handle any exceptions that might occur
logger.exception(e)
# send a message to the client
await self.active_connections[client_id].close(
code=1011, reason=str(e)
)
await self.active_connections[client_id].close(code=1000, reason=str(e))
finally:
await self.active_connections[client_id].close(
code=1000, reason="Client disconnected"
)
# await self.active_connections[client_id].close(
# code=1000, reason="Client disconnected"
# )
self.disconnect(client_id)
@ -205,8 +203,9 @@ async def process_graph(
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = get_result_and_steps(
langchain_object, chat_message.message or ""
stream_handler = StreamingLLMCallbackHandler(websocket)
result, intermediate_steps = await get_result_and_steps(
langchain_object, chat_message.message or "", callbacks=[stream_handler]
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps

View file

@ -47,7 +47,7 @@ def memoize_dict(maxsize=128):
def clear_cache():
cache.clear()
wrapper.clear_cache = clear_cache
wrapper.clear_cache = clear_cache # type: ignore
return wrapper
return decorator
@ -119,7 +119,8 @@ def save_binary_file(content: str, file_name: str, accepted_types: list[str]) ->
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
if content is None:
raise ValueError("Please, reload the file in the loader.")
data = content.split(",")[1]
decoded_bytes = base64.b64decode(data)

View file

@ -185,8 +185,11 @@ def fix_memory_inputs(langchain_object):
update_memory_keys(langchain_object, possible_new_mem_key)
def get_result_and_steps(langchain_object, message: str):
async def get_result_and_steps(langchain_object, message: str, callbacks=None):
"""Get result and thought from extracted json"""
if callbacks is None:
callbacks = []
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
@ -206,17 +209,17 @@ def get_result_and_steps(langchain_object, message: str):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
langchain_object.return_intermediate_steps = True
fix_memory_inputs(langchain_object)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
output = langchain_object(chat_input)
output = await langchain_object.acall(chat_input, callbacks=callbacks)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input)
output = langchain_object.run(chat_input, callbacks=callbacks)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []

View file

@ -4,13 +4,9 @@ import os
from io import BytesIO
import yaml
from langchain.callbacks.manager import AsyncCallbackManager
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.llms import AzureOpenAI, OpenAI
from langchain.base_language import BaseLanguageModel
from PIL.Image import Image
from langflow.api.callback import StreamingLLMCallbackHandler
def load_file_into_dict(file_path: str) -> dict:
if not os.path.exists(file_path):
@ -48,10 +44,7 @@ def try_setting_streaming_options(langchain_object, websocket):
langchain_object.llm_chain, "llm"
):
llm = langchain_object.llm_chain.llm
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
if isinstance(llm, BaseLanguageModel):
llm.streaming = bool(hasattr(llm, "streaming"))
stream_handler = StreamingLLMCallbackHandler(websocket)
stream_manager = AsyncCallbackManager([stream_handler])
llm.callback_manager = stream_manager
return langchain_object

View file

@ -150,6 +150,13 @@ class TimeTravelGuideChainNode(FrontendNode):
multiline=False,
name="llm",
),
TemplateField(
field_type="BaseChatMemory",
required=False,
show=True,
name="memory",
advanced=False,
),
],
)
description: str = "Time travel guide chain to be used in the flow."

View file

@ -435,5 +435,16 @@ def test_time_travel_guide_chain(client: TestClient):
"list": False,
"advanced": False,
}
assert template["memory"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"password": False,
"name": "memory",
"type": "BaseChatMemory",
"list": False,
"advanced": False,
}
assert chain["description"] == "Time travel guide chain to be used in the flow."