diff --git a/django_multitenant/utils.py b/django_multitenant/utils.py index 6ba0947..c2e4934 100644 --- a/django_multitenant/utils.py +++ b/django_multitenant/utils.py @@ -1,4 +1,5 @@ import inspect +from functools import lru_cache from django.apps import apps from .settings import TENANT_USE_ASGIREF @@ -57,20 +58,32 @@ def get_tenant_column(model_class_or_instance): ) from not_a_tenant_model -def get_tenant_field(model_class_or_instance): +@lru_cache(None) +def get_field_matching_column(model_class, column): """ - Gets the tenant field object from the model + Gets a field object from the model class, matching the column. """ - tenant_column = get_tenant_column(model_class_or_instance) - all_fields = model_class_or_instance._meta.fields + all_fields = model_class._meta.fields try: - return next(field for field in all_fields if field.column == tenant_column) + return next(field for field in all_fields if field.column == column) except StopIteration as no_field_found: raise ValueError( - f'No field found in {type(model_class_or_instance).name} with column name "{tenant_column}"' + f'No field found in {model_class.__name__} with column name "{column}"' ) from no_field_found +def get_tenant_field(model_class_or_instance): + """ + Gets the tenant field object from the model + """ + tenant_column = get_tenant_column(model_class_or_instance) + + if not inspect.isclass(model_class_or_instance): + model_class_or_instance = model_class_or_instance.__class__ + + return get_field_matching_column(model_class_or_instance, tenant_column) + + def get_object_tenant(instance): """ Gets the tenant value from the object. If the object itself is a tenant, it will return the same object