feat: improve LangWatch integration by introducing langchain callbacks on the tracing service, and component and workflow span types (#3094)

* Improve LangWatch integration by introducing langchain callbacks on the tracing service, and component and workflow span types. Bump LangWatch to v0.1.14

* [autofix.ci] apply automated fixes

* Fix type checks

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Rogério Chaves 2024-08-01 15:28:40 +02:00 committed by GitHub
commit 916fca4051
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 96 additions and 69 deletions

16
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aenum"
@ -5263,13 +5263,13 @@ requests = ">=2,<3"
[[package]]
name = "langwatch"
version = "0.1.13"
version = "0.1.14"
description = "Python SDK for LangWatch for monitoring your LLMs"
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "langwatch-0.1.13-py3-none-any.whl", hash = "sha256:d326c9c2d1a164a54ae0dde66aef50a641af8ead593a983f4c7e2bf6cac74e6b"},
{file = "langwatch-0.1.13.tar.gz", hash = "sha256:0dca8ca7627468bfada9aeda94452a9924775e6fbb9e761f7c7f9e9d074691d5"},
{file = "langwatch-0.1.14-py3-none-any.whl", hash = "sha256:5b3994ce3ee06e20de999635ecaaa9e2c6393839eb90a8bf402511444edda8a8"},
{file = "langwatch-0.1.14.tar.gz", hash = "sha256:3f97e891a61dff43a95b6fcb12db5c5a5170ff411dd56a55eff9ef24f19be967"},
]
[package.dependencies]
@ -5988,7 +5988,7 @@ files = [
{file = "msgpack-1.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fbb160554e319f7b22ecf530a80a3ff496d38e8e07ae763b9e82fadfe96f273"},
{file = "msgpack-1.0.8-cp39-cp39-win32.whl", hash = "sha256:f9af38a89b6a5c04b7d18c492c8ccf2aee7048aff1ce8437c4683bb5a1df893d"},
{file = "msgpack-1.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:ed59dd52075f8fc91da6053b12e8c89e37aa043f8986efd89e61fae69dc1b011"},
{file = "msgpack-1.0.8.tar.gz", hash = "sha256:95c02b0e27e706e48d0e5426d1710ca78e0f0628d6e89d5b5a5b91a5f12274f3"},
{file = "msgpack-1.0.8-py3-none-any.whl", hash = "sha256:24f727df1e20b9876fa6e95f840a2a2651e34c0ad147676356f4bf5fbb0206ca"},
]
[[package]]
@ -6473,7 +6473,6 @@ description = "Nvidia JIT LTO Library"
optional = true
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"},
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"},
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"},
]
@ -6913,7 +6912,6 @@ optional = false
python-versions = ">=3.9"
files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
@ -6934,7 +6932,6 @@ files = [
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
@ -8647,7 +8644,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -11995,4 +11991,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "3d81751b63f97649e1b46a0d626c055303df6ab3bfce2e6b49941400aad08793"
content-hash = "360b7bde5bb4338b128efbec31990d5fbce816b6fc5d58875a488d314c86fe7b"

View file

@ -99,7 +99,7 @@ langchain-nvidia-ai-endpoints = "0.1.6"
langchain-google-calendar-tools = "^0.0.1"
langchain-milvus = "^0.1.1"
crewai = {extras = ["tools"], version = "^0.36.0"}
langwatch = "^0.1.10"
langwatch = "^0.1.14"
langsmith = "^0.1.86"
yfinance = "^0.2.40"
langchain-google-community = "1.0.7"

View file

@ -98,7 +98,9 @@ class LCAgentComponent(Component):
self.chat_history = self.get_chat_history_data()
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
result = await agent.ainvoke(
input_dict, config={"callbacks": [AgentAsyncHandler(self.log)] + self.get_langchain_callbacks()}
)
self.status = result
if "output" not in result:
raise ValueError("Output key not found in result. Tried 'output'.")
@ -141,7 +143,10 @@ class LCToolsAgentComponent(LCAgentComponent):
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
result = await runnable.ainvoke(
input_dict, config={"callbacks": [AgentAsyncHandler(self.log)] + self.get_langchain_callbacks()}
)
self.status = result
if "output" not in result:
raise ValueError("Output key not found in result. Tried 'output'.")

View file

@ -164,7 +164,11 @@ class LCModelComponent(Component):
inputs: Union[list, dict] = messages or {}
try:
runnable = runnable.with_config( # type: ignore
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
{
"run_name": self.display_name,
"project_name": self._tracing_service.project_name, # type: ignore
"callbacks": self.get_langchain_callbacks(),
}
)
if stream:
return runnable.stream(inputs) # type: ignore

View file

@ -28,7 +28,7 @@ class ConversationChainComponent(LCChainComponent):
else:
chain = ConversationChain(llm=self.llm, memory=self.memory)
result = chain.invoke({"input": self.input_value})
result = chain.invoke({"input": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})
if isinstance(result, dict):
result = result.get(chain.output_key, "") # type: ignore

View file

@ -20,7 +20,9 @@ class LLMCheckerChainComponent(LCChainComponent):
def invoke_chain(self) -> Message:
chain = LLMCheckerChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
response = chain.invoke(
{chain.input_key: self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
)
result = response.get(chain.output_key, "")
result = str(result)
self.status = result

View file

@ -23,7 +23,9 @@ class LLMMathChainComponent(LCChainComponent):
def invoke_chain(self) -> Message:
chain = LLMMathChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
response = chain.invoke(
{chain.input_key: self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
)
result = response.get(chain.output_key, "")
result = str(result)
self.status = result

View file

@ -52,7 +52,7 @@ class RetrievalQAComponent(LCChainComponent):
return_source_documents=True,
)
result = runnable.invoke({"query": self.input_value})
result = runnable.invoke({"query": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})
source_docs = self.to_data(result.get("source_documents", []))
result_str = str(result.get("result", ""))

View file

@ -43,7 +43,9 @@ class SQLGeneratorComponent(LCChainComponent):
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, prompt=prompt_template, k=self.top_k)
query_writer: Runnable = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
response = query_writer.invoke({"question": self.input_value})
response = query_writer.invoke(
{"question": self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
)
query = response.get("query")
self.status = query
return query

View file

@ -17,7 +17,10 @@ class ShouldRunNextComponent(CustomComponent):
chain = prompt | llm
error_message = ""
for i in range(retries):
result = chain.invoke(dict(question=question, context=context, error_message=error_message))
result = chain.invoke(
dict(question=question, context=context, error_message=error_message),
config={"callbacks": self.get_langchain_callbacks()},
)
if isinstance(result, BaseMessage):
content = result.content
elif isinstance(result, str):

View file

@ -46,7 +46,7 @@ class CohereRerankComponent(LCVectorStoreComponent):
async def search_documents(self) -> List[Data]: # type: ignore
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query)
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
data = self.to_data(documents)
self.status = data
return data

View file

@ -58,7 +58,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
async def search_documents(self) -> List[Data]: # type: ignore
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query)
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
data = self.to_data(documents)
self.status = data
return data

View file

@ -64,7 +64,7 @@ class SelfQueryRetrieverComponent(CustomComponent):
if not isinstance(query, str):
raise ValueError(f"Query type {type(query)} not supported.")
documents = self_query_retriever.invoke(input=input_text)
documents = self_query_retriever.invoke(input=input_text, config={"callbacks": self.get_langchain_callbacks()})
data = [Data.from_document(document) for document in documents]
self.status = data
return data

View file

@ -120,7 +120,7 @@ class VectaraRagComponent(Component):
rerank_config=rerank_config,
)
rag = vectara.as_rag(config)
response = rag.invoke(self.search_query)
response = rag.invoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
text_output = response["answer"]

View file

@ -28,6 +28,7 @@ if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.storage.service import StorageService
from langflow.services.tracing.service import TracingService
from langchain.callbacks.base import BaseCallbackHandler
class CustomComponent(BaseComponent):
@ -528,3 +529,8 @@ class CustomComponent(BaseComponent):
frontend_node=new_frontend_node, raw_frontend_node=current_frontend_node
)
return frontend_node
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
    """Return the LangChain callback handlers exposed by the tracing service.

    Returns:
        The list produced by ``self._tracing_service.get_langchain_callbacks()``
        when a tracing service is set, otherwise an empty list.
    """
    if self._tracing_service:
        return self._tracing_service.get_langchain_callbacks()
    # No tracing service attached: callers concatenate this result with other
    # handlers, so always hand back a list rather than None.
    return []

