bugfix: parse password in db connection string when it has @ in it (#3173)

*  (test_connection_string.py): add unit test for the transform_connection_string function to ensure correct transformation of connection strings

* 📝 (pgvector.py): add support for parsing and transforming connection string in PGVectorStoreComponent to improve security and maintainability
📝 (connection_string_parser.py): create utility function to transform connection string by encoding password to improve security
📝 (test_connection_string_parser.py): add unit tests for transform_connection_string function to ensure correct transformation of connection string

* test: fix import

---------

Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
Cristhian Zanforlin Lousa 2024-08-05 09:15:33 -03:00 committed by GitHub
commit 5dc3ff2b1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 41 additions and 2 deletions

View file

@ -6,6 +6,7 @@ from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
from langflow.utils.connection_string_parser import transform_connection_string
class PGVectorStoreComponent(LCVectorStoreComponent):
@ -46,18 +47,20 @@ class PGVectorStoreComponent(LCVectorStoreComponent):
else:
documents.append(_input)
connection_string_parsed = transform_connection_string(self.pg_server_url)
if documents:
pgvector = PGVector.from_documents(
embedding=self.embedding,
documents=documents,
collection_name=self.collection_name,
connection_string=self.pg_server_url,
connection_string=connection_string_parsed,
)
else:
pgvector = PGVector.from_existing_index(
embedding=self.embedding,
collection_name=self.collection_name,
connection_string=self.pg_server_url,
connection_string=connection_string_parsed,
)
return pgvector

View file

@ -0,0 +1,11 @@
from urllib.parse import quote
def transform_connection_string(connection_string):
    """Percent-encode the password portion of a database connection string.

    The string is split at its LAST ``@``: everything after it is kept
    verbatim as the host/database part (which may itself contain ``:``, e.g.
    ``localhost:5432/db``). The credentials before it are then split at their
    LAST ``:``; the text after that colon is treated as the password and is
    passed through ``urllib.parse.quote`` (default ``safe="/"``), so an ``@``
    inside the password becomes ``%40``.

    Args:
        connection_string: Connection string such as
            ``"postgresql://user:p@ss@localhost:5432/db"``.

    Returns:
        The connection string with the password encoded. A string that has no
        ``@``, or no ``:`` before its last ``@``, is returned unchanged because
        there is no password to encode.
    """
    # Split on the last "@" only: earlier "@" characters belong to the
    # password, and the host part must not be searched for ":" (it may carry
    # a port number).
    credentials, at_sign, db_url_name = connection_string.rpartition("@")
    if not at_sign:
        return connection_string

    # The password is whatever follows the last ":" of the credentials.
    protocol_user, colon, password_string = credentials.rpartition(":")
    if not colon:
        return connection_string

    encoded_password = quote(password_string)
    return f"{protocol_user}:{encoded_password}@{db_url_name}"

View file

@ -0,0 +1,25 @@
import pytest
from langflow.utils.connection_string_parser import transform_connection_string
@pytest.fixture
def client():
    """Placeholder ``client`` fixture that supplies no value.

    NOTE(review): presumably defined to shadow a shared ``client`` fixture so
    these pure-function tests need no application setup — confirm.
    """
    return None
@pytest.mark.parametrize(
    "connection_string, expected",
    [
        # Plain password: nothing to encode.
        ("protocol:user:password@host", "protocol:user:password@host"),
        # Only one ":" before "@": output equals input.
        ("protocol:user@host", "protocol:user@host"),
        # "@" inside the password is percent-encoded as %40.
        ("protocol:user:pass@word@host", "protocol:user:pass%40word@host"),
        # Extra ":" characters before the last one are left as-is.
        ("protocol:user:pa:ss:word@host", "protocol:user:pa:ss:word@host"),
        # No protocol prefix.
        ("user:password@host", "user:password@host"),
        # Empty user segment.
        ("protocol::password@host", "protocol::password@host"),
        # Empty host segment.
        ("protocol:user:password@", "protocol:user:password@"),
        # Several "@" characters inside the password.
        ("protocol:user:pa@ss@word@host", "protocol:user:pa%40ss%40word@host"),
    ],
)
def test_transform_connection_string(connection_string, expected):
    """Check that each input connection string is transformed to ``expected``."""
    assert transform_connection_string(connection_string) == expected