From 759920b0d6c928471a0d4c2566281bdba244a842 Mon Sep 17 00:00:00 2001 From: Cezar Vasconcelos Date: Fri, 14 Jun 2024 21:27:28 +0000 Subject: [PATCH] upd: OpenAIModel, add seed and json mode --- .../langflow/components/models/OpenAIModel.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/backend/base/langflow/components/models/OpenAIModel.py b/src/backend/base/langflow/components/models/OpenAIModel.py index 0e4b2ae46..9d35db30a 100644 --- a/src/backend/base/langflow/components/models/OpenAIModel.py +++ b/src/backend/base/langflow/components/models/OpenAIModel.py @@ -5,11 +5,11 @@ from langflow.base.constants import STREAM_INFO_TEXT from langflow.base.models.model import LCModelComponent from langflow.base.models.openai_constants import MODEL_NAMES from langflow.field_typing import BaseLanguageModel, Text -from langflow.inputs import BoolInput, DictInput, DropdownInput, FloatInput, SecretStrInput, StrInput -from langflow.inputs.inputs import IntInput +from langflow.inputs import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput from langflow.template import Output + class OpenAIModelComponent(LCModelComponent): display_name = "OpenAI" description = "Generates text using OpenAI LLMs." @@ -31,7 +31,7 @@ class OpenAIModelComponent(LCModelComponent): name="openai_api_base", display_name="OpenAI API Base", advanced=True, - info="The base URL of the OpenAI API. Defaults to https://api.openai.com/v1.\n\nYou can change this to use other APIs like JinaChat, LocalAI and Prem.", + info="The base URL of the OpenAI API. Defaults to https://api.openai.com/v1. You can change this to use other APIs like JinaChat, LocalAI and Prem.", ), SecretStrInput( name="openai_api_key", @@ -48,6 +48,19 @@ class OpenAIModelComponent(LCModelComponent): info="System message to pass to the model.", advanced=True, ), + BoolInput( + name="json_mode", + display_name="JSON Mode", + info="Enable JSON mode for the model output.", + advanced=True, + ), + IntInput( + name="seed", + display_name="Seed", + info="The seed controls the reproducibility of the job.", + advanced=True, + value=1, + ), ] outputs = [ Output(display_name="Text", name="text_output", method="text_response"), @@ -70,12 +83,15 @@ class OpenAIModelComponent(LCModelComponent): max_tokens = self.max_tokens model_kwargs = self.model_kwargs openai_api_base = self.openai_api_base or "https://api.openai.com/v1" - + json_mode = self.json_mode + seed = self.seed if openai_api_key: api_key = SecretStr(openai_api_key) else: api_key = None - + response_format = None + if json_mode: + response_format = {"type": "json_object"} output = ChatOpenAI( max_tokens=max_tokens or None, model_kwargs=model_kwargs or {}, @@ -83,5 +99,8 @@ class OpenAIModelComponent(LCModelComponent): base_url=openai_api_base, api_key=api_key, temperature=temperature or 0.1, + response_format=response_format, + seed=seed, ) + return output