From 5dc3ff2b1ff5a5d0076ae6e8d043dadcfacdf3c8 Mon Sep 17 00:00:00 2001
From: Cristhian Zanforlin Lousa <72977554+Cristhianzl@users.noreply.github.com>
Date: Mon, 5 Aug 2024 09:15:33 -0300
Subject: [PATCH] bugfix: parse password on db connection string when it has @ on It (#3173)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* ✨ (test_connection_string.py): add unit test for the transform_connection_string function to ensure correct transformation of connection strings

* 📝 (pgvector.py): add support for parsing and transforming connection string in PGVectorStoreComponent to improve security and maintainability
📝 (connection_string_parser.py): create utility function to transform connection string by encoding password to improve security
📝 (test_connection_string_parser.py): add unit tests for transform_connection_string function to ensure correct transformation of connection string

* test: fix import

---------

Co-authored-by: italojohnny
---
 .../components/vectorstores/pgvector.py    |  7 ++++--
 .../utils/connection_string_parser.py      | 11 ++++++++
 .../utils/test_connection_string_parser.py | 25 +++++++++++++++++++
 3 files changed, 41 insertions(+), 2 deletions(-)
 create mode 100644 src/backend/base/langflow/utils/connection_string_parser.py
 create mode 100644 src/backend/tests/unit/utils/test_connection_string_parser.py

diff --git a/src/backend/base/langflow/components/vectorstores/pgvector.py b/src/backend/base/langflow/components/vectorstores/pgvector.py
index 5b0825d82..0d2be230d 100644
--- a/src/backend/base/langflow/components/vectorstores/pgvector.py
+++ b/src/backend/base/langflow/components/vectorstores/pgvector.py
@@ -6,6 +6,7 @@ from langflow.base.vectorstores.model import LCVectorStoreComponent
 from langflow.helpers.data import docs_to_data
 from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
 from langflow.schema import Data
+from langflow.utils.connection_string_parser import transform_connection_string
 
 
 class PGVectorStoreComponent(LCVectorStoreComponent):
@@ -46,18 +47,20 @@ class PGVectorStoreComponent(LCVectorStoreComponent):
             else:
                 documents.append(_input)
 
+        connection_string_parsed = transform_connection_string(self.pg_server_url)
+
         if documents:
             pgvector = PGVector.from_documents(
                 embedding=self.embedding,
                 documents=documents,
                 collection_name=self.collection_name,
-                connection_string=self.pg_server_url,
+                connection_string=connection_string_parsed,
             )
         else:
             pgvector = PGVector.from_existing_index(
                 embedding=self.embedding,
                 collection_name=self.collection_name,
-                connection_string=self.pg_server_url,
+                connection_string=connection_string_parsed,
             )
 
         return pgvector
diff --git a/src/backend/base/langflow/utils/connection_string_parser.py b/src/backend/base/langflow/utils/connection_string_parser.py
new file mode 100644
index 000000000..f83bdb9ab
--- /dev/null
+++ b/src/backend/base/langflow/utils/connection_string_parser.py
@@ -0,0 +1,11 @@
+from urllib.parse import quote
+
+
+def transform_connection_string(connection_string):
+    db_url_name = connection_string.split("@")[-1]
+    password_url = connection_string.split(":")[-1]
+    password_string = password_url.replace(f"@{db_url_name}", "")
+    encoded_password = quote(password_string)
+    protocol_user = connection_string.split(":")[:-1]
+    transformed_connection_string = f'{":".join(protocol_user)}:{encoded_password}@{db_url_name}'
+    return transformed_connection_string
diff --git a/src/backend/tests/unit/utils/test_connection_string_parser.py b/src/backend/tests/unit/utils/test_connection_string_parser.py
new file mode 100644
index 000000000..1ab82279e
--- /dev/null
+++ b/src/backend/tests/unit/utils/test_connection_string_parser.py
@@ -0,0 +1,25 @@
+import pytest
+from langflow.utils.connection_string_parser import transform_connection_string
+
+
+@pytest.fixture
+def client():
+    pass
+
+
+@pytest.mark.parametrize(
+    "connection_string, expected",
+    [
+        ("protocol:user:password@host", "protocol:user:password@host"),
+        ("protocol:user@host", "protocol:user@host"),
+        ("protocol:user:pass@word@host", "protocol:user:pass%40word@host"),
+        ("protocol:user:pa:ss:word@host", "protocol:user:pa:ss:word@host"),
+        ("user:password@host", "user:password@host"),
+        ("protocol::password@host", "protocol::password@host"),
+        ("protocol:user:password@", "protocol:user:password@"),
+        ("protocol:user:pa@ss@word@host", "protocol:user:pa%40ss%40word@host"),
+    ],
+)
+def test_transform_connection_string(connection_string, expected):
+    result = transform_connection_string(connection_string)
+    assert result == expected