Skip to content

Commit

Permalink
Fix N+1 problem for one-to-many and many-to-many relationships (#254)
Browse files Browse the repository at this point in the history
This optimization batches what used to be multiple SQL statements into a single SQL statement. For now, you'll have to enable the optimization via the `SQLAlchemyObjectType.Meta.connection_field_factory` (see `test_batching.py`).
  • Loading branch information
jnak authored Jan 22, 2020
1 parent 98e6fe7 commit d90de4a
Show file tree
Hide file tree
Showing 8 changed files with 458 additions and 120 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var/
*.egg-info/
.installed.cfg
*.egg
.python-version

# PyInstaller
# Usually these files are written by a python script from a template
Expand All @@ -47,6 +48,7 @@ nosetests.xml
coverage.xml
*,cover
.pytest_cache/
.benchmarks/

# Translations
*.mo
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .fields import SQLAlchemyConnectionField
from .utils import get_query, get_session

__version__ = "2.2.2"
__version__ = "2.3.0.dev0"

__all__ = [
"__version__",
Expand Down
69 changes: 69 additions & 0 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sqlalchemy
from promise import dataloader, promise
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext


def get_batch_resolver(relationship_prop):
    """
    Build a resolver that loads `relationship_prop` for many parents at once.

    :param relationship_prop: the relationship whose `mapper`, `parent` and
        `key` attributes are read below (presumably a SQLAlchemy
        `RelationshipProperty` -- confirm against callers).
    :return: a `resolve(root, info, **args)` function that passes `root` to a
        single `RelationshipLoader` instance created here and returns
        whatever `loader.load(root)` returns.
    """
    class RelationshipLoader(dataloader.DataLoader):
        # NOTE(review): presumably turns off the DataLoader's per-key cache.
        # The loader instance below is created once per call to
        # `get_batch_resolver` and reused by every `resolve` call -- confirm
        # that this is the reason caching is disabled.
        cache = False

        def batch_load_fn(self, parents):  # pylint: disable=method-hidden
            """
            Batch loads the relationships of all the parents as one SQL statement.

            There is no way to do this out-of-the-box with SQLAlchemy but
            we can piggyback on some internal APIs of the `selectin`
            eager loading strategy. It's a bit hacky but it's preferable
            to re-implementing and maintaining a big chunk of the `selectin`
            loader logic ourselves.

            The approach here is to build a regular query that
            selects the parent and `selectin` loads the relationship.
            But instead of having the query emit 2 `SELECT` statements
            when calling `all()`, we skip the first `SELECT` statement
            and jump right before the `selectin` loader is called.
            To accomplish this, we have to construct objects that are
            normally built in the first part of the query in order
            to call `SelectInLoader._load_for_path` directly.

            TODO Move this logic to a util in the SQLAlchemy repo as per
              the suggestion of SQLAlchemy's main maintainer.
              See https://git.io/JewQ7

            :param parents: the parent instances to load the relationship
                for. The first element is used to look up the session, so
                the list is assumed to be non-empty.
            :return: a resolved promise holding one relationship value per
                parent, in the same order as `parents`.
            """
            child_mapper = relationship_prop.mapper
            parent_mapper = relationship_prop.parent
            # The session of the first parent is the reference that every
            # other parent is checked against below.
            session = Session.object_session(parents[0])

            # These issues are very unlikely to happen in practice...
            for parent in parents:
                # assert parent.__mapper__ is parent_mapper
                # All instances must share the same session
                assert session is Session.object_session(parent)
                # The behavior of `selectin` is undefined if the parent is dirty
                assert parent not in session.dirty

            # Instantiate the `selectin` strategy directly for this relationship.
            loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))

            # Should the boolean be set to False? Does it matter for our purposes?
            states = [(sqlalchemy.inspect(parent), True) for parent in parents]

            # For our purposes, the query_context will only be used to get the session
            query_context = QueryContext(session.query(parent_mapper.entity))

            # Private SQLAlchemy API: the positional argument order matters.
            loader._load_for_path(
                query_context,
                parent_mapper._path_registry,
                states,
                None,
                child_mapper,
            )

            # Read the relationship attribute off each parent, keeping the
            # order of `parents`; the attributes are expected to have been
            # populated by the loader call above.
            return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents])

    # One loader instance is shared by every call to `resolve` below.
    loader = RelationshipLoader()

    def resolve(root, info, **args):
        # `info` and `args` are accepted but unused; only `root` is loaded.
        return loader.load(root)

    return resolve
38 changes: 30 additions & 8 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .batching import get_batch_resolver
from .utils import get_query


Expand All @@ -33,14 +34,8 @@ def model(self):
return self.type._meta.node._meta.model

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if isinstance(sort, six.string_types):
query = query.order_by(sort.value)
else:
query = query.order_by(*(col.value for col in sort))
return query
def get_query(cls, model, info, **args):
return get_query(model, info.context)

@classmethod
def resolve_connection(cls, connection_type, model, info, args, resolved):
Expand Down Expand Up @@ -78,6 +73,7 @@ def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)


# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
def __init__(self, type, *args, **kwargs):
if "sort" not in kwargs and issubclass(type, Connection):
Expand All @@ -95,6 +91,32 @@ def __init__(self, type, *args, **kwargs):
del kwargs["sort"]
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if isinstance(sort, six.string_types):
query = query.order_by(sort.value)
else:
query = query.order_by(*(col.value for col in sort))
return query


class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
    """
    Connection field that resolves through the field's own `resolver`
    instead of the parent resolver it is handed.

    This is currently experimental.
    The API and behavior may change in future versions.
    Use at your own risk.
    """

    def get_resolver(self, parent_resolver):
        # `parent_resolver` is deliberately ignored: the resolver stored on
        # the field (see `from_relationship`) takes its place.
        return partial(
            self.connection_resolver,
            self.resolver,
            self.type,
            self.model,
        )

    @classmethod
    def from_relationship(cls, relationship, registry, **field_kwargs):
        """
        Alternate constructor: build the field for `relationship`, using the
        connection of the type registered for the related model and a batch
        resolver created by `get_batch_resolver`.
        """
        related_type = registry.get_type_for_model(relationship.mapper.entity)
        connection_type = related_type._meta.connection
        return cls(
            connection_type,
            resolver=get_batch_resolver(relationship),
            **field_kwargs
        )


def default_connection_field_factory(relationship, registry, **field_kwargs):
model = relationship.mapper.entity
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Reporter(Base):
last_name = Column(String(30), doc="Last name")
email = Column(String(), doc="Email")
favorite_pet_kind = Column(PetKind)
pets = relationship("Pet", secondary=association_table, backref="reporters")
pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id")
articles = relationship("Article", backref="reporter")
favorite_article = relationship("Article", uselist=False)

Expand Down
Loading

0 comments on commit d90de4a

Please sign in to comment.