feat: improve LangWatch integration by introducing langchain callbacks on the tracing service, and component and workflow span types (#3094)

* Improve the LangWatch integration by adding LangChain callbacks to the tracing service and introducing `component` and `workflow` span types. Bump LangWatch to v0.1.14.

* [autofix.ci] apply automated fixes

* Fix type checks

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Rogério Chaves 2024-08-01 15:28:40 +02:00 committed by GitHub
commit 916fca4051
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 96 additions and 69 deletions

View file

@ -98,7 +98,9 @@ class LCAgentComponent(Component):
self.chat_history = self.get_chat_history_data()
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
result = await agent.ainvoke(
input_dict, config={"callbacks": [AgentAsyncHandler(self.log)] + self.get_langchain_callbacks()}
)
self.status = result
if "output" not in result:
raise ValueError("Output key not found in result. Tried 'output'.")
@ -141,7 +143,10 @@ class LCToolsAgentComponent(LCAgentComponent):
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
result = await runnable.ainvoke(
input_dict, config={"callbacks": [AgentAsyncHandler(self.log)] + self.get_langchain_callbacks()}
)
self.status = result
if "output" not in result:
raise ValueError("Output key not found in result. Tried 'output'.")

View file

@ -164,7 +164,11 @@ class LCModelComponent(Component):
inputs: Union[list, dict] = messages or {}
try:
runnable = runnable.with_config( # type: ignore
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
{
"run_name": self.display_name,
"project_name": self._tracing_service.project_name, # type: ignore
"callbacks": self.get_langchain_callbacks(),
}
)
if stream:
return runnable.stream(inputs) # type: ignore

View file

@ -28,7 +28,7 @@ class ConversationChainComponent(LCChainComponent):
else:
chain = ConversationChain(llm=self.llm, memory=self.memory)
result = chain.invoke({"input": self.input_value})
result = chain.invoke({"input": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})
if isinstance(result, dict):
result = result.get(chain.output_key, "") # type: ignore

View file

@ -20,7 +20,9 @@ class LLMCheckerChainComponent(LCChainComponent):
def invoke_chain(self) -> Message:
chain = LLMCheckerChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
response = chain.invoke(
{chain.input_key: self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
)
result = response.get(chain.output_key, "")
result = str(result)
self.status = result

View file

@ -23,7 +23,9 @@ class LLMMathChainComponent(LCChainComponent):
def invoke_chain(self) -> Message:
chain = LLMMathChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
response = chain.invoke(
{chain.input_key: self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
)
result = response.get(chain.output_key, "")
result = str(result)
self.status = result

View file

@ -52,7 +52,7 @@ class RetrievalQAComponent(LCChainComponent):
return_source_documents=True,
)
result = runnable.invoke({"query": self.input_value})
result = runnable.invoke({"query": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})
source_docs = self.to_data(result.get("source_documents", []))
result_str = str(result.get("result", ""))

View file

@ -43,7 +43,9 @@ class SQLGeneratorComponent(LCChainComponent):
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, prompt=prompt_template, k=self.top_k)
query_writer: Runnable = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
response = query_writer.invoke({"question": self.input_value})
response = query_writer.invoke(
{"question": self.input_value}, config={"callbacks": self.get_langchain_callbacks()}
)
query = response.get("query")
self.status = query
return query

View file

@ -17,7 +17,10 @@ class ShouldRunNextComponent(CustomComponent):
chain = prompt | llm
error_message = ""
for i in range(retries):
result = chain.invoke(dict(question=question, context=context, error_message=error_message))
result = chain.invoke(
dict(question=question, context=context, error_message=error_message),
config={"callbacks": self.get_langchain_callbacks()},
)
if isinstance(result, BaseMessage):
content = result.content
elif isinstance(result, str):

View file

@ -46,7 +46,7 @@ class CohereRerankComponent(LCVectorStoreComponent):
async def search_documents(self) -> List[Data]: # type: ignore
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query)
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
data = self.to_data(documents)
self.status = data
return data

View file

@ -58,7 +58,7 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
async def search_documents(self) -> List[Data]: # type: ignore
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query)
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
data = self.to_data(documents)
self.status = data
return data

View file

@ -64,7 +64,7 @@ class SelfQueryRetrieverComponent(CustomComponent):
if not isinstance(query, str):
raise ValueError(f"Query type {type(query)} not supported.")
documents = self_query_retriever.invoke(input=input_text)
documents = self_query_retriever.invoke(input=input_text, config={"callbacks": self.get_langchain_callbacks()})
data = [Data.from_document(document) for document in documents]
self.status = data
return data

