feat: improve LangWatch integration by introducing langchain callbacks on the tracing service, and component and workflow span types (#3094)
* Improve LangWatch integration by introducing langchain callbacks on the tracing service, and component and workflow span types. Bump LangWatch to v0.1.14 * [autofix.ci] apply automated fixes * Fix type checks * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1dd840c526
commit
916fca4051
21 changed files with 96 additions and 69 deletions
|
|
@ -98,7 +98,9 @@ class LCAgentComponent(Component):
|
|||
self.chat_history = self.get_chat_history_data()
|
||||
if self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
|
||||
result = await agent.ainvoke(
|
||||
input_dict, config={"callbacks": [AgentAsyncHandler(self.log)] + self.get_langchain_callbacks()}
|
||||
)
|
||||
self.status = result
|
||||
if "output" not in result:
|
||||
raise ValueError("Output key not found in result. Tried 'output'.")
|
||||
|
|
@ -141,7 +143,10 @@ class LCToolsAgentComponent(LCAgentComponent):
|
|||
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
|
||||
if self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
|
||||
|
||||
result = await runnable.ainvoke(
|
||||
input_dict, config={"callbacks": [AgentAsyncHandler(self.log)] + self.get_langchain_callbacks()}
|
||||
)
|
||||
self.status = result
|
||||
if "output" not in result:
|
||||
raise ValueError("Output key not found in result. Tried 'output'.")
|
||||
|
|
|
|||
|
|
@ -164,7 +164,11 @@ class LCModelComponent(Component):
|
|||
inputs: Union[list, dict] = messages or {}
|
||||
try:
|
||||
runnable = runnable.with_config( # type: ignore
|
||||
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
|
||||
{
|
||||
"run_name": self.display_name,
|
||||
"project_name": self._tracing_service.project_name, # type: ignore
|
||||
"callbacks": self.get_langchain_callbacks(),
|
||||
}
|
||||
)
|
||||
if stream:
|
||||
return runnable.stream(inputs) # type: ignore
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class ConversationChainComponent(LCChainComponent):
|
|||
else:
|
||||
chain = ConversationChain(llm=self.llm, memory=self.memory)
|
||||
|
||||
result = chain.invoke({"input": self.input_value})
|
||||
result = chain.invoke({"input": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})
|
||||
if isinstance(result, dict):
|
||||
result = result.get(chain.output_key, "") # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ class LLMCheckerChainComponent(LCChainComponent):
|
|||
|
||||
def invoke_chain(self) -> Message:
|
||||
chain = LLMCheckerChain.from_llm(llm=self.llm)
|
||||
response = chain.invoke({chain.input_key: self.input_value})
|
||||
response = chain.invoke(
|
||||
{chain.input_key: self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
|
||||
)
|
||||
result = response.get(chain.output_key, "")
|
||||
result = str(result)
|
||||
self.status = result
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ class LLMMathChainComponent(LCChainComponent):
|
|||
|
||||
def invoke_chain(self) -> Message:
|
||||
chain = LLMMathChain.from_llm(llm=self.llm)
|
||||
response = chain.invoke({chain.input_key: self.input_value})
|
||||
response = chain.invoke(
|
||||
{chain.input_key: self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
|
||||
)
|
||||
result = response.get(chain.output_key, "")
|
||||
result = str(result)
|
||||
self.status = result
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class RetrievalQAComponent(LCChainComponent):
|
|||
return_source_documents=True,
|
||||
)
|
||||
|
||||
result = runnable.invoke({"query": self.input_value})
|
||||
result = runnable.invoke({"query": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})
|
||||
|
||||
source_docs = self.to_data(result.get("source_documents", []))
|
||||
result_str = str(result.get("result", ""))
|
||||
|
|
|
|||
|
|
@ -43,7 +43,9 @@ class SQLGeneratorComponent(LCChainComponent):
|
|||
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
|
||||
sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, prompt=prompt_template, k=self.top_k)
|
||||
query_writer: Runnable = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
|
||||
response = query_writer.invoke({"question": self.input_value})
|
||||
response = query_writer.invoke(
|
||||
{"question": self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
|
||||
)
|
||||
query = response.get("query")
|
||||
self.status = query
|
||||
return query
|
||||
|
|
|
|||
|
|
@ -17,7 +17,10 @@ class ShouldRunNextComponent(CustomComponent):
|
|||
chain = prompt | llm
|
||||
error_message = ""
|
||||
for i in range(retries):
|
||||
result = chain.invoke(dict(question=question, context=context, error_message=error_message))
|
||||
result = chain.invoke(
|
||||
dict(question=question, context=context, error_message=error_message),
|
||||
config={"callbacks": self.get_langchain_callbacks()},
|
||||
)
|
||||
if isinstance(result, BaseMessage):
|
||||
content = result.content
|
||||
elif isinstance(result, str):
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class CohereRerankComponent(LCVectorStoreComponent):
|
|||
|
||||
async def search_documents(self) -> List[Data]: # type: ignore
|
||||
retriever = self.build_base_retriever()
|
||||
documents = await retriever.ainvoke(self.search_query)
|
||||
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
|
|||
|
||||
async def search_documents(self) -> List[Data]: # type: ignore
|
||||
retriever = self.build_base_retriever()
|
||||
documents = await retriever.ainvoke(self.search_query)
|
||||
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ class SelfQueryRetrieverComponent(CustomComponent):
|
|||
|
||||
if not isinstance(query, str):
|
||||
raise ValueError(f"Query type {type(query)} not supported.")
|
||||
documents = self_query_retriever.invoke(input=input_text)
|
||||
documents = self_query_retriever.invoke(input=input_text, config={"callbacks": self.get_langchain_callbacks()})
|
||||
data = [Data.from_document(document) for document in documents]
|
||||
self.status = data
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ class VectaraRagComponent(Component):
|
|||
rerank_config=rerank_config,
|
||||
)
|
||||
rag = vectara.as_rag(config)
|
||||
response = rag.invoke(self.search_query)
|
||||
response = rag.invoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
|
||||
|
||||
text_output = response["answer"]
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.services.tracing.service import TracingService
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class CustomComponent(BaseComponent):
|
||||
|
|
@ -528,3 +529,8 @@ class CustomComponent(BaseComponent):
|
|||
frontend_node=new_frontend_node, raw_frontend_node=current_frontend_node
|
||||
)
|
||||
return frontend_node
|
||||
|
||||
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
|
||||
if self._tracing_service:
|
||||
return self._tracing_service.get_langchain_callbacks()
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -136,7 +136,9 @@ class Data(BaseModel):
|
|||
contents = [{"type": "text", "text": text}]
|
||||
for file_path in files:
|
||||
image_template = ImagePromptTemplate()
|
||||
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) # type: ignore
|
||||
image_prompt_value: ImagePromptValue = image_template.invoke(
|
||||
input={"path": file_path}, config={"callbacks": self.get_langchain_callbacks()}
|
||||
) # type: ignore
|
||||
contents.append({"type": "image_url", "image_url": image_prompt_value.image_url})
|
||||
human_message = HumanMessage(content=contents) # type: ignore
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -165,7 +165,9 @@ class Message(Data):
|
|||
content_dicts.append(file.to_content_dict())
|
||||
else:
|
||||
image_template = ImagePromptTemplate()
|
||||
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file})
|
||||
image_prompt_value: ImagePromptValue = image_template.invoke(
|
||||
input={"path": file}, config={"callbacks": self.get_langchain_callbacks()}
|
||||
) # type: ignore
|
||||
content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
|
||||
return content_dicts
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langflow.services.tracing.schema import Log
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class BaseTracer(ABC):
|
||||
|
|
@ -49,3 +50,7 @@ class BaseTracer(ABC):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from langflow.services.tracing.schema import Log
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class LangSmithTracer(BaseTracer):
|
||||
|
|
@ -158,3 +159,6 @@ class LangSmithTracer(BaseTracer):
|
|||
self._run_tree.end(outputs=outputs, error=self._error_to_string(error))
|
||||
self._run_tree.post()
|
||||
self._run_link = self._run_tree.get_url()
|
||||
|
||||
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -11,9 +10,9 @@ from langflow.services.tracing.schema import Log
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langwatch.tracer import ContextSpan
|
||||
from langwatch.types import SpanTypes
|
||||
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
class LangWatchTracer(BaseTracer):
|
||||
|
|
@ -40,7 +39,7 @@ class LangWatchTracer(BaseTracer):
|
|||
self.trace.root_span.update(
|
||||
span_id=f"{self.flow_id}-{nanoid.generate(size=6)}", # nanoid to make the span_id globally unique, which is required for LangWatch for now
|
||||
name=name_without_id,
|
||||
type=self._convert_trace_type(trace_type),
|
||||
type="workflow",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error setting up LangWatch tracer: {e}")
|
||||
|
|
@ -51,8 +50,6 @@ class LangWatchTracer(BaseTracer):
|
|||
return self._ready
|
||||
|
||||
def setup_langwatch(self):
|
||||
if os.getenv("LANGWATCH_API_KEY") is None:
|
||||
return False
|
||||
try:
|
||||
import langwatch
|
||||
|
||||
|
|
@ -62,14 +59,6 @@ class LangWatchTracer(BaseTracer):
|
|||
return False
|
||||
return True
|
||||
|
||||
def _convert_trace_type(self, trace_type: str):
|
||||
trace_type_: "SpanTypes" = (
|
||||
cast("SpanTypes", trace_type)
|
||||
if trace_type in ["span", "llm", "chain", "tool", "agent", "guardrail", "rag"]
|
||||
else "span"
|
||||
)
|
||||
return trace_type_
|
||||
|
||||
def add_trace(
|
||||
self,
|
||||
trace_id: str,
|
||||
|
|
@ -88,23 +77,21 @@ class LangWatchTracer(BaseTracer):
|
|||
|
||||
name_without_id = " (".join(trace_name.split(" (")[0:-1])
|
||||
|
||||
trace_type_ = self._convert_trace_type(trace_type)
|
||||
self.spans[trace_id] = self.trace.span(
|
||||
span_id=f"{trace_id}-{nanoid.generate(size=6)}", # Add a nanoid to make the span_id globally unique, which is required for LangWatch for now
|
||||
name=name_without_id,
|
||||
type=trace_type_,
|
||||
parent=(
|
||||
[span for key, span in self.spans.items() for edge in vertex.incoming_edges if key == edge.source_id][
|
||||
-1
|
||||
]
|
||||
if vertex and len(vertex.incoming_edges) > 0
|
||||
else self.trace.root_span
|
||||
),
|
||||
input=self._convert_to_langwatch_types(inputs),
|
||||
previous_nodes = (
|
||||
[span for key, span in self.spans.items() for edge in vertex.incoming_edges if key == edge.source_id]
|
||||
if vertex and len(vertex.incoming_edges) > 0
|
||||
else []
|
||||
)
|
||||
|
||||
if trace_type_ == "llm" and "model_name" in inputs:
|
||||
self.spans[trace_id].update(model=inputs["model_name"])
|
||||
span = self.trace.span(
|
||||
span_id=f"{trace_id}-{nanoid.generate(size=6)}", # Add a nanoid to make the span_id globally unique, which is required for LangWatch for now
|
||||
name=name_without_id,
|
||||
type="component",
|
||||
parent=(previous_nodes[-1] if len(previous_nodes) > 0 else self.trace.root_span),
|
||||
input=self._convert_to_langwatch_types(inputs),
|
||||
)
|
||||
self.trace.set_current_span(span)
|
||||
self.spans[trace_id] = span
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
|
|
@ -117,16 +104,6 @@ class LangWatchTracer(BaseTracer):
|
|||
if not self._ready:
|
||||
return
|
||||
if self.spans.get(trace_id):
|
||||
# Workaround for when model is used just as a component not actually called as an LLM,
|
||||
# to prevent LangWatch from calculating the cost based on it when it was in fact never called
|
||||
if (
|
||||
self.spans[trace_id].type == "llm"
|
||||
and outputs
|
||||
and "model_output" in outputs
|
||||
and "text_output" not in outputs
|
||||
):
|
||||
self.spans[trace_id].update(metrics={"prompt_tokens": 0, "completion_tokens": 0})
|
||||
|
||||
self.spans[trace_id].end(output=self._convert_to_langwatch_types(outputs), error=error)
|
||||
|
||||
def end(
|
||||
|
|
@ -146,7 +123,9 @@ class LangWatchTracer(BaseTracer):
|
|||
|
||||
if metadata and "flow_name" in metadata:
|
||||
self.trace.update(metadata=(self.trace.metadata or {}) | {"labels": [f"Flow: {metadata['flow_name']}"]})
|
||||
self.trace.deferred_send_spans()
|
||||
|
||||
if self.trace.api_key or self._client.api_key:
|
||||
self.trace.deferred_send_spans()
|
||||
|
||||
def _convert_to_langwatch_types(self, io_dict: Optional[Dict[str, Any]]):
|
||||
from langwatch.utils import autoconvert_typed_values
|
||||
|
|
@ -183,3 +162,9 @@ class LangWatchTracer(BaseTracer):
|
|||
elif isinstance(value, Data):
|
||||
value = cast(dict, value.to_lc_document())
|
||||
return value
|
||||
|
||||
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
|
||||
if self.trace is None:
|
||||
return None
|
||||
|
||||
return self.trace.get_langchain_callback()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.monitor.service import MonitorService
|
||||
from langflow.services.settings.service import SettingsService
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
def _get_langsmith_tracer():
|
||||
|
|
@ -115,9 +116,7 @@ class TracingService(Service):
|
|||
|
||||
def _initialize_langwatch_tracer(self):
|
||||
if (
|
||||
os.getenv("LANGWATCH_API_KEY")
|
||||
and "langwatch" not in self._tracers
|
||||
or self._tracers["langwatch"].trace_id != self.run_id # type: ignore
|
||||
"langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id # type: ignore
|
||||
):
|
||||
langwatch_tracer = _get_langwatch_tracer()
|
||||
self._tracers["langwatch"] = langwatch_tracer(
|
||||
|
|
@ -229,3 +228,13 @@ class TracingService(Service):
|
|||
if "api_key" in key:
|
||||
inputs[key] = "*****" # avoid logging api_keys for security reasons
|
||||
return inputs
|
||||
|
||||
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
|
||||
callbacks = []
|
||||
for tracer in self._tracers.values():
|
||||
if not tracer.ready: # type: ignore
|
||||
continue
|
||||
langchain_callback = tracer.get_langchain_callback()
|
||||
if langchain_callback:
|
||||
callbacks.append(langchain_callback)
|
||||
return callbacks
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue