Optimize registration of default activations
daniil-berg committed Dec 29, 2023
commit 4895317 (parent: 7d6180d)
Showing 3 changed files with 50 additions and 59 deletions.
89 changes: 49 additions & 40 deletions src/activation.rs
@@ -65,7 +65,27 @@ impl<T: TensorBase + 'static> Activation<T> {
     ///
     /// When called for the first time, the activation registry is initialized first.
     fn get_activation_registry() -> &'static RwLock<ActivationRegistry<T>> {
-        get_or_init!(|| RwLock::new(ActivationRegistry::<T>::new()))
+        get_or_init!(|| {
+            let registry_lock = RwLock::new(ActivationRegistry::<T>::new());
+            Self::register_common(&registry_lock);
+            registry_lock
+        })
     }
+
+    /// Registers reference implementations of some common activation functions.
+    ///
+    /// # Arguments
+    /// - `registry_lock` - Reference to the activation registry wrapped in a [`RwLock`]
+    fn register_common(registry_lock: &RwLock<ActivationRegistry<T>>) {
+        let mut registry = registry_lock.write().unwrap();
+        let _ = registry.insert(
+            "sigmoid".to_owned(),
+            Self::new("sigmoid".to_owned(), sigmoid, sigmoid_prime),
+        );
+        let _ = registry.insert(
+            "relu".to_owned(),
+            Self::new("relu".to_owned(), relu, relu_prime),
+        );
+    }
 
     /// Creates and registers a new [`Activation`] with the specified parameters.
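For context: the `get_or_init!` macro used in `get_activation_registry` is not part of this diff. Because `Activation<T>` is generic and a Rust `static` cannot mention a type parameter, a per-`T` singleton is typically emulated with a type-keyed map behind a `OnceLock`. A rough sketch of one way such a macro could expand — purely illustrative, with all names hypothetical rather than the crate's actual implementation:

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};

// One global map from concrete type to a leaked, never-dropped singleton.
static SINGLETONS: OnceLock<Mutex<HashMap<TypeId, &'static (dyn Any + Send + Sync)>>> =
    OnceLock::new();

// What `get_or_init!(|| expr)` could expand to: fetch the singleton for the
// inferred type, creating and leaking it on first access.
fn get_or_init<T, F>(init: F) -> &'static T
where
    T: Any + Send + Sync,
    F: FnOnce() -> T,
{
    let mut map = SINGLETONS
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .unwrap();
    let entry: &'static (dyn Any + Send + Sync) = *map
        .entry(TypeId::of::<T>())
        .or_insert_with(|| Box::leak(Box::new(init())));
    drop(map);
    // The entry stored under `TypeId::of::<T>()` was inserted as a `T`.
    entry.downcast_ref::<T>().unwrap()
}
```

Under this reading, the initializer closure — including the call to `register_common` — runs at most once per tensor type, which is what makes moving default registration into it a sound optimization: `from_name` no longer needs to check for an empty registry on every call.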
@@ -88,19 +108,20 @@ impl<T: TensorBase + 'static> Activation<T> {
     pub fn register<S: Into<String>>(name: S, function: TFunc<T>, derivative: TFunc<T>) -> Option<Self> {
         let name: String = name.into();
         let registry_lock = Self::get_activation_registry();
-        registry_lock.write().unwrap().insert(
-            name.clone(),
-            Activation::new(name, function, derivative),
-        )
+        registry_lock.write().unwrap()
+            .insert(name.clone(), Activation::new(name, function, derivative))
     }
 
     /// Retrieves a clone of a previously registered [`Activation`] by name.
     ///
-    /// Unless custom instances were added before via [`Activation::register`], reference implementations
-    /// for the following activation functions will be available by default:
+    /// Reference implementations for the following activation functions are available by default:
     /// - `sigmoid`
     /// - `relu`
+    ///
+    /// A custom instance registered via [`Activation::register`] under one of those names will
+    /// replace the default implementation.
+    ///
     ///
     /// # Arguments
     /// - `name` - The name/key of the activation to return.
     ///
@@ -109,33 +130,9 @@ impl<T: TensorBase + 'static> Activation<T> {
     /// registered under that name.
     pub fn from_name<S: Into<String>>(name: S) -> Option<Self> {
         let registry_lock = Self::get_activation_registry();
-        // The registry should only be empty the first time this method is called.
-        // In that case, fill it with the known/common activation functions.
-        if registry_lock.read().unwrap().is_empty() {
-            Self::register_common(registry_lock);
-        }
-        registry_lock.read().unwrap().get(&name.into()).and_then(|activation| Some(activation.clone()))
-    }
-
-    /// Registers reference implementations of some common activation functions.
-    ///
-    /// Activation functions that will be available by name after calling this method:
-    /// - `sigmoid`
-    /// - `relu`
-    ///
-    /// # Arguments
-    /// - `registry_lock` - Reference to the activation registry wrapped in a [`RwLock`]
-    pub fn register_common(registry_lock: &RwLock<ActivationRegistry<T>>) {
-        // TODO: Consider using `HashMap::try_insert` instead to avoid overwriting any custom
-        //       implementations of common activation functions.
-        registry_lock.write().unwrap().insert(
-            "sigmoid".to_owned(),
-            Activation::<T>::new("sigmoid".to_owned(), sigmoid, sigmoid_prime),
-        );
-        registry_lock.write().unwrap().insert(
-            "relu".to_owned(),
-            Activation::<T>::new("relu".to_owned(), relu, relu_prime),
-        );
+        registry_lock.read().unwrap()
+            .get(&name.into())
+            .and_then(|activation| Some(activation.clone()))
     }
 }

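An aside on the retrieval chain above, unrelated to this commit: `and_then(|x| Some(…))` is just `map` in disguise (Clippy flags the pattern as `bind_instead_of_map`), and for an `Option<&T>` the whole closure collapses to `cloned()`. The final expression of `from_name` could equivalently read:

```rust
// Drop-in equivalent of the `.get(...).and_then(...)` chain:
registry_lock.read().unwrap()
    .get(&name.into())
    .cloned()
```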
@@ -237,27 +234,39 @@ mod tests {
     fn test_register_and_from_name() {
         use ndarray::Array2;
 
-        type NDActivation = Activation<Array2<f64>>;
+        type NDActivation = Activation<Array2<f32>>;
 
         let name = "foo".to_owned();
 
+        let option = NDActivation::from_name("foo");
+        assert_eq!(option, None);
+
         // Register under that name for the first time.
         let option = NDActivation::register(name.clone(), relu, relu_prime);
         assert_eq!(option, None);
 
         // Get from registry by name.
         let option = NDActivation::from_name("foo");
-        let activation_relu = Activation { name: name.clone(), function: relu, derivative: relu_prime };
-        assert_eq!(option, Some(activation_relu));
+        assert_eq!(option, Some(Activation { name: name.clone(), function: relu, derivative: relu_prime }));
 
         // Register a different one under the same name. Should return the previous one.
         let option = NDActivation::register(name.clone(), sigmoid, sigmoid_prime);
-        let activation_sigmoid = Activation { name: name.clone(), function: sigmoid, derivative: sigmoid_prime };
-        assert_eq!(option, Some(activation_relu));
+        assert_eq!(option, Some(Activation { name: name.clone(), function: relu, derivative: relu_prime }));
 
         // Get the new one from the registry by name.
         let option = NDActivation::from_name("foo");
-        assert_eq!(option, Some(activation_sigmoid));
+        assert_eq!(option, Some(Activation { name: name.clone(), function: sigmoid, derivative: sigmoid_prime }));
+
+        // Get default `sigmoid` from the registry.
+        let option = NDActivation::from_name("sigmoid");
+        assert_eq!(option, Some(Activation { name: "sigmoid".to_owned(), function: sigmoid, derivative: sigmoid_prime }));
+
+        // Replace it with a different `Activation` instance.
+        fn identity<T: TensorBase>(t: &T) -> T { t.clone() }
+        let option = NDActivation::register("sigmoid", identity, identity);
+        assert_eq!(option, Some(Activation { name: "sigmoid".to_owned(), function: sigmoid, derivative: sigmoid_prime }));
+        let option = NDActivation::from_name("sigmoid");
+        assert_eq!(option, Some(Activation { name: "sigmoid".to_owned(), function: identity, derivative: identity }));
     }
 
     #[test]
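From a caller's perspective, the workflow exercised by this test looks roughly as follows. A hypothetical custom activation — assuming, as the `identity` helper above suggests, that `TFunc<T>` is a plain `fn(&T) -> T`, that `Array2<f64>` implements `TensorBase`, and with the crate's `Activation` in scope:

```rust
use ndarray::Array2;

// Hypothetical custom activation and its derivative.
fn leaky_relu(t: &Array2<f64>) -> Array2<f64> {
    t.mapv(|x| if x > 0.0 { x } else { 0.01 * x })
}

fn leaky_relu_prime(t: &Array2<f64>) -> Array2<f64> {
    t.mapv(|x| if x > 0.0 { 1.0 } else { 0.01 })
}

fn main() {
    // No prior entry under this name, so `register` returns `None`.
    let previous = Activation::<Array2<f64>>::register("leaky_relu", leaky_relu, leaky_relu_prime);
    assert!(previous.is_none());

    // After this commit, the defaults registered by `register_common` are
    // available immediately as well, with no explicit setup call.
    assert!(Activation::<Array2<f64>>::from_name("relu").is_some());
    assert!(Activation::<Array2<f64>>::from_name("leaky_relu").is_some());
}
```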
18 changes: 0 additions & 18 deletions src/layer.rs
@@ -35,23 +35,5 @@ mod tests {
 
     #[test]
     fn test() {
-        let input = array![
-            [1.],
-            [1.],
-        ];
-        let layer = Layer{
-            weights: array![
-                [1., 0.],
-                [0., 1.],
-            ],
-            biases: array![
-                [-1.],
-                [-1.],
-            ],
-            activation: Activation::from_name("sigmoid"),
-        };
-        let (z, a) = layer.feed_forward(&input);
-        println!("{}", z);
-        println!("{}", a);
     }
 }
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -1,6 +1,6 @@
 //! Rust library for creating, training and evolving neural networks.
 
-#![feature(iter_array_chunks, trait_alias)]
+#![feature(iter_array_chunks, map_try_insert, trait_alias)]
 
 pub mod activation;
 pub mod component;
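The newly enabled `map_try_insert` feature gates `HashMap::try_insert`, which the TODO removed from the old `register_common` mentioned as a way to keep default registration from overwriting custom entries — presumably the reason the flag is turned on here. Its semantics on nightly, shown in isolation:

```rust
#![feature(map_try_insert)]

use std::collections::HashMap;

fn main() {
    let mut registry: HashMap<String, &str> = HashMap::new();
    registry.insert("sigmoid".to_owned(), "custom");

    // Unlike `insert`, `try_insert` refuses to overwrite an occupied entry
    // and reports the conflict instead of returning the old value.
    let result = registry.try_insert("sigmoid".to_owned(), "default");
    assert!(result.is_err());
    assert_eq!(registry["sigmoid"], "custom");
}
```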
