Add flags that you can supply to GEMM in CNNP models.
liuliu committed Jul 2, 2024
1 parent bd40f0b commit 885188a
Showing 17 changed files with 140 additions and 131 deletions.
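
For context, this commit threads an extra flags argument through the GEMM-backed CNNP constructors, ccv_cnnp_dense and ccv_cnnp_matmul; every call site in the diff below simply passes 0 for it. A minimal before/after sketch of the call shape follows, with the parameter roles inferred from the call sites rather than quoted from the header, so treat the comments as assumptions:

/* before this commit */
ccv_cnnp_dense(10, 0 /* no bias */, 1 /* trainable */, 0 /* name */);
ccv_cnnp_matmul(NO_TRANSPOSE, TRANSPOSE(1, 2), 0 /* name */);
/* after this commit: a flags argument (0 = default behavior) is inserted before the trailing parameters */
ccv_cnnp_dense(10, 0 /* no bias */, 0 /* flags */, 1 /* trainable */, 0 /* name */);
ccv_cnnp_matmul(NO_TRANSPOSE, TRANSPOSE(1, 2), 0 /* flags */, 0 /* name */);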
8 changes: 4 additions & 4 deletions bin/nnc/cifar-10.c
@@ -69,7 +69,7 @@ ccv_cnnp_model_t* _cifar_10_resnet56(void)
output = ccv_cnnp_model_apply(identity, MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_average_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(10, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(10, 0, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_softmax(0), MODEL_IO_LIST(output));
return ccv_cnnp_model_new(MODEL_IO_LIST(input), MODEL_IO_LIST(output), 1, 0);
}
@@ -123,7 +123,7 @@ ccv_cnnp_model_t* _cifar_10_dawn(void)
layer3,
ccv_cnnp_max_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0),
ccv_cnnp_flatten(0),
ccv_cnnp_dense(10, 0, 1, 0)), 1, 0);
ccv_cnnp_dense(10, 0, 0, 1, 0)), 1, 0);
}

ccv_cnnp_model_t* _cifar_10_alexnet(void)
@@ -142,10 +142,10 @@ ccv_cnnp_model_t* _cifar_10_alexnet(void)
ccv_cnnp_relu(0),
ccv_cnnp_average_pool(DIM_ALLOC(3, 3), HINT((2, 2), (0, 0)), 0),
ccv_cnnp_flatten(0),
ccv_cnnp_dense(256, 0, 1, 0),
ccv_cnnp_dense(256, 0, 0, 1, 0),
ccv_cnnp_batch_norm(0.9, 1e-4, 1, 0),
ccv_cnnp_relu(0),
ccv_cnnp_dense(10, 0, 1, 0),
ccv_cnnp_dense(10, 0, 0, 1, 0),
ccv_cnnp_softmax(0)
), 1, 0);
}
2 changes: 1 addition & 1 deletion bin/nnc/coco.c
@@ -100,7 +100,7 @@ ccv_array_t* _imagenet_resnet50_v1d(void)
output = ccv_cnnp_model_apply(conv_4, MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_average_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_softmax(0), MODEL_IO_LIST(output));
ccv_cnnp_model_t* const resnet50_v1d = ccv_cnnp_model_new(MODEL_IO_LIST(input), MODEL_IO_LIST(output), 1, 0);
ccv_array_push(backbones, &resnet50_v1d);
8 changes: 4 additions & 4 deletions bin/nnc/imagenet.c
@@ -89,7 +89,7 @@ ccv_cnnp_model_t* _imagenet_resnet50_v1d(void)
output = ccv_cnnp_model_apply(_resnet_block_layer_new(512, 4, 2, 3), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_average_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_softmax(0), MODEL_IO_LIST(output));
return ccv_cnnp_model_new(MODEL_IO_LIST(input), MODEL_IO_LIST(output), 1, 0);
}
@@ -116,7 +116,7 @@ ccv_cnnp_model_t* _imagenet_resnet101_v1d(void)
output = ccv_cnnp_model_apply(_resnet_block_layer_new(512, 4, 2, 3), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_average_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_softmax(0), MODEL_IO_LIST(output));
return ccv_cnnp_model_new(MODEL_IO_LIST(input), MODEL_IO_LIST(output), 1, 0);
}
@@ -150,7 +150,7 @@ ccv_cnnp_model_t* _imagenet_vgg13(void)
ccv_cnnp_relu(0),
ccv_cnnp_average_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0),
ccv_cnnp_flatten(0),
ccv_cnnp_dense(1000, 0, 1, 0),
ccv_cnnp_dense(1000, 0, 0, 1, 0),
ccv_cnnp_softmax(0)
), 1, 0);
return vgg13;
@@ -243,7 +243,7 @@ ccv_cnnp_model_t* _efficientnet_b0(void)
output = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(output));
if (dropout > 0)
output = ccv_cnnp_model_apply(ccv_cnnp_dropout(dropout, 0, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_dense(1000, 0, 0, 1, 0), MODEL_IO_LIST(output));
output = ccv_cnnp_model_apply(ccv_cnnp_softmax(0), MODEL_IO_LIST(output));
return ccv_cnnp_model_new(MODEL_IO_LIST(input), MODEL_IO_LIST(output), 1, 0);
}
18 changes: 9 additions & 9 deletions bin/nnc/imdb.c
@@ -98,9 +98,9 @@ static ccv_cnnp_model_t* _self_attention_new(const int k, const int h, const int
const ccv_cnnp_model_io_t x = ccv_cnnp_input();
ccv_cnnp_model_io_t mask = ccv_cnnp_input();
ccv_cnnp_model_io_t multiheads = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(x));
ccv_cnnp_model_t* const tokeys = ccv_cnnp_dense(k * h, 1, 1, "tokeys");
ccv_cnnp_model_t* const toqueries = ccv_cnnp_dense(k * h, 1, 1, "toqueries");
ccv_cnnp_model_t* const tovalues = ccv_cnnp_dense(k * h, 1, 1, "tovalues");
ccv_cnnp_model_t* const tokeys = ccv_cnnp_dense(k * h, 1, 0, 1, "tokeys");
ccv_cnnp_model_t* const toqueries = ccv_cnnp_dense(k * h, 1, 0, 1, "toqueries");
ccv_cnnp_model_t* const tovalues = ccv_cnnp_dense(k * h, 1, 0, 1, "tovalues");
ccv_cnnp_model_io_t keys = ccv_cnnp_model_apply(tokeys, MODEL_IO_LIST(multiheads));
ccv_cnnp_model_io_t queries = ccv_cnnp_model_apply(toqueries, MODEL_IO_LIST(multiheads));
ccv_cnnp_model_io_t values = ccv_cnnp_model_apply(tovalues, MODEL_IO_LIST(multiheads));
@@ -113,7 +113,7 @@ static ccv_cnnp_model_t* _self_attention_new(const int k, const int h, const int
keys = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(keys));
queries = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(queries));
values = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(values));
ccv_cnnp_model_io_t dot = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, TRANSPOSE(1, 2), 0), MODEL_IO_LIST(queries, keys));
ccv_cnnp_model_io_t dot = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, TRANSPOSE(1, 2), 0, 0), MODEL_IO_LIST(queries, keys));
const float scale = 1. / sqrt(k);
dot = ccv_cnnp_model_apply(ccv_cnnp_scalar_mul(scale, 0), MODEL_IO_LIST(dot));
dot = ccv_cnnp_model_apply(ccv_cnnp_masked_fill(0, -1e9, 0), MODEL_IO_LIST(dot, mask));
@@ -122,11 +122,11 @@ static ccv_cnnp_model_t* _self_attention_new(const int k, const int h, const int
if (dropout > 0)
dot = ccv_cnnp_model_apply(ccv_cnnp_dropout(dropout, 0, 0), MODEL_IO_LIST(dot));
dot = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, t), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(dot));
ccv_cnnp_model_io_t out = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, NO_TRANSPOSE, 0), MODEL_IO_LIST(dot, values));
ccv_cnnp_model_io_t out = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, NO_TRANSPOSE, 0, 0), MODEL_IO_LIST(dot, values));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(h, b, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_transpose(0, 2, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, h * k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
ccv_cnnp_model_t* const unifyheads = ccv_cnnp_dense(k, 0, 1, "unifyheads");
ccv_cnnp_model_t* const unifyheads = ccv_cnnp_dense(k, 0, 0, 1, "unifyheads");
out = ccv_cnnp_model_apply(unifyheads, MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(t, b, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
return ccv_cnnp_model_new(MODEL_IO_LIST(x, mask), MODEL_IO_LIST(out), 1, "self-attention");
@@ -145,9 +145,9 @@ static ccv_cnnp_model_t* _transformer_block_new(const int k, const int h, const
else
out = first;
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(ff, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(ff, 0, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_relu(0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(k, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(k, 0, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(t, b, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_sum(0), MODEL_IO_LIST(first, out));
out = ccv_cnnp_model_apply(ccv_cnnp_layer_norm(1e-5, DIM_ALLOC(2), 1, 1, 1, 0), MODEL_IO_LIST(out));
@@ -170,7 +170,7 @@ static ccv_cnnp_model_t* _classifier_transformer_new(const int layers, const int
out = ccv_cnnp_model_apply(ccv_cnnp_average_pool(DIM_ALLOC(0, 0), ccv_nnc_no_hint, 0), MODEL_IO_LIST(out));
// Last layer, get it to 1.
out = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(1, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(1, 0, 0, 1, 0), MODEL_IO_LIST(out));
return ccv_cnnp_model_new(MODEL_IO_LIST(x, mask), MODEL_IO_LIST(out), 1, "classifier");
}

2 changes: 1 addition & 1 deletion bin/nnc/imdb_lstm.c
@@ -103,7 +103,7 @@ static ccv_cnnp_model_t* _classifier_lstm_new(const int batch_size, const int ba
out = ccv_cnnp_model_apply(ccv_cnnp_index_select(0), MODEL_IO_LIST(out, index));
// Last layer, get it to 1.
out = ccv_cnnp_model_apply(ccv_cnnp_flatten(0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(1, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(1, 0, 0, 1, 0), MODEL_IO_LIST(out));
return ccv_cnnp_model_new(MODEL_IO_LIST(x, mask, index), MODEL_IO_LIST(out), 1, "classifier");
}

22 changes: 11 additions & 11 deletions bin/nnc/iwslt.c
@@ -138,9 +138,9 @@ static ccv_array_t* _array_from_disk_new(const char* src_file, const char* tgt_f
static ccv_cnnp_model_t* _multihead_attention_new(const int k, const int h, const int b, const int t, const float dropout, const int has_m)
{
const ccv_cnnp_model_io_t x = ccv_cnnp_input();
ccv_cnnp_model_t* const tokeys = ccv_cnnp_dense(k * h, 1, 1, 0);
ccv_cnnp_model_t* const toqueries = ccv_cnnp_dense(k * h, 1, 1, 0);
ccv_cnnp_model_t* const tovalues = ccv_cnnp_dense(k * h, 1, 1, 0);
ccv_cnnp_model_t* const tokeys = ccv_cnnp_dense(k * h, 1, 0, 1, 0);
ccv_cnnp_model_t* const toqueries = ccv_cnnp_dense(k * h, 1, 0, 1, 0);
ccv_cnnp_model_t* const tovalues = ccv_cnnp_dense(k * h, 1, 0, 1, 0);
ccv_cnnp_model_io_t queries = ccv_cnnp_model_apply(toqueries, MODEL_IO_LIST(x));
ccv_cnnp_model_io_t m = has_m ? ccv_cnnp_input() : 0;
ccv_cnnp_model_io_t mask = ccv_cnnp_input();
@@ -155,7 +155,7 @@ static ccv_cnnp_model_t* _multihead_attention_new(const int k, const int h, cons
keys = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(keys));
queries = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(queries));
values = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(values));
ccv_cnnp_model_io_t dot = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, TRANSPOSE(1, 2), 0), MODEL_IO_LIST(queries, keys));
ccv_cnnp_model_io_t dot = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, TRANSPOSE(1, 2), 0, 0), MODEL_IO_LIST(queries, keys));
const float scale = 1. / sqrt(k);
dot = ccv_cnnp_model_apply(ccv_cnnp_scalar_mul(scale, 0), MODEL_IO_LIST(dot));
dot = ccv_cnnp_model_apply(ccv_cnnp_masked_fill(0, -1e9, 0), MODEL_IO_LIST(dot, mask));
@@ -164,11 +164,11 @@ static ccv_cnnp_model_t* _multihead_attention_new(const int k, const int h, cons
if (dropout > 0)
dot = ccv_cnnp_model_apply(ccv_cnnp_dropout(dropout, 0, 0), MODEL_IO_LIST(dot));
dot = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * h, t, t), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(dot));
ccv_cnnp_model_io_t out = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, NO_TRANSPOSE, 0), MODEL_IO_LIST(dot, values));
ccv_cnnp_model_io_t out = ccv_cnnp_model_apply(ccv_cnnp_matmul(NO_TRANSPOSE, NO_TRANSPOSE, 0, 0), MODEL_IO_LIST(dot, values));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(h, b, t, k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_transpose(0, 2, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, h * k), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
ccv_cnnp_model_t* const unifyheads = ccv_cnnp_dense(k * h, 0, 1, 0);
ccv_cnnp_model_t* const unifyheads = ccv_cnnp_dense(k * h, 0, 0, 1, 0);
out = ccv_cnnp_model_apply(unifyheads, MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(t, b, k * h), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
if (m)
@@ -190,9 +190,9 @@ static ccv_cnnp_model_t* _encoder_block_new(const int k, const int h, const int
out = ccv_cnnp_model_apply(ccv_cnnp_dropout(dropout, 0, 0), MODEL_IO_LIST(out));
// feed-forward
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, k * h), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(ff, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(ff, 0, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_relu(0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(k * h, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(k * h, 0, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(t, b, k * h), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_layer_norm(1e-5, DIM_ALLOC(2), 1, 1, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_sum(0), MODEL_IO_LIST(first, out));
@@ -223,9 +223,9 @@ static ccv_cnnp_model_t* _decoder_block_new(const int k, const int h, const int
out = ccv_cnnp_model_apply(ccv_cnnp_dropout(dropout, 0, 0), MODEL_IO_LIST(out));
// feed-forward
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, k * h), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(ff, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(ff, 0, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_relu(0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(k * h, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(k * h, 0, 0, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(t, b, k * h), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_layer_norm(1e-5, DIM_ALLOC(2), 1, 1, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_sum(0), MODEL_IO_LIST(first, out));
@@ -250,7 +250,7 @@ ccv_cnnp_model_t* _encoder_decoder_new(const int tgt_vocab_size, const int layer
decoder_out = ccv_cnnp_model_apply(_decoder_block_new(k, h, b, t, ff, dropout), MODEL_IO_LIST(decoder_out, encoder_out, src_mask, tgt_mask));
ccv_cnnp_model_io_t out = ccv_cnnp_model_apply(ccv_cnnp_transpose(0, 1, 0), MODEL_IO_LIST(decoder_out)); // t, b, d -> b, t, d
out = ccv_cnnp_model_apply(ccv_cnnp_reshape(0, DIM_ALLOC(b * t, k * h), DIM_ALLOC(), DIM_ALLOC(), 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(tgt_vocab_size, 1, 1, 0), MODEL_IO_LIST(out));
out = ccv_cnnp_model_apply(ccv_cnnp_dense(tgt_vocab_size, 1, 0, 1, 0), MODEL_IO_LIST(out));
return ccv_cnnp_model_new(MODEL_IO_LIST(src, tgt, src_mask, tgt_mask), MODEL_IO_LIST(out), 1, 0);
}
