fix: Corrected URI handling in SQLDatabaseComponent (#3291)

This commit is contained in:
Carlos Coelho 2024-08-12 19:45:13 -03:00 committed by GitHub
commit dc419d11fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,6 +1,7 @@
from langchain_experimental.sql.base import SQLDatabase
from langchain_community.utilities.sql_database import SQLDatabase
from langflow.custom import CustomComponent
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
class SQLDatabaseComponent(CustomComponent):
@ -14,10 +15,12 @@ class SQLDatabaseComponent(CustomComponent):
}
def clean_up_uri(self, uri: str) -> str:
if uri.startswith("postgresql://"):
uri = uri.replace("postgresql://", "postgres://")
if uri.startswith("postgres://"):
uri = uri.replace("postgres://", "postgresql://")
return uri.strip()
def build(self, uri: str) -> SQLDatabase:
uri = self.clean_up_uri(uri)
return SQLDatabase.from_uri(uri)
# Create an engine using SQLAlchemy with StaticPool
engine = create_engine(uri, poolclass=StaticPool)
return SQLDatabase(engine)