Merge pull request #44 from robertknight/layer-norm-op
Add LayerNormalization operator, Depth Anything example
robertknight authored Jan 23, 2024
2 parents ca699cb + 7d12d21 commit 178ccdf
Showing 13 changed files with 643 additions and 44 deletions.
4 changes: 4 additions & 0 deletions rten-examples/Cargo.toml
@@ -43,6 +43,10 @@ path = "src/imagenet.rs"
name = "yolo"
path = "src/yolo.rs"

[[bin]]
name = "depth_anything"
path = "src/depth_anything.rs"

# Text
[[bin]]
name = "bert_qa"
1 change: 1 addition & 0 deletions rten-examples/README.md
@@ -53,6 +53,7 @@ The examples have been chosen to cover common tasks and popular models.
This example works with a wide variety of models, such as ResNet, MobileNet,
ConvNeXt, ViT.
- **deeplab** - Semantic segmentation of images using [DeepLabv3](https://arxiv.org/abs/1706.05587)
- **depth_anything** - Monocular depth estimation using [Depth Anything](https://github.com/LiheYoung/Depth-Anything)
- **detr** - Object detection using [DETR](https://research.facebook.com/publications/end-to-end-object-detection-with-transformers/)
- **yolo** - Object detection using [YOLO v8](https://github.com/ultralytics/ultralytics)

113 changes: 113 additions & 0 deletions rten-examples/src/depth_anything.rs
@@ -0,0 +1,113 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fs;

use rten::{FloatOperators, Model};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

struct Args {
model: String,
image: String,
output: String,
}

fn parse_args() -> Result<Args, lexopt::Error> {
use lexopt::prelude::*;

let mut values = VecDeque::new();
let mut parser = lexopt::Parser::from_env();

while let Some(arg) = parser.next()? {
match arg {
Value(val) => values.push_back(val.string()?),
Long("help") => {
println!(
"Perform monocular depth estimation on an image.
Usage: {bin_name} <model> <image> [<output>]
Args:
<model> - Input Depth Anything model
<image> - Image to process
<output> - Path to save depth image to. Defaults to \"depth-map.png\".
",
bin_name = parser.bin_name().unwrap_or("depth_anything")
);
std::process::exit(0);
}
_ => return Err(arg.unexpected()),
}
}

let model = values.pop_front().ok_or("missing `model` arg")?;
let image = values.pop_front().ok_or("missing `image` arg")?;
let output = values.pop_front().unwrap_or("depth-map.png".into());

let args = Args {
image,
model,
output,
};

Ok(args)
}

/// Perform monocular depth estimation using [Depth Anything][depth_anything].
///
/// The ONNX models can be obtained from
/// https://github.com/fabio-sim/Depth-Anything-ONNX. See the
/// [releases](https://github.com/fabio-sim/Depth-Anything-ONNX/releases) page
/// for pre-trained model links. The small ("vits") model is recommended for
/// CPU inference.
///
/// After downloading the model, it can be run on an image using:
///
/// ```
/// tools/convert-onnx.py depth_anything.onnx
/// cargo run --release --bin depth_anything depth_anything.rten image.jpg
/// ```
///
/// This will generate a depth map as `depth-map.png`.
///
/// [depth_anything]: <https://github.com/LiheYoung/Depth-Anything>
fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;
let model_bytes = fs::read(args.model)?;
let model = Model::load(&model_bytes)?;

let mut image: Tensor = read_image(&args.image)?.into();
let [_, orig_height, orig_width] = image.shape().try_into()?;
normalize_image(image.nd_view_mut());
image.insert_axis(0); // Add batch dim

// Input size taken from README in https://github.com/fabio-sim/Depth-Anything-ONNX.
let [input_h, input_w] = [518, 518];
let image = image.resize_image([input_h, input_w])?;

// Run model to estimate depth for each pixel.
// Generates a (batch, depth, height, width) tensor, where `depth` == 1.
let mut output: NdTensor<f32, 4> = model.run_one(image.view().into(), None)?.try_into()?;

// Normalize depth values to be in the range [0, 1].
let min = output
.reduce_min(None, false /* keep_dims */)?
.item()
.copied()
.unwrap();
let max = output
.reduce_max(None, false /* keep_dims */)?
.item()
.copied()
.unwrap();
output.apply(|x| (x - min) / (max - min));

// Resize output map back to original input size and write to file.
let resized = output.resize_image([orig_height, orig_width])?;
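// Remove the batch dimension so the writer receives a (channels, height, width) view.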
let resized = resized.slice::<3, _>(0);
write_image(&args.output, resized)?;

Ok(())
}
21 changes: 21 additions & 0 deletions src/model.rs
@@ -320,6 +320,7 @@ impl_default_factory!(HardSigmoid, read_hard_sigmoid_op);
impl_default_factory!(HardSwish);
impl_default_factory!(Identity);
impl_default_factory!(InstanceNormalization, read_instance_normalization_op);
impl_default_factory!(LayerNormalization, read_layer_normalization_op);
impl_default_factory!(LeakyRelu, read_leaky_relu_op);
impl_default_factory!(Less);
impl_default_factory!(LessOrEqual);
@@ -469,6 +470,7 @@ impl OpRegistry {
register_op!(HardSwish);
register_op!(Identity);
register_op!(InstanceNormalization);
register_op!(LayerNormalization);
register_op!(LeakyRelu);
register_op!(Less);
register_op!(LessOrEqual);
@@ -724,6 +726,16 @@ fn read_instance_normalization_op(node: &OperatorNode) -> ReadOpResult {
}))
}

fn read_layer_normalization_op(node: &OperatorNode) -> ReadOpResult {
let attrs = node
.attrs_as_layer_normalization_attrs()
.ok_or(ReadOpError::AttrError)?;
Ok(Box::new(ops::LayerNormalization {
axis: attrs.axis() as isize,
epsilon: Some(attrs.epsilon()),
}))
}

fn read_leaky_relu_op(node: &OperatorNode) -> ReadOpResult {
let attrs = node
.attrs_as_leaky_relu_attrs()
@@ -842,6 +854,7 @@ fn read_resize_op(node: &OperatorNode) -> ReadOpResult {
let coord_mode = match attrs.coord_mode() {
sg::CoordTransformMode::Asymmetric => CoordTransformMode::Asymmetric,
sg::CoordTransformMode::HalfPixel => CoordTransformMode::HalfPixel,
sg::CoordTransformMode::AlignCorners => CoordTransformMode::AlignCorners,
_ => CoordTransformMode::default(),
};

@@ -1407,6 +1420,14 @@ mod tests {
input_node, instance_norm_scale, instance_norm_bias
], { epsilon: Some(1e-5) });

let layer_norm_scale_val = tensor!([1.0]);
let layer_norm_scale = builder.add_float_constant(&layer_norm_scale_val);
let layer_norm_bias_val = tensor!([1.0]);
let layer_norm_bias = builder.add_float_constant(&layer_norm_bias_val);
add_operator!(LayerNormalization, [
input_node, layer_norm_scale, layer_norm_bias
], { axis: -1, epsilon: Some(1e-5) });

add_operator!(LeakyRelu, [input_node], { alpha: 0.01 });
add_operator!(Less, [input_node, input_node]);
add_operator!(LessOrEqual, [input_node, input_node]);
18 changes: 14 additions & 4 deletions src/model_builder.rs
@@ -8,10 +8,10 @@ use crate::graph::Dimension;
use crate::ops::{
ArgMax, ArgMin, AveragePool, BatchNormalization, BoxOrder, Cast, Concat, ConstantOfShape, Conv,
ConvTranspose, CoordTransformMode, DataType, Flatten, Gather, GatherElements, Gemm,
HardSigmoid, InstanceNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod, NearestMode,
NonMaxSuppression, OneHot, Padding, ReduceMax, ReduceMean, ReduceMin, ReduceProd, ReduceSum,
Reshape, Resize, ResizeMode, Scalar, ScatterElements, ScatterReduction, Softmax, Split, TopK,
Transpose, Trilu,
HardSigmoid, InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod,
NearestMode, NonMaxSuppression, OneHot, Padding, ReduceMax, ReduceMean, ReduceMin, ReduceProd,
ReduceSum, Reshape, Resize, ResizeMode, Scalar, ScatterElements, ScatterReduction, Softmax,
Split, TopK, Transpose, Trilu,
};
use crate::schema_generated as sg;

@@ -52,6 +52,7 @@ pub enum OpType {
HardSwish,
Identity,
InstanceNormalization(InstanceNormalization),
LayerNormalization(LayerNormalization),
LeakyRelu(LeakyRelu),
Less,
LessOrEqual,
@@ -472,6 +473,14 @@ impl<'a> ModelBuilder<'a> {
epsilon: args.epsilon.unwrap_or(1e-5)
}
),
OpType::LayerNormalization(args) => op_with_attrs!(
LayerNormalization,
LayerNormalizationAttrs,
sg::LayerNormalizationAttrsArgs {
axis: args.axis as i32,
epsilon: args.epsilon.unwrap_or(1e-5)
}
),
OpType::LeakyRelu(args) => op_with_attrs!(
LeakyRelu,
LeakyReluAttrs,
@@ -565,6 +574,7 @@ impl<'a> ModelBuilder<'a> {
let coord_mode = match args.coord_mode {
CoordTransformMode::Asymmetric => sg::CoordTransformMode::Asymmetric,
CoordTransformMode::HalfPixel => sg::CoordTransformMode::HalfPixel,
CoordTransformMode::AlignCorners => sg::CoordTransformMode::AlignCorners,
};
let nearest_mode = match args.nearest_mode {
NearestMode::Ceil => sg::NearestMode::Ceil,
4 changes: 2 additions & 2 deletions src/ops/mod.rs
@@ -65,8 +65,8 @@ pub use layout::{
pub use matmul::{gemm_op, matmul, Gemm, MatMul};
pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
pub use norm::{
batch_norm, batch_norm_in_place, instance_normalization, log_softmax, softmax,
BatchNormalization, InstanceNormalization, LogSoftmax, Softmax,
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
softmax, BatchNormalization, InstanceNormalization, LayerNormalization, LogSoftmax, Softmax,
};
pub use pad::{pad, Pad};
pub use pooling::{
109 changes: 108 additions & 1 deletion src/ops/norm.rs
@@ -3,7 +3,9 @@ use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use rten_vecmath::vec_softmax_in_place;
use smallvec::SmallVec;

use crate::ops::{add, mul, reduce_mean, sub};
use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output};
use crate::slice_reductions::{slice_max, slice_sum};
use crate::{check_dims, static_dims};
@@ -228,6 +230,79 @@ impl Operator for InstanceNormalization {
}
}

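/// Perform layer normalization of `input`, as in the ONNX `LayerNormalization`
/// operator.
///
/// Elements are normalized over the axes `[axis, input.ndim())` to have zero
/// mean and unit variance, then multiplied by `scale` and, if provided, offset
/// by `bias`. Both `scale` and `bias` must be broadcastable to the input
/// shape. `epsilon` defaults to 1e-5 if not specified.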
pub fn layer_normalization(
input: TensorView,
scale: TensorView,
bias: Option<TensorView>,
axis: isize,
epsilon: Option<f32>,
) -> Result<Tensor, OpError> {
if !scale.can_broadcast_to(input.shape()) {
return Err(OpError::IncompatibleInputShapes(
"`scale` cannot be broadcast to input shape",
));
}
if let Some(bias) = bias.as_ref() {
if !bias.can_broadcast_to(input.shape()) {
return Err(OpError::IncompatibleInputShapes(
"`bias` cannot be broadcast to input shape",
));
}
}

let epsilon = epsilon.unwrap_or(1e-5);
let resolved_axis = resolve_axis(input.ndim(), axis)?;
let normalized_axes: SmallVec<[i32; 5]> = (resolved_axis..input.ndim())
.map(|axis| axis as i32)
.collect();

// First step: standardize input elements to have zero mean and unit variance.
let mean = reduce_mean(
input.view(),
Some(normalized_axes.as_slice()),
true, /* keep_dims */
)?;
let d = sub(input, mean.view())?;
let dd = mul(d.view(), d.view())?;
let var = reduce_mean(
dd.view(),
Some(normalized_axes.as_slice()),
true, /* keep_dims */
)?;
let inverse_std_dev = var.map(|x| 1. / (x + epsilon).sqrt());
let normalized = mul(d.view(), inverse_std_dev.view())?;

// Second step: Shift and scale input.
let normalized_scaled = mul(normalized.view(), scale)?;
let output = if let Some(bias) = bias {
add(normalized_scaled.view(), bias)?
} else {
normalized_scaled
};

Ok(output)
}

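/// ONNX `LayerNormalization` operator.
///
/// `axis` is the first axis of the normalized range (negative values count
/// from the last dimension). `epsilon` is added to the variance before taking
/// the square root and defaults to 1e-5.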
#[derive(Debug)]
pub struct LayerNormalization {
pub axis: isize,
pub epsilon: Option<f32>,
}

impl Operator for LayerNormalization {
fn name(&self) -> &str {
"LayerNormalization"
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
let input = inputs.require_as(0)?;
let scale = inputs.require_as(1)?;
let bias = inputs.get_as(2)?;

layer_normalization(input.view(), scale, bias, self.axis, self.epsilon).into_op_result()
}
}

pub fn log_softmax(input: TensorView, axis: isize) -> Result<Tensor, OpError> {
let mut output = input.to_tensor();
log_softmax_in_place(&mut output, axis)?;
@@ -375,7 +450,8 @@ mod tests {

use crate::ops::tests::expect_eq_1e4;
use crate::ops::{
batch_norm, batch_norm_in_place, instance_normalization, log_softmax, softmax,
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
softmax,
};

#[test]
@@ -460,6 +536,37 @@
Ok(())
}

#[test]
fn test_layer_normalization() -> Result<(), Box<dyn Error>> {
// Sample values generated using `torch.rand`.
let input = tensor!((1, 5, 2); [
0.9562, 0.0572, 0.4366, 0.5655, 0.2017,
0.0230, 0.7941, 0.1554, 0.3226, 0.120
]);
let scale = tensor!([0.0751, 0.6952]);
let bias = tensor!([0.9993, 0.7632]);

let result = layer_normalization(
input.view(),
scale.view(),
Some(bias.view()),
-1, /* axis */
None, /* epsilon */
)
.unwrap();

let expected = Tensor::from([[
[1.0744, 0.0680],
[0.9243, 1.4576],
[1.0744, 0.0684],
[1.0744, 0.0680],
[1.0744, 0.0683],
]]);
expect_eq_1e4(&result, &expected)?;

Ok(())
}

#[test]
fn test_log_softmax() -> Result<(), Box<dyn Error>> {
// 1D input