Skip to content

Commit

Permalink
Add tril_ layer for lower triangular matrix operations (#3018)
Browse files Browse the repository at this point in the history
* Add tril_ layer for lower triangular matrix operations

* Improved layer consistency

* Added constant_wrapper to fix the issue of the float in the template in c++17

* Looking for a solution for c++ 14

* Refactor tril_ layer for improved flexibility and C++14 compatibility

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates
  • Loading branch information
Cydral authored Sep 30, 2024
1 parent 72822fe commit 4e53f83
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 0 deletions.
126 changes: 126 additions & 0 deletions dlib/dnn/layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4696,6 +4696,132 @@ namespace dlib

template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;

// ----------------------------------------------------------------------------------------

struct neg_infinity_tag {};
struct zero_tag {};

template<typename T>
struct is_special_value : std::false_type {};
template<>
struct is_special_value<neg_infinity_tag> : std::true_type {};
template<>
struct is_special_value<zero_tag> : std::true_type {};

template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
class tril_
{
public:
tril_(): diag(diag_), diag_value(compute_diag_value()) {}

template <typename SUBNET>
void setup(const SUBNET& /*sub*/)
{
}

template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
auto& prev = sub.get_output();
output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());

check_mask(prev);
tt::multiply(false, output, prev, binary_mask);
if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{
auto& prev_grad = sub.get_gradient_input();
tt::multiply(true, prev_grad, gradient_input, binary_mask);
}

inline dpoint map_input_to_output(const dpoint& p) const { return p; }
inline dpoint map_output_to_input(const dpoint& p) const { return p; }

const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }

friend void serialize(const tril_& item, std::ostream& out)
{
serialize("tril_", out);
serialize(item.diag, out);
serialize(item.diag_value, out);
}
friend void deserialize(tril_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "tril_")
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
deserialize(item.diag, in);
deserialize(item.diag_value, in);
}

friend std::ostream& operator<<(std::ostream& out, const tril_& item)
{
out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
return out;
}
friend void to_xml(const tril_& item, std::ostream& out)
{
out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
}

private:
float compute_diag_value() const {
if (std::is_same<tag_, neg_infinity_tag>::value)
return -std::numeric_limits<float>::infinity();
else if (std::is_same<tag_, zero_tag>::value)
return 0.0f;
else
return static_cast<float>(num_) / static_cast<float>(den_);
}

void check_mask(const tensor& t)
{
if (!have_same_dimensions(binary_mask, t)) {
binary_mask.copy_size(t);
binary_mask = 1;
if (diag_value != 0.0f) {
output_mask.copy_size(t);
output_mask = 0;
}
for (long s = 0; s < output_mask.num_samples(); ++s)
{
for (long k = 0; k < output_mask.k(); ++k)
{
for (long r = 0; r < output_mask.nr(); ++r)
{
for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c)
{
if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
}
}
}
}
}
}

template <typename T>
struct always_false : std::false_type {};

resizable_tensor params; // unused
resizable_tensor binary_mask, output_mask;
long diag;
float diag_value;
};

template <typename SUBNET>
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;

template <typename SUBNET>
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;

template <long diag, long num, long den, typename SUBNET>
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;

// ----------------------------------------------------------------------------------------

}
Expand Down
156 changes: 156 additions & 0 deletions dlib/dnn/layers_abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -3711,6 +3711,162 @@ namespace dlib
template <typename SUBNET>
using transpose = add_layer<transpose_, SUBNET>;

// ----------------------------------------------------------------------------------------

struct neg_infinity_tag {};
struct zero_tag {};

template<typename T>
struct is_special_value : std::false_type {};
template<>
struct is_special_value<neg_infinity_tag> : std::true_type {};
template<>
struct is_special_value<zero_tag> : std::true_type {};

template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
class tril_
{
/*!
TEMPLATE PARAMETERS
- diag_: A long integer specifying the diagonal offset.
- tag_: A type tag specifying special values or void for numeric values.
- num_: Numerator for numeric diagonal value (default is 0, only used if tag_ is void).
- den_: Denominator for numeric diagonal value (default is 1, only used if tag_ is void).
REQUIREMENTS
- diag_ must be an integer.
- tag_ must be either neg_infinity_tag, zero_tag, or void.
- If tag_ is void, num_ and den_ are used to compute the diagonal value.
- If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
WHAT THIS OBJECT REPRESENTS
This object implements a layer in a deep neural network that applies a lower triangular mask to
its input tensor. The mask is defined such that all elements above the specified diagonal are set
to a given value. The diagonal offset and the mask value are determined by the template parameters.
DIAGONAL VALUE DETERMINATION
- If tag_ is neg_infinity_tag: diagonal value is set to negative infinity.
- If tag_ is zero_tag: diagonal value is set to zero.
- If tag_ is void: diagonal value is set to num_ / den_ as a float.
DIAGONAL OFFSET
The diag_ parameter determines the diagonal above which elements are masked:
- diag_ = 0: main diagonal
- diag_ > 0: diag_ steps above the main diagonal
- diag_ < 0: |diag_| steps below the main diagonal
EXAMPLE USAGE
// Create a layer that masks all elements above the main diagonal with -inf
tril_<0, neg_infinity_tag> layer1;
// Create a layer that masks all elements above the main diagonal with 0
tril_<0, zero_tag> layer2;
// Create a layer that masks all elements above the main diagonal with 0.5
tril_<0, void, 1, 2> layer3;
// Create a layer that masks all elements 5 positions above the main diagonal with -inf
tril_<5, neg_infinity_tag> layer4;
// Create a layer that masks all elements 3 positions below the main diagonal with 0.25
tril_<-3, void, 1, 4> layer5;
SERIALIZATION SUPPORT
This object supports serialization and deserialization via the serialize() and deserialize() functions.
!*/

public:
tril_() = default;
/*!
ensures
- This object is properly initialized.
!*/

template <typename SUBNET>
void setup(const SUBNET& sub);
/*!
requires
- SUBNET is a valid network layer type.
ensures
- Initializes the mask based on the dimensions of the input tensor from sub.
!*/

template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output);
/*!
requires
- SUBNET is a valid network layer type.
ensures
- Applies the lower triangular mask to the input tensor from sub and stores the result in output.
!*/

