From 978bdf5fec1126ea0c493a72df9ede6eb94fa34a Mon Sep 17 00:00:00 2001 From: Rodrigo Nader Date: Tue, 3 Sep 2024 17:11:18 -0300 Subject: [PATCH] add text embedder component (#3663) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add text embedder component * [autofix.ci] apply automated fixes * add embedding similarity component * [autofix.ci] apply automated fixes * change text embedder output type to data * ✨ (similarity.spec.ts): Add end-to-end test for checking similarity between embedding texts in the frontend application. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: cristhianzl --- .../embeddings/EmbeddingSimilarity.py | 70 ++++ .../GoogleGenerativeAIEmbeddings.py | 4 +- .../components/embeddings/TextEmbedder.py | 51 +++ .../tests/end-to-end/similarity.spec.ts | 336 ++++++++++++++++++ 4 files changed, 459 insertions(+), 2 deletions(-) create mode 100644 src/backend/base/langflow/components/embeddings/EmbeddingSimilarity.py create mode 100644 src/backend/base/langflow/components/embeddings/TextEmbedder.py create mode 100644 src/frontend/tests/end-to-end/similarity.spec.ts diff --git a/src/backend/base/langflow/components/embeddings/EmbeddingSimilarity.py b/src/backend/base/langflow/components/embeddings/EmbeddingSimilarity.py new file mode 100644 index 000000000..9bbdac8dc --- /dev/null +++ b/src/backend/base/langflow/components/embeddings/EmbeddingSimilarity.py @@ -0,0 +1,70 @@ +from typing import List +import numpy as np +from langflow.custom import Component +from langflow.io import DataInput, DropdownInput, Output +from langflow.schema import Data + + +class EmbeddingSimilarityComponent(Component): + display_name: str = "Embedding Similarity" + description: str = "Compute selected form of similarity between two embedding vectors." + icon = "equal" + + inputs = [ + DataInput( + name="embedding_vectors", + display_name="Embedding Vectors", + info="A list containing exactly two data objects with embedding vectors to compare.", + is_list=True, + ), + DropdownInput( + name="similarity_metric", + display_name="Similarity Metric", + info="Select the similarity metric to use.", + options=["Cosine Similarity", "Euclidean Distance", "Manhattan Distance"], + value="Cosine Similarity", + ), + ] + + outputs = [ + Output(display_name="Similarity Data", name="similarity_data", method="compute_similarity"), + ] + + def compute_similarity(self) -> Data: + embedding_vectors: List[Data] = self.embedding_vectors + + # Assert that the list contains exactly two Data objects + assert len(embedding_vectors) == 2, "Exactly two embedding vectors are required." + + embedding_1 = np.array(embedding_vectors[0].data["embeddings"]) + embedding_2 = np.array(embedding_vectors[1].data["embeddings"]) + + if embedding_1.shape != embedding_2.shape: + similarity_score = {"error": "Embeddings must have the same dimensions."} + else: + similarity_metric = self.similarity_metric + + if similarity_metric == "Cosine Similarity": + score = np.dot(embedding_1, embedding_2) / (np.linalg.norm(embedding_1) * np.linalg.norm(embedding_2)) + similarity_score = {"cosine_similarity": score} + + elif similarity_metric == "Euclidean Distance": + score = np.linalg.norm(embedding_1 - embedding_2) + similarity_score = {"euclidean_distance": score} + + elif similarity_metric == "Manhattan Distance": + score = np.sum(np.abs(embedding_1 - embedding_2)) + similarity_score = {"manhattan_distance": score} + + # Create a Data object to encapsulate the similarity score and additional information + similarity_data = Data( + data={ + "embedding_1": embedding_vectors[0].data["embeddings"], + "embedding_2": embedding_vectors[1].data["embeddings"], + "similarity_score": similarity_score, + }, + text_key="similarity_score", + ) + + self.status = similarity_data + return similarity_data diff --git a/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py b/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py index c4150cae1..72245003c 100644 --- a/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py @@ -19,11 +19,11 @@ import numpy as np class GoogleGenerativeAIEmbeddingsComponent(Component): - display_name = "Google GenerativeAI Embeddings" + display_name = "Google Generative AI Embeddings" description = "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, found in the langchain-google-genai package." documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/" icon = "Google" - name = "Google GenerativeAI Embeddings" + name = "Google Generative AI Embeddings" inputs = [ SecretStrInput(name="api_key", display_name="API Key"), diff --git a/src/backend/base/langflow/components/embeddings/TextEmbedder.py b/src/backend/base/langflow/components/embeddings/TextEmbedder.py new file mode 100644 index 000000000..2fc1ab631 --- /dev/null +++ b/src/backend/base/langflow/components/embeddings/TextEmbedder.py @@ -0,0 +1,51 @@ +from langflow.custom import Component +from langflow.io import HandleInput, MessageInput, Output +from langflow.field_typing import Embeddings +from langflow.schema.message import Message +from langflow.schema import Data + + +class TextEmbedderComponent(Component): + display_name: str = "Text Embedder" + description: str = "Generate embeddings for a given message using the specified embedding model." + icon = "binary" + + inputs = [ + HandleInput( + name="embedding_model", + display_name="Embedding Model", + info="The embedding model to use for generating embeddings.", + input_types=["Embeddings"], + ), + MessageInput( + name="message", + display_name="Message", + info="The message to generate embeddings for.", + ), + ] + + outputs = [ + Output(display_name="Embedding Data", name="embeddings", method="generate_embeddings"), + ] + + def generate_embeddings(self) -> Data: + embedding_model: Embeddings = self.embedding_model + message: Message = self.message + + # Extract the text content from the message + text_content = message.text + + # Generate embeddings using the provided embedding model + embeddings = embedding_model.embed_documents([text_content]) + + # Assuming the embedding model returns a list of embeddings, we take the first one + if embeddings: + embedding_vector = embeddings[0] + else: + embedding_vector = [] + + # Create a Data object to encapsulate the results + result_data = Data(data={"text": text_content, "embeddings": embedding_vector}) + + self.status = {"text": text_content, "embeddings": embedding_vector} + return result_data diff --git a/src/frontend/tests/end-to-end/similarity.spec.ts b/src/frontend/tests/end-to-end/similarity.spec.ts new file mode 100644 index 000000000..0c18f4e36 --- /dev/null +++ b/src/frontend/tests/end-to-end/similarity.spec.ts @@ -0,0 +1,336 @@ +import { expect, test } from "@playwright/test"; + +test("user must be able to check similarity between embedding texts", async ({ + page, +}) => { + test.skip( + !process?.env?.OPENAI_API_KEY, + "OPENAI_API_KEY required to run this test", + ); + + await page.goto("/"); + // await page.waitForTimeout(2000); + + let modalCount = 0; + try { + const modalTitleElement = await page?.getByTestId("modal-title"); + if (modalTitleElement) { + modalCount = await modalTitleElement.count(); + } + } catch (error) { + modalCount = 0; + } + + while (modalCount === 0) { + await page.getByText("New Project", { exact: true }).click(); + await page.waitForTimeout(3000); + modalCount = await page.getByTestId("modal-title")?.count(); + } + + await page.getByRole("heading", { name: "Blank Flow" }).click(); + + //first component + + await page.getByTestId("extended-disclosure").click(); + await page.getByPlaceholder("Search").click(); + await page.getByPlaceholder("Search").fill("openai"); + // await page.waitForTimeout(1000); + + await page + .getByTestId("embeddingsOpenAI Embeddings") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + //second component + + await page + .getByTestId("embeddingsOpenAI Embeddings") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + //third component + + await page.getByTestId("extended-disclosure").click(); + await page.getByPlaceholder("Search").click(); + await page.getByPlaceholder("Search").fill("text embedder"); + // await page.waitForTimeout(1000); + + await page + .getByTestId("embeddingsText Embedder") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + //fourth component + + await page + .getByTestId("embeddingsText Embedder") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + //fifth component + + await page.getByTestId("extended-disclosure").click(); + await page.getByPlaceholder("Search").click(); + await page.getByPlaceholder("Search").fill("embedding similarity"); + // await page.waitForTimeout(1000); + + await page + .getByTestId("embeddingsEmbedding Similarity") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + //sisxth component + + await page.getByTestId("extended-disclosure").click(); + await page.getByPlaceholder("Search").click(); + await page.getByPlaceholder("Search").fill("parse data"); + // await page.waitForTimeout(1000); + + await page + .getByTestId("helpersParse Data") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + //seventh component + + await page.getByTestId("extended-disclosure").click(); + await page.getByPlaceholder("Search").click(); + await page.getByPlaceholder("Search").fill("text output"); + // await page.waitForTimeout(1000); + + await page + .getByTestId("outputsText Output") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + await page.getByTestId("extended-disclosure").click(); + await page.getByPlaceholder("Search").click(); + await page.getByPlaceholder("Search").fill("filter data"); + // await page.waitForTimeout(1000); + + await page + .getByTestId("helpersFilter Data") + .dragTo(page.locator('//*[@id="react-flow-id"]')); + + await page.getByTitle("zoom out").click(); + await page + .locator('//*[@id="react-flow-id"]') + .hover() + .then(async () => { + await page.mouse.down(); + await page.mouse.move(-800, 300); + }); + + await page.mouse.up(); + + let outdatedComponents = await page.getByTestId("icon-AlertTriangle").count(); + + while (outdatedComponents > 0) { + await page.getByTestId("icon-AlertTriangle").first().click(); + // await page.waitForTimeout(1000); + outdatedComponents = await page.getByTestId("icon-AlertTriangle").count(); + } + + await page.getByTitle("fit view").click(); + + await page + .getByTestId("textarea_str_template") + .last() + .fill("{similarity_score}"); + + await page + .getByTestId("popover-anchor-input-message") + .last() + .fill("datastax"); + await page + .getByTestId("popover-anchor-input-message") + .first() + .fill("langflow"); + + await page + .getByTestId("popover-anchor-input-openai_api_key") + .nth(0) + .fill(process.env.OPENAI_API_KEY ?? ""); + + await page + .getByTestId("popover-anchor-input-openai_api_key") + .nth(1) + .fill(process.env.OPENAI_API_KEY ?? ""); + + await page + .getByTestId("inputlist_str_filter_criteria_0") + .nth(0) + .fill("similarity_score"); + + //connection 1 + const openAiEmbeddingOutput_0 = await page + .getByTestId("handle-openaiembeddings-shownode-embeddings-right") + .nth(2); + await openAiEmbeddingOutput_0.hover(); + await page.mouse.down(); + const textEmbedderInput_0 = await page + .getByTestId("handle-textembeddercomponent-shownode-embedding model-left") + .nth(0); + await textEmbedderInput_0.hover(); + await page.mouse.up(); + + //connection 2 + const openAiEmbeddingOutput_1 = await page + .getByTestId("handle-openaiembeddings-shownode-embeddings-right") + .nth(0); + await openAiEmbeddingOutput_1.hover(); + await page.mouse.down(); + const textEmbedderInput_1 = await page + .getByTestId("handle-textembeddercomponent-shownode-embedding model-left") + .nth(1); + await textEmbedderInput_1.hover(); + await page.mouse.up(); + + //connection 3 + const textEmbedderOutput_0 = await page + .getByTestId("handle-textembeddercomponent-shownode-embedding data-right") + .nth(0); + await textEmbedderOutput_0.hover(); + await page.mouse.down(); + const embeddingSimilarityInput = await page + .getByTestId( + "handle-embeddingsimilaritycomponent-shownode-embedding vectors-left", + ) + .nth(0); + await embeddingSimilarityInput.hover(); + await page.mouse.up(); + + //connection 4 + const textEmbedderOutput_1 = await page + .getByTestId("handle-textembeddercomponent-shownode-embedding data-right") + .nth(2); + await textEmbedderOutput_1.hover(); + await page.mouse.down(); + await embeddingSimilarityInput.hover(); + await page.mouse.up(); + + //connection 5 + const embeddingSimilarityOutput = await page + .getByTestId( + "handle-embeddingsimilaritycomponent-shownode-similarity data-right", + ) + .nth(0); + await embeddingSimilarityOutput.hover(); + await page.mouse.down(); + const filterDataInput = await page + .getByTestId("handle-filterdata-shownode-data-left") + .nth(0); + await filterDataInput.hover(); + await page.mouse.up(); + + //connection 6 + const filterDataOutput = await page + .getByTestId("handle-filterdata-shownode-filtered data-right") + .nth(0); + await filterDataOutput.hover(); + await page.mouse.down(); + const parseDataInput = await page + .getByTestId("handle-parsedata-shownode-data-left") + .nth(0); + await parseDataInput.hover(); + await page.mouse.up(); + + //connection 7 + const parseDataOutput = await page + .getByTestId("handle-parsedata-shownode-text-right") + .nth(0); + await parseDataOutput.hover(); + await page.mouse.down(); + const textOutputInput = await page + .getByTestId("handle-textoutput-shownode-text-left") + .nth(0); + await textOutputInput.hover(); + await page.mouse.up(); + + await page.getByTestId("button_run_text output").click(); + + await page.waitForSelector("text=built successfully", { timeout: 30000 }); + + await page.waitForTimeout(1000); + await page.getByText("Playground", { exact: true }).click(); + await page.waitForTimeout(1000); + + await page + .getByPlaceholder("Empty") + .waitFor({ state: "visible", timeout: 30000 }); + + const valueSimilarity = await page.getByPlaceholder("Empty").textContent(); + expect(valueSimilarity).toContain("cosine_similarity"); + const valueLength = valueSimilarity!.length; + expect(valueLength).toBeGreaterThan(20); +});