fix: stream option is not working with tracing enabled (#2602)
* fix: cannot create 'generator' instances * fix: cannot create 'generator' instances * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ce9b4b09e5
commit
21adbd5531
4 changed files with 40 additions and 23 deletions
|
|
@ -86,24 +86,21 @@ class Message(Data):
|
|||
# they are: "text", "sender"
|
||||
if self.text is None or not self.sender:
|
||||
logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.")
|
||||
if not isinstance(self.text, str):
|
||||
text = ""
|
||||
else:
|
||||
text = self.text
|
||||
|
||||
if self.sender == "User" or not self.sender:
|
||||
if self.files:
|
||||
contents = [{"type": "text", "text": self.text}]
|
||||
contents = [{"type": "text", "text": text}]
|
||||
contents.extend(self.get_file_content_dicts())
|
||||
human_message = HumanMessage(content=contents) # type: ignore
|
||||
else:
|
||||
if not isinstance(self.text, str):
|
||||
text = ""
|
||||
else:
|
||||
text = self.text
|
||||
human_message = HumanMessage(
|
||||
content=text,
|
||||
)
|
||||
|
||||
human_message = HumanMessage(content=text)
|
||||
return human_message
|
||||
|
||||
return AIMessage(content=self.text) # type: ignore
|
||||
return AIMessage(content=text) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data: "Data") -> "Message":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -214,11 +215,8 @@ class LangSmithTracer(BaseTracer):
|
|||
):
|
||||
if not self._ready:
|
||||
return
|
||||
raw_inputs = {}
|
||||
processed_inputs = {}
|
||||
if inputs:
|
||||
raw_inputs = inputs.copy()
|
||||
raw_inputs |= metadata or {}
|
||||
processed_inputs = self._convert_to_langchain_types(inputs)
|
||||
child = self._run_tree.create_child(
|
||||
name=trace_name,
|
||||
|
|
@ -226,7 +224,7 @@ class LangSmithTracer(BaseTracer):
|
|||
inputs=processed_inputs,
|
||||
)
|
||||
if metadata:
|
||||
child.add_metadata(metadata)
|
||||
child.add_metadata(self._convert_to_langchain_types(metadata))
|
||||
self._children[trace_name] = child
|
||||
self._child_link: dict[str, str] = {}
|
||||
|
||||
|
|
@ -254,6 +252,9 @@ class LangSmithTracer(BaseTracer):
|
|||
value = value.to_lc_document()
|
||||
elif isinstance(value, Data):
|
||||
value = value.to_lc_document()
|
||||
elif isinstance(value, types.GeneratorType):
|
||||
# generator is not serializable, also we can't consume it
|
||||
value = str(value)
|
||||
return value
|
||||
|
||||
def end_trace(
|
||||
|
|
@ -270,8 +271,8 @@ class LangSmithTracer(BaseTracer):
|
|||
raw_outputs = outputs
|
||||
processed_outputs = self._convert_to_langchain_types(outputs)
|
||||
if logs:
|
||||
child.add_metadata({"logs": {log.get("name"): log for log in logs}})
|
||||
child.add_metadata({"outputs": raw_outputs})
|
||||
child.add_metadata(self._convert_to_langchain_types({"logs": {log.get("name"): log for log in logs}}))
|
||||
child.add_metadata(self._convert_to_langchain_types({"outputs": raw_outputs}))
|
||||
child.end(outputs=processed_outputs, error=error)
|
||||
if error:
|
||||
child.patch()
|
||||
|
|
|
|||
|
|
@ -73,13 +73,6 @@ def get_text():
|
|||
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_openai_api_key_in_environment_variables():
|
||||
import os
|
||||
|
||||
assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def async_client() -> AsyncGenerator:
|
||||
from langflow.main import create_app
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langflow.schema.message import Message
|
|||
from langflow.services.database.models.message import MessageCreate, MessageRead
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.tracing.utils import convert_to_langchain_type
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -74,3 +75,28 @@ def test_store_message():
|
|||
stored_messages = store_message(message)
|
||||
assert len(stored_messages) == 1
|
||||
assert stored_messages[0].text == "Stored message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["message", "convert_to_langchain_type"])
|
||||
def test_convert_to_langchain(method_name):
|
||||
def convert(value):
|
||||
if method_name == "message":
|
||||
return value.to_lc_message()
|
||||
elif method_name == "convert_to_langchain_type":
|
||||
return convert_to_langchain_type(value)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method_name}")
|
||||
|
||||
lc_message = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
|
||||
assert lc_message.content == "Test message 1"
|
||||
assert lc_message.type == "human"
|
||||
|
||||
lc_message = convert(Message(text="Test message 2", sender="AI", session_id="session_id2"))
|
||||
assert lc_message.content == "Test message 2"
|
||||
assert lc_message.type == "ai"
|
||||
|
||||
iterator = iter(["stream", "message"])
|
||||
lc_message = convert(Message(text=iterator, sender="AI", session_id="session_id2"))
|
||||
assert lc_message.content == ""
|
||||
assert lc_message.type == "ai"
|
||||
assert len(list(iterator)) == 2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue