Integrated into mfa call flow.
liuliu committed Sep 13, 2024
1 parent 290be87 commit 6737683
Showing 9 changed files with 328 additions and 49 deletions.
14 changes: 10 additions & 4 deletions bin/nnc/square_attention_test.cpp
@@ -347,21 +347,26 @@ void validateProblemSize(int sequenceDimension, int headDimension)
NS::SharedPtr<MTL::Buffer> bufferQ = NS::TransferPtr(device->newBuffer(network.Q.data(), sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
NS::SharedPtr<MTL::Buffer> bufferK = NS::TransferPtr(device->newBuffer(network.K.data(), sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
NS::SharedPtr<MTL::Buffer> bufferV = NS::TransferPtr(device->newBuffer(network.V.data(), sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
float* resultL = (float*)ccmalloc(sizeof(float) * sequenceDimension);
resultL[0] = NAN;
NS::SharedPtr<MTL::Buffer> bufferL = NS::TransferPtr(device->newBuffer(resultL, sizeof(float) * sequenceDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
float* resultO = (float*)ccmalloc(sizeof(float) * sequenceDimension * headDimension);
resultO[0] = NAN;
NS::SharedPtr<MTL::Buffer> bufferO = NS::TransferPtr(device->newBuffer(resultO, sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
NS::SharedPtr<MTL::CommandBuffer> commandBuffer = NS::TransferPtr(queue->commandBuffer());
NS::SharedPtr<MTL::ComputeCommandEncoder> encoder = NS::TransferPtr(commandBuffer->computeCommandEncoder());
encoder->setComputePipelineState(pipelineValue->pipeline.get());
encoder->setThreadgroupMemoryLength(pipelineValue->kernel->threadgroupMemoryAllocation, 0);
encoder->setBuffer(bufferQ.get(), 0, 0);
encoder->setBuffer(bufferK.get(), 0, 1);
encoder->setBuffer(bufferV.get(), 0, 2);
encoder->setBuffer(bufferO.get(), 0, 3);
encoder->setBuffer(bufferQ.get(), 0, AttentionOperand(AttentionOperand::Q).bufferIndex());
encoder->setBuffer(bufferK.get(), 0, AttentionOperand(AttentionOperand::K).bufferIndex());
encoder->setBuffer(bufferV.get(), 0, AttentionOperand(AttentionOperand::V).bufferIndex());
encoder->setBuffer(bufferO.get(), 0, AttentionOperand(AttentionOperand::O).bufferIndex());
encoder->setBuffer(bufferL.get(), 0, AttentionOperand(AttentionOperand::L).bufferIndex());
encoder->useResource(bufferQ.get(), MTL::ResourceUsageRead);
encoder->useResource(bufferK.get(), MTL::ResourceUsageRead);
encoder->useResource(bufferV.get(), MTL::ResourceUsageRead);
encoder->useResource(bufferO.get(), MTL::ResourceUsageWrite);
encoder->useResource(bufferL.get(), MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
auto ceilDivide =
[=](int64_t target, uint16_t granularity) -> int64_t {
return (target + int64_t(granularity) - 1) / int64_t(granularity);
@@ -412,6 +417,7 @@ void validateProblemSize(int sequenceDimension, int headDimension)
} else {
check(O, resultO, 2e-5);
}
ccfree(resultL);
ccfree(resultO);
}
}
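The dispatch in this test sizes its grid by ceiling division over the sequence dimension. A minimal standalone sketch of that arithmetic, with illustrative numbers not taken from the test:

#include <cassert>
#include <cstdint>

// Ceiling division, as in the test's ceilDivide lambda above: every row
// must land in some block, so a partial trailing block rounds up.
int64_t ceilDivide(int64_t target, uint16_t granularity) {
  return (target + int64_t(granularity) - 1) / int64_t(granularity);
}

int main() {
  // A 1000-row sequence with 64-row blocks needs 16 threadgroups; the last
  // one covers rows 960..999 and the kernel masks off the overhang.
  assert(ceilDivide(1000, 64) == 16);
  assert(ceilDivide(1024, 64) == 16);
  assert(ceilDivide(1, 64) == 1);
  return 0;
}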
151 changes: 149 additions & 2 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -5,6 +5,11 @@ using namespace ccv::nnc;

#include <string>

#include "v2/ShaderCache.hpp"
#include "v2/AttentionKernel.hpp"
#include "v2/AttentionKernelDescriptor.hpp"
#include "v2/AttentionDescriptor.hpp"

// MARK: - C

void ccv_nnc_mfa_prepare_attention(mfa::context* context, ccv_nnc_mfa_attention_params_t params)
@@ -43,6 +48,148 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
CCV_NNC_MFA_PRECONDITION(false);
break;
}
if (!params.masked && params.Hq == params.Hk) {
simd::ushort2 num_batch_dims(0);
simd::uint2 batch_sizes(1);
if (params.batched) {
for (uint16_t operand = 0; operand < 2; ++operand) {
uint32_t* batch_dims;
if (operand == 0) {
batch_dims = params.batch_dims_q;
} else if (operand == 1) {
batch_dims = params.batch_dims_mask;
}

for (int i = 0; i < CCV_NNC_MAX_DIM_ALLOC; ++i) {
if (batch_dims[i] == 0) {
break;
}
num_batch_dims[operand] += 1;
batch_sizes[operand] *= batch_dims[i];
}

bool dims_match_q = true;
if (num_batch_dims[0] != num_batch_dims[operand]) {
dims_match_q = false;
} else if (batch_sizes[0] != batch_sizes[operand]) {
dims_match_q = false;
} else {
for (int i = 0; i < CCV_NNC_MAX_DIM_ALLOC; ++i) {
if (params.batch_dims_q[i] != batch_dims[i]) {
dims_match_q = false;
}
}
}

if (!dims_match_q) {
CCV_NNC_MFA_PRECONDITION(batch_sizes[operand] == 1);
}
}
}
AttentionDescriptor attentionDesc;
attentionDesc.lowPrecisionInputs = (params.data_type == MTL::DataTypeHalf) ? true : false;
attentionDesc.lowPrecisionIntermediates = false;
attentionDesc.matrixDimensions[0] = hash.R;
attentionDesc.matrixDimensions[1] = hash.C;
attentionDesc.matrixDimensions[2] = hash.D;
attentionDesc.transposeState[0] = false;
attentionDesc.transposeState[1] = false;
attentionDesc.transposeState[2] = false;
attentionDesc.transposeState[3] = false;
attentionDesc.Hq = hash.Hq;
attentionDesc.batchDimension = batch_sizes[0];
attentionDesc.type = AttentionKernelType::forward;
attentionDesc.scale = hash.alpha;
if (params.batched) {
attentionDesc.batchStrides[AttentionOperand::Q] = hash.R * hash.D * hash.Hq;
attentionDesc.batchStrides[AttentionOperand::K] = hash.C * hash.D * hash.Hk;
attentionDesc.batchStrides[AttentionOperand::V] = hash.C * hash.D * hash.Hk;
attentionDesc.batchStrides[AttentionOperand::O] = hash.R * hash.D * hash.Hq;
}
simd::uint4 leadingDimensions;
leadingDimensions[0] = hash.Hq * hash.D;
leadingDimensions[1] = hash.Hk * hash.D;
leadingDimensions[2] = hash.Hk * hash.D;
leadingDimensions[3] = hash.Hq * hash.D;
attentionDesc.leadingDimensions = leadingDimensions;
auto pool = NS::AutoreleasePool::alloc()->init();
auto &shaderCache = context->v2_cache;
DeviceProperties dprops = DeviceProperties();
auto pipelineValue = shaderCache.findKernel<AttentionKernel, AttentionDescriptor, AttentionKernelDescriptor>(attentionDesc, context->device.get(), dprops);
pool->drain();
auto kernel = pipelineValue->kernel;
auto pipeline = pipelineValue->pipeline;

// Allocate a new command.
encoder->setComputePipelineState(pipeline.get());
encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0);

// Bind the function arguments.
encoder->useResource(tensors[0], MTL::ResourceUsageRead);
encoder->useResource(tensors[1], MTL::ResourceUsageRead);
encoder->useResource(tensors[2], MTL::ResourceUsageRead);
auto scratch_size = sizeof(float) * hash.R * hash.Hq;
if (attentionDesc.lowPrecisionInputs) {
// Extra scratch to stage the FP32 output before casting it down to FP16.
scratch_size += sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension;
}
auto scratch = context->request_scratch(scratch_size);
if (attentionDesc.lowPrecisionInputs) {
encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
} else {
encoder->useResource(tensors[3], MTL::ResourceUsageWrite);
encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
}

encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
encoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
encoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
if (attentionDesc.lowPrecisionInputs) {
encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::O).bufferIndex());
encoder->setBuffer(scratch, hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
} else {
encoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
}

// Calculate the grid size.
auto ceilDivide =
[=](int64_t target, uint16_t granularity) -> int64_t {
return (target + int64_t(granularity) - 1) / int64_t(granularity);
};
MTL::Size gridSize
(ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
hash.Hq,
attentionDesc.batchDimension);
MTL::Size groupSize
(int64_t(kernel->threadgroupSize), 1, 1);

// Dispatch the required number of threads.
encoder->dispatchThreadgroups(gridSize, groupSize);

// Finish the command.
command_batch->finishCommand(encoder);
if (attentionDesc.lowPrecisionInputs) {
// Dispatch a follow-up cast from the FP32 scratch to the FP16 destination.
ccv_nnc_mfa_cast_params_t cast_params = {
.original_data_type = MTL::DataTypeFloat,
.data_type = MTL::DataTypeHalf,
.length = hash.R * hash.D * hash.Hq * attentionDesc.batchDimension
};
ccv_nnc_mfa_prepare_cast(context, cast_params);
mtl_buffer_t* cast_tensors[3] = {
scratch, // source (FP32 staging)
tensors[3], // destination
NULL
};
size_t cast_tensor_offsets[2] = {
0,
tensor_offsets[3]
};
ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
}
return;
}

// Simple broadcasting rules; no support yet for NumPy broadcasting rules.
simd::ushort2 num_batch_dims(0);
@@ -101,7 +248,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
};
}
encoder->setBytes(matrix_offsets, batch_sizes[0] * 32, 10);
}
}

if (params.masked) {
@@ -135,7 +282,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
CCV_NNC_MFA_PRECONDITION(params.Hq > params.Hk);
CCV_NNC_MFA_PRECONDITION((params.Hq % params.Hk) == 0);
uint32_t query_offsets[params.Hq * 4];
const int h_h_k_ratio = params.Hq / params.Hk;
for (int i = 0; i < params.Hq; i++) {
query_offsets[i * 4] = i;
query_offsets[i * 4 + 1] = i / h_h_k_ratio;
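For low-precision inputs, the encode path above stages the FP32 attention output and the per-row L (logsumexp) terms in one scratch allocation, then dispatches a follow-up cast into the FP16 destination. A sketch of the layout those calls imply, read off the code above rather than any documented contract:

#include <cstddef>

// Scratch layout implied by ccv_nnc_mfa_encode_attention:
//   [ FP32 O staging : R * D * Hq * batch floats ][ L : R * Hq floats ]
// The O staging region exists only for low-precision inputs; L stays FP32.
struct AttentionScratchLayout {
  size_t oElements;   // FP32 staging for O (0 when inputs are already FP32)
  size_t lElements;   // one float per row per query head
  size_t totalBytes;
};

AttentionScratchLayout scratchLayout(size_t R, size_t D, size_t Hq,
                                     size_t batch, bool lowPrecisionInputs) {
  AttentionScratchLayout layout = {};
  layout.lElements = R * Hq;
  layout.oElements = lowPrecisionInputs ? R * D * Hq * batch : 0;
  layout.totalBytes = sizeof(float) * (layout.oElements + layout.lElements);
  return layout;
}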
51 changes: 49 additions & 2 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
@@ -6,18 +6,30 @@

bool AttentionDescriptor::operator==(const AttentionDescriptor& rhs) const {
return
batchDimension == rhs.batchDimension &&
Hq == rhs.Hq &&
(lowPrecisionInputs == rhs.lowPrecisionInputs) &&
(lowPrecisionIntermediates == rhs.lowPrecisionIntermediates) &&
simd_all(leadingDimensions.value_or(simd::uint4(UINT32_MAX)) == rhs.leadingDimensions.value_or(simd::uint4(UINT32_MAX))) &&
batchStrides == rhs.batchStrides &&
simd_all(matrixDimensions == rhs.matrixDimensions) &&
simd_all(transposeState == rhs.transposeState);
}

std::size_t std::hash<AttentionDescriptor>::operator()(const AttentionDescriptor& hash) const noexcept {
std::size_t seed = 0;
using namespace ccv::nnc::mfa::hash;
combine_32(seed, hash.batchDimension);
combine_32(seed, hash.Hq);
combine_32(seed, hash.matrixDimensions[0]);
combine_32(seed, hash.matrixDimensions[1]);
combine_32(seed, hash.matrixDimensions[2]);
if (hash.leadingDimensions.has_value()) {
combine_32(seed, hash.leadingDimensions.value()[0]);
combine_32(seed, hash.leadingDimensions.value()[1]);
combine_32(seed, hash.leadingDimensions.value()[2]);
combine_32(seed, hash.leadingDimensions.value()[3]);
}
combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], hash.transposeState[3] }));
combine_32(seed, pack_32(simd::uchar4 { hash.lowPrecisionInputs, hash.lowPrecisionIntermediates, 0, 0 }));
combine_32(seed, pack_32(simd::ushort2 { hash.type.value, 0 } ));
@@ -86,10 +98,27 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
return output;
};

auto createLeadingDimensions =
[=]() -> AttentionOperands<unsigned short> {
AttentionOperands<unsigned short> output;
if (leadingDimensions.has_value()) {
output[AttentionOperand::Q] = leadingDimensions.value()[0];
output[AttentionOperand::K] = leadingDimensions.value()[1];
output[AttentionOperand::V] = leadingDimensions.value()[2];
output[AttentionOperand::O] = leadingDimensions.value()[3];

output[AttentionOperand::dO] = leadingDimensions.value()[3];
output[AttentionOperand::dV] = leadingDimensions.value()[2];
output[AttentionOperand::dK] = leadingDimensions.value()[1];
output[AttentionOperand::dQ] = leadingDimensions.value()[0];
}
return output;
};

if (device->supportsFamily(MTL::GPUFamily(1009))) {
return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), type, scale);
return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
} else {
return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), type, scale);
return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
}
}

@@ -101,8 +130,26 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
(MTL::FunctionConstantValues::alloc()->init());
uint32_t rowDimension = matrixDimensions[0];
uint32_t columnDimension = matrixDimensions[1];
uint32_t Hq = this->Hq;
constants->setConstantValue(&rowDimension, MTL::DataTypeUInt, NS::Integer(0));
constants->setConstantValue(&columnDimension, MTL::DataTypeUInt, 1);
constants->setConstantValue(&Hq, MTL::DataTypeUInt, 2);
std::vector<AttentionOperand> operands;
switch (type.value) {
case AttentionKernelType::forward:
operands = {AttentionOperand::Q, AttentionOperand::K, AttentionOperand::V, AttentionOperand::O};
break;
case AttentionKernelType::backwardQuery:
operands = {AttentionOperand::Q, AttentionOperand::K, AttentionOperand::V, AttentionOperand::O, AttentionOperand::dO, AttentionOperand::dQ};
break;
case AttentionKernelType::backwardKeyValue:
operands = {AttentionOperand::Q, AttentionOperand::K, AttentionOperand::V, AttentionOperand::O, AttentionOperand::dO, AttentionOperand::dV, AttentionOperand::dK};
break;
}
for (const auto& operand : operands) {
uint32_t batchStride = batchStrides[operand].value_or(0);
constants->setConstantValue(&batchStride, MTL::DataTypeUInt, 3 + operand.bufferIndex());
}

NS::String* swiftName = NS::String::string("attention", NS::UTF8StringEncoding);
NS::Error* error = nil;
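The function constants bound in findKernel follow a small indexing convention: index 0 carries the row dimension R, index 1 the column dimension C, index 2 Hq, and each operand's batch stride sits at 3 plus that operand's buffer index. A hypothetical helper spelling the convention out; the enum names are illustrative, only bufferIndex() appears in the diff:

#include <cstdint>

// Hypothetical mirror of the host-side function-constant layout above.
enum AttentionFunctionConstant : uint32_t {
  kRowDimension    = 0,  // matrixDimensions[0], R
  kColumnDimension = 1,  // matrixDimensions[1], C
  kQueryHeads      = 2,  // Hq
  kBatchStrideBase = 3,  // kBatchStrideBase + operand.bufferIndex()
};

constexpr uint32_t batchStrideConstantIndex(uint32_t operandBufferIndex) {
  return kBatchStrideBase + operandBufferIndex;
}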
12 changes: 12 additions & 0 deletions lib/nnc/mfa/v2/AttentionDescriptor.hpp
@@ -23,6 +23,12 @@ struct AttentionParameterRow {
};

struct AttentionDescriptor {
/// The number of equally sized attention problems that run in parallel (the batch size).
uint32_t batchDimension = 1;

/// The number of query heads per sequence that run in parallel.
unsigned short Hq = 1;

/// Q, K, V, dO
bool lowPrecisionInputs;

@@ -37,6 +43,12 @@ struct AttentionDescriptor {
/// Q, K, V, O
simd::uchar4 transposeState;

/// The leading dimensions after transposition (if applied).
/// Q, K, V, O
std::optional<simd::uint4> leadingDimensions;

/// The stride, in elements, between consecutive batch items for each operand (0 when unset).
AttentionOperands<unsigned int> batchStrides;

AttentionKernelType type;

float scale;
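Taken together, the new fields describe a packed [batch, sequence, heads, D] layout the same way the encode path earlier in this commit fills them in. A minimal sketch with made-up sizes (R = C = 1024, D = 64, Hq = 8, batch = 2); the field names come from this diff, everything else is illustrative:

#include <simd/simd.h>
#include "v2/AttentionDescriptor.hpp"

AttentionDescriptor makeExampleDescriptor() {
  AttentionDescriptor desc;
  desc.batchDimension = 2;           // two attention problems in flight
  desc.Hq = 8;                       // query heads per problem
  desc.matrixDimensions[0] = 1024;   // R, rows of Q
  desc.matrixDimensions[1] = 1024;   // C, rows of K and V
  desc.matrixDimensions[2] = 64;     // D, head dimension
  desc.transposeState = simd::uchar4 { false, false, false, false };
  desc.lowPrecisionInputs = false;
  desc.lowPrecisionIntermediates = false;
  // Row stride of each operand when heads are packed: Hq * D = 512.
  desc.leadingDimensions = simd::uint4 { 512, 512, 512, 512 };
  // Elements between consecutive batch items: R * D * Hq.
  desc.batchStrides[AttentionOperand::Q] = 1024 * 64 * 8;
  desc.batchStrides[AttentionOperand::K] = 1024 * 64 * 8;
  desc.batchStrides[AttentionOperand::V] = 1024 * 64 * 8;
  desc.batchStrides[AttentionOperand::O] = 1024 * 64 * 8;
  desc.type = AttentionKernelType::forward;
  desc.scale = 1.0f / 8.0f;          // 1 / sqrt(D) for D = 64
  return desc;
}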
