Skip to content

Commit

Permalink
"Thread per test class" via MultiEnvTestEngine
Browse files Browse the repository at this point in the history
See projectnessie#9441 for a complete description of the underlying problem. TL;DR is: `ThreadLocal`s from various 3rd party libraries leak into the single `Test worker` thread that runs all the tests, resulting in TL objects/suppliers from the various Quarkus test class loaders, eventually leading to nasty OOMs.

This change updates the `MultiEnvTestEngine` by using the new `ThreadPerTestClassExecutionExecutorService` and also "assimilate" really all tests, even the non-multi-env tests, so that those also run on a thread per test-class. The logic to distinguish multi-env from non-multi-env tests via `MultiEnvExtensionRegistry.registerExtension()` via test discovery is not perfect (but good enough), it can add multi-env tests to the non-multi-env tests, so an additional check is needed there.

Since each test class runs on "its own thread", the `ThreadLocal`s are registered on that thread. Once the test class finishes, the thread is disposed and its thread locals become eligible for garbage collection, which is what is needed.

The bump of the max-heap size for test workers is also reduced back to 2g (was changed in projectnessie#9433).

Fixes projectnessie#9441
  • Loading branch information
snazy committed Aug 30, 2024
1 parent 0fc02c8 commit 1e246c5
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 48 deletions.
2 changes: 1 addition & 1 deletion build-logic/src/main/kotlin/nessie-testing.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ tasks.withType<Test>().configureEach {
)

minHeapSize = if (testHeapSize != null) testHeapSize as String else "768m"
maxHeapSize = if (testHeapSize != null) testHeapSize as String else "3g"
maxHeapSize = if (testHeapSize != null) testHeapSize as String else "2g"
} else if (testHeapSize != null) {
minHeapSize = testHeapSize!!
maxHeapSize = testHeapSize!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.engine.config.DefaultJupiterConfiguration;
import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor;
import org.junit.jupiter.engine.extension.MutableExtensionRegistry;
import org.junit.platform.commons.util.AnnotationUtils;

Expand All @@ -33,17 +37,40 @@
public class MultiEnvExtensionRegistry {
private final MutableExtensionRegistry registry;

private final Set<ClassBasedTestDescriptor> probablyNotMultiEnv = new LinkedHashSet<>();

public MultiEnvExtensionRegistry() {
this.registry =
MutableExtensionRegistry.createRegistryWithDefaultExtensions(
new DefaultJupiterConfiguration(new EmptyConfigurationParameters()));
}

public void registerExtensions(Class<?> testClass) {
AnnotationUtils.findRepeatableAnnotations(testClass, ExtendWith.class).stream()
.flatMap(e -> Arrays.stream(e.value()))
.filter(MultiEnvTestExtension.class::isAssignableFrom)
public void registerExtensions(ClassBasedTestDescriptor descriptor) {
AtomicBoolean multiEnv = new AtomicBoolean(false);

findMultiEnvExtensions(descriptor)
.peek(x -> multiEnv.set(true))
.forEach(registry::registerExtension);

if (!multiEnv.get()) {
probablyNotMultiEnv.add(descriptor);
}
}

public boolean isMultiEnvClass(ClassBasedTestDescriptor descriptor) {
return findMultiEnvExtensions(descriptor).findFirst().isPresent();
}

private Stream<Class<? extends Extension>> findMultiEnvExtensions(
ClassBasedTestDescriptor descriptor) {
Class<?> testClass = descriptor.getTestClass();
return AnnotationUtils.findRepeatableAnnotations(testClass, ExtendWith.class).stream()
.flatMap(e -> Arrays.stream(e.value()))
.filter(MultiEnvTestExtension.class::isAssignableFrom);
}

public Stream<ClassBasedTestDescriptor> probablyNotMultiEnv() {
return probablyNotMultiEnv.stream();
}

public Stream<MultiEnvTestExtension> stream() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor;
import org.junit.jupiter.engine.descriptor.JupiterEngineDescriptor;
import org.junit.jupiter.engine.discovery.DiscoverySelectorResolver;
import org.junit.platform.engine.ConfigurationParameters;
import org.junit.platform.engine.EngineDiscoveryRequest;
import org.junit.platform.engine.ExecutionRequest;
import org.junit.platform.engine.TestDescriptor;
Expand All @@ -51,7 +52,8 @@ public class MultiEnvTestEngine implements TestEngine {

private static final MultiEnvExtensionRegistry registry = new MultiEnvExtensionRegistry();

private final JupiterTestEngine delegate = new JupiterTestEngine();
private final ThreadPerTestClassExecutionTestEngine delegate =
new ThreadPerTestClassExecutionTestEngine();

static MultiEnvExtensionRegistry registry() {
return registry;
Expand All @@ -76,36 +78,36 @@ public TestDescriptor discover(EngineDiscoveryRequest discoveryRequest, UniqueId
preliminaryResult.accept(
descriptor -> {
if (descriptor instanceof ClassBasedTestDescriptor) {
Class<?> testClass = ((ClassBasedTestDescriptor) descriptor).getTestClass();
registry.registerExtensions(testClass);
registry.registerExtensions(((ClassBasedTestDescriptor) descriptor));
}
});

ConfigurationParameters configurationParameters =
discoveryRequest.getConfigurationParameters();

// JupiterEngineDescriptor must be the root, that's what the JUnit Jupiter engine
// implementation expects.
JupiterEngineDescriptor multiEnvDescriptor =
JupiterEngineDescriptor multiEnvRootDescriptor =
new JupiterEngineDescriptor(
uniqueId,
new DefaultJupiterConfiguration(discoveryRequest.getConfigurationParameters()));
uniqueId, new DefaultJupiterConfiguration(configurationParameters));

// Handle the "multi-env" tests.
List<String> extensions = new ArrayList<>();
AtomicBoolean envDiscovered = new AtomicBoolean();
AtomicBoolean multiEnvDiscovered = new AtomicBoolean();
registry.stream()
.forEach(
ext -> {
extensions.add(ext.getClass().getSimpleName());
for (String envId :
ext.allEnvironmentIds(discoveryRequest.getConfigurationParameters())) {
envDiscovered.set(true);
for (String envId : ext.allEnvironmentIds(configurationParameters)) {
multiEnvDiscovered.set(true);
UniqueId segment = uniqueId.append(ext.segmentType(), envId);

MultiEnvTestDescriptor envRoot = new MultiEnvTestDescriptor(segment, envId);
multiEnvDescriptor.addChild(envRoot);
multiEnvRootDescriptor.addChild(envRoot);

JupiterConfiguration envRootConfiguration =
new CachingJupiterConfiguration(
new MultiEnvJupiterConfiguration(
discoveryRequest.getConfigurationParameters(), envId));
new MultiEnvJupiterConfiguration(configurationParameters, envId));
JupiterEngineDescriptor discoverResult =
new JupiterEngineDescriptor(segment, envRootConfiguration);
new DiscoverySelectorResolver()
Expand All @@ -116,18 +118,53 @@ public TestDescriptor discover(EngineDiscoveryRequest discoveryRequest, UniqueId
for (TestDescriptor child : children) {
// Note: this also removes the reference to parent from the child
discoverResult.removeChild(child);
envRoot.addChild(child);

// Must check whether the test class is a multi-env test, because discovery
// returns all test classes.
ClassBasedTestDescriptor classBased = (ClassBasedTestDescriptor) child;
boolean multi = registry().isMultiEnvClass(classBased);
if (multi) {
envRoot.addChild(child);
}
}
}
});

// Also execute all other tests via the MultiEnv test engine to get the "thread per
// test-class" behavior.
registry()
.probablyNotMultiEnv()
.forEach(
clazz -> {
JupiterConfiguration jupiterConfiguration =
new CachingJupiterConfiguration(
new DefaultJupiterConfiguration(configurationParameters));

JupiterEngineDescriptor discoverResult =
new JupiterEngineDescriptor(uniqueId, jupiterConfiguration);
new DiscoverySelectorResolver().resolveSelectors(discoveryRequest, discoverResult);

List<? extends TestDescriptor> children =
new ArrayList<>(discoverResult.getChildren());
for (TestDescriptor child : children) {
// Must check whether the test class is not a multi-env test here, as the
// `multiEnvNotDetected` contains some actual multi-env tests.
ClassBasedTestDescriptor classBased = (ClassBasedTestDescriptor) child;
boolean multi = registry().isMultiEnvClass(classBased);
if (!multi) {
discoverResult.removeChild(child);
multiEnvRootDescriptor.addChild(child);
}
}
});

if (!extensions.isEmpty() && !envDiscovered.get() && FAIL_ON_MISSING_ENVIRONMENTS) {
if (!extensions.isEmpty() && !multiEnvDiscovered.get() && FAIL_ON_MISSING_ENVIRONMENTS) {
throw new IllegalStateException(
String.format(
"%s was enabled, but test extensions did not discover any environment IDs: %s",
getClass().getSimpleName(), extensions));
}
return multiEnvDescriptor;
return multiEnvRootDescriptor;
} catch (Exception e) {
LOGGER.error("Failed to discover tests", e);
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package org.projectnessie.junit.engine;

import static org.projectnessie.junit.engine.MultiEnvTestEngine.registry;

import java.util.Optional;
import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor;
import org.junit.platform.engine.FilterResult;
Expand Down Expand Up @@ -46,39 +44,27 @@ private Optional<Class<?>> classFor(TestDescriptor object) {
return Optional.empty();
}

private FilterResult filter(Class<?> testClass, UniqueId id) {
// Use the static extension data collected during the discovery phase.
// It is possible to reload extensions based of class objects from test descriptors,
// however that would add unnecessary overhead.
MultiEnvExtensionRegistry registry = registry();

/**
* This filter effectively routes all tests via the {@link MultiEnvTestEngine}, both actual
* multi-env tests but also non-multi-env tests to achieve the needed thread-per-test-class
* behavior.
*
* <p>"Thread-per-test-class behavior" is needed to prevent the class/class-loader leak via {@link
* ThreadLocal}s as described in <a
* href="https://github.com/projectnessie/nessie/issues/9441">#9441</a>.
*/
private FilterResult filter(UniqueId id) {
if (id.getEngineId().map("junit-jupiter"::equals).orElse(false)) {
if (registry.stream(testClass).findAny().isPresent()) {
return FilterResult.excluded("Excluding multi-env test from Jupiter Engine: " + id);
} else {
return FilterResult.included(null);
}
return FilterResult.excluded("Excluding multi-env test from Jupiter Engine: " + id);
} else {
// check whether any of the extensions declared by the test recognize the version segment
boolean matched =
registry.stream(testClass)
.anyMatch(
ext ->
id.getSegments().stream()
.anyMatch(s -> ext.segmentType().equals(s.getType())));

if (matched) {
return FilterResult.included(null);
} else {
return FilterResult.excluded("Excluding unmatched multi-env test: " + id);
}
return FilterResult.included(null);
}
}

@Override
public FilterResult apply(TestDescriptor test) {
return classFor(test)
.map(testClass -> filter(testClass, test.getUniqueId()))
.map(testClass -> filter(test.getUniqueId()))
.orElseGet(() -> FilterResult.included(null)); // fallback for non-class descriptors
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (C) 2024 Dremio
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.projectnessie.junit.engine;

import static java.util.concurrent.CompletableFuture.completedFuture;

import java.lang.reflect.Field;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.platform.engine.TestDescriptor;
import org.junit.platform.engine.UniqueId;
import org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutorService;

/**
* Implements a JUnit test executor that provides thread-per-test-class behavior.
*
* <p>"Thread-per-test-class behavior" is needed to prevent the class/class-loader leak via {@link
* ThreadLocal}s as described in <a
* href="https://github.com/projectnessie/nessie/issues/9441">#9441</a>.
*/
public class ThreadPerTestClassExecutionExecutorService implements HierarchicalTestExecutorService {

private static final Class<?> CLASS_NODE_TEST_TASK;
private static final Field FIELD_TEST_DESCRIPTOR;

static {
try {
CLASS_NODE_TEST_TASK =
Class.forName("org.junit.platform.engine.support.hierarchical.NodeTestTask");
FIELD_TEST_DESCRIPTOR = CLASS_NODE_TEST_TASK.getDeclaredField("testDescriptor");
FIELD_TEST_DESCRIPTOR.setAccessible(true);
} catch (Exception e) {
throw new RuntimeException(
"ThreadPerExecutionExecutorService is probably not compatible with the current JUnit version",
e);
}
}

protected TestDescriptor getTestDescriptor(TestTask testTask) {
if (!CLASS_NODE_TEST_TASK.isAssignableFrom(testTask.getClass())) {
throw new IllegalArgumentException(
testTask.getClass().getName() + " is not of type " + CLASS_NODE_TEST_TASK.getName());
}
try {
return (TestDescriptor) FIELD_TEST_DESCRIPTOR.get(testTask);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}

public ThreadPerTestClassExecutionExecutorService() {}

@Override
public Future<Void> submit(TestTask testTask) {
executeTask(testTask);
return completedFuture(null);
}

@Override
public void invokeAll(List<? extends TestTask> tasks) {
tasks.forEach(this::executeTask);
}

protected void executeTask(TestTask testTask) {
TestDescriptor testDescriptor = getTestDescriptor(testTask);
UniqueId.Segment lastSegment = testDescriptor.getUniqueId().getLastSegment();
String type = lastSegment.getType();
if ("class".equals(type)) {
AtomicReference<Exception> failure = new AtomicReference<>();
Thread threadPerClass =
new Thread(
() -> {
try {
testTask.execute();
} catch (Exception e) {
failure.set(e);
}
},
"TEST THREAD FOR " + lastSegment.getValue());
threadPerClass.setDaemon(true);
threadPerClass.start();
try {
threadPerClass.join();
} catch (InterruptedException e) {
// delegate a thread-interrupt
threadPerClass.interrupt();
}
Exception ex = failure.get();
if (ex instanceof RuntimeException) {
throw (RuntimeException) ex;
} else if (ex != null) {
throw new RuntimeException(ex);
}
} else {
testTask.execute();
}
}

@Override
public void close() {
// nothing to do
}
}
Loading

0 comments on commit 1e246c5

Please sign in to comment.