diff --git a/lib/activerecord-multi-tenant/model_extensions.rb b/lib/activerecord-multi-tenant/model_extensions.rb index 0ec84f9..2056651 100644 --- a/lib/activerecord-multi-tenant/model_extensions.rb +++ b/lib/activerecord-multi-tenant/model_extensions.rb @@ -67,7 +67,7 @@ def inherited(subclass) partition_key = @partition_key # Create an implicit belongs_to association only if tenant class exists - if MultiTenant.tenant_klass_defined?(tenant_name) + if MultiTenant.tenant_klass_defined?(tenant_name, options) belongs_to tenant_name, **options.slice(:class_name, :inverse_of, :optional) .merge(foreign_key: options[:partition_key]) end @@ -103,7 +103,7 @@ def inherited(subclass) tenant_id end - if MultiTenant.tenant_klass_defined?(tenant_name) + if MultiTenant.tenant_klass_defined?(tenant_name, options) define_method "#{tenant_name}=" do |model| super(model) if send("#{partition_key}_changed?") && persisted? && !send("#{partition_key}_was").nil? diff --git a/lib/activerecord-multi-tenant/multi_tenant.rb b/lib/activerecord-multi-tenant/multi_tenant.rb index b2edb59..b541e33 100644 --- a/lib/activerecord-multi-tenant/multi_tenant.rb +++ b/lib/activerecord-multi-tenant/multi_tenant.rb @@ -5,8 +5,13 @@ class Current < ::ActiveSupport::CurrentAttributes attribute :tenant end - def self.tenant_klass_defined?(tenant_name) - !!tenant_name.to_s.classify.safe_constantize + def self.tenant_klass_defined?(tenant_name, options = {}) + class_name = if options[:class_name].present? + options[:class_name] + else + tenant_name.to_s.classify + end + !!class_name.safe_constantize end def self.partition_key(tenant_name) diff --git a/spec/activerecord-multi-tenant/multi_tenant_spec.rb b/spec/activerecord-multi-tenant/multi_tenant_spec.rb index 0cae722..cedd0bb 100644 --- a/spec/activerecord-multi-tenant/multi_tenant_spec.rb +++ b/spec/activerecord-multi-tenant/multi_tenant_spec.rb @@ -64,4 +64,58 @@ end end end + + describe '.tenant_klass_defined?' do + context 'without options' do + before(:all) do + class SampleTenant < ActiveRecord::Base + multi_tenant :sample_tenant + end + end + + it 'return true with valid tenant_name' do + expect(MultiTenant.tenant_klass_defined?(:sample_tenant)).to eq(true) + end + + it 'return false with invalid_tenant_name' do + invalid_tenant_name = :tenant + expect(MultiTenant.tenant_klass_defined?(invalid_tenant_name)).to eq(false) + end + end + + context 'with options' do + context 'and valid class_name' do + it 'return true' do + class SampleTenant < ActiveRecord::Base + multi_tenant :tenant + end + + tenant_name = :tenant + options = { + class_name: 'SampleTenant' + } + expect(MultiTenant.tenant_klass_defined?(tenant_name, options)).to eq(true) + end + + it 'return true when tenant class is nested' do + module SampleModule + class SampleNestedTenant < ActiveRecord::Base + multi_tenant :tenant + end + # rubocop:disable Layout/TrailingWhitespace + # Trailing whitespace is intentionally left here + + class AnotherTenant < ActiveRecord::Base + end + # rubocop:enable Layout/TrailingWhitespace + end + tenant_name = :tenant + options = { + class_name: 'SampleModule::SampleNestedTenant' + } + expect(MultiTenant.tenant_klass_defined?(tenant_name, options)).to eq(true) + end + end + end + end end