From 4430dcd9bcd880cbc841d3f2715e7ef716a90902 Mon Sep 17 00:00:00 2001 From: Daniel Manson Date: Mon, 5 Aug 2024 10:44:19 +0100 Subject: [PATCH] Add connect_fn kwarg to pool --- asyncpg/_testbase/__init__.py | 3 ++- asyncpg/pool.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 7aca834f..83101cea 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -268,6 +268,7 @@ def create_pool(dsn=None, *, pool_class=pg_pool.Pool, connection_class=pg_connection.Connection, record_class=asyncpg.Record, + connect_fn=pg_connection.connect, **connect_kwargs): return pool_class( dsn, @@ -275,7 +276,7 @@ def create_pool(dsn=None, *, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, connection_class=connection_class, - record_class=record_class, + record_class=record_class, connect_fn=connect_fn, **connect_kwargs) diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 8a00d64b..1e9a0457 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -313,7 +313,7 @@ class Pool: __slots__ = ( '_queue', '_loop', '_minsize', '_maxsize', - '_init', '_connect_args', '_connect_kwargs', + '_init', '_connect_fn', '_connect_args', '_connect_kwargs', '_holders', '_initialized', '_initializing', '_closing', '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' @@ -329,6 +329,7 @@ def __init__(self, *connect_args, loop, connection_class, record_class, + connect_fn, **connect_kwargs): if len(connect_args) > 1: @@ -388,6 +389,7 @@ def __init__(self, *connect_args, self._init = init self._connect_args = connect_args self._connect_kwargs = connect_kwargs + self._connect_fn = connect_fn self._setup = setup self._max_queries = max_queries @@ -503,7 +505,7 @@ def set_connect_args(self, dsn=None, **connect_kwargs): self._connect_kwargs = connect_kwargs async def _get_new_connection(self): - con = await connection.connect( + con = await self._connect_fn( *self._connect_args, loop=self._loop, connection_class=self._connection_class, @@ -1097,6 +1099,10 @@ def create_pool(dsn=None, *, or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. + :param coroutine connect_fn: + A coroutine with signature identical to :func:`~asyncpg.connection.connect`. This can be used to add custom + authentication or ssl logic when creating a connection, as is required by GCP's cloud-sql-python-connector. + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -1127,7 +1133,7 @@ def create_pool(dsn=None, *, return Pool( dsn, connection_class=connection_class, - record_class=record_class, + record_class=record_class, connect_fn=connection.connect, min_size=min_size, max_size=max_size, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime,