diff --git a/src/activation.rs b/src/activation.rs index acc2fb4..05133fe 100644 --- a/src/activation.rs +++ b/src/activation.rs @@ -65,7 +65,27 @@ impl Activation { /// /// When called for the first time, the activation registry is initialized first. fn get_activation_registry() -> &'static RwLock> { - get_or_init!(|| RwLock::new(ActivationRegistry::::new())) + get_or_init!(|| { + let registry_lock = RwLock::new(ActivationRegistry::::new()); + Self::register_common(®istry_lock); + registry_lock + }) + } + + /// Registers reference implementations of some common activation functions. + /// + /// # Arguments + /// - `registry_lock` - Reference to the activation registry wrapped in a [`RwLock`] + fn register_common(registry_lock: &RwLock>) { + let mut registry = registry_lock.write().unwrap(); + let _ = registry.insert( + "sigmoid".to_owned(), + Self::new("sigmoid".to_owned(), sigmoid, sigmoid_prime), + ); + let _ = registry.insert( + "relu".to_owned(), + Self::new("relu".to_owned(), relu, relu_prime), + ); } /// Creates and registers a new [`Activation`] with the specified parameters. @@ -88,19 +108,20 @@ impl Activation { pub fn register>(name: S, function: TFunc, derivative: TFunc) -> Option { let name: String = name.into(); let registry_lock = Self::get_activation_registry(); - registry_lock.write().unwrap().insert( - name.clone(), - Activation::new(name, function, derivative), - ) + registry_lock.write().unwrap() + .insert(name.clone(), Activation::new(name, function, derivative)) } /// Retrieves a clone of a previously registered [`Activation`] by name. /// - /// Unless custom instances were added before via [`Activation::register`], reference implementations - /// for the following activation functions will be available by default: + /// Reference implementations for the following activation functions are available by default: /// - `sigmoid` /// - `relu` /// + /// A custom instance registered via [`Activation::register`] under one of those names will + /// replace the default implementation. + /// + /// /// # Arguments /// - `name` - The name/key of the activation to return. /// @@ -109,33 +130,9 @@ impl Activation { /// registered under that name. pub fn from_name>(name: S) -> Option { let registry_lock = Self::get_activation_registry(); - // The registry should only be empty the first time this method is called. - // In that case, fill it with the known/common activation functions. - if registry_lock.read().unwrap().is_empty() { - Self::register_common(registry_lock); - } - registry_lock.read().unwrap().get(&name.into()).and_then(|activation| Some(activation.clone())) - } - - /// Registers reference implementations of some common activation functions. - /// - /// Activation functions that will be available by name after calling this method: - /// - `sigmoid` - /// - `relu` - /// - /// # Arguments - /// - `registry_lock` - Reference to the activation registry wrapped in a [`RwLock`] - pub fn register_common(registry_lock: &RwLock>) { - // TODO: Consider using `HashMap::try_insert` instead to avoid overwriting any custom - // implementations of common activation functions. - registry_lock.write().unwrap().insert( - "sigmoid".to_owned(), - Activation::::new("sigmoid".to_owned(), sigmoid, sigmoid_prime), - ); - registry_lock.write().unwrap().insert( - "relu".to_owned(), - Activation::::new("relu".to_owned(), relu, relu_prime), - ); + registry_lock.read().unwrap() + .get(&name.into()) + .and_then(|activation| Some(activation.clone())) } } @@ -237,27 +234,39 @@ mod tests { fn test_register_and_from_name() { use ndarray::Array2; - type NDActivation = Activation>; + type NDActivation = Activation>; let name = "foo".to_owned(); + let option = NDActivation::from_name("foo"); + assert_eq!(option, None); + // Register under that name for the first time. let option = NDActivation::register(name.clone(), relu, relu_prime); assert_eq!(option, None); // Get from registry by name. let option = NDActivation::from_name("foo"); - let activation_relu = Activation { name: name.clone(), function: relu, derivative: relu_prime }; - assert_eq!(option, Some(activation_relu)); + assert_eq!(option, Some(Activation { name: name.clone(), function: relu, derivative: relu_prime })); // Register different one under the same name. Should return the previous one. let option = NDActivation::register(name.clone(), sigmoid, sigmoid_prime); - let activation_sigmoid = Activation { name: name.clone(), function: sigmoid, derivative: sigmoid_prime }; - assert_eq!(option, Some(activation_relu)); + assert_eq!(option, Some(Activation { name: name.clone(), function: relu, derivative: relu_prime })); // Get the new one from the registry by name. let option = NDActivation::from_name("foo"); - assert_eq!(option, Some(activation_sigmoid)); + assert_eq!(option, Some(Activation { name: name.clone(), function: sigmoid, derivative: sigmoid_prime })); + + // Get default `sigmoid` from the registry. + let option = NDActivation::from_name("sigmoid"); + assert_eq!(option, Some(Activation { name: "sigmoid".to_owned(), function: sigmoid, derivative: sigmoid_prime })); + + // Replace it with a different `Activation` instance. + fn identity(t: &T) -> T { t.clone() } + let option = NDActivation::register("sigmoid", identity, identity); + assert_eq!(option, Some(Activation { name: "sigmoid".to_owned(), function: sigmoid, derivative: sigmoid_prime })); + let option = NDActivation::from_name("sigmoid"); + assert_eq!(option, Some(Activation { name: "sigmoid".to_owned(), function: identity, derivative: identity })); } #[test] diff --git a/src/layer.rs b/src/layer.rs index 99a3229..4da9e61 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -35,23 +35,5 @@ mod tests { #[test] fn test() { - let input = array![ - [1.], - [1.], - ]; - let layer = Layer{ - weights: array![ - [1., 0.], - [0., 1.], - ], - biases: array![ - [-1.], - [-1.], - ], - activation: Activation::from_name("sigmoid"), - }; - let (z, a) = layer.feed_forward(&input); - println!("{}", z); - println!("{}", a); } } diff --git a/src/lib.rs b/src/lib.rs index b4ad86d..05eefe8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! Rust library for creating, training and evolving neural networks. -#![feature(iter_array_chunks, trait_alias)] +#![feature(iter_array_chunks, map_try_insert, trait_alias)] pub mod activation; pub mod component;