refactor: Update OpenAIModelComponent to support JSON mode with schema
This commit is contained in:
parent
84c1320c82
commit
f49ebe5f9b
2 changed files with 17 additions and 7 deletions
|
|
@@ -1,3 +1,4 @@
|
|||
import json
|
||||
import warnings
|
||||
from typing import Optional, Union
|
||||
|
||||
|
|
@@ -108,10 +109,13 @@ class LCModelComponent(Component):
|
|||
return runnable.stream(inputs)
|
||||
else:
|
||||
message = runnable.invoke(inputs)
|
||||
result = message.content
|
||||
result = message.content if hasattr(message, "content") else message
|
||||
if isinstance(message, AIMessage):
|
||||
status_message = self.build_status_message(message)
|
||||
self.status = status_message
|
||||
elif isinstance(result, dict):
|
||||
result = json.dumps(message, indent=4)
|
||||
self.status = result
|
||||
else:
|
||||
self.status = result
|
||||
return result
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,5 @@
|
|||
from langchain_openai import ChatOpenAI
|
||||
from langflow.schema.message import Message
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
|
|
@@ -32,6 +33,13 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
|
||||
),
|
||||
DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
|
||||
DictInput(
|
||||
name="schema",
|
||||
is_list=True,
|
||||
display_name="Schema",
|
||||
advanced=True,
|
||||
info="The schema for the Output of the model. You must pass the word JSON in the prompt. If left blank, JSON mode will be disabled.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="model_name", display_name="Model Name", advanced=False, options=MODEL_NAMES, value=MODEL_NAMES[0]
|
||||
),
|
||||
|
|
@@ -75,7 +83,7 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
Output(display_name="Language Model", name="model_output", method="build_model"),
|
||||
]
|
||||
|
||||
def text_response(self) -> Text:
|
||||
def text_response(self) -> Message:
|
||||
input_value = self.input_value
|
||||
stream = self.stream
|
||||
system_message = self.system_message
|
||||
|
|
@@ -91,15 +99,12 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
max_tokens = self.max_tokens
|
||||
model_kwargs = self.model_kwargs
|
||||
openai_api_base = self.openai_api_base or "https://api.openai.com/v1"
|
||||
json_mode = self.json_mode
|
||||
json_mode = bool(self.schema)
|
||||
seed = self.seed
|
||||
if openai_api_key:
|
||||
api_key = SecretStr(openai_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
response_format = None
|
||||
if json_mode:
|
||||
response_format = {"type": "json_object"}
|
||||
output = ChatOpenAI(
|
||||
max_tokens=max_tokens or None,
|
||||
model_kwargs=model_kwargs or {},
|
||||
|
|
@@ -107,8 +112,9 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
base_url=openai_api_base,
|
||||
api_key=api_key,
|
||||
temperature=temperature or 0.1,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
)
|
||||
if json_mode:
|
||||
output = output.with_structured_output(schema=self.schema, method="json_mode")
|
||||
|
||||
return output
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue