Skip to content

Commit

Permalink
[Codegen] Drop TransformStrategies (#18820)
Browse files Browse the repository at this point in the history
The strategies in TransformStrategies have been off by default for some
time and are unmaintained. Drop all related code and tests. Some
previous pipeline tests are combined into existing pipeline tests and
simplified so that they are no longer change detectors.
  • Loading branch information
qedawkins authored Oct 18, 2024
1 parent 4f33005 commit 012f8a6
Show file tree
Hide file tree
Showing 67 changed files with 762 additions and 8,575 deletions.
1 change: 0 additions & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
/compiler/src/iree/compiler/Codegen/LLVMCPU/ @hanhanW @MaheshRavishankar
/compiler/src/iree/compiler/Codegen/LLVMGPU/ @MaheshRavishankar @qedawkins @kuhar @Groverkss
/compiler/src/iree/compiler/Codegen/SPIRV/ @antiagainst @MaheshRavishankar @kuhar
/compiler/src/iree/compiler/Codegen/TransformStrategies/ @qedawkins @MaheshRavishankar
/compiler/src/iree/compiler/ConstEval/ @hanhanW @stellaraccident
/compiler/src/iree/compiler/Dialect/Encoding/ @bjacob @hanhanW
/compiler/src/iree/compiler/Dialect/Flow/ @hanhanW @MaheshRavishankar @IanWood1
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,6 @@ iree_compiler_cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:DialectUtils",
# TransformStrategies
"//compiler/src/iree/compiler/Codegen/TransformStrategies/Common:TransformStrategies",
# TransformExtensions (needed for registration in the pass)
"//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
"//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions",
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ iree_cc_library(
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::TransformStrategies::Common::TransformStrategies
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::TransformExtensions::FlowExtensions
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
"//compiler/src/iree/compiler/Codegen/TransformStrategies/CPU",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ iree_cc_library(
iree::compiler::Codegen::Common::TransformDialectInterpreterPass
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
iree::compiler::Codegen::TransformStrategies::CPU
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
Expand Down
34 changes: 0 additions & 34 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Codegen/TransformStrategies/CPU/Common.h"
#include "iree/compiler/Codegen/Utils/CPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand Down Expand Up @@ -100,12 +99,6 @@ static llvm::cl::opt<bool> clDisableArmSMETiling(
"target (i.e., when the +sme feature flag is present)"),
llvm::cl::init(false));

// Non-static options are used in other places.
llvm::cl::opt<bool> clEnableTransformDialectJit(
"iree-llvmcpu-enable-transform-dialect-jit",
llvm::cl::desc("enable the usage of the transform dialect JIT"),
llvm::cl::init(false));

using IREE::Codegen::DispatchLoweringPassPipeline;

// Encodes the pre-processing strategy to be applied on a Linalg operation
Expand Down Expand Up @@ -2007,28 +2000,6 @@ setDefaultGenericOpRootConfig(mlir::FunctionOpInterface entryPointFn,
/*subgroupSize=*/{}, pipelineConfig);
}

/// Set lowering info to be used by the transform dialect jitter.
static LogicalResult
setTransformStrategyRootConfig(mlir::FunctionOpInterface entryPointFn,
linalg::GenericOp genericOp,
const LinalgOpInfo &linalgOpInfo,
const TargetMLTransformInfo &targetMLTransInfo) {
assert(!getLoweringConfig(genericOp) &&
"expected lowering_config is not set");
if (!clEnableTransformDialectJit)
return failure();
cpu::CPUModel cpuModel;
if (failed(
cpu::matchAndSetReductionStrategy(entryPointFn, genericOp, cpuModel)))
return failure();
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
entryPointFn->getContext(),
IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen);
if (failed(setTranslationInfo(entryPointFn, translationInfo)))
return failure();
return success();
}

