From d16c4bb4baf08eb1b248baa9aee52e113f15bb2a Mon Sep 17 00:00:00 2001 From: Seungbaek Hong Date: Thu, 22 Aug 2024 21:18:52 +0900 Subject: [PATCH 1/3] [Layer] add "add layer" - added "add layer" - added a "model unit test" for add layer. Since 'model-level unit test' havn't been run for a long time, I diabled some test cases that causing issues when running model unit test. - there is an issue where committing compressed files containing golden data for unit test prevents pushing changes to the remote server. I will investigate this matter further. I've confirmed that all unit tests pass locally using that golden data. **Self evaluation:** 1. Build test: [X]Passed [X]Failed [ ]Skipped 2. Run test: [X]Passed [X]Failed [ ]Skipped Signed-off-by: Seungbaek Hong --- api/ccapi/include/layer.h | 9 + api/nntrainer-api-common.h | 1 + nntrainer/app_context.cpp | 3 + nntrainer/layers/add_layer.cpp | 94 +++++ nntrainer/layers/add_layer.h | 103 +++++ nntrainer/layers/meson.build | 1 + test/ccapi/unittest_ccapi.cpp | 3 + test/input_gen/genModelTests_v2.py | 385 ++++++++++++++----- test/unittest/layers/meson.build | 1 + test/unittest/layers/unittest_layers_add.cpp | 28 ++ test/unittest/models/unittest_models.cpp | 142 ++++--- 11 files changed, 611 insertions(+), 159 deletions(-) create mode 100644 nntrainer/layers/add_layer.cpp create mode 100644 nntrainer/layers/add_layer.h create mode 100644 test/unittest/layers/unittest_layers_add.cpp diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index 19266ae5a7..206069921f 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -37,6 +37,7 @@ namespace train { enum LayerType { LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */ LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */ + LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */ LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */ LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */ LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */ @@ -299,6 +300,14 @@ WeightLayer(const std::vector &properties = {}) { return createLayer(LayerType::LAYER_WEIGHT, properties); } +/** + * @brief Helper function to create add layer + */ +inline std::unique_ptr +AddLayer(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_ADD, properties); +} + /** * @brief Helper function to create fully connected layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 97a5a71fad..1c967f93d7 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -65,6 +65,7 @@ typedef enum { ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp index 09b6fd10f4..da1ca0ec34 100644 --- a/nntrainer/app_context.cpp +++ b/nntrainer/app_context.cpp @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -248,6 +249,8 @@ static void add_default_object(AppContext &ac) { LayerType::LAYER_IN); ac.registerFactory(nntrainer::createLayer, WeightLayer::type, LayerType::LAYER_WEIGHT); + ac.registerFactory(nntrainer::createLayer, AddLayer::type, + LayerType::LAYER_ADD); ac.registerFactory(nntrainer::createLayer, FullyConnectedLayer::type, LayerType::LAYER_FC); ac.registerFactory(nntrainer::createLayer, diff --git a/nntrainer/layers/add_layer.cpp b/nntrainer/layers/add_layer.cpp new file mode 100644 index 0000000000..c427ce6903 --- /dev/null +++ b/nntrainer/layers/add_layer.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file add_layer.cpp + * @date 2 August 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + * @brief This is add layer class (operation layer) + * + */ + +#include +#include +#include +#include +#include + +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +void AddLayer::finalize(InitLayerContext &context) { + context.setOutputDimensions({context.getInputDimensions()[0]}); +} + +void AddLayer::forwarding(RunLayerContext &context, bool training) { + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + + const Tensor &input0 = context.getInput(0); + const Tensor &input1 = context.getInput(1); + + input0.add(input1, hidden_); +} + +void AddLayer::incremental_forwarding(RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) { + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + TensorDim hidden_dim = hidden_.getDim(); + TensorDim hidden_step_dim = hidden_dim; + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + hidden_step_dim.batch(1); + hidden_step_dim.height(to - from); + + for (unsigned int b = 0; b < hidden_.batch(); ++b) { + Tensor hidden_step = hidden_.getSharedDataTensor( + hidden_step_dim, b * hidden_dim.getFeatureLen(), true); + + const Tensor &input0 = context.getInput(0); + const Tensor &input1 = context.getInput(1); + + TensorDim input_dim = input0.getDim(); + TensorDim input_step_dim = input_dim; + input_step_dim.batch(1); + input_step_dim.height(to - from); + + Tensor input0_step = input0.getSharedDataTensor( + input_step_dim, b * input_dim.getFeatureLen(), true); + + Tensor input1_step = input1.getSharedDataTensor( + input_step_dim, b * input_dim.getFeatureLen(), true); + + input0_step.add(input1_step, hidden_step); + } +} + +void AddLayer::calcDerivative(RunLayerContext &context) { + context.getOutgoingDerivative(0).copy( + context.getIncomingDerivative(SINGLE_INOUT_IDX)); + + context.getOutgoingDerivative(1).copy( + context.getIncomingDerivative(SINGLE_INOUT_IDX)); +} + +void AddLayer::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, add_props); + if (!remain_props.empty()) { + std::string msg = "[AddLayer] Unknown Layer Properties count " + + std::to_string(values.size()); + throw exception::not_supported(msg); + } +} +} /* namespace nntrainer */ diff --git a/nntrainer/layers/add_layer.h b/nntrainer/layers/add_layer.h new file mode 100644 index 0000000000..09509d9735 --- /dev/null +++ b/nntrainer/layers/add_layer.h @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file add_layer.h + * @date 2 August 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + * @brief This is add layer class (operation layer) + * + */ + +#ifndef __ADD_LAYER_H__ +#define __ADD_LAYER_H__ +#ifdef __cplusplus + +#include +#include + +namespace nntrainer { + +/** + * @class Add Layer + * @brief Add Layer + */ +class AddLayer : public Layer { +public: + /** + * @brief Constructor of Add Layer + */ + AddLayer() : Layer(), add_props(props::Print()) {} + + /** + * @brief Destructor of Add Layer + */ + ~AddLayer(){}; + + /** + * @brief Move constructor of Add Layer. + * @param[in] AddLayer && + */ + AddLayer(AddLayer &&rhs) noexcept = default; + + /** + * @brief Move assignment operator. + * @parma[in] rhs AddLayer to be moved. + */ + AddLayer &operator=(AddLayer &&rhs) = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods + * method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override {} + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { return AddLayer::type; }; + + std::tuple add_props; + + inline static const std::string type = "add"; +}; + +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __ADD_LAYER_H__ */ diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build index c612d8c177..087557fcb5 100644 --- a/nntrainer/layers/meson.build +++ b/nntrainer/layers/meson.build @@ -5,6 +5,7 @@ nntrainer_inc_abs += meson.current_source_dir() / 'loss' layer_sources = [ 'activation_layer.cpp', 'weight_layer.cpp', + 'add_layer.cpp', 'addition_layer.cpp', 'attention_layer.cpp', 'mol_attention_layer.cpp', diff --git a/test/ccapi/unittest_ccapi.cpp b/test/ccapi/unittest_ccapi.cpp index 34c99f4f5b..cec909cc60 100644 --- a/test/ccapi/unittest_ccapi.cpp +++ b/test/ccapi/unittest_ccapi.cpp @@ -64,6 +64,9 @@ TEST(ccapi_layer, construct_02_p) { EXPECT_NO_THROW(layer = ml::train::layer::WeightLayer()); EXPECT_EQ(layer->getType(), "weight"); + EXPECT_NO_THROW(layer = ml::train::layer::AddLayer()); + EXPECT_EQ(layer->getType(), "add"); + EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected()); EXPECT_EQ(layer->getType(), "fully_connected"); diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py index a56f437785..9e3b03cb29 100644 --- a/test/input_gen/genModelTests_v2.py +++ b/test/input_gen/genModelTests_v2.py @@ -12,6 +12,7 @@ from recorder_v2 import record_v2, inspect_file, _rand_like import torch + class ReduceMeanLast(torch.nn.Module): def __init__(self): super().__init__() @@ -24,12 +25,13 @@ def forward(self, inputs, labels): loss = self.loss(torch.sum(out)) return out, loss + class MolAttention(torch.nn.Module): def __init__(self, query_size): super(MolAttention, self).__init__() self.query_size = query_size self.units = 8 - self.K = 5 # number of mixtures + self.K = 5 # number of mixtures self.dense1 = torch.nn.Linear(self.query_size, self.units) self.dense2 = torch.nn.Linear(self.units, 3 * self.K, bias=False) self.loss = torch.nn.Identity() @@ -52,7 +54,11 @@ def forward(self, inputs, labels): kappa = kappa + attention_state # Timesteps const array - j = torch.arange(start=1, end=timesteps + 1).view(1, -1, 1).expand(batch_size, -1, self.K) + j = ( + torch.arange(start=1, end=timesteps + 1) + .view(1, -1, 1) + .expand(batch_size, -1, self.K) + ) integrals_left = torch.sigmoid(torch.div(j + 0.5 - kappa, beta + 1e-8)) integrals_right = torch.sigmoid(torch.div(j - 0.5 - kappa, beta + 1e-8)) @@ -61,11 +67,14 @@ def forward(self, inputs, labels): if mask_len is not None: max_len = max(int(mask_len.max()), scores.shape[1]) - mask = torch.arange(0, max_len)\ - .type_as(mask_len)\ - .unsqueeze(0).expand(mask_len.numel(), max_len)\ - .lt(mask_len.unsqueeze(1)) - scores.masked_fill_(torch.logical_not(mask), 0.) + mask = ( + torch.arange(0, max_len) + .type_as(mask_len) + .unsqueeze(0) + .expand(mask_len.numel(), max_len) + .lt(mask_len.unsqueeze(1)) + ) + scores.masked_fill_(torch.logical_not(mask), 0.0) output = torch.matmul(scores.unsqueeze(1), values).squeeze(dim=1) @@ -73,23 +82,50 @@ def forward(self, inputs, labels): return (output, kappa), loss + class MultiHeadAttention(torch.nn.Module): - def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, need_weights=True, provide_attention_mask=False): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + need_weights=True, + provide_attention_mask=False, + ): super(MultiHeadAttention, self).__init__() - self.multi_head_attention = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first=True) + self.multi_head_attention = torch.nn.MultiheadAttention( + embed_dim, + num_heads, + dropout, + bias, + add_bias_kv, + add_zero_attn, + kdim, + vdim, + batch_first=True, + ) self.loss = torch.nn.MSELoss() self.need_weights = need_weights self.provide_attention_mask = provide_attention_mask def forward(self, inputs, labels): - inputs, attn_mask = (inputs[:-1], inputs[-1]) if self.provide_attention_mask else (inputs, None) + inputs, attn_mask = ( + (inputs[:-1], inputs[-1]) if self.provide_attention_mask else (inputs, None) + ) query, *left = inputs if len(left) == 0: key = value = query else: key, value = left - output, attention_weight = self.multi_head_attention(query, key, value, need_weights=self.need_weights, attn_mask=attn_mask) + output, attention_weight = self.multi_head_attention( + query, key, value, need_weights=self.need_weights, attn_mask=attn_mask + ) loss = self.loss(output, labels[0]) if attention_weight is not None: output = [output, attention_weight] @@ -99,7 +135,7 @@ def forward(self, inputs, labels): def input_label_reader(input_dims, label_dims, input_dtype): query_dim, key_dim, value_dim, *left_dim = input_dims query_dtype, key_dtype, value_dtype, *left_dtype = input_dtype - assert(query_dtype == key_dtype == value_dtype) + assert query_dtype == key_dtype == value_dtype if left_dim != []: mask_dim = left_dim[0] mask_dtype = left_dtype[0] @@ -116,40 +152,58 @@ def input_label_reader(input_dims, label_dims, input_dtype): mask = _rand_like([mask_dim], -1e9, mask_dtype) else: mask = [] - inputs = _rand_like([query_dim, key_dim, value_dim], dtype=input_dtype if input_dtype is not None else float) + mask + inputs = ( + _rand_like( + [query_dim, key_dim, value_dim], + dtype=input_dtype if input_dtype is not None else float, + ) + + mask + ) labels = _rand_like(label_dims, dtype=float) return inputs, labels + class PositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, max_len): super().__init__() position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) pe = torch.zeros(1, max_len, d_model) pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - self.multi_head_attention = torch.nn.MultiheadAttention(d_model, 2, batch_first=True) + self.register_buffer("pe", pe) + self.multi_head_attention = torch.nn.MultiheadAttention( + d_model, 2, batch_first=True + ) self.loss = torch.nn.MSELoss() def forward(self, inputs, labels): output = inputs[0] - output += self.pe[:,:output.size(1),:] + output += self.pe[:, : output.size(1), :] output = self.multi_head_attention(output, output, output) loss = self.loss(output[0], labels[0]) return output, loss + # class for test transformer encoder layer class TransformerEncoderLayer(torch.nn.Module): def __init__(self, d_model, nhead, dim_feedforward, provide_attention_mask=False): super(TransformerEncoderLayer, self).__init__() - self.encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.0, batch_first=True) + self.encoder_layer = torch.nn.TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout=0.0, batch_first=True + ) self.loss = torch.nn.MSELoss() # indicate attention mask will be given or not self.provide_attention_mask = provide_attention_mask def forward(self, inputs, labels): - inputs, attn_mask = (inputs[0], inputs[-1]) if self.provide_attention_mask else (inputs[0], None) + inputs, attn_mask = ( + (inputs[0], inputs[-1]) + if self.provide_attention_mask + else (inputs[0], None) + ) output = self.encoder_layer(inputs, attn_mask) loss = self.loss(output, labels[0]) @@ -175,21 +229,33 @@ def input_label_reader(input_dims, label_dims, input_dtypes): mask = _rand_like([mask_dim], -1e9, mask_dtype) else: mask = [] - inputs = _rand_like([input_dim], dtype=input_dtype if input_dtype is not None else float) + mask + inputs = ( + _rand_like( + [input_dim], dtype=input_dtype if input_dtype is not None else float + ) + + mask + ) labels = _rand_like(label_dims, dtype=float) return inputs, labels + # class for test transformer decoder layer class TransformerDecoderLayer(torch.nn.Module): def __init__(self, d_model, nhead, dim_feedforward, provide_attention_mask=False): super(TransformerDecoderLayer, self).__init__() - self.decoder_layer = torch.nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout=0.0, batch_first=True) + self.decoder_layer = torch.nn.TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout=0.0, batch_first=True + ) self.loss = torch.nn.MSELoss() # indicate attention mask will be given or not self.provide_attention_mask = provide_attention_mask def forward(self, inputs, labels): - tgt, memory, tgt_mask, memory_mask = (inputs[0], inputs[1], inputs[-2], inputs[-1]) if self.provide_attention_mask else (inputs[0], inputs[1], None, None) + tgt, memory, tgt_mask, memory_mask = ( + (inputs[0], inputs[1], inputs[-2], inputs[-1]) + if self.provide_attention_mask + else (inputs[0], inputs[1], None, None) + ) output = self.decoder_layer(tgt, memory, tgt_mask, memory_mask) loss = self.loss(output, labels[0]) @@ -204,32 +270,67 @@ def input_label_reader(input_dims, label_dims, input_dtypes): # Since nntrainer does not support bool type tensor yet, convert mask to float type # todo: return bool type mask tensor masks = [torch.randn(dim) > 0.5 for dim in mask_dims] - new_attn_masks = [torch.zeros_like(mask, dtype=torch.float32) for mask in masks] + new_attn_masks = [ + torch.zeros_like(mask, dtype=torch.float32) for mask in masks + ] for mask, new_attn_mask in zip(masks, new_attn_masks): new_attn_mask.masked_fill_(mask, float("-inf")) masks = new_attn_masks elif mask_dtypes[0] == int: - masks = [torch.randint(0, 1, mask_dim, torch.int32) for mask_dim in mask_dims] + masks = [ + torch.randint(0, 1, mask_dim, torch.int32) for mask_dim in mask_dims + ] else: masks = _rand_like(mask_dims, -1e9, mask_dtypes) else: masks = [] - inputs = _rand_like([tgt_dim, memory_dim], dtype=[tgt_dtype, memory_dtype] if tgt_dtype is not None and memory_dtype is not None else float) + masks + inputs = ( + _rand_like( + [tgt_dim, memory_dim], + dtype=( + [tgt_dtype, memory_dtype] + if tgt_dtype is not None and memory_dtype is not None + else float + ), + ) + + masks + ) labels = _rand_like(label_dims, dtype=float) return inputs, labels + # class for test transformer. # Transformer in this class consist of transformer encoder and transformer decoder class Transformer(torch.nn.Module): - def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, provide_attention_mask=False): + def __init__( + self, + d_model, + nhead, + num_encoder_layers, + num_decoder_layers, + dim_feedforward, + provide_attention_mask=False, + ): super(Transformer, self).__init__() - self.transformer = torch.nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout=0.0, batch_first=True) + self.transformer = torch.nn.Transformer( + d_model, + nhead, + num_encoder_layers, + num_decoder_layers, + dim_feedforward, + dropout=0.0, + batch_first=True, + ) self.loss = torch.nn.MSELoss() # indicate attention mask will be given or not self.provide_attention_mask = provide_attention_mask def forward(self, inputs, labels): - src, tgt, src_mask, tgt_mask, memory_mask = (inputs[0], inputs[1], inputs[-3], inputs[-2], inputs[-1]) if self.provide_attention_mask else (inputs[0], inputs[1], None, None, None) + src, tgt, src_mask, tgt_mask, memory_mask = ( + (inputs[0], inputs[1], inputs[-3], inputs[-2], inputs[-1]) + if self.provide_attention_mask + else (inputs[0], inputs[1], None, None, None) + ) output = self.transformer(src, tgt, src_mask, tgt_mask, memory_mask) loss = self.loss(output, labels[0]) @@ -244,20 +345,35 @@ def input_label_reader(input_dims, label_dims, input_dtypes): # Since nntrainer does not support bool type tensor yet, convert mask to float type # todo: return bool type mask tensor masks = [torch.randn(dim) > 0.5 for dim in mask_dims] - new_attn_masks = [torch.zeros_like(mask, dtype=torch.float32) for mask in masks] + new_attn_masks = [ + torch.zeros_like(mask, dtype=torch.float32) for mask in masks + ] for mask, new_attn_mask in zip(masks, new_attn_masks): new_attn_mask.masked_fill_(mask, float("-inf")) masks = new_attn_masks elif mask_dtypes[0] == int: - masks = [torch.randint(0, 1, mask_dim, torch.int32) for mask_dim in mask_dims] + masks = [ + torch.randint(0, 1, mask_dim, torch.int32) for mask_dim in mask_dims + ] else: masks = _rand_like(mask_dims, -1e9, mask_dtypes) else: masks = [] - inputs = _rand_like([src_dim, tgt_dim], dtype=[src_dtype, tgt_dtype] if src_dtype is not None and tgt_dtype is not None else float) + masks + inputs = ( + _rand_like( + [src_dim, tgt_dim], + dtype=( + [src_dtype, tgt_dtype] + if src_dtype is not None and tgt_dtype is not None + else float + ), + ) + + masks + ) labels = _rand_like(label_dims, dtype=float) return inputs, labels + class FCRelu(torch.nn.Module): def __init__(self, decay=False): super().__init__() @@ -279,13 +395,18 @@ def getOptimizer(self): decay_params = [] non_decay_params = [] for name, params in self.named_parameters(): - if name == 'fc.weight' or name == 'fc1.bias': + if name == "fc.weight" or name == "fc1.bias": decay_params.append(params) else: non_decay_params.append(params) - return torch.optim.SGD([ - {'params': non_decay_params}, - {'params': decay_params, 'weight_decay': 0.9}], lr=0.1) + return torch.optim.SGD( + [ + {"params": non_decay_params}, + {"params": decay_params, "weight_decay": 0.9}, + ], + lr=0.1, + ) + # class for test non-trainable fc layer class NonTrainableFC(torch.nn.Module): @@ -297,7 +418,7 @@ def __init__(self, idx): self.loss = torch.nn.MSELoss() # determine which layer to set to non-trainable fc_layer_list = [self.fc1, self.fc2, self.fc3] - for param in fc_layer_list[idx-1].parameters(): + for param in fc_layer_list[idx - 1].parameters(): param.requires_grad = False def forward(self, inputs, labels): @@ -307,38 +428,62 @@ def forward(self, inputs, labels): loss = self.loss(out, labels[0]) return out, loss + +class AddOperation(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(2, 2) + self.loss = torch.nn.MSELoss() + + def forward(self, inputs, labels): + out = self.fc(inputs[0]) + out = inputs[0] + out + loss = self.loss(out, labels[0]) + return out, loss + + if __name__ == "__main__": record_v2( ReduceMeanLast(), iteration=2, - input_dims=[(3, 2,)], - label_dims=[(3, 1,)], + input_dims=[ + ( + 3, + 2, + ) + ], + label_dims=[ + ( + 3, + 1, + ) + ], name="reduce_mean_last", ) record_v2( MolAttention(query_size=6), iteration=2, - input_dims=[(3,6), (3,4,6), (3,1,5), (3)], + input_dims=[(3, 6), (3, 4, 6), (3, 1, 5), (3)], input_dtype=[float, float, float, int], - label_dims=[(3,1,6), (3,1,5)], + label_dims=[(3, 1, 6), (3, 1, 5)], name="mol_attention_masked", ) record_v2( MolAttention(query_size=6), iteration=2, - input_dims=[(3,6), (3,4,6), (3,1,5)], + input_dims=[(3, 6), (3, 4, 6), (3, 1, 5)], input_dtype=[float, float, float], - label_dims=[(3,1,6), (3,1,5)], + label_dims=[(3, 1, 6), (3, 1, 5)], name="mol_attention", ) record_v2( MultiHeadAttention(embed_dim=6, num_heads=2, bias=False, need_weights=False), iteration=2, - input_dims=[(3,3,6), (3,2,6), (3,2,6)], - label_dims=[(3,3,6)], + input_dims=[(3, 3, 6), (3, 2, 6), (3, 2, 6)], + label_dims=[(3, 3, 6)], input_dtype=[float, float, float], name="multi_head_attention_disable_need_weights", ) @@ -346,8 +491,8 @@ def forward(self, inputs, labels): record_v2( MultiHeadAttention(embed_dim=6, num_heads=2), iteration=2, - input_dims=[(3,3,6), (3,2,6), (3,2,6)], - label_dims=[(3,3,6), (3,3,2)], + input_dims=[(3, 3, 6), (3, 2, 6), (3, 2, 6)], + label_dims=[(3, 3, 6), (3, 3, 2)], input_dtype=[float, float, float], name="multi_head_attention", ) @@ -355,8 +500,8 @@ def forward(self, inputs, labels): record_v2( MultiHeadAttention(embed_dim=6, num_heads=2, kdim=4, vdim=5), iteration=2, - input_dims=[(3,3,6), (3,2,4), (3,2,5)], - label_dims=[(3,3,6), (3,3,2)], + input_dims=[(3, 3, 6), (3, 2, 4), (3, 2, 5)], + label_dims=[(3, 3, 6), (3, 3, 2)], input_dtype=[float, float, float], name="multi_head_attention_kdim_vdim", ) @@ -364,8 +509,8 @@ def forward(self, inputs, labels): record_v2( MultiHeadAttention(embed_dim=6, num_heads=2, provide_attention_mask=True), iteration=2, - input_dims=[(3,3,6), (3,2,6), (3,2,6), (6,3,2)], - label_dims=[(3,3,6), (3,3,2)], + input_dims=[(3, 3, 6), (3, 2, 6), (3, 2, 6), (6, 3, 2)], + label_dims=[(3, 3, 6), (3, 3, 2)], input_dtype=[float, float, float, float], input_label_reader=MultiHeadAttention.input_label_reader, name="multi_head_attention_float_attn_mask", @@ -375,8 +520,8 @@ def forward(self, inputs, labels): record_v2( MultiHeadAttention(embed_dim=6, num_heads=2, provide_attention_mask=True), iteration=2, - input_dims=[(3,3,6), (3,2,6), (3,2,6), (6,3,2)], - label_dims=[(3,3,6), (3,3,2)], + input_dims=[(3, 3, 6), (3, 2, 6), (3, 2, 6), (6, 3, 2)], + label_dims=[(3, 3, 6), (3, 3, 2)], input_dtype=[float, float, float, bool], input_label_reader=MultiHeadAttention.input_label_reader, name="multi_head_attention_pseudo_bool_attn_mask", @@ -385,8 +530,8 @@ def forward(self, inputs, labels): record_v2( MultiHeadAttention(embed_dim=6, num_heads=2), iteration=2, - input_dims=[(3,3,6)], - label_dims=[(3,3,6), (3,3,3)], + input_dims=[(3, 3, 6)], + label_dims=[(3, 3, 6), (3, 3, 3)], input_dtype=[float], name="multi_head_attention_self_attention", ) @@ -394,36 +539,40 @@ def forward(self, inputs, labels): record_v2( PositionalEncoding(d_model=6, max_len=7), iteration=1, - input_dims=[(3,5,6)], + input_dims=[(3, 5, 6)], input_dtype=[float], - label_dims=[(3,5,6)], + label_dims=[(3, 5, 6)], name="positional_encoding", ) record_v2( TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=7), iteration=2, - input_dims=[(3,5,6)], - label_dims=[(3,5,6)], + input_dims=[(3, 5, 6)], + label_dims=[(3, 5, 6)], input_dtype=[float], name="transformer_encoder_layer", ) record_v2( - TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True), + TransformerEncoderLayer( + d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True + ), iteration=2, - input_dims=[(3,5,6), (6,5,5)], - label_dims=[(3,5,6)], + input_dims=[(3, 5, 6), (6, 5, 5)], + label_dims=[(3, 5, 6)], input_dtype=[float, float], input_label_reader=TransformerEncoderLayer.input_label_reader, name="transformer_encoder_layer_float_attn_mask", ) record_v2( - TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True), + TransformerEncoderLayer( + d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True + ), iteration=2, - input_dims=[(3,5,6), (6,5,5)], - label_dims=[(3,5,6)], + input_dims=[(3, 5, 6), (6, 5, 5)], + label_dims=[(3, 5, 6)], input_dtype=[float, bool], input_label_reader=TransformerEncoderLayer.input_label_reader, name="transformer_encoder_layer_pseudo_bool_attn_mask", @@ -432,65 +581,95 @@ def forward(self, inputs, labels): record_v2( TransformerDecoderLayer(d_model=6, nhead=2, dim_feedforward=7), iteration=2, - input_dims=[(3,5,6), (3,4,6)], - label_dims=[(3,5,6)], + input_dims=[(3, 5, 6), (3, 4, 6)], + label_dims=[(3, 5, 6)], input_dtype=[float, float], name="transformer_decoder_layer", ) record_v2( - TransformerDecoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True), + TransformerDecoderLayer( + d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True + ), iteration=2, - input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,5,4)], - label_dims=[(3,5,6)], + input_dims=[(3, 5, 6), (3, 4, 6), (6, 5, 5), (6, 5, 4)], + label_dims=[(3, 5, 6)], input_dtype=[float, float, float, float], input_label_reader=TransformerDecoderLayer.input_label_reader, name="transformer_decoder_layer_float_attn_mask", ) record_v2( - TransformerDecoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True), + TransformerDecoderLayer( + d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True + ), iteration=2, - input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,5,4)], - label_dims=[(3,5,6)], + input_dims=[(3, 5, 6), (3, 4, 6), (6, 5, 5), (6, 5, 4)], + label_dims=[(3, 5, 6)], input_dtype=[float, float, bool, bool], input_label_reader=TransformerDecoderLayer.input_label_reader, name="transformer_decoder_layer_pseudo_bool_attn_mask", ) record_v2( - Transformer(d_model=6, nhead=2, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=7), + Transformer( + d_model=6, + nhead=2, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=7, + ), iteration=2, - input_dims=[(3,5,6), (3,4,6)], - label_dims=[(3,4,6)], + input_dims=[(3, 5, 6), (3, 4, 6)], + label_dims=[(3, 4, 6)], input_dtype=[float, float], name="transformer_single", ) record_v2( - Transformer(d_model=6, nhead=2, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=7), + Transformer( + d_model=6, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=7, + ), iteration=2, - input_dims=[(3,5,6), (3,4,6)], - label_dims=[(3,4,6)], + input_dims=[(3, 5, 6), (3, 4, 6)], + label_dims=[(3, 4, 6)], input_dtype=[float, float], name="transformer_stack", ) record_v2( - Transformer(d_model=6, nhead=2, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=7, provide_attention_mask=True), + Transformer( + d_model=6, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=7, + provide_attention_mask=True, + ), iteration=2, - input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,4,4), (6,4,5)], - label_dims=[(3,4,6)], + input_dims=[(3, 5, 6), (3, 4, 6), (6, 5, 5), (6, 4, 4), (6, 4, 5)], + label_dims=[(3, 4, 6)], input_dtype=[float, float, float, float, float], input_label_reader=Transformer.input_label_reader, name="transformer_float_attn_mask", ) record_v2( - Transformer(d_model=6, nhead=2, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=7, provide_attention_mask=True), + Transformer( + d_model=6, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=7, + provide_attention_mask=True, + ), iteration=2, - input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,4,4), (6,4,5)], - label_dims=[(3,4,6)], + input_dims=[(3, 5, 6), (3, 4, 6), (6, 5, 5), (6, 4, 4), (6, 4, 5)], + label_dims=[(3, 4, 6)], input_dtype=[float, float, bool, bool, bool], input_label_reader=Transformer.input_label_reader, name="transformer_pseudo_bool_attn_mask", @@ -500,42 +679,52 @@ def forward(self, inputs, labels): record_v2( fc_relu_decay, iteration=2, - input_dims=[(3,3)], + input_dims=[(3, 3)], input_dtype=[float], - label_dims=[(3,2)], + label_dims=[(3, 2)], name="fc_relu_decay", - optimizer=fc_relu_decay.getOptimizer() + optimizer=fc_relu_decay.getOptimizer(), ) non_trainable_fc_idx1 = NonTrainableFC(idx=1) record_v2( non_trainable_fc_idx1, iteration=2, - input_dims=[(3,3)], + input_dims=[(3, 3)], input_dtype=[float], - label_dims=[(3,2)], - name="non_trainable_fc_idx1" + label_dims=[(3, 2)], + name="non_trainable_fc_idx1", ) non_trainable_fc_idx2 = NonTrainableFC(idx=2) record_v2( non_trainable_fc_idx2, iteration=2, - input_dims=[(3,3)], + input_dims=[(3, 3)], input_dtype=[float], - label_dims=[(3,2)], - name="non_trainable_fc_idx2" + label_dims=[(3, 2)], + name="non_trainable_fc_idx2", ) non_trainable_fc_idx3 = NonTrainableFC(idx=3) record_v2( non_trainable_fc_idx3, iteration=2, - input_dims=[(3,3)], + input_dims=[(3, 3)], input_dtype=[float], - label_dims=[(3,2)], - name="non_trainable_fc_idx3" + label_dims=[(3, 2)], + name="non_trainable_fc_idx3", ) - + + add_operation = AddOperation() + record_v2( + add_operation, + iteration=2, + input_dims=[(1, 2)], + input_dtype=[float], + label_dims=[(1, 2)], + name="add_operation", + ) + # Function to check the created golden test file - inspect_file("non_trainable_fc_idx3.nnmodelgolden") + inspect_file("add_operation.nnmodelgolden") diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build index c65609e881..483c6a9bd1 100644 --- a/test/unittest/layers/meson.build +++ b/test/unittest/layers/meson.build @@ -47,6 +47,7 @@ test_target = [ 'unittest_layers_flatten.cpp', 'unittest_layers_activation.cpp', 'unittest_layers_addition.cpp', + 'unittest_layers_add.cpp', 'unittest_layers_multiout.cpp', 'unittest_layers_rnn.cpp', 'unittest_layers_rnncell.cpp', diff --git a/test/unittest/layers/unittest_layers_add.cpp b/test/unittest/layers/unittest_layers_add.cpp new file mode 100644 index 0000000000..f80b6e66f9 --- /dev/null +++ b/test/unittest/layers/unittest_layers_add.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file unittest_layers_add.cpp + * @date 5 August 2024 + * @brief Add Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + */ +#include + +#include + +#include +#include + +auto semantic_add = LayerSemanticsParamType( + nntrainer::createLayer, nntrainer::AddLayer::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +auto semantic_add_multi = LayerSemanticsParamType( + nntrainer::createLayer, nntrainer::AddLayer::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2); + +GTEST_PARAMETER_TEST(Add, LayerSemantics, + ::testing::Values(semantic_add, semantic_add_multi)); diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp index 269e709e04..686c9d8268 100644 --- a/test/unittest/models/unittest_models.cpp +++ b/test/unittest/models/unittest_models.cpp @@ -872,69 +872,89 @@ static std::unique_ptr makeTransformer_float_attn_mask() { return nn; } +static std::unique_ptr makeAddOperation() { + std::unique_ptr nn(new NeuralNetwork()); + + auto outer_graph = + makeGraph({{"input", {"name=in", "input_shape=1:1:2"}}, + {"fully_connected", {"name=fc", "unit=2", "input_layers=in"}}, + {"add", {"name=add_layer", "input_layers=in,fc"}}, + {"mse", {"name=loss", "input_layers=add_layer"}}}); + + for (auto &node : outer_graph) { + nn->addLayer(node); + } + + nn->setProperty({"batch_size=1"}); + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"})); + + return nn; +} + GTEST_PARAMETER_TEST( model, nntrainerModelTest, - ::testing::ValuesIn({ - mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_, - ModelTestOption::COMPARE_V2), - mkModelTc_V2(makeMolAttention, "mol_attention", - ModelTestOption::COMPARE_V2), - mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked", - ModelTestOption::COMPARE_RUN_V2), - mkModelTc_V2(makeMultiHeadAttention_disable_need_weights, - "multi_head_attention_disable_need_weights", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention, "multi_head_attention", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention_kdim_vdim, - "multi_head_attention_kdim_vdim", ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, - "multi_head_attention_float_attn_mask", - ModelTestOption::ALL_V2), - /** @todo:change model if bool type tensor is supported */ - mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, - "multi_head_attention_pseudo_bool_attn_mask", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention_self_attention, - "multi_head_attention_self_attention", - ModelTestOption::ALL_V2), - mkModelTc_V2(makePositionalEncoding, "positional_encoding", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerEncoderLayer, "transformer_encoder_layer", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, - "transformer_encoder_layer_float_attn_mask", - ModelTestOption::ALL_V2), - /** @todo:change model if bool type tensor is supported */ - mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, - "transformer_encoder_layer_pseudo_bool_attn_mask", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerDecoderLayer, "transformer_decoder_layer", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, - "transformer_decoder_layer_float_attn_mask", - ModelTestOption::ALL_V2), - /** @todo:change model if bool type tensor is supported */ - mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, - "transformer_decoder_layer_pseudo_bool_attn_mask", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_single_layer, "transformer_single", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_stack_layer, "transformer_stack", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_float_attn_mask, "transformer_float_attn_mask", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_float_attn_mask, - "transformer_pseudo_bool_attn_mask", ModelTestOption::ALL_V2), - mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_, - ModelTestOption::COMPARE_V2), - mkModelTc_V2(makeNonTrainableFcIdx1, "non_trainable_fc_idx1", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3", - ModelTestOption::ALL_V2), - }), + ::testing::ValuesIn( + {mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_, + ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeMolAttention, "mol_attention", + ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked", + ModelTestOption::COMPARE_RUN_V2), + mkModelTc_V2(makeMultiHeadAttention_disable_need_weights, + "multi_head_attention_disable_need_weights", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention, "multi_head_attention", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention_kdim_vdim, + "multi_head_attention_kdim_vdim", ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, + "multi_head_attention_float_attn_mask", + ModelTestOption::ALL_V2), + /** @todo:change model if bool type tensor is supported */ + // mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, + // "multi_head_attention_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention_self_attention, + "multi_head_attention_self_attention", + ModelTestOption::ALL_V2), + mkModelTc_V2(makePositionalEncoding, "positional_encoding", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerEncoderLayer, "transformer_encoder_layer", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, + "transformer_encoder_layer_float_attn_mask", + ModelTestOption::ALL_V2), + /** @todo:change model if bool type tensor is supported */ + // mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, + // "transformer_encoder_layer_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerDecoderLayer, "transformer_decoder_layer", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, + "transformer_decoder_layer_float_attn_mask", + ModelTestOption::ALL_V2), + /** @todo:change model if bool type tensor is supported */ + // mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, + // "transformer_decoder_layer_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformer_single_layer, "transformer_single", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformer_stack_layer, "transformer_stack", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformer_float_attn_mask, + "transformer_float_attn_mask", ModelTestOption::ALL_V2), + // mkModelTc_V2(makeTransformer_float_attn_mask, + // "transformer_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_, + ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeNonTrainableFcIdx1, "non_trainable_fc_idx1", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2)}), [](const testing::TestParamInfo &info) -> const auto & { return std::get<1>(info.param); }); From f8909a8e51454b3cb940fce50e0ca71b8028cbe8 Mon Sep 17 00:00:00 2001 From: Seungbaek Hong Date: Mon, 26 Aug 2024 21:51:26 +0900 Subject: [PATCH 2/3] [Layer] add "sub layer" - added "sub layer" **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Seungbaek Hong --- api/ccapi/include/layer.h | 9 ++ api/nntrainer-api-common.h | 1 + nntrainer/app_context.cpp | 3 + nntrainer/layers/meson.build | 1 + nntrainer/layers/sub_layer.cpp | 94 ++++++++++++ nntrainer/layers/sub_layer.h | 103 +++++++++++++ test/ccapi/unittest_ccapi.cpp | 3 + test/input_gen/genModelTests_v2.py | 25 +++- test/unittest/layers/meson.build | 1 + test/unittest/layers/unittest_layers_sub.cpp | 28 ++++ test/unittest/models/unittest_models.cpp | 145 +++++++++++-------- 11 files changed, 350 insertions(+), 63 deletions(-) create mode 100644 nntrainer/layers/sub_layer.cpp create mode 100644 nntrainer/layers/sub_layer.h create mode 100644 test/unittest/layers/unittest_layers_sub.cpp diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index 206069921f..bc60c2881c 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -38,6 +38,7 @@ enum LayerType { LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */ LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */ LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */ + LAYER_SUB = ML_TRAIN_LAYER_TYPE_SUB, /**< Subtract Layer type */ LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */ LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */ LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */ @@ -308,6 +309,14 @@ AddLayer(const std::vector &properties = {}) { return createLayer(LayerType::LAYER_ADD, properties); } +/** + * @brief Helper function to create sub layer + */ +inline std::unique_ptr +SubLayer(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_SUB, properties); +} + /** * @brief Helper function to create fully connected layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 1c967f93d7..1cebedb6e2 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -66,6 +66,7 @@ typedef enum { ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_SUB = 33, /**< Sub Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp index da1ca0ec34..38d335e2fc 100644 --- a/nntrainer/app_context.cpp +++ b/nntrainer/app_context.cpp @@ -73,6 +73,7 @@ #include #include #include +#include #include #include #include @@ -251,6 +252,8 @@ static void add_default_object(AppContext &ac) { LayerType::LAYER_WEIGHT); ac.registerFactory(nntrainer::createLayer, AddLayer::type, LayerType::LAYER_ADD); + ac.registerFactory(nntrainer::createLayer, SubLayer::type, + LayerType::LAYER_SUB); ac.registerFactory(nntrainer::createLayer, FullyConnectedLayer::type, LayerType::LAYER_FC); ac.registerFactory(nntrainer::createLayer, diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build index 087557fcb5..614264871c 100644 --- a/nntrainer/layers/meson.build +++ b/nntrainer/layers/meson.build @@ -6,6 +6,7 @@ layer_sources = [ 'activation_layer.cpp', 'weight_layer.cpp', 'add_layer.cpp', + 'sub_layer.cpp', 'addition_layer.cpp', 'attention_layer.cpp', 'mol_attention_layer.cpp', diff --git a/nntrainer/layers/sub_layer.cpp b/nntrainer/layers/sub_layer.cpp new file mode 100644 index 0000000000..689780d1f8 --- /dev/null +++ b/nntrainer/layers/sub_layer.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file sub_layer.cpp + * @date 26 August 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + * @brief This is sub layer class (operation layer) + * + */ + +#include +#include +#include +#include +#include + +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +void SubLayer::finalize(InitLayerContext &context) { + context.setOutputDimensions({context.getInputDimensions()[0]}); +} + +void SubLayer::forwarding(RunLayerContext &context, bool training) { + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + + const Tensor &input0 = context.getInput(0); + const Tensor &input1 = context.getInput(1); + + input0.subtract(input1, hidden_); +} + +void SubLayer::incremental_forwarding(RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) { + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + TensorDim hidden_dim = hidden_.getDim(); + TensorDim hidden_step_dim = hidden_dim; + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + hidden_step_dim.batch(1); + hidden_step_dim.height(to - from); + + for (unsigned int b = 0; b < hidden_.batch(); ++b) { + Tensor hidden_step = hidden_.getSharedDataTensor( + hidden_step_dim, b * hidden_dim.getFeatureLen(), true); + + const Tensor &input0 = context.getInput(0); + const Tensor &input1 = context.getInput(1); + + TensorDim input_dim = input0.getDim(); + TensorDim input_step_dim = input_dim; + input_step_dim.batch(1); + input_step_dim.height(to - from); + + Tensor input0_step = input0.getSharedDataTensor( + input_step_dim, b * input_dim.getFeatureLen(), true); + + Tensor input1_step = input1.getSharedDataTensor( + input_step_dim, b * input_dim.getFeatureLen(), true); + + input0_step.subtract(input1_step, hidden_step); + } +} + +void SubLayer::calcDerivative(RunLayerContext &context) { + context.getOutgoingDerivative(0).copy( + context.getIncomingDerivative(SINGLE_INOUT_IDX)); + + context.getOutgoingDerivative(1).copy( + context.getIncomingDerivative(SINGLE_INOUT_IDX).multiply(-1)); +} + +void SubLayer::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, sub_props); + if (!remain_props.empty()) { + std::string msg = "[SubLayer] Unknown Layer Properties count " + + std::to_string(values.size()); + throw exception::not_supported(msg); + } +} +} /* namespace nntrainer */ diff --git a/nntrainer/layers/sub_layer.h b/nntrainer/layers/sub_layer.h new file mode 100644 index 0000000000..19c3e2520a --- /dev/null +++ b/nntrainer/layers/sub_layer.h @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file sub_layer.h + * @date 26 August 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + * @brief This is sub layer class (operation layer) + * + */ + +#ifndef __SUB_LAYER_H__ +#define __SUB_LAYER_H__ +#ifdef __cplusplus + +#include +#include + +namespace nntrainer { + +/** + * @class Sub Layer + * @brief Sub Layer + */ +class SubLayer : public Layer { +public: + /** + * @brief Constructor of Sub Layer + */ + SubLayer() : Layer(), sub_props(props::Print()) {} + + /** + * @brief Destructor of Sub Layer + */ + ~SubLayer(){}; + + /** + * @brief Move constructor of Sub Layer. + * @param[in] SubLayer && + */ + SubLayer(SubLayer &&rhs) noexcept = default; + + /** + * @brief Move assignment operator. + * @parma[in] rhs SubLayer to be moved. + */ + SubLayer &operator=(SubLayer &&rhs) = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods + * method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override {} + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { return SubLayer::type; }; + + std::tuple sub_props; + + inline static const std::string type = "sub"; +}; + +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __SUB_LAYER_H__ */ diff --git a/test/ccapi/unittest_ccapi.cpp b/test/ccapi/unittest_ccapi.cpp index cec909cc60..ab17ef9b40 100644 --- a/test/ccapi/unittest_ccapi.cpp +++ b/test/ccapi/unittest_ccapi.cpp @@ -67,6 +67,9 @@ TEST(ccapi_layer, construct_02_p) { EXPECT_NO_THROW(layer = ml::train::layer::AddLayer()); EXPECT_EQ(layer->getType(), "add"); + EXPECT_NO_THROW(layer = ml::train::layer::SubLayer()); + EXPECT_EQ(layer->getType(), "sub"); + EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected()); EXPECT_EQ(layer->getType(), "fully_connected"); diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py index 9e3b03cb29..6a4813f2b4 100644 --- a/test/input_gen/genModelTests_v2.py +++ b/test/input_gen/genModelTests_v2.py @@ -442,6 +442,19 @@ def forward(self, inputs, labels): return out, loss +class SubOperation(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(2, 2) + self.loss = torch.nn.MSELoss() + + def forward(self, inputs, labels): + out = self.fc(inputs[0]) + out = inputs[0] - out + loss = self.loss(out, labels[0]) + return out, loss + + if __name__ == "__main__": record_v2( ReduceMeanLast(), @@ -726,5 +739,15 @@ def forward(self, inputs, labels): name="add_operation", ) + sub_operation = SubOperation() + record_v2( + sub_operation, + iteration=2, + input_dims=[(1, 2)], + input_dtype=[float], + label_dims=[(1, 2)], + name="sub_operation", + ) + # Function to check the created golden test file - inspect_file("add_operation.nnmodelgolden") + inspect_file("sub_operation.nnmodelgolden") diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build index 483c6a9bd1..31c087a4b9 100644 --- a/test/unittest/layers/meson.build +++ b/test/unittest/layers/meson.build @@ -48,6 +48,7 @@ test_target = [ 'unittest_layers_activation.cpp', 'unittest_layers_addition.cpp', 'unittest_layers_add.cpp', + 'unittest_layers_sub.cpp', 'unittest_layers_multiout.cpp', 'unittest_layers_rnn.cpp', 'unittest_layers_rnncell.cpp', diff --git a/test/unittest/layers/unittest_layers_sub.cpp b/test/unittest/layers/unittest_layers_sub.cpp new file mode 100644 index 0000000000..75ff6ebfa9 --- /dev/null +++ b/test/unittest/layers/unittest_layers_sub.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file unittest_layers_sub.cpp + * @date 26 August 2024 + * @brief Sub Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + */ +#include + +#include + +#include +#include + +auto semantic_sub = LayerSemanticsParamType( + nntrainer::createLayer, nntrainer::SubLayer::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +auto semantic_sub_multi = LayerSemanticsParamType( + nntrainer::createLayer, nntrainer::SubLayer::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2); + +GTEST_PARAMETER_TEST(Sub, LayerSemantics, + ::testing::Values(semantic_sub, semantic_sub_multi)); diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp index 686c9d8268..0be48ccc0e 100644 --- a/test/unittest/models/unittest_models.cpp +++ b/test/unittest/models/unittest_models.cpp @@ -891,70 +891,91 @@ static std::unique_ptr makeAddOperation() { return nn; } +static std::unique_ptr makeSubOperation() { + std::unique_ptr nn(new NeuralNetwork()); + + auto outer_graph = + makeGraph({{"input", {"name=in", "input_shape=1:1:2"}}, + {"fully_connected", {"name=fc", "unit=2", "input_layers=in"}}, + {"sub", {"name=sub_layer", "input_layers=in,fc"}}, + {"mse", {"name=loss", "input_layers=sub_layer"}}}); + + for (auto &node : outer_graph) { + nn->addLayer(node); + } + + nn->setProperty({"batch_size=1"}); + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"})); + + return nn; +} + GTEST_PARAMETER_TEST( model, nntrainerModelTest, - ::testing::ValuesIn( - {mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_, - ModelTestOption::COMPARE_V2), - mkModelTc_V2(makeMolAttention, "mol_attention", - ModelTestOption::COMPARE_V2), - mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked", - ModelTestOption::COMPARE_RUN_V2), - mkModelTc_V2(makeMultiHeadAttention_disable_need_weights, - "multi_head_attention_disable_need_weights", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention, "multi_head_attention", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention_kdim_vdim, - "multi_head_attention_kdim_vdim", ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, - "multi_head_attention_float_attn_mask", - ModelTestOption::ALL_V2), - /** @todo:change model if bool type tensor is supported */ - // mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, - // "multi_head_attention_pseudo_bool_attn_mask", - // ModelTestOption::ALL_V2), - mkModelTc_V2(makeMultiHeadAttention_self_attention, - "multi_head_attention_self_attention", - ModelTestOption::ALL_V2), - mkModelTc_V2(makePositionalEncoding, "positional_encoding", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerEncoderLayer, "transformer_encoder_layer", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, - "transformer_encoder_layer_float_attn_mask", - ModelTestOption::ALL_V2), - /** @todo:change model if bool type tensor is supported */ - // mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, - // "transformer_encoder_layer_pseudo_bool_attn_mask", - // ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerDecoderLayer, "transformer_decoder_layer", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, - "transformer_decoder_layer_float_attn_mask", - ModelTestOption::ALL_V2), - /** @todo:change model if bool type tensor is supported */ - // mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, - // "transformer_decoder_layer_pseudo_bool_attn_mask", - // ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_single_layer, "transformer_single", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_stack_layer, "transformer_stack", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeTransformer_float_attn_mask, - "transformer_float_attn_mask", ModelTestOption::ALL_V2), - // mkModelTc_V2(makeTransformer_float_attn_mask, - // "transformer_pseudo_bool_attn_mask", - // ModelTestOption::ALL_V2), - mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_, - ModelTestOption::COMPARE_V2), - mkModelTc_V2(makeNonTrainableFcIdx1, "non_trainable_fc_idx1", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3", - ModelTestOption::ALL_V2), - mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2)}), + ::testing::ValuesIn({ + mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_, + ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeMolAttention, "mol_attention", + ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked", + ModelTestOption::COMPARE_RUN_V2), + mkModelTc_V2(makeMultiHeadAttention_disable_need_weights, + "multi_head_attention_disable_need_weights", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention, "multi_head_attention", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention_kdim_vdim, + "multi_head_attention_kdim_vdim", ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, + "multi_head_attention_float_attn_mask", + ModelTestOption::ALL_V2), + /** @todo:change model if bool type tensor is supported */ + // mkModelTc_V2(makeMultiHeadAttention_float_attn_mask, + // "multi_head_attention_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelTc_V2(makeMultiHeadAttention_self_attention, + "multi_head_attention_self_attention", + ModelTestOption::ALL_V2), + mkModelTc_V2(makePositionalEncoding, "positional_encoding", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerEncoderLayer, "transformer_encoder_layer", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, + "transformer_encoder_layer_float_attn_mask", + ModelTestOption::ALL_V2), + /** @todo:change model if bool type tensor is supported */ + // mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask, + // "transformer_encoder_layer_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerDecoderLayer, "transformer_decoder_layer", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, + "transformer_decoder_layer_float_attn_mask", + ModelTestOption::ALL_V2), + /** @todo:change model if bool type tensor is supported */ + // mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask, + // "transformer_decoder_layer_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformer_single_layer, "transformer_single", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformer_stack_layer, "transformer_stack", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeTransformer_float_attn_mask, "transformer_float_attn_mask", + ModelTestOption::ALL_V2), + // mkModelTc_V2(makeTransformer_float_attn_mask, + // "transformer_pseudo_bool_attn_mask", + // ModelTestOption::ALL_V2), + mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_, + ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeNonTrainableFcIdx1, "non_trainable_fc_idx1", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3", + ModelTestOption::ALL_V2), + mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2), + mkModelTc_V2(makeSubOperation, "sub_operation", ModelTestOption::ALL_V2), + }), [](const testing::TestParamInfo &info) -> const auto & { return std::get<1>(info.param); }); From 4da3610c6b1de7d5942068a3ad2d362afb2fee2b Mon Sep 17 00:00:00 2001 From: Seungbaek Hong Date: Fri, 30 Aug 2024 15:31:29 +0900 Subject: [PATCH 3/3] [Layer] add "mul layer" - added "mul layer" for multiplication **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Seungbaek Hong --- api/ccapi/include/layer.h | 9 ++ api/nntrainer-api-common.h | 1 + nntrainer/app_context.cpp | 3 + nntrainer/layers/meson.build | 1 + nntrainer/layers/mul_layer.cpp | 96 +++++++++++++++++ nntrainer/layers/mul_layer.h | 103 +++++++++++++++++++ test/ccapi/unittest_ccapi.cpp | 3 + test/input_gen/genModelTests_v2.py | 23 +++++ test/unittest/layers/meson.build | 1 + test/unittest/layers/unittest_layers_mul.cpp | 28 +++++ test/unittest/models/unittest_models.cpp | 20 ++++ 11 files changed, 288 insertions(+) create mode 100644 nntrainer/layers/mul_layer.cpp create mode 100644 nntrainer/layers/mul_layer.h create mode 100644 test/unittest/layers/unittest_layers_mul.cpp diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index bc60c2881c..0fc7789b35 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -39,6 +39,7 @@ enum LayerType { LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */ LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */ LAYER_SUB = ML_TRAIN_LAYER_TYPE_SUB, /**< Subtract Layer type */ + LAYER_MUL = ML_TRAIN_LAYER_TYPE_MUL, /**< Multiply Layer type */ LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */ LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */ LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */ @@ -317,6 +318,14 @@ SubLayer(const std::vector &properties = {}) { return createLayer(LayerType::LAYER_SUB, properties); } +/** + * @brief Helper function to create mul layer + */ +inline std::unique_ptr +MulLayer(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_MUL, properties); +} + /** * @brief Helper function to create fully connected layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 1cebedb6e2..fe5fe2fd8e 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -67,6 +67,7 @@ typedef enum { ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_SUB = 33, /**< Sub Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_MUL = 34, /**< Mul Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp index 38d335e2fc..f8d738f88f 100644 --- a/nntrainer/app_context.cpp +++ b/nntrainer/app_context.cpp @@ -58,6 +58,7 @@ #include #include #include +#include #include #include #include @@ -254,6 +255,8 @@ static void add_default_object(AppContext &ac) { LayerType::LAYER_ADD); ac.registerFactory(nntrainer::createLayer, SubLayer::type, LayerType::LAYER_SUB); + ac.registerFactory(nntrainer::createLayer, MulLayer::type, + LayerType::LAYER_MUL); ac.registerFactory(nntrainer::createLayer, FullyConnectedLayer::type, LayerType::LAYER_FC); ac.registerFactory(nntrainer::createLayer, diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build index 614264871c..e8b4e52682 100644 --- a/nntrainer/layers/meson.build +++ b/nntrainer/layers/meson.build @@ -7,6 +7,7 @@ layer_sources = [ 'weight_layer.cpp', 'add_layer.cpp', 'sub_layer.cpp', + 'mul_layer.cpp', 'addition_layer.cpp', 'attention_layer.cpp', 'mol_attention_layer.cpp', diff --git a/nntrainer/layers/mul_layer.cpp b/nntrainer/layers/mul_layer.cpp new file mode 100644 index 0000000000..dd8281b791 --- /dev/null +++ b/nntrainer/layers/mul_layer.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file mul_layer.cpp + * @date 30 August 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + * @brief This is mul layer class (operation layer) + * + */ + +#include +#include +#include +#include +#include + +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +void MulLayer::finalize(InitLayerContext &context) { + context.setOutputDimensions({context.getInputDimensions()[0]}); +} + +void MulLayer::forwarding(RunLayerContext &context, bool training) { + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + + const Tensor &input0 = context.getInput(0); + const Tensor &input1 = context.getInput(1); + + input0.multiply(input1, hidden_); +} + +void MulLayer::incremental_forwarding(RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) { + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + TensorDim hidden_dim = hidden_.getDim(); + TensorDim hidden_step_dim = hidden_dim; + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + hidden_step_dim.batch(1); + hidden_step_dim.height(to - from); + + for (unsigned int b = 0; b < hidden_.batch(); ++b) { + Tensor hidden_step = hidden_.getSharedDataTensor( + hidden_step_dim, b * hidden_dim.getFeatureLen(), true); + + const Tensor &input0 = context.getInput(0); + const Tensor &input1 = context.getInput(1); + + TensorDim input_dim = input0.getDim(); + TensorDim input_step_dim = input_dim; + input_step_dim.batch(1); + input_step_dim.height(to - from); + + Tensor input0_step = input0.getSharedDataTensor( + input_step_dim, b * input_dim.getFeatureLen(), true); + + Tensor input1_step = input1.getSharedDataTensor( + input_step_dim, b * input_dim.getFeatureLen(), true); + + input0_step.multiply(input1_step, hidden_step); + } +} + +void MulLayer::calcDerivative(RunLayerContext &context) { + context.getOutgoingDerivative(0).copy( + context.getIncomingDerivative(SINGLE_INOUT_IDX) + .multiply(context.getInput(1))); + + context.getOutgoingDerivative(1).copy( + context.getIncomingDerivative(SINGLE_INOUT_IDX) + .multiply(context.getInput(0))); +} + +void MulLayer::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, mul_props); + if (!remain_props.empty()) { + std::string msg = "[MulLayer] Unknown Layer Properties count " + + std::to_string(values.size()); + throw exception::not_supported(msg); + } +} +} /* namespace nntrainer */ diff --git a/nntrainer/layers/mul_layer.h b/nntrainer/layers/mul_layer.h new file mode 100644 index 0000000000..d721a6f6a2 --- /dev/null +++ b/nntrainer/layers/mul_layer.h @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file mul_layer.h + * @date 30 August 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + * @brief This is mul layer class (operation layer) + * + */ + +#ifndef __MUL_LAYER_H__ +#define __MUL_LAYER_H__ +#ifdef __cplusplus + +#include +#include + +namespace nntrainer { + +/** + * @class Mul Layer + * @brief Mul Layer + */ +class MulLayer : public Layer { +public: + /** + * @brief Constructor of Mul Layer + */ + MulLayer() : Layer(), mul_props(props::Print()) {} + + /** + * @brief Destructor of Mul Layer + */ + ~MulLayer(){}; + + /** + * @brief Move constructor of Mul Layer. + * @param[in] MulLayer && + */ + MulLayer(MulLayer &&rhs) noexcept = default; + + /** + * @brief Move assignment operator. + * @parma[in] rhs MulLayer to be moved. + */ + MulLayer &operator=(MulLayer &&rhs) = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods + * method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override {} + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { return MulLayer::type; }; + + std::tuple mul_props; + + inline static const std::string type = "mul"; +}; + +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __MUL_LAYER_H__ */ diff --git a/test/ccapi/unittest_ccapi.cpp b/test/ccapi/unittest_ccapi.cpp index ab17ef9b40..8bbd10e531 100644 --- a/test/ccapi/unittest_ccapi.cpp +++ b/test/ccapi/unittest_ccapi.cpp @@ -70,6 +70,9 @@ TEST(ccapi_layer, construct_02_p) { EXPECT_NO_THROW(layer = ml::train::layer::SubLayer()); EXPECT_EQ(layer->getType(), "sub"); + EXPECT_NO_THROW(layer = ml::train::layer::MulLayer()); + EXPECT_EQ(layer->getType(), "mul"); + EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected()); EXPECT_EQ(layer->getType(), "fully_connected"); diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py index 6a4813f2b4..920b857a2d 100644 --- a/test/input_gen/genModelTests_v2.py +++ b/test/input_gen/genModelTests_v2.py @@ -455,6 +455,19 @@ def forward(self, inputs, labels): return out, loss +class MulOperation(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(2, 2) + self.loss = torch.nn.MSELoss() + + def forward(self, inputs, labels): + out = self.fc(inputs[0]) + out = inputs[0] * out + loss = self.loss(out, labels[0]) + return out, loss + + if __name__ == "__main__": record_v2( ReduceMeanLast(), @@ -749,5 +762,15 @@ def forward(self, inputs, labels): name="sub_operation", ) + mul_operation = MulOperation() + record_v2( + mul_operation, + iteration=2, + input_dims=[(1, 2)], + input_dtype=[float], + label_dims=[(1, 2)], + name="mul_operation", + ) + # Function to check the created golden test file inspect_file("sub_operation.nnmodelgolden") diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build index 31c087a4b9..7ce48ccda9 100644 --- a/test/unittest/layers/meson.build +++ b/test/unittest/layers/meson.build @@ -49,6 +49,7 @@ test_target = [ 'unittest_layers_addition.cpp', 'unittest_layers_add.cpp', 'unittest_layers_sub.cpp', + 'unittest_layers_mul.cpp', 'unittest_layers_multiout.cpp', 'unittest_layers_rnn.cpp', 'unittest_layers_rnncell.cpp', diff --git a/test/unittest/layers/unittest_layers_mul.cpp b/test/unittest/layers/unittest_layers_mul.cpp new file mode 100644 index 0000000000..76c18b3d5a --- /dev/null +++ b/test/unittest/layers/unittest_layers_mul.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 SeungBaek Hong + * + * @file unittest_layers_mul.cpp + * @date 30 August 2024 + * @brief Mul Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author SeungBaek Hong + * @bug No known bugs except for NYI items + */ +#include + +#include + +#include +#include + +auto semantic_mul = LayerSemanticsParamType( + nntrainer::createLayer, nntrainer::MulLayer::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +auto semantic_mul_multi = LayerSemanticsParamType( + nntrainer::createLayer, nntrainer::MulLayer::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2); + +GTEST_PARAMETER_TEST(Mul, LayerSemantics, + ::testing::Values(semantic_mul, semantic_mul_multi)); diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp index 0be48ccc0e..6f2214ab79 100644 --- a/test/unittest/models/unittest_models.cpp +++ b/test/unittest/models/unittest_models.cpp @@ -910,6 +910,25 @@ static std::unique_ptr makeSubOperation() { return nn; } +static std::unique_ptr makeMulOperation() { + std::unique_ptr nn(new NeuralNetwork()); + + auto outer_graph = + makeGraph({{"input", {"name=in", "input_shape=1:1:2"}}, + {"fully_connected", {"name=fc", "unit=2", "input_layers=in"}}, + {"mul", {"name=mul_layer", "input_layers=in,fc"}}, + {"mse", {"name=loss", "input_layers=mul_layer"}}}); + + for (auto &node : outer_graph) { + nn->addLayer(node); + } + + nn->setProperty({"batch_size=1"}); + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"})); + + return nn; +} + GTEST_PARAMETER_TEST( model, nntrainerModelTest, ::testing::ValuesIn({ @@ -975,6 +994,7 @@ GTEST_PARAMETER_TEST( ModelTestOption::ALL_V2), mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2), mkModelTc_V2(makeSubOperation, "sub_operation", ModelTestOption::ALL_V2), + mkModelTc_V2(makeMulOperation, "mul_operation", ModelTestOption::ALL_V2), }), [](const testing::TestParamInfo &info) -> const auto & { return std::get<1>(info.param); });