Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Substitute zdnn calls for stick/unstick late, after most ZLow optimizations are performed #2812

Merged
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
939c53e
first version, can generate an op but it vanishes in the canonicalize…
AlexandreEichenberger Apr 4, 2024
4df21b9
update
AlexandreEichenberger Apr 4, 2024
7f239bd
prefetch are now guarded
AlexandreEichenberger Apr 4, 2024
b6016fb
version that works
AlexandreEichenberger Apr 5, 2024
01ba9a6
merge
AlexandreEichenberger Apr 5, 2024
727db44
remove test code
AlexandreEichenberger Apr 5, 2024
5a1c8cf
remove test code
AlexandreEichenberger Apr 5, 2024
475bf26
format
AlexandreEichenberger Apr 5, 2024
b8414cc
initial
AlexandreEichenberger Apr 9, 2024
0dab0c1
update
AlexandreEichenberger Apr 9, 2024
581c5bf
first version
AlexandreEichenberger Apr 9, 2024
56bf221
try prefetch before simd conversion
AlexandreEichenberger Apr 9, 2024
768c6c5
added prefetch of input
AlexandreEichenberger Apr 9, 2024
768a4ec
added prefetch of input as read
AlexandreEichenberger Apr 9, 2024
ed3fa42
prefetch fix
AlexandreEichenberger Apr 9, 2024
ee72930
loop ahead
AlexandreEichenberger Apr 9, 2024
1761c5e
version with older prefetch scheme
AlexandreEichenberger Apr 10, 2024
ba0d183
gen prefetch like zDNN
AlexandreEichenberger Apr 11, 2024
1b18ecf
settled on prefetch like zDNN with N=M=1, dist=0, locality=1
AlexandreEichenberger Apr 11, 2024
8e1f487
prefetch without dist for unstick
AlexandreEichenberger Apr 11, 2024
66ee7bc
added stick without buffer
AlexandreEichenberger Apr 12, 2024
b46f7d6
added prefetch to no buffer
AlexandreEichenberger Apr 12, 2024
e9dcc4e
manually unrolled
AlexandreEichenberger Apr 12, 2024
06bd764
fix prefetch error
AlexandreEichenberger Apr 12, 2024
a993510
multi prefetch
AlexandreEichenberger Apr 12, 2024
b5546e7
update
AlexandreEichenberger Apr 23, 2024
f847139
update
AlexandreEichenberger Apr 26, 2024
bca7aa5
removed trivial reshape
AlexandreEichenberger Apr 26, 2024
bd1b392
no buffer unstick, partial impl
AlexandreEichenberger Apr 29, 2024
c05c98f
handling partial blocks
AlexandreEichenberger Apr 29, 2024
6d1c3ed
unrolled
AlexandreEichenberger Apr 29, 2024
83c6d3c
added prefetch
AlexandreEichenberger Apr 29, 2024
35d642f
fixed issue with dyn shape
AlexandreEichenberger Apr 29, 2024
9e7dc95
reverted simple prefetch computation
AlexandreEichenberger Apr 30, 2024
35f1511
cleaned up version with no buffers
AlexandreEichenberger Apr 30, 2024
3dfc49e
update
AlexandreEichenberger May 1, 2024
15022f0
initial
AlexandreEichenberger May 1, 2024
7a65395
first try to gen code
AlexandreEichenberger May 1, 2024
234fc65
disable parallel
AlexandreEichenberger May 1, 2024
df11f6a
added parallel back
AlexandreEichenberger May 1, 2024
c2692cf
removed all but the late expansion of stick/unstick
AlexandreEichenberger May 2, 2024
4335049
format
AlexandreEichenberger May 2, 2024
945f873
spelling
AlexandreEichenberger May 6, 2024
ed37a4e
response to comments
AlexandreEichenberger May 7, 2024
adedee8
Merge branch 'main' into opt-su-v5-3-late
AlexandreEichenberger May 9, 2024
3ae8ac5
Merge branch 'main' into opt-su-v5-3-late
AlexandreEichenberger May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
Expand Down Expand Up @@ -215,6 +216,10 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
addKrnlToAffinePasses(pm);
// Optimizations at ZLow that needs affine map in MemRef.
pm.addPass(zlow::createZLowRewritePass());
// Late generation of code for stick/unstick, needed to be after a
// ZLowRewrite pass.
if (nnpaEnableCompilerStickUnstick)
pm.addPass(zlow::createZLowStickExpansionPass(enableParallel));
pm.addPass(mlir::createCanonicalizerPass());
// Normalize MemRefs.
normalizeMemRefsPasses(pm);
Expand All @@ -223,6 +228,11 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
addKrnlToAffinePasses(pm);
// Optimizations at ZLow after normalizing MemRefs.
pm.addPass(zlow::createZLowRewritePass());
// The createZLowStickExpansion pass may create parallel constructs,
// they need to be handled here.
if (nnpaEnableCompilerStickUnstick && enableParallel)
pm.addPass(mlir::createConvertSCFToOpenMPPass());

pm.addPass(mlir::createCanonicalizerPass());
// Constant folding for std.alloc.
pm.addNestedPass<func::FuncOp>(onnx_mlir::createFoldStdAllocPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,11 @@ bool isSuitableForZDNN<ONNXBatchNormalizationInferenceModeOp>(

return true;
}

/// Legality check specialized for ONNXReshapeOp.
/// Returns whatever isIdentityReshape reports for the given op and analysis.
template <>
bool isSuitableForZDNN<ONNXReshapeOp>(
    ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
  // A no-op (identity) reshape is reported as suitable for zAIU because this
  // pass removes such reshape ops.
  const bool isNoopReshape = isIdentityReshape(op, dimAnalysis);
  return isNoopReshape;
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
#include "src/Support/TypeUtilities.hpp"
Expand Down Expand Up @@ -467,6 +468,31 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern<OP_TYPE> {
}
};

class RemoveReshapeWithIdentityPattern
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's clever to do it here. Thanks!

: public OpRewritePattern<ONNXReshapeOp> {
public:
using OpRewritePattern<ONNXReshapeOp>::OpRewritePattern;

DimAnalysis *dimAnalysis;

// Build the pattern for the given context and remember the dimension analysis
// pointer, which matchAndRewrite later passes to isIdentityReshape.
RemoveReshapeWithIdentityPattern(
MLIRContext *context, DimAnalysis *dimAnalysis)
// NOTE(review): 1001 is forwarded as the second base-constructor argument,
// presumably the pattern benefit (priority) -- confirm.
: OpRewritePattern<ONNXReshapeOp>(context, 1001),
dimAnalysis(dimAnalysis) {}

// Replace a reshape that isIdentityReshape classifies as an identity by its
// data operand; any other reshape is left untouched and reported as failure.
LogicalResult matchAndRewrite(
    ONNXReshapeOp reshapeOp, PatternRewriter &rewriter) const override {
  if (isIdentityReshape(reshapeOp, dimAnalysis)) {
    // Identity reshape: forward the input value to all users of the result.
    rewriter.replaceOp(reshapeOp.getOperation(), reshapeOp.getData());
    return success();
  }
  // Not an identity reshape: this pattern does not apply.
  return failure();
}
};

//===----------------------------------------------------------------------===//
// Rewrite ONNX ops to ZHigh ops and ONNX ops for ZHigh.
//===----------------------------------------------------------------------===//
Expand All @@ -482,6 +508,8 @@ void getRewriteONNXForZHighPatterns(
patterns.getContext(), dimAnalysis);
patterns.insert<AddSubWithRHSZeroExpandPattern<ONNXSubOp>>(
patterns.getContext(), dimAnalysis);
patterns.insert<RemoveReshapeWithIdentityPattern>(
patterns.getContext(), dimAnalysis);
}

void getRewriteONNXForZHighDynamicallyLegal(
Expand Down Expand Up @@ -643,6 +671,15 @@ void getRewriteONNXForZHighDynamicallyLegal(
return isSuitableForZDNN<ONNXConvOp>(op) ||
!canInferencePadsForNNPAConv(op);
});
#if 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you will remove this in the final version.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tx

addDynamicallyLegalOpFor<ONNXReshapeOp>(target, dimAnalysis,
[](ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
// Get rid of identity reshape here, as it impacts stick/unstick.
// So all reshape are legal, unless it is an identity reshape, in which
// case there is a rule here to remove it.
return !isIdentityReshape(op, dimAnalysis);
});
#endif
}

struct RewriteONNXForZHighPass
Expand Down
Loading
Loading