From d90de4ae8547e6725a1ec7bf4914be55d1fe32de Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Wed, 22 Jan 2020 17:58:55 -0500 Subject: [PATCH] Fix N+1 problem for one-to-many and many-to-many relationships (#254) This optimization batches what used to be multiple SQL statements into a single SQL statement. For now, you'll have to enable the optimization via the `SQLAlchemyObjectType.Meta.connection_field_factory` (see `test_batching.py`). --- .gitignore | 2 + graphene_sqlalchemy/__init__.py | 2 +- graphene_sqlalchemy/batching.py | 69 ++++ graphene_sqlalchemy/fields.py | 38 ++- graphene_sqlalchemy/tests/models.py | 2 +- graphene_sqlalchemy/tests/test_batching.py | 356 +++++++++++++++++++-- graphene_sqlalchemy/tests/test_fields.py | 40 ++- graphene_sqlalchemy/types.py | 69 +--- 8 files changed, 458 insertions(+), 120 deletions(-) create mode 100644 graphene_sqlalchemy/batching.py diff --git a/.gitignore b/.gitignore index d4f71e35..a97b8c21 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ var/ *.egg-info/ .installed.cfg *.egg +.python-version # PyInstaller # Usually these files are written by a python script from a template @@ -47,6 +48,7 @@ nosetests.xml coverage.xml *,cover .pytest_cache/ +.benchmarks/ # Translations *.mo diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 9ed4b0f6..ba71f614 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.2.2" +__version__ = "2.3.0.dev0" __all__ = [ "__version__", diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py new file mode 100644 index 00000000..0665248f --- /dev/null +++ b/graphene_sqlalchemy/batching.py @@ -0,0 +1,69 @@ +import sqlalchemy +from promise import dataloader, promise +from sqlalchemy.orm import Session, strategies +from sqlalchemy.orm.query import QueryContext + + +def 
get_batch_resolver(relationship_prop): + class RelationshipLoader(dataloader.DataLoader): + cache = False + + def batch_load_fn(self, parents): # pylint: disable=method-hidden + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + to re-implementing and maintaining a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emit 2 `SELECT` statements + when calling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maintainer's suggestion. + See https://git.io/JewQ7 + """ + child_mapper = relationship_prop.mapper + parent_mapper = relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) + + # Should the boolean be set to False? Does it matter for our purposes? 
+ states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only be used to get the session + query_context = QueryContext(session.query(parent_mapper.entity)) + + loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + + return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents]) + + loader = RelationshipLoader() + + def resolve(root, info, **args): + return loader.load(root) + + return resolve diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index a9f514ba..840204ae 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -9,6 +9,7 @@ from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice +from .batching import get_batch_resolver from .utils import get_query @@ -33,14 +34,8 @@ def model(self): return self.type._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) - if sort is not None: - if isinstance(sort, six.string_types): - query = query.order_by(sort.value) - else: - query = query.order_by(*(col.value for col in sort)) - return query + def get_query(cls, model, info, **args): + return get_query(model, info.context) @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): @@ -78,6 +73,7 @@ def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.model) +# TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): def __init__(self, type, *args, **kwargs): if "sort" not in kwargs and issubclass(type, Connection): @@ -95,6 +91,32 @@ def __init__(self, type, *args, **kwargs): del kwargs["sort"] super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + 
@classmethod + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if isinstance(sort, six.string_types): + query = query.order_by(sort.value) + else: + query = query.order_by(*(col.value for col in sort)) + return query + + +class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): + """ + This is currently experimental. + The API and behavior may change in future versions. + Use at your own risk. + """ + def get_resolver(self, parent_resolver): + return partial(self.connection_resolver, self.resolver, self.type, self.model) + + @classmethod + def from_relationship(cls, relationship, registry, **field_kwargs): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return cls(model_type._meta.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + def default_connection_field_factory(relationship, registry, **field_kwargs): model = relationship.mapper.entity diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 1df28333..88e992b9 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -61,7 +61,7 @@ class Reporter(Base): last_name = Column(String(30), doc="Last name") email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) - pets = relationship("Pet", secondary=association_table, backref="reporters") + pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id") articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 0881f71e..77681069 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -5,9 +5,11 @@ import pytest import graphene +from graphene import relay +from ..fields import 
BatchSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from .models import Article, Reporter +from .models import Article, HairKind, Pet, Reporter from .utils import to_std_dicts @@ -37,46 +39,34 @@ def mock_sqlalchemy_logging_handler(): sql_logger.setLevel(previous_level) -def make_fixture(session): - reporter_1 = Reporter( - first_name='Reporter_1', - ) - session.add(reporter_1) - reporter_2 = Reporter( - first_name='Reporter_2', - ) - session.add(reporter_2) - - article_1 = Article(headline='Article_1') - article_1.reporter = reporter_1 - session.add(article_1) - - article_2 = Article(headline='Article_2') - article_2.reporter = reporter_2 - session.add(article_2) - - session.commit() - session.close() - - -def get_schema(session): +def get_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class ArticleType(SQLAlchemyObjectType): class Meta: model = Article + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) reporters = graphene.Field(graphene.List(ReporterType)) - def resolve_articles(self, _info): - return session.query(Article).all() + def resolve_articles(self, info): + return info.context.get('session').query(Article).all() - def resolve_reporters(self, _info): - return session.query(Reporter).all() + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() return graphene.Schema(query=Query) @@ -91,8 +81,28 @@ def is_sqlalchemy_version_less_than(version_string): def test_many_to_one(session_factory): session = session_factory() - make_fixture(session) - 
schema = get_schema(session) + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level @@ -115,6 +125,8 @@ def test_many_to_one(session_factory): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] + assert len(sql_statements) == 1 return assert messages == [ @@ -160,8 +172,28 @@ def test_many_to_one(session_factory): def test_one_to_one(session_factory): session = session_factory() - make_fixture(session) - schema = get_schema(session) + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level @@ -184,6 +216,8 @@ def test_one_to_one(session_factory): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu + sql_statements = [message for 
message in messages if 'SELECT' in message and 'JOIN articles' in message] + assert len(sql_statements) == 1 return assert messages == [ @@ -226,3 +260,261 @@ def test_one_to_one(session_factory): }, ], } + + +def test_one_to_many(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline='Article_3') + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline='Article_4') + article_4.reporter = reporter_2 + session.add(article_4) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS 
reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT articles.reporter_id AS articles_reporter_id, ' + 'articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date \n' + 'FROM articles \n' + 'WHERE articles.reporter_id IN (?, ?) ' + 'ORDER BY articles.reporter_id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], + }, + }, + ], + } + + +def test_many_to_many(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + pets(first: 
2) { + edges { + node { + name + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT reporters_1.id AS reporters_1_id, ' + 'pets.id AS pets_id, ' + 'pets.name AS pets_name, ' + 'pets.pet_kind AS pets_pet_kind, ' + 'pets.hair_kind AS pets_hair_kind, ' + 'pets.reporter_id AS pets_reporter_id \n' + 'FROM reporters AS reporters_1 ' + 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' + 'JOIN pets ON pets.id = association_1.pet_id \n' + 'WHERE reporters_1.id IN (?, ?) 
' + 'ORDER BY reporters_1.id, pets.id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], + }, + }, + ], + } diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 875b729d..557ff114 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,9 +1,11 @@ import pytest from promise import Promise -from graphene.relay import Connection +from graphene import ObjectType +from graphene.relay import Connection, Node -from ..fields import SQLAlchemyConnectionField +from ..fields import (SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField) from ..types import SQLAlchemyObjectType from .models import Editor as EditorModel from .models import Pet as PetModel @@ -12,44 +14,58 @@ class Pet(SQLAlchemyObjectType): class Meta: model = PetModel + interfaces = (Node,) class Editor(SQLAlchemyObjectType): class Meta: model = EditorModel - -class PetConnection(Connection): - class Meta: - node = Pet +## +# SQLAlchemyConnectionField +## def test_promise_connection_resolver(): def resolver(_obj, _info): return Promise.resolve([]) - result = SQLAlchemyConnectionField.connection_resolver( - resolver, PetConnection, Pet, None, None + result = UnsortedSQLAlchemyConnectionField.connection_resolver( + resolver, Pet._meta.connection, Pet, None, None ) assert isinstance(result, Promise) +def test_type_assert_sqlalchemy_object_type(): + with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"): + SQLAlchemyConnectionField(ObjectType).type + + +def test_type_assert_object_has_connection(): + with 
pytest.raises(AssertionError, match="doesn't have a connection"): + SQLAlchemyConnectionField(Editor).type + +## +# UnsortedSQLAlchemyConnectionField +## + + def test_sort_added_by_default(): - field = SQLAlchemyConnectionField(PetConnection) + field = SQLAlchemyConnectionField(Pet._meta.connection) assert "sort" in field.args assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - field = SQLAlchemyConnectionField(PetConnection, sort=None) + field = SQLAlchemyConnectionField(Pet._meta.connection, sort=None) assert "sort" not in field.args def test_custom_sort(): - field = SQLAlchemyConnectionField(PetConnection, sort=Editor.sort_argument()) + field = SQLAlchemyConnectionField(Pet._meta.connection, sort=Editor.sort_argument()) assert field.args["sort"] == Editor.sort_argument() -def test_init_raises(): +def test_sort_init_raises(): with pytest.raises(TypeError, match="Cannot create sort"): SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 23c8288e..2ed5110e 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,12 +1,10 @@ from collections import OrderedDict import sqlalchemy -from promise import dataloader, promise from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty, Session, strategies) + RelationshipProperty, strategies) from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.orm.query import QueryContext from graphene import Field from graphene.relay import Connection, Node @@ -15,6 +13,7 @@ from graphene.utils.get_unbound_function import get_unbound_function from graphene.utils.orderedtype import OrderedType +from .batching import get_batch_resolver from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, @@ -220,73 +219,11 @@ def _get_relationship_resolver(obj_type, relationship_prop, 
model_attr): :param str model_attr: the name of the SQLAlchemy attribute :rtype: Callable """ - child_mapper = relationship_prop.mapper - parent_mapper = relationship_prop.parent - if not getattr(strategies, 'SelectInLoader', None) or relationship_prop.uselist: # TODO Batch many-to-many and one-to-many relationships return _get_attr_resolver(obj_type, model_attr, model_attr) - class NonListRelationshipLoader(dataloader.DataLoader): - cache = False - - def batch_load_fn(self, parents): # pylint: disable=method-hidden - """ - Batch loads the relationship of all the parents as one SQL statement. - - There is no way to do this out-of-the-box with SQLAlchemy but - we can piggyback on some internal APIs of the `selectin` - eager loading strategy. It's a bit hacky but it's preferable - than re-implementing and maintainnig a big chunk of the `selectin` - loader logic ourselves. - - The approach here is to build a regular query that - selects the parent and `selectin` load the relationship. - But instead of having the query emits 2 `SELECT` statements - when callling `all()`, we skip the first `SELECT` statement - and jump right before the `selectin` loader is called. - To accomplish this, we have to construct objects that are - normally built in the first part of the query in order - to call directly `SelectInLoader._load_for_path`. - - TODO Move this logic to a util in the SQLAlchemy repo as per - SQLAlchemy's main maitainer suggestion. - See https://git.io/JewQ7 - """ - session = Session.object_session(parents[0]) - - # These issues are very unlikely to happen in practice... - for parent in parents: - assert parent.__mapper__ is parent_mapper - # All instances must share the same session - assert session is Session.object_session(parent) - # The behavior of `selectin` is undefined if the parent is dirty - assert parent not in session.dirty - - loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - - # Should the boolean be set to False? 
Does it matter for our purposes? - states = [(sqlalchemy.inspect(parent), True) for parent in parents] - - # For our purposes, the query_context will only used to get the session - query_context = QueryContext(session.query(parent_mapper.entity)) - - loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - ) - - return promise.Promise.resolve([getattr(parent, model_attr) for parent in parents]) - - loader = NonListRelationshipLoader() - - def resolve(root, info): - return loader.load(root) - - return resolve + return get_batch_resolver(relationship_prop) def _get_attr_resolver(obj_type, orm_field_name, model_attr):