template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
/*!
requires
- SUBNET is a valid network layer type.
ensures
- Computes the gradient of the loss with respect to the input tensor and stores it in sub.
!*/

inline dpoint map_input_to_output(const dpoint& p) const;
/*!
ensures
- Maps a point from the input tensor to the corresponding point in the output tensor.
!*/

inline dpoint map_output_to_input(const dpoint& p) const;
/*!
ensures
- Maps a point from the output tensor to the corresponding point in the input tensor.
!*/

const tensor& get_layer_params() const;
/*!
ensures
- Returns the parameters of this layer.
!*/

tensor& get_layer_params();
/*!
ensures
- Returns the parameters of this layer.
!*/

friend void serialize(const tril_& item, std::ostream& out);
/*!
ensures
- Serializes the state of this object to the given output stream.
!*/

friend void deserialize(tril_& item, std::istream& in);
/*!
ensures
- Deserializes the state of this object from the given input stream.
!*/

friend std::ostream& operator<<(std::ostream& out, const tril_& item);
/*!
ensures
- Prints a human-readable representation of this object to the given output stream.
!*/

friend void to_xml(const tril_& item, std::ostream& out);
/*!
ensures
- Serializes the state of this object to XML format and writes it to the given output stream.
!*/
};

template <typename SUBNET>
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;

template <typename SUBNET>
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;

template <long diag, long num, long den, typename SUBNET>
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;

// ----------------------------------------------------------------------------------------

}
Expand Down
16 changes: 16 additions & 0 deletions dlib/dnn/visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,22 @@ namespace dlib
update(i);
}

template <long diag, typename tag, long num, long den, typename U, typename E>
void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
{
start_node(i, "tril");
out << " | {diag|{" << diag << "}}";
out << " | {diag_value|{";

if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
else if (std::is_same<tag, zero_tag>::value) out << "0";
else out << static_cast<float>(num) / static_cast<float>(den);

out << "}}";
end_node();
update(i);
}

template <typename T, typename U, typename E>
void operator()(size_t i, const add_layer<T, U, E>&)
{
Expand Down
48 changes: 48 additions & 0 deletions dlib/test/dnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,12 @@ namespace
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
tril_<-5, void, 1, 2> l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
extract_<0,2,2,2> l;
Expand Down Expand Up @@ -4447,6 +4453,47 @@ namespace
}
}

// ----------------------------------------------------------------------------------------

void test_tril()
{
print_spinner();
using net_type = tag1<tril_mask<tag2<input<matrix<float>>>>>;
net_type net;

// Input tensor
dlib::rand rnd;
const int nr = 2, nc = 3;
constexpr int n_samples = 3, k = 1;
std::vector<matrix<float>> x(n_samples);
matrix<float> xtmp(nr, nc);
for (int ii = 0; ii < n_samples; ++ii) {
for (int jj = 0; jj < nr; ++jj)
for (int kk = 0; kk < nc; ++kk)
xtmp(jj, kk) = rnd.get_random_gaussian();
x[ii] = xtmp;
}

// Convert input matrix to tensor
resizable_tensor input_tensor;
net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
net.forward(input_tensor);

// Expected output tensor (manually set for comparison)
resizable_tensor expected_output;
expected_output.copy_size(input_tensor);
tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
for (int ii = 0; ii < n_samples; ++ii) {
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
}

// Compare output tensor with expected output
auto& net_output = layer<tag1>(net).get_output();
DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
}

// ----------------------------------------------------------------------------------------

class dnn_tester : public tester
Expand Down Expand Up @@ -4527,6 +4574,7 @@ namespace
test_layer_normalize();
test_rms_normalize();
test_transpose();
test_tril();
test_basic_tensor_ops();
test_layers();
test_visit_functions();
Expand Down

0 comments on commit 4e53f83

Please sign in to comment.