[ LORA ] Bugfix in LoRA support in FC Layer
- In the previous code, LoRA did not work when batch_size > 1.
- Tensors used in the LoRA-related computation were not updated when the
  batch size changed (see the sketch below).
- A `setBatch()` function is implemented for `FullyConnectedLayer`.
- Bugfix in the lifespan of the loraTmp tensor: FORWARD_DERIV_LIFESPAN ->
  FORWARD_GRAD_LIFESPAN.
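For context, the sketch below shows why the LoRA intermediates must be sized by the batch. With the standard LoRA formulation y = xW + scaling * (xA)B, both xA and (xA)B carry a leading batch dimension, so the scratch tensors that hold them must grow with the batch. This is a framework-agnostic toy using plain std::vector matrices; Mat, matmul, lora_forward, and scaling are illustrative names, not the nntrainer implementation.

// Toy batched LoRA forward: y = x.W + scaling * (x.A).B
// Shapes (flattened to 2D per batch element):
//   x   : (batch, in_dim)     W : (in_dim, unit)
//   A   : (in_dim, rank)      B : (rank, unit)
//   tmp : (batch, rank)       <- must be sized by the batch
//   out : (batch, unit)       <- must be sized by the batch
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>;

static Mat matmul(const Mat &a, const Mat &b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  Mat c(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

static Mat lora_forward(const Mat &x, const Mat &W, const Mat &A, const Mat &B,
                        float scaling) {
  Mat hidden = matmul(x, W);     // (batch, unit)
  Mat tmp = matmul(x, A);        // (batch, rank), batch-sized scratch
  Mat lora_out = matmul(tmp, B); // (batch, unit), batch-sized scratch
  for (std::size_t i = 0; i < hidden.size(); ++i)
    for (std::size_t j = 0; j < hidden[0].size(); ++j)
      hidden[i][j] += scaling * lora_out[i][j];
  return hidden;
}

int main() {
  const std::size_t batch = 4, in_dim = 8, unit = 16, rank = 2;
  Mat x(batch, std::vector<float>(in_dim, 1.0f));
  Mat W(in_dim, std::vector<float>(unit, 0.5f));
  Mat A(in_dim, std::vector<float>(rank, 0.1f));
  Mat B(rank, std::vector<float>(unit, 0.1f));
  Mat y = lora_forward(x, W, A, B, /*scaling=*/1.0f);
  assert(y.size() == batch && y[0].size() == unit);
  return 0;
}

If the two scratch buffers are allocated for a single sample, as the old (1, 1, 1, lora_rank) loraTmp and the bias-shaped loraOut in the diff below were, there is no room for the remaining batch rows, which is the failure described in the first bullet.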

Self evaluation:

	Build test: [X]Passed [ ]Failed [ ]Skipped
	Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <[email protected]>
EunjuYang authored and jijoongmoon committed Sep 22, 2024
1 parent 4ae1477 commit 8104cbe
Showing 2 changed files with 33 additions and 7 deletions.
34 changes: 27 additions & 7 deletions nntrainer/layers/fc_layer.cpp
@@ -91,11 +91,14 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
/** set weight specifications */
// @todo : This NCHW format setting is just temporal, it needs to be set by
// global configuration

+ /** Bias Dimension : (1, 1, 1, unit) */
TensorDim bias_dim(
1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1,
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0001 : 0b0100);

+ /** Weight Dimension : (1, 1, in_dim.width(), unit) */
TensorDim weight_dim(
1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1,
is_nchw ? unit : in_dim.channel(),
@@ -115,25 +118,33 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
/** create weights for LoRA */
if (lora_rank) {

- /** loraA : (in_dim.width, lora_rank) */
+ /** loraA Dimension : (1, 1, in_dim.width, lora_rank) */
TensorDim loraA_dim(
1, is_nchw ? 1 : lora_rank, is_nchw ? in_dim.width() : 1,
is_nchw ? lora_rank : in_dim.channel(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);

- /** loraB: (lora_rank, out_dim) */
+ /** loraB Dimension : (1, 1, lora_rank, unit) */
TensorDim loraB_dim(
1, is_nchw ? 1 : unit, is_nchw ? lora_rank : 1,
is_nchw ? unit : lora_rank,
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);

- /** loraTmp: (1, lora_rank) */
+ /** loraTmp Dimension : (B, 1, in_dim.height(), lora_rank) */
TensorDim loraTmp_dim(
- 1, is_nchw ? 1 : lora_rank, 1, is_nchw ? lora_rank : 1,
+ in_dim.batch(), is_nchw ? 1 : lora_rank, is_nchw ? in_dim.height() : 1,
+ is_nchw ? lora_rank : in_dim.width(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
- is_nchw ? 0b0001 : 0b0100);
+ is_nchw ? 0b1011 : 0b1101);
+
+ /** loraOut Dimension : (B, 1, in_dim.height(), unit) */
+ TensorDim loraOut_dim(
+ in_dim.batch(), is_nchw ? 1 : unit, is_nchw ? in_dim.height() : 1,
+ is_nchw ? unit : in_dim.width(),
+ TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
+ is_nchw ? 0b1011 : 0b1101);

lora_idx[LORAParams::loraA] = context.requestWeight(
loraA_dim, Initializer::ZEROS, weight_regularizer,
@@ -145,10 +156,10 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {

lora_idx[LORAParams::loraTmp] =
context.requestTensor(loraTmp_dim, "hidden_tmp_lora", Initializer::NONE,
- true, TensorLifespan::FORWARD_DERIV_LIFESPAN);
+ true, TensorLifespan::FORWARD_GRAD_LIFESPAN);

lora_idx[LORAParams::loraOut] =
- context.requestTensor(bias_dim, "hidden_lora", Initializer::NONE, true,
+ context.requestTensor(loraOut_dim, "hidden_lora", Initializer::NONE, true,
TensorLifespan::FORWARD_FUNC_LIFESPAN);
}
}
@@ -164,6 +175,15 @@ void FullyConnectedLayer::setProperty(const std::vector<std::string> &values) {
LayerImpl::setProperty(remain_props);
}

+ void FullyConnectedLayer::setBatch(nntrainer::RunLayerContext &context,
+                                    unsigned int batch) {
+   if (!std::get<props::LoraRank>(fc_props).empty()) {
+     // update Lora Tensor's batch info.
+     context.updateTensor(lora_idx[LORAParams::loraTmp], batch);
+     context.updateTensor(lora_idx[LORAParams::loraOut], batch);
+   }
+ }

void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
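As a cross-check of the dimension changes above, the standalone snippet below encodes the NCHW shapes requested in finalize() and asserts that the chained multiplications stay compatible once the batch is part of loraTmp and loraOut. Dim and the concrete numbers (B, H, W_in, unit, rank) are illustrative, not nntrainer types or defaults.

#include <cassert>

// Illustrative 4-D shape (N, C, H, W), mirroring the NCHW branch above.
struct Dim { unsigned n, c, h, w; };

int main() {
  const unsigned B = 32, H = 1, W_in = 64; // input: (B, 1, H, W_in)
  const unsigned unit = 128, rank = 4;     // FC out features, LoRA rank

  Dim loraA{1, 1, W_in, rank};   // (1, 1, in_dim.width(), lora_rank)
  Dim loraB{1, 1, rank, unit};   // (1, 1, lora_rank, unit)
  Dim loraTmp{B, 1, H, rank};    // (B, 1, in_dim.height(), lora_rank)
  Dim loraOut{B, 1, H, unit};    // (B, 1, in_dim.height(), unit)

  // input x loraA -> loraTmp : inner dims match, batch is preserved
  assert(W_in == loraA.h && loraTmp.n == B && loraTmp.w == rank);
  // loraTmp x loraB -> loraOut
  assert(loraTmp.w == loraB.h && loraOut.n == B && loraOut.w == unit);
  return 0;
}

The old loraTmp shape (1, 1, 1, lora_rank) and the bias-shaped loraOut (1, 1, 1, unit) fail the batch-related checks as soon as B > 1, which is why both tensors now take in_dim.batch() and in_dim.height().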
6 changes: 6 additions & 0 deletions nntrainer/layers/fc_layer.h
@@ -103,6 +103,12 @@ class FullyConnectedLayer : public LayerImpl {
*/
void setProperty(const std::vector<std::string> &values) override;

+ /**
+  * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+  */
+ void setBatch(nntrainer::RunLayerContext &context,
+               unsigned int batch) override;

inline static const std::string type = "fully_connected";

private:
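To make the intent of the new override concrete, here is a framework-agnostic toy of the same pattern: a layer caches the indices of its batch-dependent scratch tensors when they are requested and resizes them whenever the effective batch size changes. ToyContext, updateTensorBatch(), and ToyLoraFC are invented for illustration; only the shape of the logic mirrors the setBatch() added above.

#include <cstddef>
#include <vector>

// Invented stand-in for a run context that owns per-layer tensors.
class ToyContext {
public:
  std::size_t requestTensor(std::vector<std::size_t> dim) {
    tensors.push_back(std::move(dim));
    return tensors.size() - 1;
  }
  // Resize only the batch (first) axis of an existing tensor.
  void updateTensorBatch(std::size_t idx, std::size_t batch) {
    tensors[idx][0] = batch;
  }
  const std::vector<std::size_t> &dim(std::size_t idx) const {
    return tensors[idx];
  }

private:
  std::vector<std::vector<std::size_t>> tensors;
};

// A layer that owns batch-dependent scratch buffers (cf. loraTmp / loraOut).
class ToyLoraFC {
public:
  void finalize(ToyContext &ctx, std::size_t batch, std::size_t height,
                std::size_t rank, std::size_t unit) {
    tmp_idx = ctx.requestTensor({batch, 1, height, rank});
    out_idx = ctx.requestTensor({batch, 1, height, unit});
  }
  // Called whenever the model's batch size changes.
  void setBatch(ToyContext &ctx, std::size_t batch) {
    ctx.updateTensorBatch(tmp_idx, batch);
    ctx.updateTensorBatch(out_idx, batch);
  }

private:
  std::size_t tmp_idx = 0, out_idx = 0;
};

int main() {
  ToyContext ctx;
  ToyLoraFC fc;
  fc.finalize(ctx, /*batch=*/1, /*height=*/1, /*rank=*/4, /*unit=*/128);
  fc.setBatch(ctx, 32); // e.g. switch from single-sample inference to training
  return (ctx.dim(0)[0] == 32 && ctx.dim(1)[0] == 32) ? 0 : 1;
}

Layers whose tensors are all batch-independent (weights, biases) do not need to touch anything here; the FC layer needs the override only because LoRA introduces per-sample temporaries.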
