Skip to content

Commit

Permalink
Add tril_ layer for lower triangular matrix operations
Browse files Browse the repository at this point in the history
  • Loading branch information
Cydral authored Sep 23, 2024
1 parent 72822fe commit 354e5d0
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 0 deletions.
106 changes: 106 additions & 0 deletions dlib/dnn/layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4696,6 +4696,112 @@ namespace dlib

template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;

// ----------------------------------------------------------------------------------------

// Type wrapper carrying a compile-time constant of type T in its ::value
// member.  tril_ receives its diagonal fill value through a type template
// parameter, and this struct is how that value is packaged as a type.
// NOTE(review): instantiating with T = float uses a floating-point non-type
// template parameter, which is only legal from C++20 -- confirm against the
// project's minimum language standard.
template <typename T, T val>
struct float_constant
{
    static constexpr T value = val;
};

template <long diag_, typename diag_value_>
class tril_
{
    /*!
        Replaces every element above the diag_-th diagonal of each nr x nc plane
        of the input tensor (i.e. elements with column c >= row r + diag_ + 1)
        with diag_value_::value, and passes all other elements through unchanged.
        diag_ = 0 keeps the main diagonal and everything below it.  The layer has
        no learnable parameters.
    !*/
public:
    tril_(): diag(diag_), diag_value(diag_value_::value) {}

    template <typename SUBNET>
    void setup(const SUBNET& sub) {
        initialize_mask(sub.get_output());
    }

    template <typename SUBNET>
    void forward(const SUBNET& sub, resizable_tensor& output)
    {
        auto& prev = sub.get_output();
        // The masks were sized in setup(), but the input dimensions can change
        // between forward calls, so re-check before using them.  This is a
        // no-op when the dimensions are unchanged.
        initialize_mask(prev);
        output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());
        // Zero the masked region while keeping the rest of the input...
        tt::multiply(false, output, prev, binary_mask);
        // ...then fill the masked region with diag_value.  When diag_value is 0
        // the output_mask is all zeros, so the add can be skipped entirely.
        if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
    }

    template <typename SUBNET>
    void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
    {
        // Masked positions hold a constant, so only the kept (unmasked)
        // elements receive gradient; multiply(true, ...) accumulates into the
        // subnetwork's gradient input.
        auto& prev_grad = sub.get_gradient_input();
        tt::multiply(true, prev_grad, gradient_input, binary_mask);
    }

    // The layer is elementwise, so both mappings are the identity.
    inline dpoint map_input_to_output(const dpoint& p) const { return p; }
    inline dpoint map_output_to_input(const dpoint& p) const { return p; }

    const tensor& get_layer_params() const { return params; }
    tensor& get_layer_params() { return params; }

    friend void serialize(const tril_& item, std::ostream& out)
    {
        serialize("tril_", out);
        serialize(item.diag, out);
        serialize(item.diag_value, out);
    }
    friend void deserialize(tril_& item, std::istream& in)
    {
        std::string version;
        deserialize(version, in);
        if (version != "tril_")
            throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
        deserialize(item.diag, in);
        deserialize(item.diag_value, in);
    }

    friend std::ostream& operator<<(std::ostream& out, const tril_& item)
    {
        out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
        return out;
    }
    friend void to_xml(const tril_& item, std::ostream& out)
    {
        out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
    }

private:
    // (Re)builds both masks for an input shaped like t.  binary_mask holds 1
    // where elements are kept and 0 where they are masked; output_mask holds 0
    // where elements are kept and diag_value where they are masked.  Does
    // nothing if the masks already match t's dimensions.
    void initialize_mask(const tensor& t)
    {
        if (!have_same_dimensions(output_mask, t)) {
            output_mask.copy_size(t);
            binary_mask.copy_size(output_mask);
            output_mask = 0;
            binary_mask = 1;
            for (long s = 0; s < output_mask.num_samples(); ++s)
            {
                for (long k = 0; k < output_mask.k(); ++k)
                {
                    for (long r = 0; r < output_mask.nr(); ++r)
                    {
                        // Columns at or beyond r + diag + 1 lie above the
                        // diag-th diagonal; clamp at 0 for negative offsets.
                        for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c)
                        {
                            output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
                            binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
                        }
                    }
                }
            }
        }
    }

    resizable_tensor params;       // unused; this layer has no learnable parameters
    resizable_tensor output_mask;  // diag_value at masked positions, 0 elsewhere
    resizable_tensor binary_mask;  // 0 at masked positions, 1 elsewhere
    long diag;
    float diag_value;
};