View file

@ -120,7 +120,7 @@ class VectaraRagComponent(Component):
rerank_config=rerank_config,
)
rag = vectara.as_rag(config)
response = rag.invoke(self.search_query)
response = rag.invoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()})
text_output = response["answer"]

View file

@ -28,6 +28,7 @@ if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.storage.service import StorageService
from langflow.services.tracing.service import TracingService
from langchain.callbacks.base import BaseCallbackHandler
class CustomComponent(BaseComponent):
@ -528,3 +529,8 @@ class CustomComponent(BaseComponent):
frontend_node=new_frontend_node, raw_frontend_node=current_frontend_node
)
return frontend_node
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
    """Return the LangChain callback handlers exposed by the tracing service.

    Returns:
        The result of ``self._tracing_service.get_langchain_callbacks()``,
        or an empty list when ``self._tracing_service`` is falsy.
    """
    # Guard clause: with no tracing service there is nothing to attach.
    if not self._tracing_service:
        return []
    return self._tracing_service.get_langchain_callbacks()

View file

@ -136,7 +136,9 @@ class Data(BaseModel):
contents = [{"type": "text", "text": text}]
for file_path in files:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file_path}) # type: ignore
image_prompt_value: ImagePromptValue = image_template.invoke(
input={"path": file_path}, config={"callbacks": self.get_langchain_callbacks()}
) # type: ignore
contents.append({"type": "image_url", "image_url": image_prompt_value.image_url})
human_message = HumanMessage(content=contents) # type: ignore
else:

View file

@ -165,7 +165,9 @@ class Message(Data):
content_dicts.append(file.to_content_dict())
else:
image_template = ImagePromptTemplate()
image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file})
image_prompt_value: ImagePromptValue = image_template.invoke(
input={"path": file}, config={"callbacks": self.get_langchain_callbacks()}
) # type: ignore
content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
return content_dicts

View file

@ -6,6 +6,7 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langchain.callbacks.base import BaseCallbackHandler
class BaseTracer(ABC):
@ -49,3 +50,7 @@ class BaseTracer(ABC):
metadata: dict[str, Any] | None = None,
):
raise NotImplementedError
@abstractmethod
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
    """Return this tracer's LangChain callback handler, if it has one.

    Declared abstract; this base implementation only raises.

    Returns:
        A callback handler, or ``None`` (per the annotated return type).

    Raises:
        NotImplementedError: Always, when this base implementation is called.
    """
    raise NotImplementedError

View file

@ -13,6 +13,7 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langchain.callbacks.base import BaseCallbackHandler
class LangSmithTracer(BaseTracer):
@ -158,3 +159,6 @@ class LangSmithTracer(BaseTracer):
self._run_tree.end(outputs=outputs, error=self._error_to_string(error))
self._run_tree.post()
self._run_link = self._run_tree.get_url()
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
    """Return the LangChain callback handler for this tracer.

    Returns:
        Always ``None``; this tracer does not supply a callback handler here.
    """
    return None

View file

