add text embedder component (#3663)
* add text embedder component
* [autofix.ci] apply automated fixes
* add embedding similarity component
* [autofix.ci] apply automated fixes
* change text embedder output type to data
* ✨ (similarity.spec.ts): Add end-to-end test for checking similarity between embedding texts in the frontend application.
---------
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
parent
61e5bbb482
commit
978bdf5fec
4 changed files with 459 additions and 2 deletions
|
|
@ -0,0 +1,70 @@
|
|||
from typing import List
|
||||
import numpy as np
|
||||
from langflow.custom import Component
|
||||
from langflow.io import DataInput, DropdownInput, Output
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class EmbeddingSimilarityComponent(Component):
|
||||
display_name: str = "Embedding Similarity"
|
||||
description: str = "Compute selected form of similarity between two embedding vectors."
|
||||
icon = "equal"
|
||||
|
||||
inputs = [
|
||||
DataInput(
|
||||
name="embedding_vectors",
|
||||
display_name="Embedding Vectors",
|
||||
info="A list containing exactly two data objects with embedding vectors to compare.",
|
||||
is_list=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="similarity_metric",
|
||||
display_name="Similarity Metric",
|
||||
info="Select the similarity metric to use.",
|
||||
options=["Cosine Similarity", "Euclidean Distance", "Manhattan Distance"],
|
||||
value="Cosine Similarity",
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Similarity Data", name="similarity_data", method="compute_similarity"),
|
||||
]
|
||||
|
||||
def compute_similarity(self) -> Data:
|
||||
embedding_vectors: List[Data] = self.embedding_vectors
|
||||
|
||||
# Assert that the list contains exactly two Data objects
|
||||
assert len(embedding_vectors) == 2, "Exactly two embedding vectors are required."
|
||||
|
||||
embedding_1 = np.array(embedding_vectors[0].data["embeddings"])
|
||||
embedding_2 = np.array(embedding_vectors[1].data["embeddings"])
|
||||
|
||||
if embedding_1.shape != embedding_2.shape:
|
||||
similarity_score = {"error": "Embeddings must have the same dimensions."}
|
||||
else:
|
||||
similarity_metric = self.similarity_metric
|
||||
|
||||
if similarity_metric == "Cosine Similarity":
|
||||
score = np.dot(embedding_1, embedding_2) / (np.linalg.norm(embedding_1) * np.linalg.norm(embedding_2))
|
||||
similarity_score = {"cosine_similarity": score}
|
||||
|
||||
elif similarity_metric == "Euclidean Distance":
|
||||
score = np.linalg.norm(embedding_1 - embedding_2)
|
||||
similarity_score = {"euclidean_distance": score}
|
||||
|
||||
elif similarity_metric == "Manhattan Distance":
|
||||
score = np.sum(np.abs(embedding_1 - embedding_2))
|
||||
similarity_score = {"manhattan_distance": score}
|
||||
|
||||
# Create a Data object to encapsulate the similarity score and additional information
|
||||
similarity_data = Data(
|
||||
data={
|
||||
"embedding_1": embedding_vectors[0].data["embeddings"],
|
||||
"embedding_2": embedding_vectors[1].data["embeddings"],
|
||||
"similarity_score": similarity_score,
|
||||
},
|
||||
text_key="similarity_score",
|
||||
)
|
||||
|
||||
self.status = similarity_data
|
||||
return similarity_data
|
||||
|
|
@ -19,11 +19,11 @@ import numpy as np
|
|||
|
||||
|
||||
class GoogleGenerativeAIEmbeddingsComponent(Component):
|
||||
display_name = "Google GenerativeAI Embeddings"
|
||||
display_name = "Google Generative AI Embeddings"
|
||||
description = "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, found in the langchain-google-genai package."
|
||||
documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/"
|
||||
icon = "Google"
|
||||
name = "Google GenerativeAI Embeddings"
|
||||
name = "Google Generative AI Embeddings"
|
||||
|
||||
inputs = [
|
||||
SecretStrInput(name="api_key", display_name="API Key"),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
from langflow.custom import Component
|
||||
from langflow.io import HandleInput, MessageInput, Output
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.schema.message import Message
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class TextEmbedderComponent(Component):
|
||||
display_name: str = "Text Embedder"
|
||||
description: str = "Generate embeddings for a given message using the specified embedding model."
|
||||
icon = "binary"
|
||||
|
||||
inputs = [
|
||||
HandleInput(
|
||||
name="embedding_model",
|
||||
display_name="Embedding Model",
|
||||
info="The embedding model to use for generating embeddings.",
|
||||
input_types=["Embeddings"],
|
||||
),
|
||||
MessageInput(
|
||||
name="message",
|
||||
display_name="Message",
|
||||
info="The message to generate embeddings for.",
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Embedding Data", name="embeddings", method="generate_embeddings"),
|
||||
]
|
||||
|
||||
def generate_embeddings(self) -> Data:
|
||||
embedding_model: Embeddings = self.embedding_model
|
||||
message: Message = self.message
|
||||
|
||||
# Extract the text content from the message
|
||||
text_content = message.text
|
||||
|
||||
# Generate embeddings using the provided embedding model
|
||||
embeddings = embedding_model.embed_documents([text_content])
|
||||
|
||||
# Assuming the embedding model returns a list of embeddings, we take the first one
|
||||
if embeddings:
|
||||
embedding_vector = embeddings[0]
|
||||
else:
|
||||
embedding_vector = []
|
||||
|
||||
# Create a Data object to encapsulate the results
|
||||
result_data = Data(data={"text": text_content, "embeddings": embedding_vector})
|
||||
|
||||
self.status = {"text": text_content, "embeddings": embedding_vector}
|
||||
return result_data
|
||||
336
src/frontend/tests/end-to-end/similarity.spec.ts
Normal file
336
src/frontend/tests/end-to-end/similarity.spec.ts
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
import { expect, test } from "@playwright/test";
|
||||
|
||||
test("user must be able to check similarity between embedding texts", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.skip(
|
||||
!process?.env?.OPENAI_API_KEY,
|
||||
"OPENAI_API_KEY required to run this test",
|
||||
);
|
||||
|
||||
await page.goto("/");
|
||||
// await page.waitForTimeout(2000);
|
||||
|
||||
let modalCount = 0;
|
||||
try {
|
||||
const modalTitleElement = await page?.getByTestId("modal-title");
|
||||
if (modalTitleElement) {
|
||||
modalCount = await modalTitleElement.count();
|
||||
}
|
||||
} catch (error) {
|
||||
modalCount = 0;
|
||||
}
|
||||
|
||||
while (modalCount === 0) {
|
||||
await page.getByText("New Project", { exact: true }).click();
|
||||
await page.waitForTimeout(3000);
|
||||
modalCount = await page.getByTestId("modal-title")?.count();
|
||||
}
|
||||
|
||||
await page.getByRole("heading", { name: "Blank Flow" }).click();
|
||||
|
||||
//first component
|
||||
|
||||
await page.getByTestId("extended-disclosure").click();
|
||||
await page.getByPlaceholder("Search").click();
|
||||
await page.getByPlaceholder("Search").fill("openai");
|
||||
// await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByTestId("embeddingsOpenAI Embeddings")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
//second component
|
||||
|
||||
await page
|
||||
.getByTestId("embeddingsOpenAI Embeddings")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
//third component
|
||||
|
||||
await page.getByTestId("extended-disclosure").click();
|
||||
await page.getByPlaceholder("Search").click();
|
||||
await page.getByPlaceholder("Search").fill("text embedder");
|
||||
// await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByTestId("embeddingsText Embedder")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
//fourth component
|
||||
|
||||
await page
|
||||
.getByTestId("embeddingsText Embedder")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
//fifth component
|
||||
|
||||
await page.getByTestId("extended-disclosure").click();
|
||||
await page.getByPlaceholder("Search").click();
|
||||
await page.getByPlaceholder("Search").fill("embedding similarity");
|
||||
// await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByTestId("embeddingsEmbedding Similarity")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
//sisxth component
|
||||
|
||||
await page.getByTestId("extended-disclosure").click();
|
||||
await page.getByPlaceholder("Search").click();
|
||||
await page.getByPlaceholder("Search").fill("parse data");
|
||||
// await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByTestId("helpersParse Data")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
//seventh component
|
||||
|
||||
await page.getByTestId("extended-disclosure").click();
|
||||
await page.getByPlaceholder("Search").click();
|
||||
await page.getByPlaceholder("Search").fill("text output");
|
||||
// await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByTestId("outputsText Output")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
await page.getByTestId("extended-disclosure").click();
|
||||
await page.getByPlaceholder("Search").click();
|
||||
await page.getByPlaceholder("Search").fill("filter data");
|
||||
// await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByTestId("helpersFilter Data")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'));
|
||||
|
||||
await page.getByTitle("zoom out").click();
|
||||
await page
|
||||
.locator('//*[@id="react-flow-id"]')
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.mouse.down();
|
||||
await page.mouse.move(-800, 300);
|
||||
});
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
let outdatedComponents = await page.getByTestId("icon-AlertTriangle").count();
|
||||
|
||||
while (outdatedComponents > 0) {
|
||||
await page.getByTestId("icon-AlertTriangle").first().click();
|
||||
// await page.waitForTimeout(1000);
|
||||
outdatedComponents = await page.getByTestId("icon-AlertTriangle").count();
|
||||
}
|
||||
|
||||
await page.getByTitle("fit view").click();
|
||||
|
||||
await page
|
||||
.getByTestId("textarea_str_template")
|
||||
.last()
|
||||
.fill("{similarity_score}");
|
||||
|
||||
await page
|
||||
.getByTestId("popover-anchor-input-message")
|
||||
.last()
|
||||
.fill("datastax");
|
||||
await page
|
||||
.getByTestId("popover-anchor-input-message")
|
||||
.first()
|
||||
.fill("langflow");
|
||||
|
||||
await page
|
||||
.getByTestId("popover-anchor-input-openai_api_key")
|
||||
.nth(0)
|
||||
.fill(process.env.OPENAI_API_KEY ?? "");
|
||||
|
||||
await page
|
||||
.getByTestId("popover-anchor-input-openai_api_key")
|
||||
.nth(1)
|
||||
.fill(process.env.OPENAI_API_KEY ?? "");
|
||||
|
||||
await page
|
||||
.getByTestId("inputlist_str_filter_criteria_0")
|
||||
.nth(0)
|
||||
.fill("similarity_score");
|
||||
|
||||
//connection 1
|
||||
const openAiEmbeddingOutput_0 = await page
|
||||
.getByTestId("handle-openaiembeddings-shownode-embeddings-right")
|
||||
.nth(2);
|
||||
await openAiEmbeddingOutput_0.hover();
|
||||
await page.mouse.down();
|
||||
const textEmbedderInput_0 = await page
|
||||
.getByTestId("handle-textembeddercomponent-shownode-embedding model-left")
|
||||
.nth(0);
|
||||
await textEmbedderInput_0.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
//connection 2
|
||||
const openAiEmbeddingOutput_1 = await page
|
||||
.getByTestId("handle-openaiembeddings-shownode-embeddings-right")
|
||||
.nth(0);
|
||||
await openAiEmbeddingOutput_1.hover();
|
||||
await page.mouse.down();
|
||||
const textEmbedderInput_1 = await page
|
||||
.getByTestId("handle-textembeddercomponent-shownode-embedding model-left")
|
||||
.nth(1);
|
||||
await textEmbedderInput_1.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
//connection 3
|
||||
const textEmbedderOutput_0 = await page
|
||||
.getByTestId("handle-textembeddercomponent-shownode-embedding data-right")
|
||||
.nth(0);
|
||||
await textEmbedderOutput_0.hover();
|
||||
await page.mouse.down();
|
||||
const embeddingSimilarityInput = await page
|
||||
.getByTestId(
|
||||
"handle-embeddingsimilaritycomponent-shownode-embedding vectors-left",
|
||||
)
|
||||
.nth(0);
|
||||
await embeddingSimilarityInput.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
//connection 4
|
||||
const textEmbedderOutput_1 = await page
|
||||
.getByTestId("handle-textembeddercomponent-shownode-embedding data-right")
|
||||
.nth(2);
|
||||
await textEmbedderOutput_1.hover();
|
||||
await page.mouse.down();
|
||||
await embeddingSimilarityInput.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
//connection 5
|
||||
const embeddingSimilarityOutput = await page
|
||||
.getByTestId(
|
||||
"handle-embeddingsimilaritycomponent-shownode-similarity data-right",
|
||||
)
|
||||
.nth(0);
|
||||
await embeddingSimilarityOutput.hover();
|
||||
await page.mouse.down();
|
||||
const filterDataInput = await page
|
||||
.getByTestId("handle-filterdata-shownode-data-left")
|
||||
.nth(0);
|
||||
await filterDataInput.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
//connection 6
|
||||
const filterDataOutput = await page
|
||||
.getByTestId("handle-filterdata-shownode-filtered data-right")
|
||||
.nth(0);
|
||||
await filterDataOutput.hover();
|
||||
await page.mouse.down();
|
||||
const parseDataInput = await page
|
||||
.getByTestId("handle-parsedata-shownode-data-left")
|
||||
.nth(0);
|
||||
await parseDataInput.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
//connection 7
|
||||
const parseDataOutput = await page
|
||||
.getByTestId("handle-parsedata-shownode-text-right")
|
||||
.nth(0);
|
||||
await parseDataOutput.hover();
|
||||
await page.mouse.down();
|
||||
const textOutputInput = await page
|
||||
.getByTestId("handle-textoutput-shownode-text-left")
|
||||
.nth(0);
|
||||
await textOutputInput.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
await page.getByTestId("button_run_text output").click();
|
||||
|
||||
await page.waitForSelector("text=built successfully", { timeout: 30000 });
|
||||
|
||||
await page.waitForTimeout(1000);
|
||||
await page.getByText("Playground", { exact: true }).click();
|
||||
await page.waitForTimeout(1000);
|
||||
|
||||
await page
|
||||
.getByPlaceholder("Empty")
|
||||
.waitFor({ state: "visible", timeout: 30000 });
|
||||
|
||||
const valueSimilarity = await page.getByPlaceholder("Empty").textContent();
|
||||
expect(valueSimilarity).toContain("cosine_similarity");
|
||||
const valueLength = valueSimilarity!.length;
|
||||
expect(valueLength).toBeGreaterThan(20);
|
||||
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue