upd: cohere llm component
This commit is contained in:
parent
b376bb26be
commit
0ff4694a16
1 changed files with 54 additions and 57 deletions
|
|
@ -1,73 +1,70 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain_cohere import ChatCohere
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Text
|
||||
|
||||
from langflow.field_typing import BaseLanguageModel, Text
|
||||
from langflow.inputs import BoolInput, FloatInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.template import Output
|
||||
|
||||
class CohereComponent(LCModelComponent):
|
||||
display_name = "Cohere"
|
||||
description = "Generate text using Cohere LLMs."
|
||||
documentation = "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/cohere"
|
||||
|
||||
icon = "Cohere"
|
||||
|
||||
field_order = [
|
||||
"cohere_api_key",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
inputs = [
|
||||
SecretStrInput(
|
||||
name="cohere_api_key",
|
||||
display_name="Cohere API Key",
|
||||
info="The Cohere API Key to use for the Cohere model.",
|
||||
advanced=False,
|
||||
value="COHERE_API_KEY",
|
||||
),
|
||||
IntInput(
|
||||
name="max_tokens",
|
||||
display_name="Max Tokens",
|
||||
advanced=True,
|
||||
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
|
||||
),
|
||||
FloatInput(name="temperature", display_name="Temperature", value=0.75),
|
||||
StrInput(name="input_value", display_name="Input", input_types=["Text", "Data", "Prompt"]),
|
||||
BoolInput(name="stream", display_name="Stream", info=STREAM_INFO_TEXT, advanced=True),
|
||||
StrInput(
|
||||
name="system_message",
|
||||
display_name="System Message",
|
||||
info="System message to pass to the model.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
outputs = [
|
||||
Output(display_name="Text", name="text_output", method="text_response"),
|
||||
Output(display_name="Language Model", name="model_output", method="build_model"),
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"cohere_api_key": {
|
||||
"display_name": "Cohere API Key",
|
||||
"type": "password",
|
||||
"password": True,
|
||||
"required": True,
|
||||
},
|
||||
"max_tokens": {
|
||||
"display_name": "Max Tokens",
|
||||
"advanced": True,
|
||||
"info": "The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"default": 0.75,
|
||||
"type": "float",
|
||||
"show": True,
|
||||
},
|
||||
"input_value": {"display_name": "Input", "input_types": ["Text", "Data", "Prompt"]},
|
||||
"stream": {
|
||||
"display_name": "Stream",
|
||||
"info": STREAM_INFO_TEXT,
|
||||
"advanced": True,
|
||||
},
|
||||
"system_message": {
|
||||
"display_name": "System Message",
|
||||
"info": "System message to pass to the model.",
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
def text_response(self) -> Text:
|
||||
input_value = self.input_value
|
||||
stream = self.stream
|
||||
system_message = self.system_message
|
||||
output = self.build_model()
|
||||
result = self.get_chat_result(output, stream, input_value, system_message)
|
||||
self.status = result
|
||||
return result
|
||||
|
||||
def build(
|
||||
self,
|
||||
cohere_api_key: str,
|
||||
input_value: Text,
|
||||
temperature: float = 0.75,
|
||||
stream: bool = False,
|
||||
system_message: Optional[str] = None,
|
||||
) -> Text:
|
||||
api_key = SecretStr(cohere_api_key)
|
||||
output = ChatCohere( # type: ignore
|
||||
def build_model(self) -> BaseLanguageModel:
|
||||
cohere_api_key = self.cohere_api_key
|
||||
temperature = self.temperature
|
||||
max_tokens = self.max_tokens
|
||||
|
||||
if cohere_api_key:
|
||||
api_key = SecretStr(cohere_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
|
||||
output = ChatCohere(
|
||||
max_tokens=max_tokens or None,
|
||||
temperature=temperature or 0.75,
|
||||
cohere_api_key=api_key,
|
||||
temperature=temperature,
|
||||
)
|
||||
return self.get_chat_result(output, stream, input_value, system_message)
|
||||
return self.get_chat_result(output, stream, input_value, system_message)
|
||||
|
||||
return output
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue