feat: Add Prompt component to prompts/__init__.py

This commit is contained in:
Rodrigo 2024-06-15 22:18:03 -03:00
commit 0d609b9376
6 changed files with 175 additions and 198 deletions

View file

@@ -1,150 +1,115 @@
from typing import Dict, List, Optional
from langchain_openai.embeddings.base import OpenAIEmbeddings
from pydantic.v1 import SecretStr
from langflow.custom import CustomComponent
from langflow.field_typing import Embeddings, NestedDict
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.inputs import (
BoolInput,
DictInput,
FloatInput,
IntInput,
SecretStrInput,
StrInput,
DropdownInput
)
from langflow.template import Output
class OpenAIEmbeddingsComponent(LCModelComponent):
    """Langflow component that builds a LangChain ``OpenAIEmbeddings`` instance.

    NOTE(review): the source text interleaved two versions of this component
    (the legacy ``CustomComponent`` ``build_config()``/``build()`` API and the
    newer declarative ``inputs``/``outputs`` API), producing duplicate class
    statements and duplicate keyword arguments — invalid Python as rendered.
    This is the newer declarative version, resolved from the merged text.
    """

    display_name = "OpenAI Embeddings"
    description = "Generate embeddings using OpenAI models."
    icon = "OpenAI"

    inputs = [
        DictInput(
            name="default_headers",
            display_name="Default Headers",
            advanced=True,
            info="Default headers to use for the API request.",
        ),
        DictInput(
            name="default_query",
            display_name="Default Query",
            advanced=True,
            info="Default query parameters to use for the API request.",
        ),
        IntInput(name="chunk_size", display_name="Chunk Size", advanced=True, value=1000),
        StrInput(name="client", display_name="Client", advanced=True),
        StrInput(name="deployment", display_name="Deployment", advanced=True),
        IntInput(
            name="embedding_ctx_length",
            display_name="Embedding Context Length",
            advanced=True,
            value=1536,
        ),
        IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
        DropdownInput(
            name="model",
            display_name="Model",
            advanced=False,
            options=[
                "text-embedding-3-small",
                "text-embedding-3-large",
                "text-embedding-ada-002",
            ],
            value="text-embedding-3-small",
        ),
        DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
        SecretStrInput(
            name="openai_api_base", display_name="OpenAI API Base", advanced=True
        ),
        SecretStrInput(name="openai_api_key", display_name="OpenAI API Key"),
        SecretStrInput(
            name="openai_api_type", display_name="OpenAI API Type", advanced=True
        ),
        StrInput(
            name="openai_api_version", display_name="OpenAI API Version", advanced=True
        ),
        StrInput(
            name="openai_organization",
            display_name="OpenAI Organization",
            advanced=True,
        ),
        StrInput(name="openai_proxy", display_name="OpenAI Proxy", advanced=True),
        FloatInput(
            name="request_timeout", display_name="Request Timeout", advanced=True
        ),
        BoolInput(
            name="show_progress_bar", display_name="Show Progress Bar", advanced=True
        ),
        BoolInput(name="skip_empty", display_name="Skip Empty", advanced=True),
        StrInput(
            name="tiktoken_model_name",
            display_name="TikToken Model Name",
            advanced=True,
        ),
        BoolInput(
            name="tiktoken_enable", display_name="TikToken Enable", advanced=True
        ),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Instantiate ``OpenAIEmbeddings`` from this component's input values.

        Returns:
            Embeddings: a configured LangChain embeddings object.
        """
        return OpenAIEmbeddings(
            tiktoken_enabled=self.tiktoken_enable,
            default_headers=self.default_headers,
            default_query=self.default_query,
            # The old build() noted this avoids errors with vector stores
            # (e.g. Chroma); both special-token sets are pinned to "all".
            allowed_special="all",
            disallowed_special="all",
            chunk_size=self.chunk_size,
            deployment=self.deployment,
            embedding_ctx_length=self.embedding_ctx_length,
            max_retries=self.max_retries,
            model=self.model,
            model_kwargs=self.model_kwargs,
            base_url=self.openai_api_base,
            api_key=self.openai_api_key,
            openai_api_type=self.openai_api_type,
            api_version=self.openai_api_version,
            organization=self.openai_organization,
            openai_proxy=self.openai_proxy,
            timeout=self.request_timeout,
            show_progress_bar=self.show_progress_bar,
            skip_empty=self.skip_empty,
            tiktoken_model_name=self.tiktoken_model_name,
        )

View file

@@ -12,7 +12,7 @@ class ParseDataComponent(Component):
inputs = [
HandleInput(
name="data", display_name="Data", info="The data to convert to text.", input_types=["Message", "Data"]
name="data", display_name="Data", info="The data to convert to text.", input_types=["Data"]
),
MultilineInput(
name="template",
@@ -27,7 +27,7 @@ class ParseDataComponent(Component):
def parse_data_to_text(self) -> Text:
data = self.data if isinstance(self.data, list) else [self.data]
template = self.template or "Text: {text}"
template = self.template
result_string = data_to_text(template, data)
self.status = result_string

View file

@@ -0,0 +1,3 @@
"""Public interface of the prompts package: re-exports the Prompt component."""
from .Prompt import Prompt

__all__ = ["Prompt"]

View file

@@ -57,30 +57,30 @@ def __getattr__(name: str) -> Any:
# Names exposed via `from ... import *` and lazily resolved by __getattr__.
# Deduplicated and alphabetized: the merged text listed eighteen of these
# twice (old unsorted list overlaid with the new sorted one).
__all__ = [
    "AgentExecutor",
    "BaseChatMemory",
    "BaseLLM",
    "BaseLanguageModel",
    "BaseLoader",
    "BaseMemory",
    "BaseOutputParser",
    "BasePromptTemplate",
    "BaseRetriever",
    "Callable",
    "Chain",
    "ChatPromptTemplate",
    "Code",
    "Data",
    "Document",
    "Embeddings",
    "Input",
    "NestedDict",
    "Object",
    "Prompt",
    "PromptTemplate",
    "RangeSpec",
    "Text",
    "TextSplitter",
    "Tool",
    "VectorStore",
]

View file

@@ -23,6 +23,11 @@ class HandleInput(BaseInputMixin, ListableInputMixin):
field_type: Optional[SerializableFieldTypes] = FieldTypes.OTHER
# class DataInput(HandleInput):
# input_types: list[str] = ["Data"]
# ! Let's add this?
class PromptInput(BaseInputMixin, ListableInputMixin):
    """Listable input whose serialized field type is ``FieldTypes.PROMPT``."""

    # Drives how the field is serialized/typed; presumably the frontend
    # renders PROMPT fields with the prompt editor — TODO confirm.
    field_type: Optional[SerializableFieldTypes] = FieldTypes.PROMPT
@@ -38,29 +43,29 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin):  # noqa:
class TextInput(StrInput):
    """String input that also accepts ``Message`` and ``Data`` values,
    coercing them to plain text via the ``value`` validator below."""

    input_types: list[str] = ["Data", "Message", "Text"]

    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, _info):
        """Coerce *v* to ``str``.

        Accepts a plain string, a ``Message`` (uses ``.text``), or a ``Data``
        object (uses ``data[text_key]``); raises ``ValueError`` for any other
        type, for a ``Data`` missing its ``text_key``, or if the extracted
        value is not a string.
        """
        value = None
        if isinstance(v, str):
            value = v
        elif isinstance(v, Message):
            value = v.text
        elif isinstance(v, Data):
            if v.text_key in v.data:
                value = v.data[v.text_key]
            else:
                keys = ", ".join(v.data.keys())
                input_name = _info.data["name"]
                raise ValueError(
                    f"The input to '{input_name}' must contain the key '{v.text_key}'."
                    f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component."
                )
        else:
            raise ValueError(f"Invalid input type {type(v)}")
        # Message.text / Data lookups may yield non-str; enforce str here.
        if isinstance(value, str):
            return value
        raise ValueError(f"Invalid value type {type(value)}")

    # NOTE(review): the block below is a byte-for-byte commented-out copy of
    # validate_value above — merge residue from overlaying two file versions.
    # Decide whether the validator should stay active (keep the code above,
    # delete this) or be disabled (delete the code above, keep this comment).
    # @field_validator("value")
    # @classmethod
    # def validate_value(cls, v: Any, _info):
    #     value = None
    #     if isinstance(v, str):
    #         value = v
    #     elif isinstance(v, Message):
    #         value = v.text
    #     elif isinstance(v, Data):
    #         if v.text_key in v.data:
    #             value = v.data[v.text_key]
    #         else:
    #             keys = ", ".join(v.data.keys())
    #             input_name = _info.data["name"]
    #             raise ValueError(
    #                 f"The input to '{input_name}' must contain the key '{v.text_key}'."
    #                 f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component."
    #             )
    #     else:
    #         raise ValueError(f"Invalid input type {type(v)}")
    #     if isinstance(value, str):
    #         return value
    #     raise ValueError(f"Invalid value type {type(value)}")
class MultilineInput(BaseInputMixin):
@@ -108,16 +113,17 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin):
# Union of every concrete input class accepted where an input is expected.
# Deduplicated and alphabetized: the merged text listed seven members twice
# (harmless at runtime since Union collapses duplicates, but misleading).
InputTypes = Union[
    BoolInput,
    # DataInput,  # ! Let's add this
    DictInput,
    DropdownInput,
    FileInput,
    FloatInput,
    HandleInput,
    IntInput,
    MultilineInput,
    NestedDictInput,
    PromptInput,
    SecretStrInput,
    StrInput,
]

View file

@@ -247,7 +247,6 @@ export const nodeColors: { [char: string]: string } = {
models: "#ab11ab",
model_specs: "#6344BE",
chains: "#FE7500",
Document: "#7AAE42",
list: "#9AAE42",
agents: "#903BBE",
tools: "#FF3434",
@@ -267,14 +266,18 @@ export const nodeColors: { [char: string]: string } = {
experimental: "#E6277A",
langchain_utilities: "#31A3CC",
output_parsers: "#E6A627",
str: "#4367BF",
Text: "#4367BF",
retrievers: "#e6b25a",
unknown: "#9CA3AF",
// custom_components: "#ab11ab",
Data: "#9CA3AF",
Message: "#4367BF",
BaseLanguageModel: "#ab11ab",
retrievers: "#e6b25a",
//
str: "#2563eb",
Text: "#2563eb",
unknown: "#9CA3AF",
Document: "#65a30d",
Data: "#dc2626",
Message: "#4f46e5",
Prompt: "#7c3aed",
Embeddings: "#10b981",
BaseLanguageModel: "#c026d3",
};
export const nodeNames: { [char: string]: string } = {