From 3cd53515d093a6e224f5a9b4fd7d07f902146c73 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 19 Mar 2024 15:41:51 +0100 Subject: [PATCH] update test dependencies and add option to control downloading (#55) --- CMakeLists.txt | 7 + tests/CMakeLists.txt | 93 ++++----- tests/gtest_mpi.cpp | 216 +++++++++++++++++++++ tests/gtest_mpi.hpp | 28 +++ tests/local_tests/test_local_transform.cpp | 4 +- tests/mpi_tests/test_multi_transform.cpp | 4 + tests/mpi_tests/test_transform.cpp | 15 +- tests/mpi_tests/test_transpose.cpp | 6 + tests/mpi_tests/test_transpose_gpu.cpp | 6 + tests/programs/benchmark.cpp | 17 +- tests/run_mpi_tests.cpp | 25 +-- 11 files changed, 338 insertions(+), 83 deletions(-) create mode 100644 tests/gtest_mpi.cpp create mode 100644 tests/gtest_mpi.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ff9e7b..7730330 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,8 @@ set(CMAKE_HIP_STANDARD 17) #add local module path set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}/cmake/modules) +include(CMakeDependentOption) + # Options option(SPFFT_STATIC "Compile as static library" OFF) option(SPFFT_OMP "Compile with OpenMP support" ON) @@ -38,6 +40,11 @@ option(SPFFT_BUILD_TESTS "Build tests" OFF) option(SPFFT_SINGLE_PRECISION "Enable single precision support" OFF) option(SPFFT_INSTALL "Enable CMake install commands" ON) option(SPFFT_FORTRAN "Compile fortran module" OFF) +option(SPFFT_BUNDLED_LIBS "Use bundled libraries for building tests" ON) + +cmake_dependent_option(SPFFT_BUNDLED_GOOGLETEST "Use bundled googletest lib" ON "SPFFT_BUNDLED_LIBS" OFF) +cmake_dependent_option(SPFFT_BUNDLED_JSON "Use bundled json lib" ON "SPFFT_BUNDLED_LIBS" OFF) +cmake_dependent_option(SPFFT_BUNDLED_CLI11 "Use bundled CLI11 lib" ON "SPFFT_BUNDLED_LIBS" OFF) set(SPFFT_GPU_BACKEND "OFF" CACHE STRING "GPU backend") set_property(CACHE SPFFT_GPU_BACKEND PROPERTY STRINGS diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 445632a..f292dff 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,72 +1,60 @@ if(SPFFT_BUILD_TESTS) - cmake_minimum_required(VERSION 3.11 FATAL_ERROR) # git fetch module requires at least 3.11 + cmake_minimum_required(VERSION 3.14 FATAL_ERROR) # FetchContent_MakeAvailable requires at least 3.14 + + # update time stamps when using FetchContent + if(POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) + endif() + set(BUILD_GMOCK OFF CACHE BOOL "") set(INSTALL_GTEST OFF CACHE BOOL "") mark_as_advanced(BUILD_GMOCK INSTALL_GTEST) include(FetchContent) # add googletest - FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG release-1.8.1 - ) - FetchContent_GetProperties(googletest) - if(NOT googletest_POPULATED) - message(STATUS "Downloading Google Test repository...") - FetchContent_Populate(googletest) - endif() - add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR}) - - # add gtest_mpi - FetchContent_Declare( - gtest_mpi - GIT_REPOSITORY https://github.com/AdhocMan/gtest_mpi.git - GIT_TAG v1.0.0 - ) - FetchContent_GetProperties(gtest_mpi) - if(NOT gtest_mpi_POPULATED) - message(STATUS "Downloading Google Test MPI extension repository...") - FetchContent_Populate(gtest_mpi) + if(SPFFT_BUNDLED_GOOGLETEST) + FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz + URL_MD5 c8340a482851ef6a3fe618a082304cfc + ) + FetchContent_MakeAvailable(googletest) + else() + find_package(googletest CONFIG REQUIRED) endif() - add_subdirectory(${gtest_mpi_SOURCE_DIR} ${gtest_mpi_BINARY_DIR}) + list(APPEND SPFFT_TEST_LIBRARIES gtest_main) # add command line parser - FetchContent_Declare( - cli11 - GIT_REPOSITORY https://github.com/CLIUtils/CLI11.git - GIT_TAG v1.7.1 - ) - FetchContent_GetProperties(cli11) - if(NOT cli11_POPULATED) - message(STATUS "Downloading CLI11 command line parser repository...") - FetchContent_Populate(cli11) + if(SPFFT_BUNDLED_CLI11) + FetchContent_Declare( + cli11 + URL https://github.com/CLIUtils/CLI11/archive/refs/tags/v2.3.2.tar.gz + URL_MD5 b80cb645dee25982110b068b426363ff + ) + FetchContent_MakeAvailable(cli11) + else() + find_package(CLI11 CONFIG REQUIRED) endif() - list(APPEND SPFFT_EXTERNAL_INCLUDE_DIRS ${cli11_SOURCE_DIR}/include) + list(APPEND SPFFT_TEST_LIBRARIES CLI11::CLI11) - # add json parser - set(JSON_Install OFF CACHE BOOL "") - FetchContent_Declare( - json - GIT_REPOSITORY https://github.com/nlohmann/json.git - GIT_TAG v3.6.1 - ) - FetchContent_GetProperties(json) - if(NOT json_POPULATED) - message(STATUS "Downloading json repository...") - FetchContent_Populate(json) + # add json parser + if(SPFFT_BUNDLED_JSON) + FetchContent_Declare( + json + URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz + URL_MD5 e8d56bc54621037842ee9f0aeae27746 + ) + FetchContent_MakeAvailable(json) + else() + find_package(nlohmann_json CONFIG REQUIRED) endif() - set(JSON_BuildTests OFF CACHE INTERNAL "") - add_subdirectory(${json_SOURCE_DIR} ${json_BINARY_DIR}) - list(APPEND SPFFT_EXTERNAL_LIBS nlohmann_json::nlohmann_json) - list(APPEND SPFFT_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/tests) # benchmark executable add_executable(benchmark programs/benchmark.cpp) - target_link_libraries(benchmark PRIVATE spfft_test ${SPFFT_EXTERNAL_LIBS}) + target_link_libraries(benchmark PRIVATE spfft_test ${SPFFT_EXTERNAL_LIBS} CLI11::CLI11 nlohmann_json::nlohmann_json) target_include_directories(benchmark PRIVATE ${SPFFT_INCLUDE_DIRS} ${SPFFT_EXTERNAL_INCLUDE_DIRS}) # test executables @@ -77,19 +65,20 @@ if(SPFFT_BUILD_TESTS) local_tests/test_fftw_prop_hash.cpp local_tests/test_local_transform.cpp ) - target_link_libraries(run_local_tests PRIVATE gtest_main gtest_mpi) + target_link_libraries(run_local_tests PRIVATE gtest_main) target_link_libraries(run_local_tests PRIVATE spfft_test ${SPFFT_EXTERNAL_LIBS}) target_include_directories(run_local_tests PRIVATE ${SPFFT_INCLUDE_DIRS} ${SPFFT_EXTERNAL_INCLUDE_DIRS}) if(SPFFT_MPI) add_executable(run_mpi_tests run_mpi_tests.cpp + gtest_mpi.cpp mpi_tests/test_transform.cpp mpi_tests/test_multi_transform.cpp mpi_tests/test_transpose.cpp mpi_tests/test_transpose_gpu.cpp ) - target_link_libraries(run_mpi_tests PRIVATE gtest_main gtest_mpi) + target_link_libraries(run_mpi_tests PRIVATE gtest_main) target_link_libraries(run_mpi_tests PRIVATE spfft_test ${SPFFT_EXTERNAL_LIBS}) target_include_directories(run_mpi_tests PRIVATE ${SPFFT_INCLUDE_DIRS} ${SPFFT_EXTERNAL_INCLUDE_DIRS}) endif() diff --git a/tests/gtest_mpi.cpp b/tests/gtest_mpi.cpp new file mode 100644 index 0000000..e7aeeac --- /dev/null +++ b/tests/gtest_mpi.cpp @@ -0,0 +1,216 @@ +#include "gtest_mpi.hpp" +#include +#include +#include +#include +#include +#include + +namespace gtest_mpi { + +namespace { + +class MPIListener : public testing::EmptyTestEventListener { +public: + using UnitTest = testing::UnitTest; + using TestCase = testing::TestCase; + using TestInfo = testing::TestInfo; + using TestPartResult = testing::TestPartResult; + using TestSuite = testing::TestSuite; + + MPIListener(testing::TestEventListener *listener) + : listener_(listener), comm_(MPI_COMM_WORLD), gather_called_(false) { + MPI_Comm_dup(MPI_COMM_WORLD, &comm_); + int rank; + MPI_Comm_rank(comm_, &rank); + if (rank != 0) + listener_.reset(); + } + + void OnTestProgramStart(const UnitTest &u) override { + if (listener_) + listener_->OnTestProgramStart(u); + } + + void OnTestProgramEnd(const UnitTest &u) override { + if (listener_) + listener_->OnTestProgramEnd(u); + } + + void OnTestStart(const TestInfo &test_info) override { + gather_called_ = false; + if (listener_) + listener_->OnTestStart(test_info); + } + + void OnTestPartResult(const TestPartResult &test_part_result) override { + if (listener_) { + listener_->OnTestPartResult(test_part_result); + } else if (test_part_result.type() == TestPartResult::Type::kFatalFailure || + test_part_result.type() == + TestPartResult::Type::kNonFatalFailure) { + std::size_t fileIndex = strings_.size(); + strings_ += test_part_result.file_name(); + strings_ += '\0'; + + std::size_t messageIndex = strings_.size(); + strings_ += test_part_result.message(); + strings_ += '\0'; + + infos_.emplace_back(ResultInfo{test_part_result.type(), fileIndex, + test_part_result.line_number(), + messageIndex}); + } + } + + void OnTestEnd(const TestInfo &test_info) override { + if(!gather_called_){ + std::cerr << "Missing GTEST_MPI_GUARD in test case!" << std::endl; + throw std::runtime_error("Missing GTEST_MPI_GUARD in test case!"); + } + + if (listener_) + listener_->OnTestEnd(test_info); + } + + void OnTestIterationStart(const UnitTest &u, int it) override { + if (listener_) + listener_->OnTestIterationStart(u, it); + } + + void OnEnvironmentsSetUpStart(const UnitTest &u) override { + if (listener_) + listener_->OnEnvironmentsSetUpStart(u); + } + + void OnEnvironmentsSetUpEnd(const UnitTest &u) override { + if (listener_) + listener_->OnEnvironmentsSetUpEnd(u); + } + + void OnTestSuiteStart(const TestSuite &t) override { + if (listener_) + listener_->OnTestSuiteStart(t); + } + + void OnTestDisabled(const TestInfo &t) override { + if (listener_) + listener_->OnTestDisabled(t); + } + void OnTestSuiteEnd(const TestSuite &t) override { + if (listener_) + listener_->OnTestSuiteEnd(t); + } + + void OnEnvironmentsTearDownStart(const UnitTest &u) override { + if (listener_) + listener_->OnEnvironmentsTearDownStart(u); + } + + void OnEnvironmentsTearDownEnd(const UnitTest &u) override { + if (listener_) + listener_->OnEnvironmentsTearDownEnd(u); + } + + void OnTestIterationEnd(const UnitTest &u, int it) override { + if (listener_) + listener_->OnTestIterationEnd(u, it); + } + + void GatherPartResults() { + gather_called_ = true; + int rank, n_proc; + MPI_Comm_rank(comm_, &rank); + MPI_Comm_size(comm_, &n_proc); + + if (rank == 0) { + decltype(infos_) remoteInfos; + decltype(strings_) remoteStrings; + for (int r = 1; r < n_proc; ++r) { + MPI_Status status; + int count; + + // Result infos + MPI_Probe(r, 0, comm_, &status); + MPI_Get_count(&status, MPI_CHAR, &count); + auto numResults = static_cast(count) / + sizeof(decltype(remoteInfos)::value_type); + remoteInfos.resize(numResults); + MPI_Recv(remoteInfos.data(), count, MPI_BYTE, r, 0, comm_, + MPI_STATUS_IGNORE); + + // Only continue if any results + if (numResults) { + // Get strings + MPI_Probe(r, 0, comm_, &status); + MPI_Get_count(&status, MPI_CHAR, &count); + auto stringSize = static_cast(count) / + sizeof(decltype(remoteStrings)::value_type); + remoteStrings.resize(stringSize); + MPI_Recv(&remoteStrings[0], count, MPI_BYTE, r, 0, comm_, + MPI_STATUS_IGNORE); + + // Create error for every remote fail + for (const auto &info : remoteInfos) { + if (info.type == TestPartResult::Type::kFatalFailure || + info.type == TestPartResult::Type::kNonFatalFailure) { + ADD_FAILURE_AT(&remoteStrings[info.fileIndex], info.lineNumber) + << "Rank " << r << ": " << &remoteStrings[info.messageIndex]; + } + } + } + } + } else { + MPI_Send(infos_.data(), + infos_.size() * sizeof(decltype(infos_)::value_type), MPI_BYTE, + 0, 0, comm_); + + // Only send string if results exist + if (infos_.size()) { + MPI_Send(strings_.data(), + strings_.size() * sizeof(decltype(strings_)::value_type), + MPI_BYTE, 0, 0, comm_); + } + } + + infos_.clear(); + strings_.clear(); + } + +private: + struct ResultInfo { + TestPartResult::Type type; + std::size_t fileIndex; + int lineNumber; + std::size_t messageIndex; + }; + + std::unique_ptr listener_; + MPI_Comm comm_; + bool gather_called_; + + std::vector infos_; + std::string strings_; +}; + +MPIListener *globalMPIListener = nullptr; + +} // namespace + +void InitGoogleTestMPI(int *argc, char **argv) { + + ::testing::InitGoogleTest(argc, argv); + + auto &test_listeners = ::testing::UnitTest::GetInstance()->listeners(); + + globalMPIListener = new MPIListener( + test_listeners.Release(test_listeners.default_result_printer())); + + test_listeners.Append(globalMPIListener); +} + +TestGuard CreateTestGuard() { + return TestGuard{[]() { globalMPIListener->GatherPartResults(); }}; +} + +} // namespace gtest_mpi diff --git a/tests/gtest_mpi.hpp b/tests/gtest_mpi.hpp new file mode 100644 index 0000000..7b905fe --- /dev/null +++ b/tests/gtest_mpi.hpp @@ -0,0 +1,28 @@ +#ifndef GTEST_MPI_HPP +#define GTEST_MPI_HPP + +#include + +namespace gtest_mpi { +// Internal helper struct +struct TestGuard { + void (*func)() = nullptr; + + ~TestGuard() { + if (func) + func(); + } +}; + +// Initialize GoogleTest and MPI functionality. MPI_Init has to called before. +void InitGoogleTestMPI(int *argc, char **argv); + +// Create a test guard, which has to be placed in all test cases. +TestGuard CreateTestGuard(); + +} // namespace gtest_mpi + +// Helper macro for creating a test guard within test cases. +#define GTEST_MPI_GUARD auto gtest_mpi_guard__LINE__ = ::gtest_mpi::CreateTestGuard(); + +#endif diff --git a/tests/local_tests/test_local_transform.cpp b/tests/local_tests/test_local_transform.cpp index a13a1b4..2b8a7e8 100644 --- a/tests/local_tests/test_local_transform.cpp +++ b/tests/local_tests/test_local_transform.cpp @@ -92,7 +92,7 @@ static auto param_type_names( #define TEST_PROCESSING_UNITS SpfftProcessingUnitType::SPFFT_PU_HOST #endif -INSTANTIATE_TEST_CASE_P(FullTest, TestLocalTransform, +INSTANTIATE_TEST_SUITE_P(FullTest, TestLocalTransform, ::testing::Combine(::testing::Values(SpfftExchangeType::SPFFT_EXCH_DEFAULT), ::testing::Values(TEST_PROCESSING_UNITS), ::testing::Values(1, 2, 11, 12, 13, 100), @@ -101,7 +101,7 @@ INSTANTIATE_TEST_CASE_P(FullTest, TestLocalTransform, ::testing::Values(false)), param_type_names); -INSTANTIATE_TEST_CASE_P(CenteredIndicesTest, TestLocalTransform, +INSTANTIATE_TEST_SUITE_P(CenteredIndicesTest, TestLocalTransform, ::testing::Combine(::testing::Values(SpfftExchangeType::SPFFT_EXCH_DEFAULT), ::testing::Values(TEST_PROCESSING_UNITS), ::testing::Values(1, 2, 11, 100), diff --git a/tests/mpi_tests/test_multi_transform.cpp b/tests/mpi_tests/test_multi_transform.cpp index a446148..073490e 100644 --- a/tests/mpi_tests/test_multi_transform.cpp +++ b/tests/mpi_tests/test_multi_transform.cpp @@ -1,11 +1,14 @@ #include + #include #include #include #include #include #include + #include "gtest/gtest.h" +#include "gtest_mpi.hpp" #include "memory/array_view_utility.hpp" #include "memory/host_array.hpp" #include "memory/host_array_view.hpp" @@ -17,6 +20,7 @@ #include "util/common_types.hpp" TEST(MPIMultiTransformTest, BackwardsForwards) { + GTEST_MPI_GUARD try { MPICommunicatorHandle comm(MPI_COMM_WORLD); const std::vector zStickDistribution(comm.size(), 1.0); diff --git a/tests/mpi_tests/test_transform.cpp b/tests/mpi_tests/test_transform.cpp index fa2516b..c3cb0c6 100644 --- a/tests/mpi_tests/test_transform.cpp +++ b/tests/mpi_tests/test_transform.cpp @@ -1,12 +1,16 @@ #include "test_util/test_transform.hpp" + #include + #include #include #include #include #include #include + #include "gtest/gtest.h" +#include "gtest_mpi.hpp" #include "memory/array_view_utility.hpp" #include "memory/host_array.hpp" #include "memory/host_array_view.hpp" @@ -36,6 +40,7 @@ class MPITransformTest : public TransformTest { Grid grid_; }; TEST_P(MPITransformTest, ForwardUniformDistribution) { + GTEST_MPI_GUARD try { std::vector zStickDistribution(comm_size(), 1.0); std::vector xyPlaneDistribution(comm_size(), 1.0); @@ -47,6 +52,7 @@ TEST_P(MPITransformTest, ForwardUniformDistribution) { } TEST_P(MPITransformTest, BackwardAllOneRank) { + GTEST_MPI_GUARD try { std::vector zStickDistribution(comm_size(), 0.0); zStickDistribution[0] = 1.0; @@ -61,6 +67,7 @@ TEST_P(MPITransformTest, BackwardAllOneRank) { } TEST_P(MPITransformTest, ForwardAllOneRank) { + GTEST_MPI_GUARD try { std::vector zStickDistribution(comm_size(), 0.0); zStickDistribution[0] = 1.0; @@ -75,6 +82,7 @@ TEST_P(MPITransformTest, ForwardAllOneRank) { } TEST_P(MPITransformTest, BackwardAllOneRankPerSide) { + GTEST_MPI_GUARD try { std::vector zStickDistribution(comm_size(), 0.0); zStickDistribution[0] = 1.0; @@ -89,6 +97,7 @@ TEST_P(MPITransformTest, BackwardAllOneRankPerSide) { } TEST_P(MPITransformTest, ForwardAllOneRankPerSide) { + GTEST_MPI_GUARD try { std::vector zStickDistribution(comm_size(), 0.0); zStickDistribution[0] = 1.0; @@ -103,6 +112,7 @@ TEST_P(MPITransformTest, ForwardAllOneRankPerSide) { } TEST_P(MPITransformTest, R2CUniformDistribution) { + GTEST_MPI_GUARD try { std::vector xyPlaneDistribution(comm_size(), 1.0); test_r2c(xyPlaneDistribution); @@ -113,6 +123,7 @@ TEST_P(MPITransformTest, R2CUniformDistribution) { } TEST_P(MPITransformTest, R2COneRankAllPlanes) { + GTEST_MPI_GUARD try { std::vector xyPlaneDistribution(comm_size(), 0.0); xyPlaneDistribution[0] = 1.0; @@ -170,7 +181,7 @@ static auto param_type_names( #define TEST_PROCESSING_UNITS SpfftProcessingUnitType::SPFFT_PU_HOST #endif -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( FullTest, MPITransformTest, ::testing::Combine(::testing::Values(SpfftExchangeType::SPFFT_EXCH_BUFFERED, SpfftExchangeType::SPFFT_EXCH_COMPACT_BUFFERED, @@ -182,7 +193,7 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(1, 2, 11, 12, 13, 100), ::testing::Values(false)), param_type_names); -INSTANTIATE_TEST_CASE_P(CenteredIndicesTest, MPITransformTest, +INSTANTIATE_TEST_SUITE_P(CenteredIndicesTest, MPITransformTest, ::testing::Combine(::testing::Values(SpfftExchangeType::SPFFT_EXCH_DEFAULT), ::testing::Values(TEST_PROCESSING_UNITS), ::testing::Values(1, 2, 11, 100), diff --git a/tests/mpi_tests/test_transpose.cpp b/tests/mpi_tests/test_transpose.cpp index d14134f..1c9c5ac 100644 --- a/tests/mpi_tests/test_transpose.cpp +++ b/tests/mpi_tests/test_transpose.cpp @@ -1,9 +1,12 @@ #include + #include #include #include #include + #include "gtest/gtest.h" +#include "gtest_mpi.hpp" #include "memory/array_view_utility.hpp" #include "memory/host_array.hpp" #include "memory/host_array_view.hpp" @@ -120,6 +123,7 @@ static void check_freq_domain(const HostArrayView2D>& freqV } TEST_F(TransposeTest, Unbuffered) { + GTEST_MPI_GUARD auto freqXYView = create_3d_view(array2_, 0, paramPtr_->num_xy_planes(comm_.rank()), paramPtr_->dim_x(), paramPtr_->dim_y()); auto freqView = @@ -138,6 +142,7 @@ TEST_F(TransposeTest, Unbuffered) { } TEST_F(TransposeTest, CompactBuffered) { + GTEST_MPI_GUARD auto freqXYView = create_3d_view(array2_, 0, paramPtr_->num_xy_planes(comm_.rank()), paramPtr_->dim_x(), paramPtr_->dim_y()); auto freqView = @@ -161,6 +166,7 @@ TEST_F(TransposeTest, CompactBuffered) { } TEST_F(TransposeTest, Buffered) { + GTEST_MPI_GUARD auto freqXYView = create_3d_view(array2_, 0, paramPtr_->num_xy_planes(comm_.rank()), paramPtr_->dim_x(), paramPtr_->dim_y()); auto freqView = diff --git a/tests/mpi_tests/test_transpose_gpu.cpp b/tests/mpi_tests/test_transpose_gpu.cpp index c67ba1d..8e90355 100644 --- a/tests/mpi_tests/test_transpose_gpu.cpp +++ b/tests/mpi_tests/test_transpose_gpu.cpp @@ -1,9 +1,12 @@ #include + #include #include #include #include + #include "gtest/gtest.h" +#include "gtest_mpi.hpp" #include "memory/array_view_utility.hpp" #include "memory/host_array.hpp" #include "memory/host_array_view.hpp" @@ -130,6 +133,7 @@ static void check_freq_domain(const HostArrayView2D>& freqV } TEST_F(TransposeGPUTest, Buffered) { + GTEST_MPI_GUARD auto freqXYView = create_3d_view(array2_, 0, paramPtr_->num_xy_planes(comm_.rank()), paramPtr_->dim_y(), paramPtr_->dim_x()); auto freqXYViewGPU = create_3d_view(gpuArray2_, 0, paramPtr_->num_xy_planes(comm_.rank()), @@ -170,6 +174,7 @@ TEST_F(TransposeGPUTest, Buffered) { } TEST_F(TransposeGPUTest, CompactBuffered) { + GTEST_MPI_GUARD auto freqXYView = create_3d_view(array2_, 0, paramPtr_->num_xy_planes(comm_.rank()), paramPtr_->dim_y(), paramPtr_->dim_x()); auto freqXYViewGPU = create_3d_view(gpuArray2_, 0, paramPtr_->num_xy_planes(comm_.rank()), @@ -212,6 +217,7 @@ TEST_F(TransposeGPUTest, CompactBuffered) { } TEST_F(TransposeGPUTest, Unbuffered) { + GTEST_MPI_GUARD auto freqXYView = create_3d_view(array2_, 0, paramPtr_->num_xy_planes(comm_.rank()), paramPtr_->dim_y(), paramPtr_->dim_x()); auto freqXYViewGPU = create_3d_view(gpuArray2_, 0, paramPtr_->num_xy_planes(comm_.rank()), diff --git a/tests/programs/benchmark.cpp b/tests/programs/benchmark.cpp index 3002d6e..032613b 100644 --- a/tests/programs/benchmark.cpp +++ b/tests/programs/benchmark.cpp @@ -141,17 +141,16 @@ int main(int argc, char** argv) { app.add_option("-o", outputFileName, "Output file name")->required(); app.add_option("-m", numTransforms, "Multiple transform number")->default_val("1"); app.add_option("-s", sparsity, "Sparsity"); - app.add_set("-t", transformTypeName, - std::set{"c2c", "r2c"}, - "Transform type") + app.add_option("-t", transformTypeName, "Transform type") + ->check(CLI::IsMember({"c2c", "r2c"})) ->default_val("c2c"); - app.add_set("-e", exchName, - std::set{"all", "compact", "compactFloat", "buffered", "bufferedFloat", - "unbuffered"}, - "Exchange type") + app.add_option("-e", exchName, "Exchange type") + ->check(CLI::IsMember( + {"all", "compact", "compactFloat", "buffered", "bufferedFloat", "unbuffered"})) ->required(); - app.add_set("-p", procName, std::set{"cpu", "gpu", "gpu-gpu"}, - "Processing unit. With gpu-gpu, device memory is used as input and output.") + app.add_option("-p", procName, + "Processing unit. With gpu-gpu, device memory is used as input and output.") + ->check(CLI::IsMember({"cpu", "gpu", "gpu-gpu"})) ->required(); CLI11_PARSE(app, argc, argv); diff --git a/tests/run_mpi_tests.cpp b/tests/run_mpi_tests.cpp index e7ac17a..d2d5410 100644 --- a/tests/run_mpi_tests.cpp +++ b/tests/run_mpi_tests.cpp @@ -1,29 +1,18 @@ #include + #include "gtest/gtest.h" -#include "gtest_mpi/gtest_mpi.hpp" +#include "gtest_mpi.hpp" int main(int argc, char* argv[]) { // Initialize MPI before any call to gtest_mpi - MPI_Init(&argc, &argv); - - // Intialize google test - ::testing::InitGoogleTest(&argc, argv); - - // Add a test envirnment, which will initialize a test communicator - // (a duplicate of MPI_COMM_WORLD) - ::testing::AddGlobalTestEnvironment(new gtest_mpi::MPITestEnvironment()); - - auto& test_listeners = ::testing::UnitTest::GetInstance()->listeners(); + int provided; + MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided); - // Remove default listener and replace with the custom MPI listener - delete test_listeners.Release(test_listeners.default_result_printer()); - test_listeners.Append(new gtest_mpi::PrettyMPIUnitTestResultPrinter()); + gtest_mpi::InitGoogleTestMPI(&argc, argv); - // run tests - auto exit_code = RUN_ALL_TESTS(); + auto status = RUN_ALL_TESTS(); - // Finalize MPI before exiting MPI_Finalize(); - return exit_code; + return status; }