Added support to generate OpenMP parallel construct clauses, at this time for num_threads and proc_bind #2944

Merged
46 changes: 45 additions & 1 deletion docs/Dialects/krnl.md
@@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop.
| :----: | ----------- |
| `none_val` | none type

### `krnl.parallel_clause` (KrnlParallelClauseOp)

_Attach OpenMP clauses to an index variable_


Syntax:

```
operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
attr-dict `:` type($parallel_loop_index)
```

Attach OpenMP clauses to an index variable. That index variable
is used to uniquely associate a parallel loop with its clauses.

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `parallel_loop_index` | index
| `num_threads` | 32-bit signless integer
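
The association between a loop index and its clauses can be pictured with a small standalone model. This is only a sketch under stated assumptions, not the ONNX-MLIR pass: `LoopIndexId`, `ParallelClauses`, and `ClauseTable` are hypothetical names, a plain integer stands in for the MLIR index value, and the positive-thread-count check is an assumption of the sketch rather than something the op verifies.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for the SSA index value of a parallel loop.
using LoopIndexId = std::uint64_t;

// Clauses carried by one krnl.parallel_clause; both are optional.
struct ParallelClauses {
  std::optional<std::int32_t> numThreads; // num_threads operand, if present
  std::optional<std::string> procBind;    // proc_bind attribute, if present
};

// Maps a loop index to its clauses, at most one entry per parallel loop.
class ClauseTable {
public:
  // Returns false when the index already has clauses attached.
  bool attach(LoopIndexId index, ParallelClauses clauses) {
    // Assumption of this sketch: a requested thread count is positive.
    assert((!clauses.numThreads || *clauses.numThreads > 0) &&
           "num_threads must be positive when given");
    return table.emplace(index, std::move(clauses)).second;
  }

  // Returns the clauses for a loop index, or nullopt meaning "use defaults".
  std::optional<ParallelClauses> lookup(LoopIndexId index) const {
    auto it = table.find(index);
    if (it == table.end())
      return std::nullopt;
    return it->second;
  }

private:
  std::unordered_map<LoopIndexId, ParallelClauses> table;
};
```

Keying on the index value is what makes the association unique: a loop has exactly one induction variable, so a lookup by that value cannot match the clauses of another loop.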

### `krnl.parallel` (KrnlParallelOp)

_Mark Krnl loops as parallel loops_
@@ -937,23 +966,38 @@ _Mark Krnl loops as parallel loops_
Syntax:

```
operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops)
operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
```

Parallelize the specified loops. When multiple loop specifiers are passed
as parameters, these loops can be parallelized as a collapsed loop.
krnl.parallel should be placed as the last operator before krnl.iterate,
since we do not want to parallelize the loop until we interpret krnl.block,
krnl.permute and krnl.unroll.

Optionally, a value may specify the number of threads requested for the
parallel loop. A proc_bind string may also be specified; valid values are
"primary", "close", or "spread". Default values are used when not specified.

```
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
```

Traits: `AttrSizedOperandSegments`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `loops` | variadic of any type
| `num_threads` | 32-bit signless integer
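
To make the optional parts of the syntax concrete, the helper below assembles the textual form from the assembly format shown above. It is a hypothetical illustration only: `renderKrnlParallel` is not part of ONNX-MLIR, operands are passed as plain SSA-name strings, the `!Krnl.loop` spelling is copied from the example above, the at-least-one-loop check is an assumption of the sketch, and the exact whitespace emitted by the real MLIR printer may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Renders one krnl.parallel line following the documented assembly format:
//   `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict
//   `:` type($loops)
// Hypothetical helper; every argument is an SSA name or attribute value
// given as a plain string.
std::string renderKrnlParallel(const std::vector<std::string> &loops,
    const std::optional<std::string> &numThreads,
    const std::optional<std::string> &procBind) {
  // Assumption of this sketch: at least one loop is being parallelized.
  assert(!loops.empty() && "expected at least one loop");
  std::string text = "krnl.parallel (";
  for (std::size_t i = 0; i < loops.size(); ++i)
    text += (i ? ", " : "") + loops[i];
  text += ")";
  // The optional group is printed only when num_threads is present.
  if (numThreads)
    text += ", num_threads(" + *numThreads + ")";
  // proc_bind is an attribute, so it appears in the attribute dictionary.
  if (procBind)
    text += " {proc_bind = \"" + *procBind + "\"}";
  text += " : ";
  for (std::size_t i = 0; i < loops.size(); ++i)
    text += std::string(i ? ", " : "") + "!Krnl.loop";
  return text;
}
```

With neither clause given, the helper reproduces the existing example line; with both given, the thread count shows up as an operand after the loop list and `proc_bind` as an attribute.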

### `krnl.permute` (KrnlPermuteOp)

1 change: 1 addition & 0 deletions src/Compiler/CompilerPasses.cpp
@@ -251,6 +251,7 @@ void addKrnlToLLVMPasses(
// The alloca_scope ops are somewhat fragile; canonicalize remove them when
// redundant, which helps reliability of the compilation of these ops.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass());
}

// The pass below is needed for subview and collapseShape.. Unfortunately,
22 changes: 22 additions & 0 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
<< parallelOp << "\n");
// ToFix handle multiple parallel loop
ValueRange loopRefs = parallelOp.getLoops();
Value numThreads = parallelOp.getNumThreads();
StringAttr procBind = parallelOp.getProcBindAttr();
bool needParallelClause =
numThreads || (procBind && procBind.getValue().size() > 0);

// Obtain the reference to the loop that needs to be parallelized
for (Value loopRef : loopRefs) {
@@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
parallelLoop.getRegion().takeBody(loopToParallel.getRegion());
Operation *yieldOp = &parallelLoop.getBody()->back();
yieldOp->setOperands(reducedValues);
if (needParallelClause) {
// Use clause only for the first one (expected the outermost one).
// Ideally, we would generate here a single, multi-dimensional
// AffineParallelOp, and we would not need to reset the flag.
needParallelClause = false;
Collaborator

Is this condition used afterwards?

Collaborator Author

@AlexandreEichenberger, Sep 18, 2024

Yes. When the parallel clause is needed, only the first iteration of `for (Value loopRef : loopRefs)` adds the `KrnlParallelClauseOp`; the flag is reset so that later iterations skip it.

// Current approach: insert after yield and then move before it.
PatternRewriter::InsertionGuard insertGuard(builder);
builder.setInsertionPointAfter(yieldOp);
Collaborator

Doesn't setInsertionPoint(yieldOp) work for inserting just before yieldOp?

Collaborator Author

For some reason, if I don't have the "moveBefore", I get this error:

flt_orig_model.mlir:18:3: error: operand #0 does not dominate this use
  krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){
  ^
flt_orig_model.mlir:18:3: note: see current operation: "krnl.parallel_clause"(%arg1, %0) {proc_bind = "spread"} : (index, i32) -> ()
flt_orig_model.mlir:18:3: note: operand defined as a block argument (block #0 in a child region)

Strangely, with the moveBefore(yieldOp), I get the same result with the setInsertionPointAfter or setInsertionPoint.
There is something fragile about the lowering of Krnl to Affine with respect to "movable".

Since it works as is, I prefer to leave it that way.

Collaborator

This conversion pass traverses the IR itself, and we manipulate the graph directly. That might be the reason why it is fragile.

Collaborator

OK, then it's fine with the current way.

// Get induction variable.
ValueRange optionalLoopIndices = parallelLoop.getIVs();
assert(optionalLoopIndices.size() >= 1 &&
"expected at least one loop index");
Value parallelLoopIndex = optionalLoopIndices[0];
Operation *newOp = opBuilder.create<KrnlParallelClauseOp>(
loc, parallelLoopIndex, numThreads, procBind);
newOp->moveBefore(yieldOp);
}
// Replace the affine.forOp with affine.parallelOp in loopRefToTop
loopRefToOp[loopRef] = parallelLoop;
loopToParallel.erase();
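
The flag handling discussed in the review thread can be isolated in a standalone sketch. It assumes nothing about MLIR: `needsParallelClause` and `loopsReceivingClause` are hypothetical names, loops are represented only by their position in `loopRefs`, and `std::optional` stands in for a possibly absent `Value` or `StringAttr`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Mirrors the condition in the hunk: a clause op is needed when a thread
// count is given or when proc_bind is a non-empty string.
bool needsParallelClause(const std::optional<int> &numThreads,
    const std::optional<std::string> &procBind) {
  return numThreads.has_value() || (procBind && !procBind->empty());
}

// Returns the positions within loopRefs that would receive a clause op.
// The flag is cleared on first use, so at most one position is returned.
std::vector<std::size_t> loopsReceivingClause(std::size_t loopCount,
    const std::optional<int> &numThreads,
    const std::optional<std::string> &procBind) {
  bool needClause = needsParallelClause(numThreads, procBind);
  std::vector<std::size_t> result;
  for (std::size_t i = 0; i < loopCount; ++i) {
    if (needClause) {
      needClause = false; // only the first (expected outermost) loop
      result.push_back(i);
    }
  }
  assert(result.size() <= 1 && "at most one loop carries the clauses");
  return result;
}
```

For example, three loops with a thread count of 8 yield only position 0, while an empty `proc_bind` string and no thread count yield no position at all.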
@@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
target.addIllegalOp<KrnlCopyToBufferOp>();
target.addIllegalOp<KrnlCopyFromBufferOp>();
target.addIllegalOp<KrnlPrefetchOp>();
target.addLegalOp<KrnlParallelClauseOp>();
target.addLegalOp<AffineYieldOp>();
target.addLegalOp<AffineLoadOp>();
target.addLegalOp<AffineStoreOp>();
21 changes: 20 additions & 1 deletion src/Dialect/Krnl/DialectBuilder.cpp
@@ -155,7 +155,26 @@ ValueRange KrnlBuilder::getInductionVarValue(ValueRange loops) const {
}

void KrnlBuilder::parallel(ValueRange loops) const {
b().template create<KrnlParallelOp>(loc(), loops);
Value noneValue;
StringAttr noneStrAttr;
b().template create<KrnlParallelOp>(loc(), loops, noneValue, noneStrAttr);
}

void KrnlBuilder::parallel(
ValueRange loops, Value numThreads, StringAttr procBind) const {
if (procBind.getValue().size() > 0) {
std::string str = procBind.getValue().str();
assert((str == "primary" || str == "close" || str == "spread") &&
"expected primary, close, or spread for proc_bind");
}
b().template create<KrnlParallelOp>(loc(), loops, numThreads, procBind);
}

void KrnlBuilder::parallelClause(
Value parallelLoopIndex, Value numThreads, StringAttr procBind) const {
// No need to check procBind as its values are derived from parallel(...).
b().template create<KrnlParallelClauseOp>(
loc(), parallelLoopIndex, numThreads, procBind);
}
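
The `proc_bind` check in `parallel(...)` can be expressed as a small standalone predicate. The sketch below is an assumption-laden illustration rather than the builder's code: `isAcceptableProcBind` and `checkedProcBind` are hypothetical names, `std::optional<std::string_view>` stands in for a possibly null `StringAttr`, and treating an absent value like an empty string is a choice made here; the diff does not show whether a null attribute can reach the three-argument overload.

```cpp
#include <array>
#include <cassert>
#include <optional>
#include <string_view>

// The three proc_bind values accepted by the assert in parallel(...).
constexpr std::array<std::string_view, 3> kProcBindValues = {
    "primary", "close", "spread"};

// Returns true when procBind is absent, empty, or one of the accepted values.
// Absent and empty both mean "nothing specified" in this sketch.
bool isAcceptableProcBind(std::optional<std::string_view> procBind) {
  if (!procBind || procBind->empty())
    return true; // nothing specified: defaults apply
  for (std::string_view value : kProcBindValues)
    if (*procBind == value)
      return true;
  return false;
}

// Asserting wrapper in the spirit of the check in parallel(...); returns the
// value to use, with an empty view meaning "no proc_bind requested".
std::string_view checkedProcBind(std::optional<std::string_view> procBind) {
  assert(isAcceptableProcBind(procBind) &&
         "expected primary, close, or spread for proc_bind");
  return procBind.value_or("");
}
```

Testing for presence before reading the string keeps the predicate safe for callers that pass no attribute at all, which is the same guard the Krnl-to-Affine hunk applies with `procBind && procBind.getValue().size() > 0`.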

void KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops,
4 changes: 4 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.hpp
@@ -66,6 +66,10 @@ struct KrnlBuilder : public DialectBuilder {
void permute(mlir::ValueRange loops, mlir::ArrayRef<int64_t> map) const;
mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const;
void parallel(mlir::ValueRange loops) const;
void parallel(mlir::ValueRange loops, mlir::Value numThreads,
mlir::StringAttr procBind) const;
void parallelClause(mlir::Value parallelLoopIndex, mlir::Value numThreads,
mlir::StringAttr procBind) const;

// Iterate over optimized loops given the original loops, lbs and ubs. Lambda
// function implement the body of the loop, and receive a KRNL builder and the
30 changes: 27 additions & 3 deletions src/Dialect/Krnl/Krnl.td
@@ -514,23 +514,47 @@ def KrnlUnrollOp : Op<Krnl_Dialect, "unroll"> {
}];
}

def KrnlParallelOp : Op<Krnl_Dialect, "parallel"> {
def KrnlParallelOp : Op<Krnl_Dialect, "parallel", [AttrSizedOperandSegments]> {
let summary = "Mark Krnl loops as parallel loops";
let description = [{
Parallelize the specified loops. When multiple loop specifiers are passed
as parameters, these loops can be parallelized as a collapsed loop.
krnl.parallel should be placed as the last operator before krnl.iterate,
since we do not want to parallelize the loop until we interpret krnl.block,
krnl.permute and krnl.unroll.

Optionally, a value may specify the number of threads requested for the
parallel loop. A proc_bind string may also be specified; valid values are
"primary", "close", or "spread". Default values are used when not specified.

```
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
```
}];

let arguments = (ins Variadic<AnyType>:$loops);
let arguments = (ins Variadic<AnyType>:$loops,
Optional<I32>:$num_threads,
OptionalAttr<StrAttr>:$proc_bind);

let assemblyFormat = [{
`(` $loops `)` attr-dict `:` type($loops)
`(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
}];
}
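
The `AttrSizedOperandSegments` trait becomes necessary once the op has a variadic operand group (`$loops`) alongside an optional one (`$num_threads`): in MLIR, the trait records how many operands belong to each group so that the flat operand list can be split again. The standalone sketch below models that split; `ParallelOperands` and `splitOperands` are hypothetical names and operands are plain strings, so this shows the idea rather than MLIR's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Operands of one op after splitting the flat list by segment sizes.
struct ParallelOperands {
  std::vector<std::string> loops;        // variadic $loops
  std::optional<std::string> numThreads; // optional $num_threads
};

// Splits a flat operand list using two segment sizes:
// the number of loops, then 0 or 1 for num_threads.
ParallelOperands splitOperands(const std::vector<std::string> &flat,
    std::size_t loopCount, std::size_t numThreadsCount) {
  assert(numThreadsCount <= 1 && "an optional operand has size 0 or 1");
  assert(loopCount + numThreadsCount == flat.size() &&
         "segment sizes must cover every operand");
  ParallelOperands result;
  // The first segment holds the loops, in order.
  result.loops.assign(
      flat.begin(), flat.begin() + static_cast<std::ptrdiff_t>(loopCount));
  // The second segment, when present, is the single trailing operand.
  if (numThreadsCount == 1)
    result.numThreads = flat.back();
  return result;
}
```

Without recorded sizes, a flat list of three operands could be read either as three loops or as two loops plus a thread count; the segment sizes remove that ambiguity.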

def KrnlParallelClauseOp : Op<Krnl_Dialect, "parallel_clause"> {
let summary = "Attach OpenMP clauses to an index variable";
let description = [{
Attach OpenMP clauses to an index variable. That index variable
is used to uniquely associate a parallel loop with its clauses.
}];

let arguments = (ins Index: $parallel_loop_index,
Optional<I32>:$num_threads,
OptionalAttr<StrAttr>:$proc_bind);

let assemblyFormat = [{
`(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
attr-dict `:` type($parallel_loop_index)
}];
}

1 change: 1 addition & 0 deletions src/Pass/Passes.hpp
@@ -91,6 +91,7 @@ void configureOnnxToKrnlLoweringPass(bool reportOnParallel,
bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd,
bool simdIsEnabled);
std::unique_ptr<mlir::Pass> createProcessScfParallelPrivatePass();
std::unique_ptr<mlir::Pass> createProcessKrnlParallelClausePass();

#ifdef ONNX_MLIR_ENABLE_STABLEHLO
/// Add pass for lowering to Stablehlo IR.
4 changes: 4 additions & 0 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -97,6 +97,10 @@ void registerOMPasses(int optLevel) {
return createProcessScfParallelPrivatePass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createProcessKrnlParallelClausePass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return krnl::createConvertSeqToMemrefPass();
});
4 changes: 3 additions & 1 deletion src/Transform/CMakeLists.txt
@@ -8,12 +8,14 @@ add_onnx_mlir_library(OMLowerKrnlRegion
MLIRTransformUtils
)

add_onnx_mlir_library(OMScfParallelPrivateRegion
add_onnx_mlir_library(OMScfParallelPrivateRegion
ProcessScfParallelPrivate.cpp
ProcessKrnlParallelClause.cpp

LINK_LIBS PUBLIC
OMSupport
MLIRTransformUtils
MLIROpenMPToLLVM
)

add_onnx_mlir_library(OMInstrument