// Keeps the lower triangle (main diagonal included) of each plane of the
// input; everything above the main diagonal is set to 0.
template <typename SUBNET>
using tril = add_layer<tril_<0, float_constant<float, 0.0f>>, SUBNET>;

// Sets everything above the main diagonal to -infinity, e.g. for use as a
// causal mask.  NOTE(review): a float non-type template argument requires
// C++20 -- confirm the project's minimum language standard.
template <typename SUBNET>
using tril_mask = add_layer<tril_<0, float_constant<float, -std::numeric_limits<float>::infinity()>>, SUBNET>;

// General form: the caller chooses the diagonal offset and the fill value.
template <long diag, typename diag_value_type, typename SUBNET>
using tril_diag = add_layer<tril_<diag, diag_value_type>, SUBNET>;

// ----------------------------------------------------------------------------------------

}
Expand Down
120 changes: 120 additions & 0 deletions dlib/dnn/layers_abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -3711,6 +3711,126 @@ namespace dlib
template <typename SUBNET>
using transpose = add_layer<transpose_, SUBNET>;

// ----------------------------------------------------------------------------------------

template <typename T, T val>
struct float_constant {
    // Wraps a compile-time constant of type T as a type so it can be passed
    // where a type template parameter is expected (e.g. tril_'s diag_value_).
    // NOTE(review): instantiating with T = float uses a floating-point
    // non-type template parameter, which requires C++20 -- confirm against
    // the project's minimum language standard.
    static constexpr T value = val;
};

template <long diag_, typename diag_value_>
class tril_
{
    /*!
        REQUIREMENTS ON diag_ and diag_value_
            - diag_ is the diagonal offset and may be any integer: 0 anchors the
              mask at the main diagonal, a positive value additionally keeps that
              many super-diagonals, and a negative value masks part of the lower
              triangle as well.
            - diag_value_ must be a type that has a static constexpr member `value` of type float.
        WHAT THIS OBJECT REPRESENTS
            This object implements a layer in a deep neural network that applies a lower triangular mask to
            its input tensor.  The mask is defined such that all elements above the specified diagonal are set
            to a given value (diag_value_::value).  The diagonal is specified by the diag_ parameter.
            This layer has no learnable parameters.
        EXAMPLE USAGE
            tril_<0, float_constant<float, -std::numeric_limits<float>::infinity()>> layer;
            // This creates a layer that masks all elements above the main diagonal with -inf.
        SERIALIZATION SUPPORT
            This object supports serialization and deserialization via the serialize() and deserialize() functions.
    !*/

public:
    tril_() = default;
    /*!
        ensures
            - This object is properly initialized.
    !*/

    template <typename SUBNET>
    void setup(const SUBNET& sub);
    /*!
        requires
            - SUBNET is a valid network layer type.
        ensures
            - Initializes the masks based on the dimensions of the input tensor from sub.
    !*/

    template <typename SUBNET>
    void forward(const SUBNET& sub, resizable_tensor& output);
    /*!
        requires
            - SUBNET is a valid network layer type.
        ensures
            - Copies the input tensor from sub into output, with every element
              above the diag_-th diagonal replaced by diag_value_::value.
    !*/

    template <typename SUBNET>
    void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
    /*!
        requires
            - SUBNET is a valid network layer type.
        ensures
            - Propagates gradient_input into sub's gradient input; masked
              positions hold a constant and therefore receive zero gradient.
            - params_grad is ignored since this layer has no learnable parameters.
    !*/

    inline dpoint map_input_to_output(const dpoint& p) const;
    /*!
        ensures
            - Maps a point from the input tensor to the corresponding point in the
              output tensor.  This layer is elementwise, so the mapping is the identity.
    !*/

    inline dpoint map_output_to_input(const dpoint& p) const;
    /*!
        ensures
            - Maps a point from the output tensor to the corresponding point in the
              input tensor.  This layer is elementwise, so the mapping is the identity.
    !*/

    const tensor& get_layer_params() const;
    /*!
        ensures
            - Returns the parameters of this layer (always empty for this layer).
    !*/

    tensor& get_layer_params();
    /*!
        ensures
            - Returns the parameters of this layer (always empty for this layer).
    !*/

    friend void serialize(const tril_& item, std::ostream& out);
    /*!
        ensures
            - Serializes the state of this object to the given output stream.
    !*/