View file

@ -136,7 +136,9 @@ class Data(BaseModel):
contents = [{"type": "text", "text": text}]
for file_path in files:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) # type: ignore
image_prompt_value: ImagePromptValue = image_template.invoke(
input={"path": file_path}, config={"callbacks": self.get_langchain_callbacks()}
) # type: ignore
contents.append({"type": "image_url", "image_url": image_prompt_value.image_url})
human_message = HumanMessage(content=contents) # type: ignore
else:

View file

@ -165,7 +165,9 @@ class Message(Data):
content_dicts.append(file.to_content_dict())
else:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file})
image_prompt_value: ImagePromptValue = image_template.invoke(
input={"path": file}, config={"callbacks": self.get_langchain_callbacks()}
) # type: ignore
content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
return content_dicts

View file

@ -6,6 +6,7 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langchain.callbacks.base import BaseCallbackHandler
class BaseTracer(ABC):
@ -49,3 +50,7 @@ class BaseTracer(ABC):
metadata: dict[str, Any] | None = None,
):
raise NotImplementedError
@abstractmethod
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
    """Return this tracer's LangChain callback handler, or ``None`` if it has none.

    Declared abstract, so concrete tracers are expected to override it.

    Raises:
        NotImplementedError: always, when this base implementation is called.
    """
    raise NotImplementedError

View file

@ -13,6 +13,7 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langchain.callbacks.base import BaseCallbackHandler
class LangSmithTracer(BaseTracer):
@ -158,3 +159,6 @@ class LangSmithTracer(BaseTracer):
self._run_tree.end(outputs=outputs, error=self._error_to_string(error))
self._run_tree.post()
self._run_link = self._run_tree.get_url()
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
    """Return ``None``: this tracer does not supply a LangChain callback handler."""
    return None

View file

