Skip to content

Commit

Permalink
Providing an option for the plugins to use Spring DI (#5012)
Browse files Browse the repository at this point in the history
providing an option for the plugins to use Spring DI

Signed-off-by: Santhosh Gandhe <[email protected]>

Update data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/DataPrepperPlugin.java

modified the comment line based on the suggession

Co-authored-by: David Venable <[email protected]>
Signed-off-by: Santhosh Gandhe <[email protected]>

Integration test to validate the DI context enabling in plugins


Signed-off-by: Santhosh Gandhe <[email protected]>
Co-authored-by: David Venable <[email protected]>
  • Loading branch information
san81 and dlvenable authored Oct 4, 2024
1 parent 72e2f27 commit 472c912
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,16 @@
* @since 1.2
*/
Class<?> pluginConfigurationType() default PluginSetting.class;

/**
* Optional Packages to scan for Data Prepper DI components.
* Plugins provide this list if they want to use Dependency Injection in its module.
* Providing this value, implicitly assumes and initiates plugin specific isolated ApplicationContext.
* <p>
* The package names that spring context scans will be picked up by these marker classes.
*
* @return Array of classes to use for package scan.
* @since 2.2
*/
Class[] packagesToScan() default {};
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import org.opensearch.dataprepper.model.configuration.PipelinesDataFlowModel;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException;
import org.opensearch.dataprepper.model.source.Source;
import org.opensearch.dataprepper.plugins.TestObjectPlugin;
import org.opensearch.dataprepper.plugins.test.TestComponent;
import org.opensearch.dataprepper.plugins.test.TestDISource;
import org.opensearch.dataprepper.plugins.test.TestPlugin;
import org.opensearch.dataprepper.validation.LoggingPluginErrorsHandler;
import org.opensearch.dataprepper.validation.PluginErrorCollector;
Expand All @@ -30,6 +33,8 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
Expand Down Expand Up @@ -96,6 +101,23 @@ void loadPlugin_should_return_a_new_plugin_instance_with_the_expected_configurat
assertThat(configuration.getOptionalString(), equalTo(optionalStringValue));
}

@Test
void loadPlugin_should_return_a_new_plugin_instance_with_DI_context_initialized() {

final Map<String, Object> pluginSettingMap = new HashMap<>();
final PluginSetting pluginSetting = new PluginSetting("test_di_source", pluginSettingMap);
pluginSetting.setPipelineName(pipelineName);

final Source sourcePlugin = createObjectUnderTest().loadPlugin(Source.class, pluginSetting);

assertThat(sourcePlugin, instanceOf(TestDISource.class));
TestDISource plugin = (TestDISource) sourcePlugin;
// Testing the auto wired been with the Dependency Injection
assertNotNull(plugin.getTestComponent());
assertInstanceOf(TestComponent.class, plugin.getTestComponent());
assertThat(plugin.getTestComponent().getIdentifier(), equalTo("test-component"));
}

@Test
void loadPlugin_should_return_a_new_plugin_instance_with_the_expected_configuration_variable_args() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.annotation.DependsOn;

