add text embedder component (#3663)

* add text embedder component

* [autofix.ci] apply automated fixes

* add embedding similarity component

* [autofix.ci] apply automated fixes

* change text embedder output type to data

*  (similarity.spec.ts): Add end-to-end test for checking similarity between embedding texts in the frontend application.

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
Rodrigo Nader 2024-09-03 17:11:18 -03:00 committed by GitHub
commit 978bdf5fec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 459 additions and 2 deletions

View file

@ -0,0 +1,70 @@
from typing import List
import numpy as np
from langflow.custom import Component
from langflow.io import DataInput, DropdownInput, Output
from langflow.schema import Data
class EmbeddingSimilarityComponent(Component):
display_name: str = "Embedding Similarity"
description: str = "Compute selected form of similarity between two embedding vectors."
icon = "equal"
inputs = [
DataInput(
name="embedding_vectors",
display_name="Embedding Vectors",
info="A list containing exactly two data objects with embedding vectors to compare.",
is_list=True,
),
DropdownInput(
name="similarity_metric",
display_name="Similarity Metric",
info="Select the similarity metric to use.",
options=["Cosine Similarity", "Euclidean Distance", "Manhattan Distance"],
value="Cosine Similarity",
),
]
outputs = [
Output(display_name="Similarity Data", name="similarity_data", method="compute_similarity"),
]
def compute_similarity(self) -> Data:
embedding_vectors: List[Data] = self.embedding_vectors
# Assert that the list contains exactly two Data objects
assert len(embedding_vectors) == 2, "Exactly two embedding vectors are required."
embedding_1 = np.array(embedding_vectors[0].data["embeddings"])
embedding_2 = np.array(embedding_vectors[1].data["embeddings"])
if embedding_1.shape != embedding_2.shape:
similarity_score = {"error": "Embeddings must have the same dimensions."}
else:
similarity_metric = self.similarity_metric
if similarity_metric == "Cosine Similarity":
score = np.dot(embedding_1, embedding_2) / (np.linalg.norm(embedding_1) * np.linalg.norm(embedding_2))
similarity_score = {"cosine_similarity": score}
elif similarity_metric == "Euclidean Distance":
score = np.linalg.norm(embedding_1 - embedding_2)
similarity_score = {"euclidean_distance": score}
elif similarity_metric == "Manhattan Distance":
score = np.sum(np.abs(embedding_1 - embedding_2))
similarity_score = {"manhattan_distance": score}
# Create a Data object to encapsulate the similarity score and additional information
similarity_data = Data(
data={
"embedding_1": embedding_vectors[0].data["embeddings"],
"embedding_2": embedding_vectors[1].data["embeddings"],
"similarity_score": similarity_score,
},
text_key="similarity_score",
)
self.status = similarity_data
return similarity_data

View file

@ -19,11 +19,11 @@ import numpy as np
class GoogleGenerativeAIEmbeddingsComponent(Component):
display_name = "Google GenerativeAI Embeddings"
display_name = "Google Generative AI Embeddings"
description = "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, found in the langchain-google-genai package."
documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/"
icon = "Google"
name = "Google GenerativeAI Embeddings"
name = "Google Generative AI Embeddings"
inputs = [
SecretStrInput(name="api_key", display_name="API Key"),

View file

@ -0,0 +1,51 @@
from langflow.custom import Component
from langflow.io import HandleInput, MessageInput, Output
from langflow.field_typing import Embeddings
from langflow.schema.message import Message
from langflow.schema import Data
class TextEmbedderComponent(Component):
display_name: str = "Text Embedder"
description: str = "Generate embeddings for a given message using the specified embedding model."
icon = "binary"
inputs = [
HandleInput(
name="embedding_model",
display_name="Embedding Model",
info="The embedding model to use for generating embeddings.",
input_types=["Embeddings"],
),
MessageInput(
name="message",
display_name="Message",
info="The message to generate embeddings for.",
),
]
outputs = [
Output(display_name="Embedding Data", name="embeddings", method="generate_embeddings"),
]
def generate_embeddings(self) -> Data:
embedding_model: Embeddings = self.embedding_model
message: Message = self.message
# Extract the text content from the message
text_content = message.text
# Generate embeddings using the provided embedding model
embeddings = embedding_model.embed_documents([text_content])
# Assuming the embedding model returns a list of embeddings, we take the first one
if embeddings:
embedding_vector = embeddings[0]
else:
embedding_vector = []
# Create a Data object to encapsulate the results
result_data = Data(data={"text": text_content, "embeddings": embedding_vector})
self.status = {"text": text_content, "embeddings": embedding_vector}
return result_data

View file

@ -0,0 +1,336 @@
import { expect, test } from "@playwright/test";
test("user must be able to check similarity between embedding texts", async ({
page,
}) => {
test.skip(
!process?.env?.OPENAI_API_KEY,
"OPENAI_API_KEY required to run this test",
);
await page.goto("/");
// await page.waitForTimeout(2000);
let modalCount = 0;
try {
const modalTitleElement = await page?.getByTestId("modal-title");
if (modalTitleElement) {
modalCount = await modalTitleElement.count();
}
} catch (error) {
modalCount = 0;
}
while (modalCount === 0) {
await page.getByText("New Project", { exact: true }).click();
await page.waitForTimeout(3000);
modalCount = await page.getByTestId("modal-title")?.count();
}
await page.getByRole("heading", { name: "Blank Flow" }).click();
//first component
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("openai");
// await page.waitForTimeout(1000);
await page
.getByTestId("embeddingsOpenAI Embeddings")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
//second component
await page
.getByTestId("embeddingsOpenAI Embeddings")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
//third component
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("text embedder");
// await page.waitForTimeout(1000);
await page
.getByTestId("embeddingsText Embedder")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
//fourth component
await page
.getByTestId("embeddingsText Embedder")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
//fifth component
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("embedding similarity");
// await page.waitForTimeout(1000);
await page
.getByTestId("embeddingsEmbedding Similarity")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
//sisxth component
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("parse data");
// await page.waitForTimeout(1000);
await page
.getByTestId("helpersParse Data")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
//seventh component
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("text output");
// await page.waitForTimeout(1000);
await page
.getByTestId("outputsText Output")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("filter data");
// await page.waitForTimeout(1000);
await page
.getByTestId("helpersFilter Data")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.getByTitle("zoom out").click();
await page
.locator('//*[@id="react-flow-id"]')
.hover()
.then(async () => {
await page.mouse.down();
await page.mouse.move(-800, 300);
});
await page.mouse.up();
let outdatedComponents = await page.getByTestId("icon-AlertTriangle").count();
while (outdatedComponents > 0) {
await page.getByTestId("icon-AlertTriangle").first().click();
// await page.waitForTimeout(1000);
outdatedComponents = await page.getByTestId("icon-AlertTriangle").count();
}
await page.getByTitle("fit view").click();
await page
.getByTestId("textarea_str_template")
.last()
.fill("{similarity_score}");
await page
.getByTestId("popover-anchor-input-message")
.last()
.fill("datastax");
await page
.getByTestId("popover-anchor-input-message")
.first()
.fill("langflow");
await page
.getByTestId("popover-anchor-input-openai_api_key")
.nth(0)
.fill(process.env.OPENAI_API_KEY ?? "");
await page
.getByTestId("popover-anchor-input-openai_api_key")
.nth(1)
.fill(process.env.OPENAI_API_KEY ?? "");
await page
.getByTestId("inputlist_str_filter_criteria_0")
.nth(0)
.fill("similarity_score");
//connection 1
const openAiEmbeddingOutput_0 = await page
.getByTestId("handle-openaiembeddings-shownode-embeddings-right")
.nth(2);
await openAiEmbeddingOutput_0.hover();
await page.mouse.down();
const textEmbedderInput_0 = await page
.getByTestId("handle-textembeddercomponent-shownode-embedding model-left")
.nth(0);
await textEmbedderInput_0.hover();
await page.mouse.up();
//connection 2
const openAiEmbeddingOutput_1 = await page
.getByTestId("handle-openaiembeddings-shownode-embeddings-right")
.nth(0);
await openAiEmbeddingOutput_1.hover();
await page.mouse.down();
const textEmbedderInput_1 = await page
.getByTestId("handle-textembeddercomponent-shownode-embedding model-left")
.nth(1);
await textEmbedderInput_1.hover();
await page.mouse.up();
//connection 3
const textEmbedderOutput_0 = await page
.getByTestId("handle-textembeddercomponent-shownode-embedding data-right")
.nth(0);
await textEmbedderOutput_0.hover();
await page.mouse.down();
const embeddingSimilarityInput = await page
.getByTestId(
"handle-embeddingsimilaritycomponent-shownode-embedding vectors-left",
)
.nth(0);
await embeddingSimilarityInput.hover();
await page.mouse.up();
//connection 4
const textEmbedderOutput_1 = await page
.getByTestId("handle-textembeddercomponent-shownode-embedding data-right")
.nth(2);
await textEmbedderOutput_1.hover();
await page.mouse.down();
await embeddingSimilarityInput.hover();
await page.mouse.up();
//connection 5
const embeddingSimilarityOutput = await page
.getByTestId(
"handle-embeddingsimilaritycomponent-shownode-similarity data-right",
)
.nth(0);
await embeddingSimilarityOutput.hover();
await page.mouse.down();
const filterDataInput = await page
.getByTestId("handle-filterdata-shownode-data-left")
.nth(0);
await filterDataInput.hover();
await page.mouse.up();
//connection 6
const filterDataOutput = await page
.getByTestId("handle-filterdata-shownode-filtered data-right")
.nth(0);
await filterDataOutput.hover();
await page.mouse.down();
const parseDataInput = await page
.getByTestId("handle-parsedata-shownode-data-left")
.nth(0);
await parseDataInput.hover();
await page.mouse.up();
//connection 7
const parseDataOutput = await page
.getByTestId("handle-parsedata-shownode-text-right")
.nth(0);
await parseDataOutput.hover();
await page.mouse.down();
const textOutputInput = await page
.getByTestId("handle-textoutput-shownode-text-left")
.nth(0);
await textOutputInput.hover();
await page.mouse.up();
await page.getByTestId("button_run_text output").click();
await page.waitForSelector("text=built successfully", { timeout: 30000 });
await page.waitForTimeout(1000);
await page.getByText("Playground", { exact: true }).click();
await page.waitForTimeout(1000);
await page
.getByPlaceholder("Empty")
.waitFor({ state: "visible", timeout: 30000 });
const valueSimilarity = await page.getByPlaceholder("Empty").textContent();
expect(valueSimilarity).toContain("cosine_similarity");
const valueLength = valueSimilarity!.length;
expect(valueLength).toBeGreaterThan(20);
});