fix: stream option is not working with tracing enabled (#2602)

* fix: cannot create 'generator' instances

* fix: cannot create 'generator' instances

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Nicolò Boschi 2024-07-09 15:37:47 +02:00 committed by GitHub
commit 21adbd5531
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 40 additions and 23 deletions

View file

@ -86,24 +86,21 @@ class Message(Data):
# they are: "text", "sender"
if self.text is None or not self.sender:
logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.")
if not isinstance(self.text, str):
text = ""
else:
text = self.text
if self.sender == "User" or not self.sender:
if self.files:
contents = [{"type": "text", "text": self.text}]
contents = [{"type": "text", "text": text}]
contents.extend(self.get_file_content_dicts())
human_message = HumanMessage(content=contents) # type: ignore
else:
if not isinstance(self.text, str):
text = ""
else:
text = self.text
human_message = HumanMessage(
content=text,
)
human_message = HumanMessage(content=text)
return human_message
return AIMessage(content=self.text) # type: ignore
return AIMessage(content=text) # type: ignore
@classmethod
def from_data(cls, data: "Data") -> "Message":

View file

@ -1,6 +1,7 @@
import asyncio
import os
import traceback
import types
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime, timezone
@ -214,11 +215,8 @@ class LangSmithTracer(BaseTracer):
):
if not self._ready:
return
raw_inputs = {}
processed_inputs = {}
if inputs:
raw_inputs = inputs.copy()
raw_inputs |= metadata or {}
processed_inputs = self._convert_to_langchain_types(inputs)
child = self._run_tree.create_child(
name=trace_name,
@ -226,7 +224,7 @@ class LangSmithTracer(BaseTracer):
inputs=processed_inputs,
)
if metadata:
child.add_metadata(metadata)
child.add_metadata(self._convert_to_langchain_types(metadata))
self._children[trace_name] = child
self._child_link: dict[str, str] = {}
@ -254,6 +252,9 @@ class LangSmithTracer(BaseTracer):
value = value.to_lc_document()
elif isinstance(value, Data):
value = value.to_lc_document()
elif isinstance(value, types.GeneratorType):
# a generator is not serializable, and we cannot consume it here without breaking streaming
value = str(value)
return value
def end_trace(
@ -270,8 +271,8 @@ class LangSmithTracer(BaseTracer):
raw_outputs = outputs
processed_outputs = self._convert_to_langchain_types(outputs)
if logs:
child.add_metadata({"logs": {log.get("name"): log for log in logs}})
child.add_metadata({"outputs": raw_outputs})
child.add_metadata(self._convert_to_langchain_types({"logs": {log.get("name"): log for log in logs}}))
child.add_metadata(self._convert_to_langchain_types({"outputs": raw_outputs}))
child.end(outputs=processed_outputs, error=error)
if error:
child.patch()

View file

@ -73,13 +73,6 @@ def get_text():
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
@pytest.fixture(autouse=True)
def check_openai_api_key_in_environment_variables():
    """Autouse guard fixture: fail fast when the OpenAI key is absent.

    Runs before every test in this module so that tests depending on the
    OpenAI API abort with a clear message instead of an opaque auth error.
    """
    import os

    api_key = os.environ.get("OPENAI_API_KEY")
    assert api_key is not None, "OPENAI_API_KEY is not set in environment variables"
@pytest.fixture()
async def async_client() -> AsyncGenerator:
from langflow.main import create_app

View file

@ -7,6 +7,7 @@ from langflow.schema.message import Message
from langflow.services.database.models.message import MessageCreate, MessageRead
from langflow.services.database.models.message.model import MessageTable
from langflow.services.deps import session_scope
from langflow.services.tracing.utils import convert_to_langchain_type
@pytest.fixture()
@ -74,3 +75,28 @@ def test_store_message():
stored_messages = store_message(message)
assert len(stored_messages) == 1
assert stored_messages[0].text == "Stored message"
@pytest.mark.parametrize("method_name", ["message", "convert_to_langchain_type"])
def test_convert_to_langchain(method_name):
    """Check Message -> LangChain conversion through both entry points.

    The same expectations are exercised via ``Message.to_lc_message`` and via
    the standalone ``convert_to_langchain_type`` helper, selected by the
    ``method_name`` parameter.
    """

    def convert(value):
        # Guard-clause dispatch on the parametrized conversion method.
        if method_name == "convert_to_langchain_type":
            return convert_to_langchain_type(value)
        if method_name == "message":
            return value.to_lc_message()
        raise ValueError(f"Invalid method: {method_name}")

    human = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
    assert human.type == "human"
    assert human.content == "Test message 1"

    ai = convert(Message(text="Test message 2", sender="AI", session_id="session_id2"))
    assert ai.type == "ai"
    assert ai.content == "Test message 2"

    # A non-string (iterator) payload must not be consumed by the conversion:
    # content falls back to "" and the iterator still yields all its items.
    stream = iter(["stream", "message"])
    streamed = convert(Message(text=stream, sender="AI", session_id="session_id2"))
    assert streamed.type == "ai"
    assert streamed.content == ""
    assert len(list(stream)) == 2