diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index 5ad7d1922..28839cf98 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -86,24 +86,21 @@ class Message(Data): # they are: "text", "sender" if self.text is None or not self.sender: logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.") + if not isinstance(self.text, str): + text = "" + else: + text = self.text if self.sender == "User" or not self.sender: if self.files: - contents = [{"type": "text", "text": self.text}] + contents = [{"type": "text", "text": text}] contents.extend(self.get_file_content_dicts()) human_message = HumanMessage(content=contents) # type: ignore else: - if not isinstance(self.text, str): - text = "" - else: - text = self.text - human_message = HumanMessage( - content=text, - ) - + human_message = HumanMessage(content=text) return human_message - return AIMessage(content=self.text) # type: ignore + return AIMessage(content=text) # type: ignore @classmethod def from_data(cls, data: "Data") -> "Message": diff --git a/src/backend/base/langflow/services/tracing/service.py b/src/backend/base/langflow/services/tracing/service.py index d3307e8b3..72e73a50c 100644 --- a/src/backend/base/langflow/services/tracing/service.py +++ b/src/backend/base/langflow/services/tracing/service.py @@ -1,6 +1,7 @@ import asyncio import os import traceback +import types from collections import defaultdict from contextlib import asynccontextmanager from datetime import datetime, timezone @@ -214,11 +215,8 @@ class LangSmithTracer(BaseTracer): ): if not self._ready: return - raw_inputs = {} processed_inputs = {} if inputs: - raw_inputs = inputs.copy() - raw_inputs |= metadata or {} processed_inputs = self._convert_to_langchain_types(inputs) child = self._run_tree.create_child( name=trace_name, @@ -226,7 +224,7 @@ class LangSmithTracer(BaseTracer): inputs=processed_inputs, ) if metadata: - child.add_metadata(metadata) + child.add_metadata(self._convert_to_langchain_types(metadata)) self._children[trace_name] = child self._child_link: dict[str, str] = {} @@ -254,6 +252,9 @@ class LangSmithTracer(BaseTracer): value = value.to_lc_document() elif isinstance(value, Data): value = value.to_lc_document() + elif isinstance(value, types.GeneratorType): + # generator is not serializable, also we can't consume it + value = str(value) return value def end_trace( @@ -270,8 +271,8 @@ class LangSmithTracer(BaseTracer): raw_outputs = outputs processed_outputs = self._convert_to_langchain_types(outputs) if logs: - child.add_metadata({"logs": {log.get("name"): log for log in logs}}) - child.add_metadata({"outputs": raw_outputs}) + child.add_metadata(self._convert_to_langchain_types({"logs": {log.get("name"): log for log in logs}})) + child.add_metadata(self._convert_to_langchain_types({"outputs": raw_outputs})) child.end(outputs=processed_outputs, error=error) if error: child.patch() diff --git a/tests/conftest.py b/tests/conftest.py index 0033b0652..43910c860 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,13 +73,6 @@ def get_text(): assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}" -@pytest.fixture(autouse=True) -def check_openai_api_key_in_environment_variables(): - import os - - assert os.environ.get("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set in environment variables" - - @pytest.fixture() async def async_client() -> AsyncGenerator: from langflow.main import create_app diff --git a/tests/unit/test_messages.py b/tests/unit/test_messages.py index 059d82b61..5ae53bb34 100644 --- a/tests/unit/test_messages.py +++ b/tests/unit/test_messages.py @@ -7,6 +7,7 @@ from langflow.schema.message import Message from langflow.services.database.models.message import MessageCreate, MessageRead from langflow.services.database.models.message.model import MessageTable from langflow.services.deps import session_scope +from langflow.services.tracing.utils import convert_to_langchain_type @pytest.fixture() @@ -74,3 +75,28 @@ def test_store_message(): stored_messages = store_message(message) assert len(stored_messages) == 1 assert stored_messages[0].text == "Stored message" + + +@pytest.mark.parametrize("method_name", ["message", "convert_to_langchain_type"]) +def test_convert_to_langchain(method_name): + def convert(value): + if method_name == "message": + return value.to_lc_message() + elif method_name == "convert_to_langchain_type": + return convert_to_langchain_type(value) + else: + raise ValueError(f"Invalid method: {method_name}") + + lc_message = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2")) + assert lc_message.content == "Test message 1" + assert lc_message.type == "human" + + lc_message = convert(Message(text="Test message 2", sender="AI", session_id="session_id2")) + assert lc_message.content == "Test message 2" + assert lc_message.type == "ai" + + iterator = iter(["stream", "message"]) + lc_message = convert(Message(text=iterator, sender="AI", session_id="session_id2")) + assert lc_message.content == "" + assert lc_message.type == "ai" + assert len(list(iterator)) == 2