/// Utility to return the transpose vector `sizes` for X86. Empty `sizes` on
/// return indicates failure.
static void getTransposeX86VectorSizes(
Expand Down Expand Up @@ -2284,11 +2255,6 @@ setRootConfig(mlir::FunctionOpInterface entryPointFn,
const TargetMLTransformInfo &targetMLTransInfo) {
assert(!getLoweringConfig(genericOp) &&
"expected lowering_config is not set");
// First, try to apply the transform dialect strategy, if defined.
if (succeeded(setTransformStrategyRootConfig(
entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
return success();
}

if (succeeded(setTransposeLikeOpRootConfig(
entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@
#include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
Expand All @@ -38,22 +28,7 @@ class LLVMCPUSelectLoweringStrategyPass
LLVMCPUSelectLoweringStrategyPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
// TODO(qedawkins): Once TransformStrategies is deprecated, drop the
// unnecessary dialect registrations.
// clang-format off
registry.insert<IREE::Codegen::IREECodegenDialect,
IREE::HAL::HALDialect,
IREE::LinalgExt::IREELinalgExtDialect,
bufferization::BufferizationDialect,
linalg::LinalgDialect,
LLVM::LLVMDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
scf::SCFDialect,
tensor::TensorDialect,
transform::TransformDialect,
vector::VectorDialect>();
// clang-format on
registry.insert<IREE::Codegen::IREECodegenDialect>();
}

void runOnOperation() override;
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
"//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
"//compiler/src/iree/compiler/Codegen/LLVMGPU/Utils",
"//compiler/src/iree/compiler/Codegen/TransformStrategies/GPU",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ iree_cc_library(
iree::compiler::Codegen::Interfaces::UKernelOpInterface
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::LLVMGPU::Utils
iree::compiler::Codegen::TransformStrategies::GPU
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Codegen::Utils::VectorOpUtils
Expand Down
62 changes: 0 additions & 62 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand Down Expand Up @@ -63,11 +62,6 @@ llvm::cl::opt<bool> clGPUEnableVectorDistribution(
llvm::cl::desc("enable the usage of the vector distribution pipeline"),
llvm::cl::init(true));

llvm::cl::opt<bool> clGPUEnableTransformDialectJit(
"iree-codegen-llvmgpu-enable-transform-dialect-jit",
llvm::cl::desc("enable the usage of the transform dialect JIT"),
llvm::cl::init(false));

/// Flag to force using WMMA tensorcore operations.
llvm::cl::opt<bool>
clGPUUseWMMA("iree-codegen-llvmgpu-use-wmma",
Expand Down Expand Up @@ -1392,57 +1386,6 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target,
preferredSubgroupSize);
}

//====---------------------------------------------------------------------===//
// Transform Dialect Pipeline Configuration
//====---------------------------------------------------------------------===//

/// Set configuration for transform dialect based strategies.
static LogicalResult
setTransformDialectConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint, Operation *op) {
if (!clGPUEnableTransformDialectJit) {
return failure();
}

auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
entryPoint.getContext(), CodeGenPipeline::TransformDialectCodegen);

// TODO: unify the target informations into one structure.
iree_compiler::gpu::GPUModel gpuModel;
gpuModel.hasWarpShuffle = target.supportsSubgroupShuffle();
gpuModel.hasTF32TensorCore = target.supportsTF32InputMMAOps();
gpuModel.hasMmaSync = target.supportsSyncMMAOps();

// Populates a subset of the fragment combinations supported in MLIR lowerings
// to NVVM (which is itself a subset of what LLVM supports) based on what the
// pipeline currently supports.
// TODO: avoid hard coding this and populate based on hardware capabilities.
// TODO: add missing supported configs once the pipeline supports it.
MLIRContext *context = entryPoint.getContext();
Type f32Type = Float32Type::get(context);
Type f16Type = Float16Type::get(context);

iree_compiler::gpu::MMAConfig f16f32AccConfig = {
/*m=*/16, /*n=*/16, /*k=*/16,
/*aType=*/f16Type, /*bType=*/f16Type, /*cType=*/f32Type};
iree_compiler::gpu::MMAConfig f16f16AccConfig = {
/*m=*/16, /*n=*/16, /*k=*/16,
/*aType=*/f16Type, /*bType=*/f16Type, /*cType=*/f16Type};
gpuModel.supportedWMMAConfigs = {f16f32AccConfig, f16f16AccConfig};

if (target.supportsTF32InputMMAOps()) {
iree_compiler::gpu::MMAConfig tf32WmmaConfig = {
/*m=*/16, /*n=*/16, /*k=*/8,
/*aType=*/f32Type, /*bType=*/f32Type, /*cType=*/f32Type};
gpuModel.supportedWMMAConfigs.push_back(tf32WmmaConfig);
}

if (failed(iree_compiler::gpu::matchAndSetTransformStrategy(entryPoint, op,
gpuModel)))
return failure();
return setTranslationInfo(entryPoint, translationInfo);
}

static bool isMatvecLike(linalg::LinalgOp linalgOp) {
if (linalgOp.getNumParallelLoops() != 2)
return false;
Expand Down Expand Up @@ -2015,11 +1958,6 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
computeOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";
});
// First try to see if there is a transform dialect configuration existing.
if (succeeded(setTransformDialectConfig(target, entryPointFn, computeOp))) {
LDBG("Transform Dialect Config");
return success();
}
if (succeeded(setDataTiledMultiMmaLoweringConfig(target, entryPointFn,
computeOp))) {
LDBG("Tile and fuse data tiled multi_mma config");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
Expand All @@ -41,24 +29,8 @@ class LLVMGPUSelectLoweringStrategyPass final
LLVMGPUSelectLoweringStrategyPass>::LLVMGPUSelectLoweringStrategyPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
// TODO(qedawkins): Once TransformStrategies is deprecated, drop the
// unnecessary dialect registrations.
// clang-format off
registry
.insert<IREE::Codegen::IREECodegenDialect,
IREE::GPU::IREEGPUDialect,
IREE::HAL::HALDialect,
IREE::LinalgExt::IREELinalgExtDialect,
linalg::LinalgDialect,
gpu::GPUDialect,
nvgpu::NVGPUDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
scf::SCFDialect,
tensor::TensorDialect,
transform::TransformDialect,
vector::VectorDialect>();
// clang-format on
.insert<IREE::Codegen::IREECodegenDialect, IREE::GPU::IREEGPUDialect>();
}

void runOnOperation() override;
Expand Down
7 changes: 1 addition & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,8 @@ iree_lit_test_suite(
"nvvm_mma_sync_pipeline_test.mlir",
"reduction_pipeline_cuda.mlir",
"reduction_pipeline_rocm.mlir",
"reduction_pipeline_transform_cuda.mlir",
"reduction_pipeline_transform_rocm.mlir",
"reduction_pipeline_softmax_rocm.mlir",
"rocdl_pipeline_test.mlir",
"set_transform_strategy_batch_matmul.mlir",
"set_transform_strategy_convolution.mlir",
"set_transform_strategy_matmul.mlir",
"set_transform_strategy_pad.mlir",
"illegal_configuration.mlir",
"legalize.mlir",
"linalg_transform.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,8 @@ iree_lit_test_suite(
"promote_matmul_to_fit_mma.mlir"
"reduction_pipeline_cuda.mlir"
"reduction_pipeline_rocm.mlir"
"reduction_pipeline_transform_cuda.mlir"
"reduction_pipeline_transform_rocm.mlir"
"reduction_pipeline_softmax_rocm.mlir"
"rocdl_pipeline_test.mlir"
"set_transform_strategy_batch_matmul.mlir"
"set_transform_strategy_convolution.mlir"
"set_transform_strategy_matmul.mlir"
"set_transform_strategy_pad.mlir"
"tensor_pad.mlir"
"tensorcore_vectorization.mlir"
"transform_dialect_bufferize.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-gpu-test-target=sm_60 %s | FileCheck %s
// RUN: --iree-gpu-test-target=sm_60 %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80
// RUN: --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80

// Transform dialect attributes are tested separately.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline, func.func(iree-llvmgpu-lower-executable-target))" \
// RUN: --iree-gpu-test-target=sm_60 \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
// RUN: --iree-codegen-transform-dialect-library=%p/transform_dialect_codegen_bufferize_spec.mlir@__transform_main | \
// RUN: FileCheck %s

// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline, func.func(iree-llvmgpu-lower-executable-target))" \
// RUN: --iree-gpu-test-target=sm_60 \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
// RUN: --iree-codegen-transform-dialect-library=%p/transform_dialect_codegen_foreach_to_gpu_spec.mlir@__transform_main | \
// RUN: FileCheck %s --check-prefix=FOREACH-TO-GPU

Expand Down
Loading

0 comments on commit 012f8a6

Please sign in to comment.