    friend void deserialize(tril_& item, std::istream& in);
    /*!
        ensures
            - Deserializes the state of this object from the given input stream.
            - Throws serialization_error if the stream does not contain a
              serialized tril_ object.
    !*/

    friend std::ostream& operator<<(std::ostream& out, const tril_& item);
    /*!
        ensures
            - Prints a human-readable representation of this object to the given output stream.
    !*/

    friend void to_xml(const tril_& item, std::ostream& out);
    /*!
        ensures
            - Serializes the state of this object to XML format and writes it to the given output stream.
    !*/
};

// Keeps the lower triangle (main diagonal included) of each plane of the
// input; everything above the main diagonal is set to 0.
template <typename SUBNET>
using tril = add_layer<tril_<0, float_constant<float, 0.0f>>, SUBNET>;

// Sets everything above the main diagonal to -infinity, e.g. for use as a
// causal mask.  NOTE(review): a float non-type template argument requires
// C++20 -- confirm the project's minimum language standard.
template <typename SUBNET>
using tril_mask = add_layer<tril_<0, float_constant<float, -std::numeric_limits<float>::infinity()>>, SUBNET>;

// General form: the caller chooses the diagonal offset and the fill value.
template <long diag, typename diag_value_type, typename SUBNET>
using tril_diag = add_layer<tril_<diag, diag_value_type>, SUBNET>;

// ----------------------------------------------------------------------------------------

}
Expand Down
10 changes: 10 additions & 0 deletions dlib/dnn/visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,16 @@ namespace dlib
update(i);
}

template <long diag, typename diag_value_type, typename U, typename E>
void operator()(size_t i, const add_layer<tril_<diag, diag_value_type>, U, E>&)
{
    // Render a tril layer node, showing its compile-time configuration.
    start_node(i, "tril");
    out << " | {diag|{" << diag << "}}"
        << " | {diag_value|{" << diag_value_type::value << "}}";
    end_node();
    update(i);
}

template <typename T, typename U, typename E>
void operator()(size_t i, const add_layer<T, U, E>&)
{
Expand Down
50 changes: 50 additions & 0 deletions dlib/test/dnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,13 @@ namespace
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
using specific_float = float_constant<float, -std::numeric_limits<float>::infinity()>;
tril_<-5, specific_float> l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
extract_<0,2,2,2> l;
Expand Down Expand Up @@ -4447,6 +4454,49 @@ namespace
}
}

// ----------------------------------------------------------------------------------------

// Checks that tril_diag with a -inf fill masks exactly the above-diagonal
// elements of each sample while leaving the rest of the input untouched.
void test_tril()
{
    print_spinner();

    using NEG_INF = float_constant<float, -std::numeric_limits<float>::infinity()>;
    using net_type = tag1<tril_diag<0, NEG_INF, tag2<input<matrix<float>>>>>;
    net_type net;

    // Random 2x3 input samples.
    dlib::rand rnd;
    const int nr = 2, nc = 3;
    constexpr int n_samples = 3;
    std::vector<matrix<float>> x(n_samples);
    matrix<float> xtmp(nr, nc);
    for (int ii = 0; ii < n_samples; ++ii) {
        for (int jj = 0; jj < nr; ++jj)
            for (int kk = 0; kk < nc; ++kk)
                xtmp(jj, kk) = rnd.get_random_gaussian();
        x[ii] = xtmp;
    }

    // Convert the input matrices to a tensor and run the forward pass.
    resizable_tensor input_tensor;
    net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
    net.forward(input_tensor);

    // Expected output: a copy of the input with the above-main-diagonal
    // positions of each 2x3 sample -- (0,1), (0,2) and (1,2) -- set to -inf.
    resizable_tensor expected_output;
    expected_output.copy_size(input_tensor);
    tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
    for (int ii = 0; ii < n_samples; ++ii) {
        expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
        expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
        expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
    }

    // Compare output tensor with expected output.  DLIB_TEST_MSG is a
    // two-argument macro (condition, message); calling it with a single
    // argument does not preprocess.
    auto& net_output = layer<tag1>(net).get_output();
    DLIB_TEST_MSG(max(abs(mat(net_output) - mat(expected_output))) < 1e-5,
                  "tril layer output differs from the expected masked tensor");
}

// ----------------------------------------------------------------------------------------

class dnn_tester : public tester
Expand Down

0 comments on commit 354e5d0

Please sign in to comment.