feat: Add Prompt component to prompts/__init__.py
This commit is contained in:
parent
2996cb726d
commit
0d609b9376
6 changed files with 175 additions and 198 deletions
|
|
@ -1,150 +1,115 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_openai.embeddings.base import OpenAIEmbeddings
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.field_typing import Embeddings, NestedDict
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.inputs import (
|
||||
BoolInput,
|
||||
DictInput,
|
||||
FloatInput,
|
||||
IntInput,
|
||||
SecretStrInput,
|
||||
StrInput,
|
||||
DropdownInput
|
||||
)
|
||||
from langflow.template import Output
|
||||
|
||||
|
||||
class OpenAIEmbeddingsComponent(CustomComponent):
|
||||
class OpenAIEmbeddingsComponent(LCModelComponent):
|
||||
display_name = "OpenAI Embeddings"
|
||||
description = "Generate embeddings using OpenAI models."
|
||||
icon = "OpenAI"
|
||||
inputs = [
|
||||
DictInput(
|
||||
name="default_headers",
|
||||
display_name="Default Headers",
|
||||
advanced=True,
|
||||
info="Default headers to use for the API request.",
|
||||
),
|
||||
DictInput(
|
||||
name="default_query",
|
||||
display_name="Default Query",
|
||||
advanced=True,
|
||||
info="Default query parameters to use for the API request.",
|
||||
),
|
||||
IntInput(name="chunk_size", display_name="Chunk Size", advanced=True, value=1000),
|
||||
StrInput(name="client", display_name="Client", advanced=True),
|
||||
StrInput(name="deployment", display_name="Deployment", advanced=True),
|
||||
IntInput(
|
||||
name="embedding_ctx_length",
|
||||
display_name="Embedding Context Length",
|
||||
advanced=True,
|
||||
value=1536
|
||||
),
|
||||
IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
|
||||
DropdownInput(
|
||||
name="model",
|
||||
display_name="Model",
|
||||
advanced=False,
|
||||
options=[
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large",
|
||||
"text-embedding-ada-002",
|
||||
],
|
||||
value="text-embedding-3-small"
|
||||
),
|
||||
DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
|
||||
SecretStrInput(
|
||||
name="openai_api_base", display_name="OpenAI API Base", advanced=True
|
||||
),
|
||||
SecretStrInput(name="openai_api_key", display_name="OpenAI API Key"),
|
||||
SecretStrInput(
|
||||
name="openai_api_type", display_name="OpenAI API Type", advanced=True
|
||||
),
|
||||
StrInput(
|
||||
name="openai_api_version", display_name="OpenAI API Version", advanced=True
|
||||
),
|
||||
StrInput(
|
||||
name="openai_organization",
|
||||
display_name="OpenAI Organization",
|
||||
advanced=True,
|
||||
),
|
||||
StrInput(name="openai_proxy", display_name="OpenAI Proxy", advanced=True),
|
||||
FloatInput(
|
||||
name="request_timeout", display_name="Request Timeout", advanced=True
|
||||
),
|
||||
BoolInput(
|
||||
name="show_progress_bar", display_name="Show Progress Bar", advanced=True
|
||||
),
|
||||
BoolInput(name="skip_empty", display_name="Skip Empty", advanced=True),
|
||||
StrInput(
|
||||
name="tiktoken_model_name",
|
||||
display_name="TikToken Model Name",
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="tiktoken_enable", display_name="TikToken Enable", advanced=True
|
||||
),
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"allowed_special": {
|
||||
"display_name": "Allowed Special",
|
||||
"advanced": True,
|
||||
"field_type": "str",
|
||||
"is_list": True,
|
||||
},
|
||||
"default_headers": {
|
||||
"display_name": "Default Headers",
|
||||
"advanced": True,
|
||||
"field_type": "dict",
|
||||
},
|
||||
"default_query": {
|
||||
"display_name": "Default Query",
|
||||
"advanced": True,
|
||||
"field_type": "NestedDict",
|
||||
},
|
||||
"disallowed_special": {
|
||||
"display_name": "Disallowed Special",
|
||||
"advanced": True,
|
||||
"field_type": "str",
|
||||
"is_list": True,
|
||||
},
|
||||
"chunk_size": {"display_name": "Chunk Size", "advanced": True},
|
||||
"client": {"display_name": "Client", "advanced": True},
|
||||
"deployment": {"display_name": "Deployment", "advanced": True},
|
||||
"embedding_ctx_length": {
|
||||
"display_name": "Embedding Context Length",
|
||||
"advanced": True,
|
||||
},
|
||||
"max_retries": {"display_name": "Max Retries", "advanced": True},
|
||||
"model": {
|
||||
"display_name": "Model",
|
||||
"advanced": False,
|
||||
"options": [
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large",
|
||||
"text-embedding-ada-002",
|
||||
],
|
||||
},
|
||||
"model_kwargs": {"display_name": "Model Kwargs", "advanced": True},
|
||||
"openai_api_base": {
|
||||
"display_name": "OpenAI API Base",
|
||||
"password": True,
|
||||
"advanced": True,
|
||||
},
|
||||
"openai_api_key": {"display_name": "OpenAI API Key", "password": True},
|
||||
"openai_api_type": {
|
||||
"display_name": "OpenAI API Type",
|
||||
"advanced": True,
|
||||
"password": True,
|
||||
},
|
||||
"openai_api_version": {
|
||||
"display_name": "OpenAI API Version",
|
||||
"advanced": True,
|
||||
},
|
||||
"openai_organization": {
|
||||
"display_name": "OpenAI Organization",
|
||||
"advanced": True,
|
||||
},
|
||||
"openai_proxy": {"display_name": "OpenAI Proxy", "advanced": True},
|
||||
"request_timeout": {"display_name": "Request Timeout", "advanced": True},
|
||||
"show_progress_bar": {
|
||||
"display_name": "Show Progress Bar",
|
||||
"advanced": True,
|
||||
},
|
||||
"skip_empty": {"display_name": "Skip Empty", "advanced": True},
|
||||
"tiktoken_model_name": {
|
||||
"display_name": "TikToken Model Name",
|
||||
"advanced": True,
|
||||
},
|
||||
"tiktoken_enable": {"display_name": "TikToken Enable", "advanced": True},
|
||||
"dimensions": {
|
||||
"display_name": "Dimensions",
|
||||
"info": "The number of dimensions the resulting output embeddings should have. Only supported by certain models.",
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
openai_api_key: str,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
default_query: Optional[NestedDict] = {},
|
||||
allowed_special: List[str] = [],
|
||||
disallowed_special: List[str] = ["all"],
|
||||
chunk_size: int = 1000,
|
||||
deployment: str = "text-embedding-ada-002",
|
||||
embedding_ctx_length: int = 8191,
|
||||
max_retries: int = 6,
|
||||
model: str = "text-embedding-ada-002",
|
||||
model_kwargs: NestedDict = {},
|
||||
openai_api_base: Optional[str] = None,
|
||||
openai_api_type: Optional[str] = None,
|
||||
openai_api_version: Optional[str] = None,
|
||||
openai_organization: Optional[str] = None,
|
||||
openai_proxy: Optional[str] = None,
|
||||
request_timeout: Optional[float] = None,
|
||||
show_progress_bar: bool = False,
|
||||
skip_empty: bool = False,
|
||||
tiktoken_enable: bool = True,
|
||||
tiktoken_model_name: Optional[str] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> Embeddings:
|
||||
# This is to avoid errors with Vector Stores (e.g Chroma)
|
||||
if disallowed_special == ["all"]:
|
||||
disallowed_special = "all" # type: ignore
|
||||
if openai_api_key:
|
||||
api_key = SecretStr(openai_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
outputs = [
|
||||
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
|
||||
]
|
||||
|
||||
def build_embeddings(self) -> Embeddings:
|
||||
return OpenAIEmbeddings(
|
||||
tiktoken_enabled=tiktoken_enable,
|
||||
default_headers=default_headers,
|
||||
default_query=default_query,
|
||||
allowed_special=set(allowed_special),
|
||||
tiktoken_enabled=self.tiktoken_enable,
|
||||
default_headers=self.default_headers,
|
||||
default_query=self.default_query,
|
||||
allowed_special="all",
|
||||
disallowed_special="all",
|
||||
chunk_size=chunk_size,
|
||||
deployment=deployment,
|
||||
embedding_ctx_length=embedding_ctx_length,
|
||||
max_retries=max_retries,
|
||||
model=model,
|
||||
model_kwargs=model_kwargs,
|
||||
base_url=openai_api_base,
|
||||
api_key=api_key,
|
||||
openai_api_type=openai_api_type,
|
||||
api_version=openai_api_version,
|
||||
organization=openai_organization,
|
||||
openai_proxy=openai_proxy,
|
||||
timeout=request_timeout,
|
||||
show_progress_bar=show_progress_bar,
|
||||
skip_empty=skip_empty,
|
||||
tiktoken_model_name=tiktoken_model_name,
|
||||
dimensions=dimensions,
|
||||
chunk_size=self.chunk_size,
|
||||
deployment=self.deployment,
|
||||
embedding_ctx_length=self.embedding_ctx_length,
|
||||
max_retries=self.max_retries,
|
||||
model=self.model,
|
||||
model_kwargs=self.model_kwargs,
|
||||
base_url=self.openai_api_base,
|
||||
api_key=self.openai_api_key,
|
||||
openai_api_type=self.openai_api_type,
|
||||
api_version=self.openai_api_version,
|
||||
organization=self.openai_organization,
|
||||
openai_proxy=self.openai_proxy,
|
||||
timeout=self.request_timeout,
|
||||
show_progress_bar=self.show_progress_bar,
|
||||
skip_empty=self.skip_empty,
|
||||
tiktoken_model_name=self.tiktoken_model_name,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class ParseDataComponent(Component):
|
|||
|
||||
inputs = [
|
||||
HandleInput(
|
||||
name="data", display_name="Data", info="The data to convert to text.", input_types=["Message", "Data"]
|
||||
name="data", display_name="Data", info="The data to convert to text.", input_types=["Data"]
|
||||
),
|
||||
MultilineInput(
|
||||
name="template",
|
||||
|
|
@ -27,7 +27,7 @@ class ParseDataComponent(Component):
|
|||
|
||||
def parse_data_to_text(self) -> Text:
|
||||
data = self.data if isinstance(self.data, list) else [self.data]
|
||||
template = self.template or "Text: {text}"
|
||||
template = self.template
|
||||
|
||||
result_string = data_to_text(template, data)
|
||||
self.status = result_string
|
||||
|
|
|
|||
3
src/backend/base/langflow/components/prompts/__init__.py
Normal file
3
src/backend/base/langflow/components/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .Prompt import Prompt
|
||||
|
||||
__all__ = ["Prompt"]
|
||||
|
|
@ -57,30 +57,30 @@ def __getattr__(name: str) -> Any:
|
|||
|
||||
|
||||
__all__ = [
|
||||
"NestedDict",
|
||||
"Data",
|
||||
"Tool",
|
||||
"PromptTemplate",
|
||||
"Chain",
|
||||
"AgentExecutor",
|
||||
"BaseChatMemory",
|
||||
"BaseLLM",
|
||||
"BaseLanguageModel",
|
||||
"BaseLLM",
|
||||
"BaseLoader",
|
||||
"BaseMemory",
|
||||
"BaseOutputParser",
|
||||
"BaseRetriever",
|
||||
"VectorStore",
|
||||
"Embeddings",
|
||||
"TextSplitter",
|
||||
"Document",
|
||||
"AgentExecutor",
|
||||
"Text",
|
||||
"Object",
|
||||
"Callable",
|
||||
"BasePromptTemplate",
|
||||
"BaseRetriever",
|
||||
"Callable",
|
||||
"Chain",
|
||||
"ChatPromptTemplate",
|
||||
"Prompt",
|
||||
"RangeSpec",
|
||||
"Input",
|
||||
"Code",
|
||||
"Data",
|
||||
"Document",
|
||||
"Embeddings",
|
||||
"Input",
|
||||
"NestedDict",
|
||||
"Object",
|
||||
"Prompt",
|
||||
"PromptTemplate",
|
||||
"RangeSpec",
|
||||
"Text",
|
||||
"TextSplitter",
|
||||
"Tool",
|
||||
"VectorStore",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -23,6 +23,11 @@ class HandleInput(BaseInputMixin, ListableInputMixin):
|
|||
field_type: Optional[SerializableFieldTypes] = FieldTypes.OTHER
|
||||
|
||||
|
||||
# class DataInput(HandleInput):
|
||||
# input_types: list[str] = ["Data"]
|
||||
# ! Let's add this?
|
||||
|
||||
|
||||
class PromptInput(BaseInputMixin, ListableInputMixin):
|
||||
field_type: Optional[SerializableFieldTypes] = FieldTypes.PROMPT
|
||||
|
||||
|
|
@ -38,29 +43,29 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin): # noqa:
|
|||
class TextInput(StrInput):
|
||||
input_types: list[str] = ["Data", "Message", "Text"]
|
||||
|
||||
@field_validator("value")
|
||||
@classmethod
|
||||
def validate_value(cls, v: Any, _info):
|
||||
value = None
|
||||
if isinstance(v, str):
|
||||
value = v
|
||||
elif isinstance(v, Message):
|
||||
value = v.text
|
||||
elif isinstance(v, Data):
|
||||
if v.text_key in v.data:
|
||||
value = v.data[v.text_key]
|
||||
else:
|
||||
keys = ", ".join(v.data.keys())
|
||||
input_name = _info.data["name"]
|
||||
raise ValueError(
|
||||
f"The input to '{input_name}' must contain the key '{v.text_key}'."
|
||||
f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input type {type(v)}")
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise ValueError(f"Invalid value type {type(value)}")
|
||||
# @field_validator("value")
|
||||
# @classmethod
|
||||
# def validate_value(cls, v: Any, _info):
|
||||
# value = None
|
||||
# if isinstance(v, str):
|
||||
# value = v
|
||||
# elif isinstance(v, Message):
|
||||
# value = v.text
|
||||
# elif isinstance(v, Data):
|
||||
# if v.text_key in v.data:
|
||||
# value = v.data[v.text_key]
|
||||
# else:
|
||||
# keys = ", ".join(v.data.keys())
|
||||
# input_name = _info.data["name"]
|
||||
# raise ValueError(
|
||||
# f"The input to '{input_name}' must contain the key '{v.text_key}'."
|
||||
# f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component."
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError(f"Invalid input type {type(v)}")
|
||||
# if isinstance(value, str):
|
||||
# return value
|
||||
# raise ValueError(f"Invalid value type {type(value)}")
|
||||
|
||||
|
||||
class MultilineInput(BaseInputMixin):
|
||||
|
|
@ -108,16 +113,17 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin):
|
|||
|
||||
|
||||
InputTypes = Union[
|
||||
StrInput,
|
||||
SecretStrInput,
|
||||
IntInput,
|
||||
FloatInput,
|
||||
BoolInput,
|
||||
NestedDictInput,
|
||||
# DataInput, # ! Let's add this
|
||||
DictInput,
|
||||
DropdownInput,
|
||||
FileInput,
|
||||
PromptInput,
|
||||
MultilineInput,
|
||||
FloatInput,
|
||||
HandleInput,
|
||||
IntInput,
|
||||
MultilineInput,
|
||||
NestedDictInput,
|
||||
PromptInput,
|
||||
SecretStrInput,
|
||||
StrInput,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -247,7 +247,6 @@ export const nodeColors: { [char: string]: string } = {
|
|||
models: "#ab11ab",
|
||||
model_specs: "#6344BE",
|
||||
chains: "#FE7500",
|
||||
Document: "#7AAE42",
|
||||
list: "#9AAE42",
|
||||
agents: "#903BBE",
|
||||
tools: "#FF3434",
|
||||
|
|
@ -267,14 +266,18 @@ export const nodeColors: { [char: string]: string } = {
|
|||
experimental: "#E6277A",
|
||||
langchain_utilities: "#31A3CC",
|
||||
output_parsers: "#E6A627",
|
||||
str: "#4367BF",
|
||||
Text: "#4367BF",
|
||||
retrievers: "#e6b25a",
|
||||
unknown: "#9CA3AF",
|
||||
// custom_components: "#ab11ab",
|
||||
Data: "#9CA3AF",
|
||||
Message: "#4367BF",
|
||||
BaseLanguageModel: "#ab11ab",
|
||||
retrievers: "#e6b25a",
|
||||
//
|
||||
str: "#2563eb",
|
||||
Text: "#2563eb",
|
||||
unknown: "#9CA3AF",
|
||||
Document: "#65a30d",
|
||||
Data: "#dc2626",
|
||||
Message: "#4f46e5",
|
||||
Prompt: "#7c3aed",
|
||||
Embeddings: "#10b981",
|
||||
BaseLanguageModel: "#c026d3",
|
||||
};
|
||||
|
||||
export const nodeNames: { [char: string]: string } = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue