From 0d609b9376ac0ed07f95adf837a7b10d4131da99 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 15 Jun 2024 22:18:03 -0300 Subject: [PATCH] feat: Add Prompt component to prompts/__init__.py --- .../components/embeddings/OpenAIEmbeddings.py | 245 ++++++++---------- .../langflow/components/helpers/ParseData.py | 4 +- .../langflow/components/prompts/__init__.py | 3 + .../base/langflow/field_typing/__init__.py | 36 +-- src/backend/base/langflow/inputs/inputs.py | 66 ++--- src/frontend/src/utils/styleUtils.ts | 19 +- 6 files changed, 175 insertions(+), 198 deletions(-) create mode 100644 src/backend/base/langflow/components/prompts/__init__.py diff --git a/src/backend/base/langflow/components/embeddings/OpenAIEmbeddings.py b/src/backend/base/langflow/components/embeddings/OpenAIEmbeddings.py index 4813fc793..1f22e2a21 100644 --- a/src/backend/base/langflow/components/embeddings/OpenAIEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/OpenAIEmbeddings.py @@ -1,150 +1,115 @@ -from typing import Dict, List, Optional - from langchain_openai.embeddings.base import OpenAIEmbeddings -from pydantic.v1 import SecretStr - -from langflow.custom import CustomComponent -from langflow.field_typing import Embeddings, NestedDict +from langflow.base.models.model import LCModelComponent +from langflow.field_typing import Embeddings +from langflow.inputs import ( + BoolInput, + DictInput, + FloatInput, + IntInput, + SecretStrInput, + StrInput, + DropdownInput +) +from langflow.template import Output -class OpenAIEmbeddingsComponent(CustomComponent): +class OpenAIEmbeddingsComponent(LCModelComponent): display_name = "OpenAI Embeddings" description = "Generate embeddings using OpenAI models." + icon = "OpenAI" + inputs = [ + DictInput( + name="default_headers", + display_name="Default Headers", + advanced=True, + info="Default headers to use for the API request.", + ), + DictInput( + name="default_query", + display_name="Default Query", + advanced=True, + info="Default query parameters to use for the API request.", + ), + IntInput(name="chunk_size", display_name="Chunk Size", advanced=True, value=1000), + StrInput(name="client", display_name="Client", advanced=True), + StrInput(name="deployment", display_name="Deployment", advanced=True), + IntInput( + name="embedding_ctx_length", + display_name="Embedding Context Length", + advanced=True, + value=1536 + ), + IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True), + DropdownInput( + name="model", + display_name="Model", + advanced=False, + options=[ + "text-embedding-3-small", + "text-embedding-3-large", + "text-embedding-ada-002", + ], + value="text-embedding-3-small" + ), + DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True), + SecretStrInput( + name="openai_api_base", display_name="OpenAI API Base", advanced=True + ), + SecretStrInput(name="openai_api_key", display_name="OpenAI API Key"), + SecretStrInput( + name="openai_api_type", display_name="OpenAI API Type", advanced=True + ), + StrInput( + name="openai_api_version", display_name="OpenAI API Version", advanced=True + ), + StrInput( + name="openai_organization", + display_name="OpenAI Organization", + advanced=True, + ), + StrInput(name="openai_proxy", display_name="OpenAI Proxy", advanced=True), + FloatInput( + name="request_timeout", display_name="Request Timeout", advanced=True + ), + BoolInput( + name="show_progress_bar", display_name="Show Progress Bar", advanced=True + ), + BoolInput(name="skip_empty", display_name="Skip Empty", advanced=True), + StrInput( + name="tiktoken_model_name", + display_name="TikToken Model Name", + advanced=True, + ), + BoolInput( + name="tiktoken_enable", display_name="TikToken Enable", advanced=True + ), + ] - def build_config(self): - return { - "allowed_special": { - "display_name": "Allowed Special", - "advanced": True, - "field_type": "str", - "is_list": True, - }, - "default_headers": { - "display_name": "Default Headers", - "advanced": True, - "field_type": "dict", - }, - "default_query": { - "display_name": "Default Query", - "advanced": True, - "field_type": "NestedDict", - }, - "disallowed_special": { - "display_name": "Disallowed Special", - "advanced": True, - "field_type": "str", - "is_list": True, - }, - "chunk_size": {"display_name": "Chunk Size", "advanced": True}, - "client": {"display_name": "Client", "advanced": True}, - "deployment": {"display_name": "Deployment", "advanced": True}, - "embedding_ctx_length": { - "display_name": "Embedding Context Length", - "advanced": True, - }, - "max_retries": {"display_name": "Max Retries", "advanced": True}, - "model": { - "display_name": "Model", - "advanced": False, - "options": [ - "text-embedding-3-small", - "text-embedding-3-large", - "text-embedding-ada-002", - ], - }, - "model_kwargs": {"display_name": "Model Kwargs", "advanced": True}, - "openai_api_base": { - "display_name": "OpenAI API Base", - "password": True, - "advanced": True, - }, - "openai_api_key": {"display_name": "OpenAI API Key", "password": True}, - "openai_api_type": { - "display_name": "OpenAI API Type", - "advanced": True, - "password": True, - }, - "openai_api_version": { - "display_name": "OpenAI API Version", - "advanced": True, - }, - "openai_organization": { - "display_name": "OpenAI Organization", - "advanced": True, - }, - "openai_proxy": {"display_name": "OpenAI Proxy", "advanced": True}, - "request_timeout": {"display_name": "Request Timeout", "advanced": True}, - "show_progress_bar": { - "display_name": "Show Progress Bar", - "advanced": True, - }, - "skip_empty": {"display_name": "Skip Empty", "advanced": True}, - "tiktoken_model_name": { - "display_name": "TikToken Model Name", - "advanced": True, - }, - "tiktoken_enable": {"display_name": "TikToken Enable", "advanced": True}, - "dimensions": { - "display_name": "Dimensions", - "info": "The number of dimensions the resulting output embeddings should have. Only supported by certain models.", - "advanced": True, - }, - } - - def build( - self, - openai_api_key: str, - default_headers: Optional[Dict[str, str]] = None, - default_query: Optional[NestedDict] = {}, - allowed_special: List[str] = [], - disallowed_special: List[str] = ["all"], - chunk_size: int = 1000, - deployment: str = "text-embedding-ada-002", - embedding_ctx_length: int = 8191, - max_retries: int = 6, - model: str = "text-embedding-ada-002", - model_kwargs: NestedDict = {}, - openai_api_base: Optional[str] = None, - openai_api_type: Optional[str] = None, - openai_api_version: Optional[str] = None, - openai_organization: Optional[str] = None, - openai_proxy: Optional[str] = None, - request_timeout: Optional[float] = None, - show_progress_bar: bool = False, - skip_empty: bool = False, - tiktoken_enable: bool = True, - tiktoken_model_name: Optional[str] = None, - dimensions: Optional[int] = None, - ) -> Embeddings: - # This is to avoid errors with Vector Stores (e.g Chroma) - if disallowed_special == ["all"]: - disallowed_special = "all" # type: ignore - if openai_api_key: - api_key = SecretStr(openai_api_key) - else: - api_key = None + outputs = [ + Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), + ] + def build_embeddings(self) -> Embeddings: return OpenAIEmbeddings( - tiktoken_enabled=tiktoken_enable, - default_headers=default_headers, - default_query=default_query, - allowed_special=set(allowed_special), + tiktoken_enabled=self.tiktoken_enable, + default_headers=self.default_headers, + default_query=self.default_query, + allowed_special="all", disallowed_special="all", - chunk_size=chunk_size, - deployment=deployment, - embedding_ctx_length=embedding_ctx_length, - max_retries=max_retries, - model=model, - model_kwargs=model_kwargs, - base_url=openai_api_base, - api_key=api_key, - openai_api_type=openai_api_type, - api_version=openai_api_version, - organization=openai_organization, - openai_proxy=openai_proxy, - timeout=request_timeout, - show_progress_bar=show_progress_bar, - skip_empty=skip_empty, - tiktoken_model_name=tiktoken_model_name, - dimensions=dimensions, + chunk_size=self.chunk_size, + deployment=self.deployment, + embedding_ctx_length=self.embedding_ctx_length, + max_retries=self.max_retries, + model=self.model, + model_kwargs=self.model_kwargs, + base_url=self.openai_api_base, + api_key=self.openai_api_key, + openai_api_type=self.openai_api_type, + api_version=self.openai_api_version, + organization=self.openai_organization, + openai_proxy=self.openai_proxy, + timeout=self.request_timeout, + show_progress_bar=self.show_progress_bar, + skip_empty=self.skip_empty, + tiktoken_model_name=self.tiktoken_model_name, ) diff --git a/src/backend/base/langflow/components/helpers/ParseData.py b/src/backend/base/langflow/components/helpers/ParseData.py index 41cd669fe..4dfd36889 100644 --- a/src/backend/base/langflow/components/helpers/ParseData.py +++ b/src/backend/base/langflow/components/helpers/ParseData.py @@ -12,7 +12,7 @@ class ParseDataComponent(Component): inputs = [ HandleInput( - name="data", display_name="Data", info="The data to convert to text.", input_types=["Message", "Data"] + name="data", display_name="Data", info="The data to convert to text.", input_types=["Data"] ), MultilineInput( name="template", @@ -27,7 +27,7 @@ class ParseDataComponent(Component): def parse_data_to_text(self) -> Text: data = self.data if isinstance(self.data, list) else [self.data] - template = self.template or "Text: {text}" + template = self.template result_string = data_to_text(template, data) self.status = result_string diff --git a/src/backend/base/langflow/components/prompts/__init__.py b/src/backend/base/langflow/components/prompts/__init__.py new file mode 100644 index 000000000..92d7b8eed --- /dev/null +++ b/src/backend/base/langflow/components/prompts/__init__.py @@ -0,0 +1,3 @@ +from .Prompt import Prompt + +__all__ = ["Prompt"] diff --git a/src/backend/base/langflow/field_typing/__init__.py b/src/backend/base/langflow/field_typing/__init__.py index 5c925daf4..e818a5b55 100644 --- a/src/backend/base/langflow/field_typing/__init__.py +++ b/src/backend/base/langflow/field_typing/__init__.py @@ -57,30 +57,30 @@ def __getattr__(name: str) -> Any: __all__ = [ - "NestedDict", - "Data", - "Tool", - "PromptTemplate", - "Chain", + "AgentExecutor", "BaseChatMemory", - "BaseLLM", "BaseLanguageModel", + "BaseLLM", "BaseLoader", "BaseMemory", "BaseOutputParser", - "BaseRetriever", - "VectorStore", - "Embeddings", - "TextSplitter", - "Document", - "AgentExecutor", - "Text", - "Object", - "Callable", "BasePromptTemplate", + "BaseRetriever", + "Callable", + "Chain", "ChatPromptTemplate", - "Prompt", - "RangeSpec", - "Input", "Code", + "Data", + "Document", + "Embeddings", + "Input", + "NestedDict", + "Object", + "Prompt", + "PromptTemplate", + "RangeSpec", + "Text", + "TextSplitter", + "Tool", + "VectorStore", ] diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index 39920974c..16db05f4a 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -23,6 +23,11 @@ class HandleInput(BaseInputMixin, ListableInputMixin): field_type: Optional[SerializableFieldTypes] = FieldTypes.OTHER +# class DataInput(HandleInput): +# input_types: list[str] = ["Data"] +# ! Let's add this? + + class PromptInput(BaseInputMixin, ListableInputMixin): field_type: Optional[SerializableFieldTypes] = FieldTypes.PROMPT @@ -38,29 +43,29 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin): # noqa: class TextInput(StrInput): input_types: list[str] = ["Data", "Message", "Text"] - @field_validator("value") - @classmethod - def validate_value(cls, v: Any, _info): - value = None - if isinstance(v, str): - value = v - elif isinstance(v, Message): - value = v.text - elif isinstance(v, Data): - if v.text_key in v.data: - value = v.data[v.text_key] - else: - keys = ", ".join(v.data.keys()) - input_name = _info.data["name"] - raise ValueError( - f"The input to '{input_name}' must contain the key '{v.text_key}'." - f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component." - ) - else: - raise ValueError(f"Invalid input type {type(v)}") - if isinstance(value, str): - return value - raise ValueError(f"Invalid value type {type(value)}") + # @field_validator("value") + # @classmethod + # def validate_value(cls, v: Any, _info): + # value = None + # if isinstance(v, str): + # value = v + # elif isinstance(v, Message): + # value = v.text + # elif isinstance(v, Data): + # if v.text_key in v.data: + # value = v.data[v.text_key] + # else: + # keys = ", ".join(v.data.keys()) + # input_name = _info.data["name"] + # raise ValueError( + # f"The input to '{input_name}' must contain the key '{v.text_key}'." + # f"You can set `text_key` to one of the following keys: {keys} or set the value using another Component." + # ) + # else: + # raise ValueError(f"Invalid input type {type(v)}") + # if isinstance(value, str): + # return value + # raise ValueError(f"Invalid value type {type(value)}") class MultilineInput(BaseInputMixin): @@ -108,16 +113,17 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin): InputTypes = Union[ - StrInput, - SecretStrInput, - IntInput, - FloatInput, BoolInput, - NestedDictInput, + # DataInput, # ! Let's add this DictInput, DropdownInput, FileInput, - PromptInput, - MultilineInput, + FloatInput, HandleInput, + IntInput, + MultilineInput, + NestedDictInput, + PromptInput, + SecretStrInput, + StrInput, ] diff --git a/src/frontend/src/utils/styleUtils.ts b/src/frontend/src/utils/styleUtils.ts index fff15809c..adead59c4 100644 --- a/src/frontend/src/utils/styleUtils.ts +++ b/src/frontend/src/utils/styleUtils.ts @@ -247,7 +247,6 @@ export const nodeColors: { [char: string]: string } = { models: "#ab11ab", model_specs: "#6344BE", chains: "#FE7500", - Document: "#7AAE42", list: "#9AAE42", agents: "#903BBE", tools: "#FF3434", @@ -267,14 +266,18 @@ export const nodeColors: { [char: string]: string } = { experimental: "#E6277A", langchain_utilities: "#31A3CC", output_parsers: "#E6A627", - str: "#4367BF", - Text: "#4367BF", - retrievers: "#e6b25a", - unknown: "#9CA3AF", // custom_components: "#ab11ab", - Data: "#9CA3AF", - Message: "#4367BF", - BaseLanguageModel: "#ab11ab", + retrievers: "#e6b25a", + // + str: "#2563eb", + Text: "#2563eb", + unknown: "#9CA3AF", + Document: "#65a30d", + Data: "#dc2626", + Message: "#4f46e5", + Prompt: "#7c3aed", + Embeddings: "#10b981", + BaseLanguageModel: "#c026d3", }; export const nodeNames: { [char: string]: string } = {