From fcd8d379a6c2c55cad5077d45fd1335296a20428 Mon Sep 17 00:00:00 2001 From: Cezar Vasconcelos <97035956+vasconceloscezar@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:32:49 -0300 Subject: [PATCH] feat: add LLM output for the AI ML component to use ChatOpenAI base model (#3230) * fix: add LLM output for the AI ML component, using ChatOpenAI base model from langchain * [autofix.ci] apply automated fixes * fix: docs * fix: unused variables --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../langflow/components/models/AIMLModel.py | 218 ++++++++---------- 1 file changed, 99 insertions(+), 119 deletions(-) diff --git a/src/backend/base/langflow/components/models/AIMLModel.py b/src/backend/base/langflow/components/models/AIMLModel.py index df0904590..bc866f6b1 100644 --- a/src/backend/base/langflow/components/models/AIMLModel.py +++ b/src/backend/base/langflow/components/models/AIMLModel.py @@ -1,148 +1,128 @@ -import json -import httpx -from langflow.base.models.aiml_constants import AIML_CHAT_MODELS -from langflow.custom.custom_component.component import Component +import operator +from functools import reduce -from langflow.inputs.inputs import FloatInput, IntInput, MessageInput, SecretStrInput -from langflow.schema.message import Message -from langflow.template.field.base import Output -from loguru import logger +from langflow.field_typing.range_spec import RangeSpec +from langchain_openai import ChatOpenAI from pydantic.v1 import SecretStr +from langflow.base.models.aiml_constants import AIML_CHAT_MODELS +from langflow.base.models.model import LCModelComponent +from langflow.field_typing import LanguageModel from langflow.inputs import ( + BoolInput, + DictInput, DropdownInput, + FloatInput, + IntInput, + SecretStrInput, StrInput, ) -class AIMLModelComponent(Component): - display_name = "AI/ML API" - description = "Generates text using the AI/ML API" - icon = "AI/ML" - chat_completion_url = "https://api.aimlapi.com/v1/chat/completions" +class AIMLModelComponent(LCModelComponent): + display_name = "AIML" + description = "Generates text using AIML LLMs." + icon = "AIML" + name = "AIMLModel" + documentation = "https://docs.aimlapi.com/api-reference" - outputs = [ - Output(display_name="Text", name="text_output", method="make_request"), - ] - - inputs = [ - DropdownInput( - name="model_name", - display_name="Model Name", - options=AIML_CHAT_MODELS, - required=True, - ), - SecretStrInput( - name="aiml_api_key", - display_name="AI/ML API Key", - value="AIML_API_KEY", - ), - MessageInput(name="input_value", display_name="Input", required=True), + inputs = LCModelComponent._base_inputs + [ IntInput( name="max_tokens", display_name="Max Tokens", advanced=True, info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.", + range_spec=RangeSpec(min=0, max=128000), + ), + DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True), + BoolInput( + name="json_mode", + display_name="JSON Mode", + advanced=True, + info="If True, it will output JSON regardless of passing a schema.", + ), + DictInput( + name="output_schema", + is_list=True, + display_name="Schema", + advanced=True, + info="The schema for the Output of the model. You must pass the word JSON in the prompt. If left blank, JSON mode will be disabled.", + ), + DropdownInput( + name="model_name", + display_name="Model Name", + advanced=False, + options=AIML_CHAT_MODELS, + value=AIML_CHAT_MODELS[0], ), StrInput( - name="stop_tokens", - display_name="Stop Tokens", - info="Comma-separated list of tokens to signal the model to stop generating text.", + name="aiml_api_base", + display_name="AIML API Base", advanced=True, + info="The base URL of the OpenAI API. Defaults to https://api.aimlapi.com . You can change this to use other APIs like JinaChat, LocalAI e Prem.", ), + SecretStrInput( + name="api_key", + display_name="AIML API Key", + info="The AIML API Key to use for the OpenAI model.", + advanced=False, + value="AIML_API_KEY", + ), + FloatInput(name="temperature", display_name="Temperature", value=0.1), IntInput( - name="top_k", - display_name="Top K", - info="Limits token selection to top K. (Default: 40)", - advanced=True, - ), - FloatInput( - name="top_p", - display_name="Top P", - info="Works together with top-k. (Default: 0.9)", - advanced=True, - ), - FloatInput( - name="repeat_penalty", - display_name="Repeat Penalty", - info="Penalty for repetitions in generated text. (Default: 1.1)", - advanced=True, - ), - FloatInput( - name="temperature", - display_name="Temperature", - value=0.2, - info="Controls the creativity of model responses.", - ), - StrInput( - name="system_message", - display_name="System Message", - info="System message to pass to the model.", + name="seed", + display_name="Seed", + info="The seed controls the reproducibility of the job.", advanced=True, + value=1, ), ] - def make_request(self) -> Message: - api_key = SecretStr(self.aiml_api_key) if self.aiml_api_key else None + def build_model(self) -> LanguageModel: # type: ignore[type-var] + output_schema_dict: dict[str, str] = reduce(operator.ior, self.output_schema or {}, {}) + aiml_api_key = self.api_key + temperature = self.temperature + model_name: str = self.model_name + max_tokens = self.max_tokens + model_kwargs = self.model_kwargs or {} + aiml_api_base = self.aiml_api_base or "https://api.aimlapi.com" + json_mode = bool(output_schema_dict) or self.json_mode + seed = self.seed - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key.get_secret_value()}" if api_key else "", - } - - messages = [] - if self.system_message: - messages.append({"role": "system", "content": self.system_message}) - - if self.input_value: - if isinstance(self.input_value, Message): - # Though we aren't using langchain here, the helper method is useful - message = self.input_value.to_lc_message() - if message.type == "human": - messages.append({"role": "user", "content": message.content}) - else: - raise ValueError(f"Expected user message, saw: {message.type}") - else: - raise TypeError(f"Expected Message type, saw: {type(self.input_value)}") + if isinstance(aiml_api_key, SecretStr): + openai_api_key = aiml_api_key.get_secret_value() else: - raise ValueError("Please provide an input value") + openai_api_key = aiml_api_key - payload = { - "model": self.model_name, - "messages": messages, - "max_tokens": self.max_tokens or None, - "temperature": self.temperature or 0.2, - "top_k": self.top_k or 40, - "top_p": self.top_p or 0.9, - "repeat_penalty": self.repeat_penalty or 1.1, - "stop_tokens": self.stop_tokens or None, - } + model = ChatOpenAI( + model=model_name, + temperature=temperature, + api_key=openai_api_key, + base_url=aiml_api_base, + max_tokens=max_tokens or None, + seed=seed, + json_mode=json_mode, + **model_kwargs, + ) + return model # type: ignore + + def _get_exception_message(self, e: Exception): + """ + Get a message from an OpenAI exception. + + Args: + exception (Exception): The exception to get the message from. + + Returns: + str: The message from the exception. + """ try: - response = httpx.post(self.chat_completion_url, headers=headers, json=payload) - try: - response.raise_for_status() - result_data = response.json() - choice = result_data["choices"][0] - result = choice["message"]["content"] - except httpx.HTTPStatusError as http_err: - logger.error(f"HTTP error occurred: {http_err}") - raise http_err - except httpx.RequestError as req_err: - logger.error(f"Request error occurred: {req_err}") - raise req_err - except json.JSONDecodeError: - logger.warning("Failed to decode JSON, response text: {response.text}") - result = response.text - except KeyError as key_err: - logger.warning(f"Key error: {key_err}, response content: {result_data}") - raise key_err - - self.status = result - except httpx.TimeoutException: - return Message(text="Request timed out.") - except Exception as exc: - logger.error(f"Error: {exc}") - raise - - return Message(text=result) + from openai.error import BadRequestError + except ImportError: + return None + if isinstance(e, BadRequestError): + message = e.json_body.get("error", {}).get("message", "") # type: ignore + if message: + return message + return None