Skip to content

Commit

Permalink
Fix bug where the shader cache is not used properly.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 16, 2024
1 parent d53469e commit 67887cd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
2 changes: 2 additions & 0 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ bool AttentionDescriptor::operator==(const AttentionDescriptor& rhs) const {
batchDimension == rhs.batchDimension &&
Hq == rhs.Hq &&
Hk == rhs.Hk &&
scale == rhs.scale &&
type == rhs.type &&
(lowPrecisionInputs == rhs.lowPrecisionInputs) &&
(lowPrecisionIntermediates == rhs.lowPrecisionIntermediates) &&
simd_all(leadingDimensions.value_or(simd::uint4(UINT32_MAX)) == rhs.leadingDimensions.value_or(simd::uint4(UINT32_MAX))) &&
Expand Down
3 changes: 2 additions & 1 deletion lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ bool AttentionKernelDescriptor::operator==(const AttentionKernelDescriptor& rhs)
registerPrecisions == rhs.registerPrecisions &&
transposeState == rhs.transposeState &&
leadingDimensions == rhs.leadingDimensions &&
type == rhs.type;
type == rhs.type &&;
scale == rhs.scale;
}

std::size_t std::hash<AttentionKernelDescriptor>::operator()(const AttentionKernelDescriptor& hash) const noexcept {
Expand Down
50 changes: 46 additions & 4 deletions lib/nnc/mfa/v2/AttentionOperand.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,52 @@ struct AttentionOperands {
constexpr AttentionOperands() : bitmap(0) {}

constexpr bool operator==(const AttentionOperands<Value>& rhs) const {
return Q == rhs.Q && K == rhs.K && S == rhs.S && P == rhs.P && V == rhs.V && O == rhs.O &&
L == rhs.L && D == rhs.D &&
dO == rhs.dO && dV == rhs.dV && dP == rhs.dP && dS == rhs.dS && dK == rhs.dK && dQ == rhs.dQ &&
bitmap == bitmap;
if (bitmap != rhs.bitmap) {
return false;
}
if (bitmap & (1 << (AttentionOperand::Q)) && Q != rhs.Q) {
return false;
}
if (bitmap & (1 << (AttentionOperand::K)) && K != rhs.K) {
return false;
}
if (bitmap & (1 << (AttentionOperand::S)) && S != rhs.S) {
return false;
}
if (bitmap & (1 << (AttentionOperand::P)) && P != rhs.P) {
return false;
}
if (bitmap & (1 << (AttentionOperand::V)) && V != rhs.V) {
return false;
}
if (bitmap & (1 << (AttentionOperand::O)) && O != rhs.O) {
return false;
}
if (bitmap & (1 << (AttentionOperand::L)) && L != rhs.L) {
return false;
}
if (bitmap & (1 << (AttentionOperand::D)) && D != rhs.D) {
return false;
}
if (bitmap & (1 << (AttentionOperand::dO)) && dO != rhs.dO) {
return false;
}
if (bitmap & (1 << (AttentionOperand::dV)) && dV != rhs.dV) {
return false;
}
if (bitmap & (1 << (AttentionOperand::dP)) && dP != rhs.dP) {
return false;
}
if (bitmap & (1 << (AttentionOperand::dS)) && dS != rhs.dS) {
return false;
}
if (bitmap & (1 << (AttentionOperand::dK)) && dK != rhs.dK) {
return false;
}
if (bitmap & (1 << (AttentionOperand::dQ)) && dQ != rhs.dQ) {
return false;
}
return true;
}

class Reference {
Expand Down

0 comments on commit 67887cd

Please sign in to comment.