@ -1,4 +1,3 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
from uuid import UUID
@ -11,9 +10,9 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langwatch.tracer import ContextSpan
from langwatch.types import SpanTypes
from langflow.graph.vertex.base import Vertex
from langchain.callbacks.base import BaseCallbackHandler
class LangWatchTracer(BaseTracer):
@ -40,7 +39,7 @@ class LangWatchTracer(BaseTracer):
self.trace.root_span.update(
span_id=f"{self.flow_id}-{nanoid.generate(size=6)}", # nanoid to make the span_id globally unique, which is required for LangWatch for now
name=name_without_id,
type=self._convert_trace_type(trace_type),
type="workflow",
)
except Exception as e:
logger.debug(f"Error setting up LangWatch tracer: {e}")
@ -51,8 +50,6 @@ class LangWatchTracer(BaseTracer):
return self._ready
def setup_langwatch(self):
if os.getenv("LANGWATCH_API_KEY") is None:
return False
try:
import langwatch
@ -62,14 +59,6 @@ class LangWatchTracer(BaseTracer):
return False
return True
def _convert_trace_type(self, trace_type: str):
trace_type_: "SpanTypes" = (
cast("SpanTypes", trace_type)
if trace_type in ["span", "llm", "chain", "tool", "agent", "guardrail", "rag"]
else "span"
)
return trace_type_
def add_trace(
self,
trace_id: str,
@ -88,23 +77,21 @@ class LangWatchTracer(BaseTracer):
name_without_id = " (".join(trace_name.split(" (")[0:-1])
trace_type_ = self._convert_trace_type(trace_type)
self.spans[trace_id] = self.trace.span(
span_id=f"{trace_id}-{nanoid.generate(size=6)}", # Add a nanoid to make the span_id globally unique, which is required for LangWatch for now
name=name_without_id,
type=trace_type_,
parent=(
[span for key, span in self.spans.items() for edge in vertex.incoming_edges if key == edge.source_id][
-1
]
if vertex and len(vertex.incoming_edges) > 0
else self.trace.root_span
),
input=self._convert_to_langwatch_types(inputs),
previous_nodes = (
[span for key, span in self.spans.items() for edge in vertex.incoming_edges if key == edge.source_id]
if vertex and len(vertex.incoming_edges) > 0
else []
)
if trace_type_ == "llm" and "model_name" in inputs:
self.spans[trace_id].update(model=inputs["model_name"])
span = self.trace.span(
span_id=f"{trace_id}-{nanoid.generate(size=6)}", # Add a nanoid to make the span_id globally unique, which is required for LangWatch for now
name=name_without_id,
type="component",
parent=(previous_nodes[-1] if len(previous_nodes) > 0 else self.trace.root_span),
input=self._convert_to_langwatch_types(inputs),
)
self.trace.set_current_span(span)
self.spans[trace_id] = span
def end_trace(
self,
@ -117,16 +104,6 @@ class LangWatchTracer(BaseTracer):
if not self._ready:
return
if self.spans.get(trace_id):
# Workaround for when model is used just as a component not actually called as an LLM,
# to prevent LangWatch from calculating the cost based on it when it was in fact never called
if (
self.spans[trace_id].type == "llm"
and outputs
and "model_output" in outputs
and "text_output" not in outputs
):
self.spans[trace_id].update(metrics={"prompt_tokens": 0, "completion_tokens": 0})
self.spans[trace_id].end(output=self._convert_to_langwatch_types(outputs), error=error)
def end(
@ -146,7 +123,9 @@ class LangWatchTracer(BaseTracer):
if metadata and "flow_name" in metadata:
self.trace.update(metadata=(self.trace.metadata or {}) | {"labels": [f"Flow: {metadata['flow_name']}"]})
self.trace.deferred_send_spans()
if self.trace.api_key or self._client.api_key:
self.trace.deferred_send_spans()
def _convert_to_langwatch_types(self, io_dict: Optional[Dict[str, Any]]):
from langwatch.utils import autoconvert_typed_values
@ -183,3 +162,9 @@ class LangWatchTracer(BaseTracer):
elif isinstance(value, Data):
value = cast(dict, value.to_lc_document())
return value
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
    """Return the LangChain callback handler of the current trace.

    Returns:
        ``None`` when ``self.trace`` is ``None``; otherwise whatever
        ``self.trace.get_langchain_callback()`` returns.
    """
    if self.trace is None:
        return None
    return self.trace.get_langchain_callback()

View file

@ -2,7 +2,7 @@ import asyncio
import os
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, List
from uuid import UUID
from loguru import logger
@ -16,6 +16,7 @@ if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.monitor.service import MonitorService
from langflow.services.settings.service import SettingsService
from langchain.callbacks.base import BaseCallbackHandler
def _get_langsmith_tracer():
@ -115,9 +116,7 @@ class TracingService(Service):
def _initialize_langwatch_tracer(self):
if (
os.getenv("LANGWATCH_API_KEY")
and "langwatch" not in self._tracers
or self._tracers["langwatch"].trace_id != self.run_id # type: ignore
"langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id # type: ignore
):
langwatch_tracer = _get_langwatch_tracer()
self._tracers["langwatch"] = langwatch_tracer(
@ -229,3 +228,13 @@ class TracingService(Service):
if "api_key" in key:
inputs[key] = "*****" # avoid logging api_keys for security reasons
return inputs
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
    """Collect the LangChain callback handlers of every ready tracer.

    Tracers that are not ready are skipped, and a tracer whose
    ``get_langchain_callback()`` returns a falsy value (e.g. ``None``)
    contributes nothing.

    Returns:
        The collected handlers, in the iteration order of ``self._tracers``.
    """
    callbacks: List["BaseCallbackHandler"] = []
    for tracer in self._tracers.values():
        if not tracer.ready:  # type: ignore
            continue
        langchain_callback = tracer.get_langchain_callback()
        if langchain_callback:
            callbacks.append(langchain_callback)
    return callbacks