fix: Do a better job of mapping Langchain to LiteLLM (#5233)

This commit is contained in:
Eric Hare 2024-12-12 10:56:39 -08:00 committed by GitHub
commit a17335e802
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import Any, cast
import litellm
from crewai import LLM, Agent, Crew, Process, Task
from crewai.task import TaskOutput
from crewai.tools.base_tool import Tool
@ -70,9 +71,12 @@ def convert_llm(llm: Any, excluded_keys=None) -> LLM:
msg = "Could not find model name in the LLM object"
raise ValueError(msg)
# Normalize Ollama with prefix TODO: Handle all litellm supported models
if llm.dict().get("_type") == "chat-ollama":
model_name = f"ollama/{model_name}"
# Normalize to the LLM model name
# Remove langchain_ prefix if present
provider = llm.get_lc_namespace()[0]
if provider.startswith("langchain_"):
provider = provider[10:]
model_name = f"{provider}/{model_name}"
# Retrieve the API Key from the LLM
if excluded_keys is None:
@ -188,8 +192,13 @@ class BaseCrewComponent(Component):
return step_callback
async def build_output(self) -> Message:
crew = self.build_crew()
result = await crew.kickoff_async()
message = Message(text=result.raw, sender=MESSAGE_SENDER_AI)
try:
crew = self.build_crew()
result = await crew.kickoff_async()
message = Message(text=result.raw, sender=MESSAGE_SENDER_AI)
except litellm.exceptions.BadRequestError as e:
raise ValueError(e) from e
self.status = message
return message