fix: added model utils to get the model name (#4532)

* added model utils

model utils to find model name

* adding deafult

updated the logic

* Refactor get_model_name function to simplify logic

---------

Co-authored-by: anovazzi1 <otavio2204@gmail.com>
This commit is contained in:
Edwin Jose 2024-11-12 20:27:28 -05:00 committed by GitHub
commit da833a566a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 3 deletions

View file

@ -0,0 +1,8 @@
def get_model_name(llm, display_name: str | None = "Custom"):
attributes_to_check = ["model_name", "model", "model_id", "deployment_name"]
# Use a generator expression with next() to find the first matching attribute
model_name = next((getattr(llm, attr) for attr in attributes_to_check if hasattr(llm, attr)), None)
# If no matching attribute is found, return the class name as a fallback
return model_name if model_name is not None else display_name

View file

@ -2,6 +2,7 @@ from langchain_core.tools import StructuredTool
from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.base.models.model_input_constants import ALL_PROVIDER_FIELDS, MODEL_PROVIDERS_DICT
from langflow.base.models.model_utils import get_model_name
from langflow.components.helpers import CurrentDateComponent
from langflow.components.langchain_utilities.tool_calling import ToolCallingAgentComponent
from langflow.components.memories.memory import MemoryComponent
@ -55,7 +56,8 @@ class AgentComponent(ToolCallingAgentComponent):
outputs = [Output(name="response", display_name="Response", method="message_response")]
async def message_response(self) -> Message:
llm_model = self.get_llm()
llm_model, display_name = self.get_llm()
self.model_name = get_model_name(llm_model, display_name=display_name)
if llm_model is None:
msg = "No language model selected"
raise ValueError(msg)
@ -98,13 +100,14 @@ class AgentComponent(ToolCallingAgentComponent):
provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm)
if provider_info:
component_class = provider_info.get("component_class")
display_name = component_class.display_name
inputs = provider_info.get("inputs")
prefix = provider_info.get("prefix", "")
return self._build_llm_model(component_class, inputs, prefix)
return self._build_llm_model(component_class, inputs, prefix), display_name
except Exception as e:
msg = f"Error building {self.agent_llm} language model"
raise ValueError(msg) from e
return self.agent_llm
return self.agent_llm, None
def _build_llm_model(self, component, inputs, prefix=""):
model_kwargs = {input_.name: getattr(self, f"{prefix}{input_.name}") for input_ in inputs}