@ -1,4 +1,3 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
from uuid import UUID
@ -11,9 +10,9 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langwatch.tracer import ContextSpan
from langwatch.types import SpanTypes
from langflow.graph.vertex.base import Vertex
from langchain.callbacks.base import BaseCallbackHandler
class LangWatchTracer(BaseTracer):
@ -40,7 +39,7 @@ class LangWatchTracer(BaseTracer):
self.trace.root_span.update(
span_id=f"{self.flow_id}-{nanoid.generate(size=6)}", # nanoid to make the span_id globally unique, which is required for LangWatch for now
name=name_without_id,
type=self._convert_trace_type(trace_type),
type="workflow",
)
except Exception as e:
logger.debug(f"Error setting up LangWatch tracer: {e}")
@ -51,8 +50,6 @@ class LangWatchTracer(BaseTracer):
return self._ready
def setup_langwatch(self):
if os.getenv("LANGWATCH_API_KEY") is None:
return False
try:
import langwatch
@ -62,14 +59,6 @@ class LangWatchTracer(BaseTracer):
return False
return True
def _convert_trace_type(self, trace_type: str):
trace_type_: "SpanTypes" = (
cast("SpanTypes", trace_type)
if trace_type in ["span", "llm", "chain", "tool", "agent", "guardrail", "rag"]
else "span"
)
return trace_type_
def add_trace(
self,
trace_id: str,
@ -88,23 +77,21 @@ class LangWatchTracer(BaseTracer):
name_without_id = " (".join(trace_name.split(" (")[0:-1])
trace_type_ = self._convert_trace_type(trace_type)
self.spans[trace_id] = self.trace.span(
span_id=f"{trace_id}-{nanoid.generate(size=6)}", # Add a nanoid to make the span_id globally unique, which is required for LangWatch for now
name=name_without_id,
type=trace_type_,
parent=(
[span for key, span in self.spans.items() for edge in vertex.incoming_edges if key == edge.source_id][
-1
]
if vertex and len(vertex.incoming_edges) > 0
else self.trace.root_span
),
input=self._convert_to_langwatch_types(inputs),
previous_nodes = (
[span for key, span in self.spans.items() for edge in vertex.incoming_edges if key == edge.source_id]
if vertex and len(vertex.incoming_edges) > 0
else []
)
if trace_type_ == "llm" and "model_name" in inputs:
self.spans[trace_id].update(model=inputs["model_name"])
span = self.trace.span(
span_id=f"{trace_id}-{nanoid.generate(size=6)}", # Add a nanoid to make the span_id globally unique, which is required for LangWatch for now
name=name_without_id,
type="component",
parent=(previous_nodes[-1] if len(previous_nodes) > 0 else self.trace.root_span),
input=self._convert_to_langwatch_types(inputs),
)
self.trace.set_current_span(span)
self.spans[trace_id] = span
def end_trace(
self,
@ -117,16 +104,6 @@ class LangWatchTracer(BaseTracer):
if not self._ready:
return
if self.spans.get(trace_id):
# Workaround for when model is used just as a component not actually called as an LLM,
# to prevent LangWatch from calculating the cost based on it when it was in fact never called
if (
self.spans[trace_id].type == "llm"
and outputs
and "model_output" in outputs
and "text_output" not in outputs
):
self.spans[trace_id].update(metrics={"prompt_tokens": 0, "completion_tokens": 0})
self.spans[trace_id].end(output=self._convert_to_langwatch_types(outputs), error=error)
def end(
@ -146,7 +123,9 @@ class LangWatchTracer(BaseTracer):
if metadata and "flow_name" in metadata:
self.trace.update(metadata=(self.trace.metadata or {}) | {"labels": [f"Flow: {metadata['flow_name']}"]})
self.trace.deferred_send_spans()
if self.trace.api_key or self._client.api_key:
self.trace.deferred_send_spans()
def _convert_to_langwatch_types(self, io_dict: Optional[Dict[str, Any]]):
from langwatch.utils import autoconvert_typed_values
@ -183,3 +162,9 @@ class LangWatchTracer(BaseTracer):
elif isinstance(value, Data):
value = cast(dict, value.to_lc_document())
return value
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
    """Return the LangChain callback handler provided by the current trace.

    Returns:
        ``None`` when ``self.trace`` is ``None``; otherwise the result of
        ``self.trace.get_langchain_callback()``.
    """
    return self.trace.get_langchain_callback() if self.trace is not None else None

View file

@ -2,7 +2,7 @@ import asyncio
import os
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, List
from uuid import UUID
from loguru import logger
@ -16,6 +16,7 @@ if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.monitor.service import MonitorService
from langflow.services.settings.service import SettingsService
from langchain.callbacks.base import BaseCallbackHandler
def _get_langsmith_tracer():
@ -115,9 +116,7 @@ class TracingService(Service):
def _initialize_langwatch_tracer(self):
if (
os.getenv("LANGWATCH_API_KEY")
and "langwatch" not in self._tracers
or self._tracers["langwatch"].trace_id != self.run_id # type: ignore
"langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id # type: ignore
):
langwatch_tracer = _get_langwatch_tracer()
self._tracers["langwatch"] = langwatch_tracer(
@ -229,3 +228,13 @@ class TracingService(Service):
if "api_key" in key:
inputs[key] = "*****" # avoid logging api_keys for security reasons
return inputs
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
    """Collect the LangChain callback handlers of every ready tracer.

    Tracers in ``self._tracers`` whose ``ready`` attribute is falsy are
    skipped without being asked for a callback. Falsy callback results
    (for example ``None``) are left out.

    Returns:
        The collected handlers, in the iteration order of ``self._tracers``.
    """
    # `tracer.ready` is tested first, so get_langchain_callback() is only
    # invoked on tracers that report themselves as ready.
    return [
        handler
        for tracer in self._tracers.values()
        if tracer.ready and (handler := tracer.get_langchain_callback())  # type: ignore
    ]