feat: add google generative ai models to agent component (#5853)

* add google generative ai to agents

* [autofix.ci] apply automated fixes

* format error fixed

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Edwin Jose 2025-01-21 17:08:21 -05:00 committed by GitHub
commit 5ce4f514e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 101 additions and 4 deletions

View file

@ -4,6 +4,7 @@ from langflow.base.models.model import LCModelComponent
from langflow.components.models.amazon_bedrock import AmazonBedrockComponent
from langflow.components.models.anthropic import AnthropicModelComponent
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
from langflow.components.models.google_generative_ai import GoogleGenerativeAIComponent
from langflow.components.models.groq import GroqModel
from langflow.components.models.nvidia import NVIDIAModelComponent
from langflow.components.models.openai import OpenAIModelComponent
@ -66,6 +67,20 @@ def create_input_fields_dict(inputs: list[Input], prefix: str) -> dict[str, Inpu
return {f"{prefix}{input_.name}": input_.to_dict() for input_ in inputs}
def _get_google_generative_ai_inputs_and_fields():
try:
from langflow.components.models.google_generative_ai import GoogleGenerativeAIComponent
google_generative_ai_inputs = get_filtered_inputs(GoogleGenerativeAIComponent)
except ImportError as e:
msg = (
"Google Generative AI is not installed. Please install it with "
"`pip install langchain-google-generative-ai`."
)
raise ImportError(msg) from e
return google_generative_ai_inputs, create_input_fields_dict(google_generative_ai_inputs, "")
def _get_openai_inputs_and_fields():
try:
from langflow.components.models.openai import OpenAIModelComponent
@ -201,6 +216,16 @@ try:
except ImportError:
pass
try:
google_generative_ai_inputs, google_generative_ai_fields = _get_google_generative_ai_inputs_and_fields()
MODEL_PROVIDERS_DICT["Google Generative AI"] = {
"fields": google_generative_ai_fields,
"inputs": google_generative_ai_inputs,
"prefix": "",
"component_class": GoogleGenerativeAIComponent(),
}
except ImportError:
pass
MODEL_PROVIDERS = list(MODEL_PROVIDERS_DICT.keys())
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]

View file

@ -1,3 +1,7 @@
from typing import Any
import requests
from loguru import logger
from pydantic.v1 import SecretStr
from langflow.base.models.google_generative_ai_constants import GOOGLE_GENERATIVE_AI_MODELS
@ -5,6 +9,8 @@ from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.inputs import DropdownInput, FloatInput, IntInput, SecretStrInput, SliderInput
from langflow.inputs.inputs import BoolInput
from langflow.schema import dotdict
class GoogleGenerativeAIComponent(LCModelComponent):
@ -19,17 +25,20 @@ class GoogleGenerativeAIComponent(LCModelComponent):
name="max_output_tokens", display_name="Max Output Tokens", info="The maximum number of tokens to generate."
),
DropdownInput(
name="model",
name="model_name",
display_name="Model",
info="The name of the model to use.",
options=GOOGLE_GENERATIVE_AI_MODELS,
value="gemini-1.5-pro",
refresh_button=True,
combobox=True,
),
SecretStrInput(
name="google_api_key",
name="api_key",
display_name="Google API Key",
info="The Google API Key to use for the Google Generative AI.",
required=True,
real_time_refresh=True,
),
FloatInput(
name="top_p",
@ -57,6 +66,11 @@ class GoogleGenerativeAIComponent(LCModelComponent):
info="Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive.",
advanced=True,
),
BoolInput(
name="tool_model_enabled",
display_name="Tool Model Enabled",
info="Whether to use the tool model.",
),
]
def build_model(self) -> LanguageModel: # type: ignore[type-var]
@ -66,8 +80,8 @@ class GoogleGenerativeAIComponent(LCModelComponent):
msg = "The 'langchain_google_genai' package is required to use the Google Generative AI model."
raise ImportError(msg) from e
google_api_key = self.google_api_key
model = self.model
google_api_key = self.api_key
model = self.model_name
max_output_tokens = self.max_output_tokens
temperature = self.temperature
top_k = self.top_k
@ -83,3 +97,50 @@ class GoogleGenerativeAIComponent(LCModelComponent):
n=n or 1,
google_api_key=SecretStr(google_api_key).get_secret_value(),
)
def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:
try:
import google.generativeai as genai
genai.configure(api_key=self.api_key)
model_ids = [
model.name.replace("models/", "")
for model in genai.list_models()
if "generateContent" in model.supported_generation_methods
]
model_ids.sort(reverse=True)
except (ImportError, ValueError) as e:
logger.exception(f"Error getting model names: {e}")
model_ids = GOOGLE_GENERATIVE_AI_MODELS
if tool_model_enabled:
try:
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
except ImportError as e:
msg = "langchain_google_genai is not installed."
raise ImportError(msg) from e
for model in model_ids:
model_with_tool = ChatGoogleGenerativeAI(
model=self.model_name,
google_api_key=self.api_key,
)
if not self.supports_tool_calling(model_with_tool):
model_ids.remove(model)
return model_ids
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value:
try:
if len(self.api_key) == 0:
ids = GOOGLE_GENERATIVE_AI_MODELS
else:
try:
ids = self.get_models(tool_model_enabled=self.tool_model_enabled)
except (ImportError, ValueError, requests.exceptions.RequestException) as e:
logger.exception(f"Error getting model names: {e}")
ids = GOOGLE_GENERATIVE_AI_MODELS
build_config["model_name"]["options"] = ids
build_config["model_name"]["value"] = ids[0]
except Exception as e:
msg = f"Error getting model names: {e}"
raise ValueError(msg) from e
return build_config

View file

@ -1714,6 +1714,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",

View file

@ -1731,6 +1731,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",

View file

@ -1974,6 +1974,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",

View file

@ -765,6 +765,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",

View file

@ -699,6 +699,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",
@ -1280,6 +1281,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",
@ -2678,6 +2680,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",

View file

@ -229,6 +229,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",

View file

@ -1360,6 +1360,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",
@ -1941,6 +1942,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",
@ -2522,6 +2524,7 @@
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Google Generative AI",
"Groq",
"NVIDIA",
"OpenAI",