diff --git a/src/backend/base/langflow/components/vectorstores/pgvector.py b/src/backend/base/langflow/components/vectorstores/pgvector.py index 5b0825d82..0d2be230d 100644 --- a/src/backend/base/langflow/components/vectorstores/pgvector.py +++ b/src/backend/base/langflow/components/vectorstores/pgvector.py @@ -6,6 +6,7 @@ from langflow.base.vectorstores.model import LCVectorStoreComponent from langflow.helpers.data import docs_to_data from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput from langflow.schema import Data +from langflow.utils.connection_string_parser import transform_connection_string class PGVectorStoreComponent(LCVectorStoreComponent): @@ -46,18 +47,20 @@ class PGVectorStoreComponent(LCVectorStoreComponent): else: documents.append(_input) + connection_string_parsed = transform_connection_string(self.pg_server_url) + if documents: pgvector = PGVector.from_documents( embedding=self.embedding, documents=documents, collection_name=self.collection_name, - connection_string=self.pg_server_url, + connection_string=connection_string_parsed, ) else: pgvector = PGVector.from_existing_index( embedding=self.embedding, collection_name=self.collection_name, - connection_string=self.pg_server_url, + connection_string=connection_string_parsed, ) return pgvector diff --git a/src/backend/base/langflow/utils/connection_string_parser.py b/src/backend/base/langflow/utils/connection_string_parser.py new file mode 100644 index 000000000..f83bdb9ab --- /dev/null +++ b/src/backend/base/langflow/utils/connection_string_parser.py @@ -0,0 +1,11 @@ +from urllib.parse import quote + + +def transform_connection_string(connection_string): + db_url_name = connection_string.split("@")[-1] + password_url = connection_string.split(":")[-1] + password_string = password_url.replace(f"@{db_url_name}", "") + encoded_password = quote(password_string) + protocol_user = connection_string.split(":")[:-1] + transformed_connection_string = f'{":".join(protocol_user)}:{encoded_password}@{db_url_name}' + return transformed_connection_string diff --git a/src/backend/tests/unit/utils/test_connection_string_parser.py b/src/backend/tests/unit/utils/test_connection_string_parser.py new file mode 100644 index 000000000..1ab82279e --- /dev/null +++ b/src/backend/tests/unit/utils/test_connection_string_parser.py @@ -0,0 +1,25 @@ +import pytest +from langflow.utils.connection_string_parser import transform_connection_string + + +@pytest.fixture +def client(): + pass + + +@pytest.mark.parametrize( + "connection_string, expected", + [ + ("protocol:user:password@host", "protocol:user:password@host"), + ("protocol:user@host", "protocol:user@host"), + ("protocol:user:pass@word@host", "protocol:user:pass%40word@host"), + ("protocol:user:pa:ss:word@host", "protocol:user:pa:ss:word@host"), + ("user:password@host", "user:password@host"), + ("protocol::password@host", "protocol::password@host"), + ("protocol:user:password@", "protocol:user:password@"), + ("protocol:user:pa@ss@word@host", "protocol:user:pa%40ss%40word@host"), + ], +) +def test_transform_connection_string(connection_string, expected): + result = transform_connection_string(connection_string) + assert result == expected