import javax.inject.Inject;
Expand Down Expand Up @@ -115,13 +116,16 @@ private <T> ComponentPluginArgumentsContext getConstructionContext(final PluginS
final PluginConfigObservable pluginConfigObservable = pluginConfigurationObservableFactory
.createDefaultPluginConfigObservable(pluginConfigurationConverter, pluginConfigurationType, pluginSetting);

Class[] markersToScan = pluginAnnotation.packagesToScan();
BeanFactory beanFactory = pluginBeanFactoryProvider.createPluginSpecificContext(markersToScan);

return new ComponentPluginArgumentsContext.Builder()
.withPluginSetting(pluginSetting)
.withPipelineDescription(pluginSetting)
.withPluginConfiguration(configuration)
.withPluginFactory(this)
.withSinkContext(sinkContext)
.withBeanFactory(pluginBeanFactoryProvider.get())
.withBeanFactory(beanFactory)
.withPluginConfigurationObservable(pluginConfigObservable)
.withTypeArgumentSuppliers(applicationContextToTypedSuppliers.getArgumentsSuppliers())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.support.GenericApplicationContext;

import javax.inject.Inject;
import javax.inject.Named;
import javax.inject.Provider;
import java.util.Arrays;
import java.util.Objects;

/**
Expand All @@ -25,7 +26,7 @@
* <p><i>publicContext</i> is the root {@link ApplicationContext}</p>
*/
@Named
class PluginBeanFactoryProvider implements Provider<BeanFactory> {
class PluginBeanFactoryProvider {
private final GenericApplicationContext sharedPluginApplicationContext;
private final GenericApplicationContext coreApplicationContext;

Expand Down Expand Up @@ -57,8 +58,17 @@ GenericApplicationContext getCoreApplicationContext() {
* instead, a new isolated {@link ApplicationContext} should be created.
* @return BeanFactory A BeanFactory that inherits from {@link PluginBeanFactoryProvider#sharedPluginApplicationContext}
*/
public BeanFactory get() {
final GenericApplicationContext isolatedPluginApplicationContext = new GenericApplicationContext(sharedPluginApplicationContext);
public BeanFactory createPluginSpecificContext(Class[] markersToScan) {
AnnotationConfigApplicationContext isolatedPluginApplicationContext = new AnnotationConfigApplicationContext();
if(markersToScan !=null && markersToScan.length>0) {
// If packages to scan is provided in this plugin annotation, which indicates
// that this plugin is interested in using Dependency Injection isolated for its module
Arrays.stream(markersToScan)
.map(Class::getPackageName)
.forEach(isolatedPluginApplicationContext::scan);
isolatedPluginApplicationContext.refresh();
}
isolatedPluginApplicationContext.setParent(sharedPluginApplicationContext);
return isolatedPluginApplicationContext.getBeanFactory();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.opensearch.dataprepper.model.plugin.NoPluginFoundException;
import org.opensearch.dataprepper.model.plugin.PluginConfigObservable;
import org.opensearch.dataprepper.model.sink.Sink;
import org.opensearch.dataprepper.model.source.Source;
import org.opensearch.dataprepper.plugins.test.TestDISource;
import org.opensearch.dataprepper.plugins.test.TestSink;
import org.springframework.beans.factory.BeanFactory;

Expand Down Expand Up @@ -192,6 +194,25 @@ void setUp() {
.willReturn(Optional.of(expectedPluginClass));
}

@Test
void loadPlugin_should_create_a_new_instance_of_the_plugin_with_di_initialized() {

final TestDISource expectedInstance = mock(TestDISource.class);
final Object convertedConfiguration = mock(Object.class);
given(pluginConfigurationConverter.convert(PluginSetting.class, pluginSetting))
.willReturn(convertedConfiguration);
given(firstPluginProvider.findPluginClass(Source.class, pluginName))
.willReturn(Optional.of(TestDISource.class));
given(pluginCreator.newPluginInstance(eq(TestDISource.class), any(ComponentPluginArgumentsContext.class), eq(pluginName)))
.willReturn(expectedInstance);

assertThat(createObjectUnderTest().loadPlugin(Source.class, pluginSetting),
equalTo(expectedInstance));
verify(pluginConfigurationObservableFactory).createDefaultPluginConfigObservable(eq(pluginConfigurationConverter),
eq(PluginSetting.class), eq(pluginSetting));
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{TestDISource.class});
}

@Test
void loadPlugin_should_create_a_new_instance_of_the_first_plugin_found() {

Expand All @@ -206,7 +227,7 @@ void loadPlugin_should_create_a_new_instance_of_the_first_plugin_found() {
equalTo(expectedInstance));
verify(pluginConfigurationObservableFactory).createDefaultPluginConfigObservable(eq(pluginConfigurationConverter),
eq(PluginSetting.class), eq(pluginSetting));
verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
}

@Test
Expand Down Expand Up @@ -240,7 +261,7 @@ void loadPlugins_should_return_an_empty_list_when_the_number_of_instances_is_0()
assertThat(plugins, notNullValue());
assertThat(plugins.size(), equalTo(0));

verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
verifyNoInteractions(pluginCreator);
}

Expand All @@ -256,7 +277,7 @@ void loadPlugins_should_return_a_single_instance_when_the_the_numberOfInstances_
final List<?> plugins = createObjectUnderTest().loadPlugins(
baseClass, pluginSetting, c -> 1);

verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
verify(pluginConfigurationObservableFactory).createDefaultPluginConfigObservable(eq(pluginConfigurationConverter),
eq(PluginSetting.class), eq(pluginSetting));
final ArgumentCaptor<ComponentPluginArgumentsContext> pluginArgumentsContextArgCapture = ArgumentCaptor.forClass(ComponentPluginArgumentsContext.class);
Expand Down Expand Up @@ -285,7 +306,7 @@ void loadPlugin_with_varargs_should_return_a_single_instance_when_the_the_number

final Object plugin = createObjectUnderTest().loadPlugin(baseClass, pluginSetting, object);

verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
verify(pluginConfigurationObservableFactory).createDefaultPluginConfigObservable(eq(pluginConfigurationConverter),
eq(PluginSetting.class), eq(pluginSetting));
final ArgumentCaptor<ComponentPluginArgumentsContext> pluginArgumentsContextArgCapture = ArgumentCaptor.forClass(ComponentPluginArgumentsContext.class);
Expand Down Expand Up @@ -320,7 +341,7 @@ void loadPlugins_should_return_an_instance_for_the_total_count() {
final List<?> plugins = createObjectUnderTest().loadPlugins(
baseClass, pluginSetting, c -> 3);

verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
final ArgumentCaptor<ComponentPluginArgumentsContext> pluginArgumentsContextArgCapture = ArgumentCaptor.forClass(ComponentPluginArgumentsContext.class);
verify(pluginCreator, times(3)).newPluginInstance(eq(expectedPluginClass), pluginArgumentsContextArgCapture.capture(), eq(pluginName));
final List<ComponentPluginArgumentsContext> actualPluginArgumentsContextList = pluginArgumentsContextArgCapture.getAllValues();
Expand Down Expand Up @@ -356,7 +377,7 @@ void loadPlugins_should_return_a_single_instance_with_values_from_ApplicationCon
final List<?> plugins = createObjectUnderTest().loadPlugins(
baseClass, pluginSetting, c -> 1);

verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
final ArgumentCaptor<ComponentPluginArgumentsContext> pluginArgumentsContextArgCapture = ArgumentCaptor.forClass(ComponentPluginArgumentsContext.class);
verify(pluginCreator).newPluginInstance(eq(expectedPluginClass), pluginArgumentsContextArgCapture.capture(), eq(pluginName));
final ComponentPluginArgumentsContext actualPluginArgumentsContext = pluginArgumentsContextArgCapture.getValue();
Expand Down Expand Up @@ -398,7 +419,7 @@ void loadPlugin_should_create_a_new_instance_of_the_first_plugin_found_with_corr

assertThat(createObjectUnderTest().loadPlugin(baseClass, pluginSetting), equalTo(expectedInstance));
MatcherAssert.assertThat(expectedInstance.getClass().getAnnotation(DataPrepperPlugin.class).deprecatedName(), equalTo(TEST_SINK_DEPRECATED_NAME));
verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
}
}

Expand Down Expand Up @@ -427,7 +448,7 @@ void loadPlugin_should_create_a_new_instance_of_the_first_plugin_found_with_corr

assertThat(createObjectUnderTest().loadPlugin(baseClass, pluginSetting), equalTo(expectedInstance));
MatcherAssert.assertThat(expectedInstance.getClass().getAnnotation(DataPrepperPlugin.class).alternateNames(), equalTo(new String[]{TEST_SINK_ALTERNATE_NAME}));
verify(beanFactoryProvider).get();
verify(beanFactoryProvider).createPluginSpecificContext(new Class[]{});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.opensearch.dataprepper.plugins.test.TestComponent;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.support.GenericApplicationContext;

import static org.hamcrest.CoreMatchers.equalTo;
Expand All @@ -21,6 +23,7 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class PluginBeanFactoryProviderTest {

Expand Down Expand Up @@ -48,14 +51,14 @@ void testPluginBeanFactoryProviderUsesParentContext() {
@Test
void testPluginBeanFactoryProviderRequiresContext() {
context = null;
assertThrows(NullPointerException.class, () -> createObjectUnderTest());
assertThrows(NullPointerException.class, this::createObjectUnderTest);
}

@Test
void testPluginBeanFactoryProviderRequiresParentContext() {
context = mock(GenericApplicationContext.class);

assertThrows(NullPointerException.class, () -> createObjectUnderTest());
assertThrows(NullPointerException.class, this::createObjectUnderTest);
}

@Test
Expand All @@ -65,16 +68,16 @@ void testPluginBeanFactoryProviderGetReturnsBeanFactory() {
final PluginBeanFactoryProvider beanFactoryProvider = createObjectUnderTest();

verify(context).getParent();
assertThat(beanFactoryProvider.get(), is(instanceOf(BeanFactory.class)));
assertThat(beanFactoryProvider.createPluginSpecificContext(new Class[]{}), is(instanceOf(BeanFactory.class)));
}

@Test
void testPluginBeanFactoryProviderGetReturnsUniqueBeanFactory() {
doReturn(context).when(context).getParent();

final PluginBeanFactoryProvider beanFactoryProvider = createObjectUnderTest();
final BeanFactory isolatedBeanFactoryA = beanFactoryProvider.get();
final BeanFactory isolatedBeanFactoryB = beanFactoryProvider.get();
final BeanFactory isolatedBeanFactoryA = beanFactoryProvider.createPluginSpecificContext(new Class[]{});
final BeanFactory isolatedBeanFactoryB = beanFactoryProvider.createPluginSpecificContext(new Class[]{});

verify(context).getParent();
assertThat(isolatedBeanFactoryA, not(sameInstance(isolatedBeanFactoryB)));
Expand All @@ -95,4 +98,22 @@ void getSharedPluginApplicationContext_called_multiple_times_returns_same_instan
final PluginBeanFactoryProvider objectUnderTest = createObjectUnderTest();
assertThat(objectUnderTest.getSharedPluginApplicationContext(), sameInstance(objectUnderTest.getSharedPluginApplicationContext()));
}

@Test
void testCreatePluginSpecificContext() {
when(context.getParent()).thenReturn(context);
final PluginBeanFactoryProvider objectUnderTest = createObjectUnderTest();
BeanFactory beanFactory = objectUnderTest.createPluginSpecificContext(new Class[]{TestComponent.class});
assertThat(beanFactory, notNullValue());
assertThat(beanFactory.getBean(TestComponent.class), notNullValue());
}

@Test
void testCreatePluginSpecificContext_with_empty_array() {
when(context.getParent()).thenReturn(context);
final PluginBeanFactoryProvider objectUnderTest = createObjectUnderTest();
BeanFactory beanFactory = objectUnderTest.createPluginSpecificContext(new Class[]{});
assertThat(beanFactory, notNullValue());
assertThrows(NoSuchBeanDefinitionException.class, ()->beanFactory.getBean(TestComponent.class));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.opensearch.dataprepper.plugins.test;

import javax.inject.Named;

@Named
public class TestComponent {
public String getIdentifier() {
return "test-component";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.test;

import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.source.Source;
import org.opensearch.dataprepper.plugin.TestPluggableInterface;

@DataPrepperPlugin(name = "test_di_source",
alternateNames = { "test_source_alternate_name1", "test_source_alternate_name2" },
deprecatedName = "test_source_deprecated_name",
pluginType = Source.class,
packagesToScan = {TestDISource.class})
public class TestDISource implements Source<Record<String>>, TestPluggableInterface {

private final TestComponent testComponent;

@DataPrepperPluginConstructor
public TestDISource(TestComponent testComponent) {
this.testComponent = testComponent;
}

@Override
public void start(Buffer<Record<String>> buffer) {
}

public TestComponent getTestComponent() {
return testComponent;
}

@Override
public void stop() {}
}

0 comments on commit 472c912

Please sign in to comment.