feat: add LLM output for the AI ML component to use ChatOpenAI base model (#3230)
* fix: add LLM output for the AI ML component, using ChatOpenAI base model from langchain * [autofix.ci] apply automated fixes * fix: docs * fix: unused variables --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
0ac7845a9a
commit
fcd8d379a6
1 changed files with 99 additions and 119 deletions
|
|
@ -1,148 +1,128 @@
|
|||
import json
|
||||
import httpx
|
||||
from langflow.base.models.aiml_constants import AIML_CHAT_MODELS
|
||||
from langflow.custom.custom_component.component import Component
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
from langflow.inputs.inputs import FloatInput, IntInput, MessageInput, SecretStrInput
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template.field.base import Output
|
||||
from loguru import logger
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.base.models.aiml_constants import AIML_CHAT_MODELS
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.inputs import (
|
||||
BoolInput,
|
||||
DictInput,
|
||||
DropdownInput,
|
||||
FloatInput,
|
||||
IntInput,
|
||||
SecretStrInput,
|
||||
StrInput,
|
||||
)
|
||||
|
||||
|
||||
class AIMLModelComponent(Component):
|
||||
display_name = "AI/ML API"
|
||||
description = "Generates text using the AI/ML API"
|
||||
icon = "AI/ML"
|
||||
chat_completion_url = "https://api.aimlapi.com/v1/chat/completions"
|
||||
class AIMLModelComponent(LCModelComponent):
|
||||
display_name = "AIML"
|
||||
description = "Generates text using AIML LLMs."
|
||||
icon = "AIML"
|
||||
name = "AIMLModel"
|
||||
documentation = "https://docs.aimlapi.com/api-reference"
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Text", name="text_output", method="make_request"),
|
||||
]
|
||||
|
||||
inputs = [
|
||||
DropdownInput(
|
||||
name="model_name",
|
||||
display_name="Model Name",
|
||||
options=AIML_CHAT_MODELS,
|
||||
required=True,
|
||||
),
|
||||
SecretStrInput(
|
||||
name="aiml_api_key",
|
||||
display_name="AI/ML API Key",
|
||||
value="AIML_API_KEY",
|
||||
),
|
||||
MessageInput(name="input_value", display_name="Input", required=True),
|
||||
inputs = LCModelComponent._base_inputs + [
|
||||
IntInput(
|
||||
name="max_tokens",
|
||||
display_name="Max Tokens",
|
||||
advanced=True,
|
||||
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
|
||||
range_spec=RangeSpec(min=0, max=128000),
|
||||
),
|
||||
DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
|
||||
BoolInput(
|
||||
name="json_mode",
|
||||
display_name="JSON Mode",
|
||||
advanced=True,
|
||||
info="If True, it will output JSON regardless of passing a schema.",
|
||||
),
|
||||
DictInput(
|
||||
name="output_schema",
|
||||
is_list=True,
|
||||
display_name="Schema",
|
||||
advanced=True,
|
||||
info="The schema for the Output of the model. You must pass the word JSON in the prompt. If left blank, JSON mode will be disabled.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="model_name",
|
||||
display_name="Model Name",
|
||||
advanced=False,
|
||||
options=AIML_CHAT_MODELS,
|
||||
value=AIML_CHAT_MODELS[0],
|
||||
),
|
||||
StrInput(
|
||||
name="stop_tokens",
|
||||
display_name="Stop Tokens",
|
||||
info="Comma-separated list of tokens to signal the model to stop generating text.",
|
||||
name="aiml_api_base",
|
||||
display_name="AIML API Base",
|
||||
advanced=True,
|
||||
info="The base URL of the OpenAI API. Defaults to https://api.aimlapi.com . You can change this to use other APIs like JinaChat, LocalAI e Prem.",
|
||||
),
|
||||
SecretStrInput(
|
||||
name="api_key",
|
||||
display_name="AIML API Key",
|
||||
info="The AIML API Key to use for the OpenAI model.",
|
||||
advanced=False,
|
||||
value="AIML_API_KEY",
|
||||
),
|
||||
FloatInput(name="temperature", display_name="Temperature", value=0.1),
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K",
|
||||
info="Limits token selection to top K. (Default: 40)",
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="top_p",
|
||||
display_name="Top P",
|
||||
info="Works together with top-k. (Default: 0.9)",
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="repeat_penalty",
|
||||
display_name="Repeat Penalty",
|
||||
info="Penalty for repetitions in generated text. (Default: 1.1)",
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
value=0.2,
|
||||
info="Controls the creativity of model responses.",
|
||||
),
|
||||
StrInput(
|
||||
name="system_message",
|
||||
display_name="System Message",
|
||||
info="System message to pass to the model.",
|
||||
name="seed",
|
||||
display_name="Seed",
|
||||
info="The seed controls the reproducibility of the job.",
|
||||
advanced=True,
|
||||
value=1,
|
||||
),
|
||||
]
|
||||
|
||||
def make_request(self) -> Message:
|
||||
api_key = SecretStr(self.aiml_api_key) if self.aiml_api_key else None
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
output_schema_dict: dict[str, str] = reduce(operator.ior, self.output_schema or {}, {})
|
||||
aiml_api_key = self.api_key
|
||||
temperature = self.temperature
|
||||
model_name: str = self.model_name
|
||||
max_tokens = self.max_tokens
|
||||
model_kwargs = self.model_kwargs or {}
|
||||
aiml_api_base = self.aiml_api_base or "https://api.aimlapi.com"
|
||||
json_mode = bool(output_schema_dict) or self.json_mode
|
||||
seed = self.seed
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key.get_secret_value()}" if api_key else "",
|
||||
}
|
||||
|
||||
messages = []
|
||||
if self.system_message:
|
||||
messages.append({"role": "system", "content": self.system_message})
|
||||
|
||||
if self.input_value:
|
||||
if isinstance(self.input_value, Message):
|
||||
# Though we aren't using langchain here, the helper method is useful
|
||||
message = self.input_value.to_lc_message()
|
||||
if message.type == "human":
|
||||
messages.append({"role": "user", "content": message.content})
|
||||
else:
|
||||
raise ValueError(f"Expected user message, saw: {message.type}")
|
||||
else:
|
||||
raise TypeError(f"Expected Message type, saw: {type(self.input_value)}")
|
||||
if isinstance(aiml_api_key, SecretStr):
|
||||
openai_api_key = aiml_api_key.get_secret_value()
|
||||
else:
|
||||
raise ValueError("Please provide an input value")
|
||||
openai_api_key = aiml_api_key
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens or None,
|
||||
"temperature": self.temperature or 0.2,
|
||||
"top_k": self.top_k or 40,
|
||||
"top_p": self.top_p or 0.9,
|
||||
"repeat_penalty": self.repeat_penalty or 1.1,
|
||||
"stop_tokens": self.stop_tokens or None,
|
||||
}
|
||||
model = ChatOpenAI(
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
api_key=openai_api_key,
|
||||
base_url=aiml_api_base,
|
||||
max_tokens=max_tokens or None,
|
||||
seed=seed,
|
||||
json_mode=json_mode,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return model # type: ignore
|
||||
|
||||
def _get_exception_message(self, e: Exception):
|
||||
"""
|
||||
Get a message from an OpenAI exception.
|
||||
|
||||
Args:
|
||||
exception (Exception): The exception to get the message from.
|
||||
|
||||
Returns:
|
||||
str: The message from the exception.
|
||||
"""
|
||||
try:
|
||||
response = httpx.post(self.chat_completion_url, headers=headers, json=payload)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
result_data = response.json()
|
||||
choice = result_data["choices"][0]
|
||||
result = choice["message"]["content"]
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
logger.error(f"HTTP error occurred: {http_err}")
|
||||
raise http_err
|
||||
except httpx.RequestError as req_err:
|
||||
logger.error(f"Request error occurred: {req_err}")
|
||||
raise req_err
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to decode JSON, response text: {response.text}")
|
||||
result = response.text
|
||||
except KeyError as key_err:
|
||||
logger.warning(f"Key error: {key_err}, response content: {result_data}")
|
||||
raise key_err
|
||||
|
||||
self.status = result
|
||||
except httpx.TimeoutException:
|
||||
return Message(text="Request timed out.")
|
||||
except Exception as exc:
|
||||
logger.error(f"Error: {exc}")
|
||||
raise
|
||||
|
||||
return Message(text=result)
|
||||
from openai.error import BadRequestError
|
||||
except ImportError:
|
||||
return None
|
||||
if isinstance(e, BadRequestError):
|
||||
message = e.json_body.get("error", {}).get("message", "") # type: ignore
|
||||
if message:
|
||||
return message
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue