From 3bda6d5a486c5f0d2fcc1d449f48cf92a84cfcb4 Mon Sep 17 00:00:00 2001
From: Daniil Fajnberg
Date: Mon, 11 Dec 2023 17:07:08 +0100
Subject: [PATCH] Remove unnecessary `return` statements; make minor stylistic changes

---
 src/activation.rs    | 29 +++++++++++++----------------
 src/cost_function.rs | 24 +++++++++---------------
 src/individual.rs    | 15 +++++++--------
 src/layer.rs         |  4 ++--
 src/population.rs    |  4 ++--
 src/species.rs       |  4 ++--
 6 files changed, 35 insertions(+), 45 deletions(-)

diff --git a/src/activation.rs b/src/activation.rs
index 9308135..3871c69 100644
--- a/src/activation.rs
+++ b/src/activation.rs
@@ -44,7 +44,7 @@ impl Activation {
         } else {
             panic!();
         }
-        return Self {
+        Self {
             name: name.to_owned(),
             function,
             derivative,
@@ -53,12 +53,12 @@
 
     /// Proxy for the actual activation function.
     pub fn call(&self, tensor: &T) -> T {
-        return (self.function)(tensor);
+        (self.function)(tensor)
     }
 
     /// Proxy for the derivative of the activation function.
    pub fn call_derivative(&self, tensor: &T) -> T {
-        return (self.derivative)(tensor);
+        (self.derivative)(tensor)
     }
 }
 
@@ -66,7 +66,7 @@ impl Activation {
 /// Allows `serde` to serialize `Activation` objects.
 impl<T: TensorBase> Serialize for Activation<T> {
     fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
-        return serializer.serialize_str(&self.name);
+        serializer.serialize_str(&self.name)
     }
 }
 
@@ -75,7 +75,7 @@ impl<T: TensorBase> Serialize for Activation<T> {
 impl<'de, T: TensorBase> Deserialize<'de> for Activation<T> {
     fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
         let name = String::deserialize(deserializer)?;
-        return Ok(Self::from_name(name.as_str()));
+        Ok(Self::from_name(name.as_str()))
     }
 }
 
@@ -84,7 +84,7 @@ impl<'de, T: TensorBase> Deserialize<'de> for Activation<T> {
 ///
 /// Takes a tensor as input and returns a new tensor.
 pub fn sigmoid(tensor: &T) -> T {
-    return tensor.map(sigmoid_component);
+    tensor.map(sigmoid_component)
 }
 
 
@@ -99,7 +99,7 @@ pub fn sigmoid_inplace(tensor: &mut T) {
 /// Sigmoid function for a scalar/number.
 pub fn sigmoid_component(number: C) -> C {
     let one = C::from_usize(1).unwrap();
-    return one / (one + (-number).exp());
+    one / (one + (-number).exp())
 }
 
 
@@ -107,14 +107,14 @@ pub fn sigmoid_component(number: C) -> C {
 ///
 /// Takes a tensor as input and returns a new tensor.
 pub fn sigmoid_prime(tensor: &T) -> T {
-    return tensor.map(sigmoid_prime_component);
+    tensor.map(sigmoid_prime_component)
 }
 
 
 /// Derivative of the sigmoid function for a scalar/number.
 pub fn sigmoid_prime_component(number: C) -> C {
     let one = C::from_usize(1).unwrap();
-    return sigmoid_component(number) * (one - sigmoid_component(number));
+    sigmoid_component(number) * (one - sigmoid_component(number))
 }
 
 
@@ -122,17 +122,14 @@ pub fn sigmoid_prime_component(number: C) -> C {
 ///
 /// Takes a tensor as input and returns a new tensor.
 pub fn relu(tensor: &T) -> T {
-    return tensor.map(relu_component);
+    tensor.map(relu_component)
 }
 
 
 /// Rectified Linear Unit (RELU) activation function for a scalar/number.
 pub fn relu_component(number: C) -> C {
     let zero = C::from_usize(0).unwrap();
-    if number < zero {
-        return zero;
-    }
-    return number;
+    if number < zero { zero } else { number }
 }
 
 
@@ -140,7 +137,7 @@ pub fn relu_component(number: C) -> C {
 ///
 /// Takes a tensor as input and returns a new tensor.
 pub fn relu_prime(tensor: &T) -> T {
-    return tensor.map(relu_prime_component);
+    tensor.map(relu_prime_component)
 }
 
 
@@ -148,7 +145,7 @@ pub fn relu_prime(tensor: &T) -> T {
 pub fn relu_prime_component(number: C) -> C {
     let zero = C::from_usize(0).unwrap();
     let one = C::from_usize(1).unwrap();
-    return if number < zero { zero } else { one };
+    if number < zero { zero } else { one }
 }
 
 #[cfg(test)]
diff --git a/src/cost_function.rs b/src/cost_function.rs
index 85f1c3a..f805b9d 100644
--- a/src/cost_function.rs
+++ b/src/cost_function.rs
@@ -27,7 +27,7 @@ pub struct CostFunction {
 impl CostFunction {
     /// Basic constructor to manually define all fields.
     pub fn new(name: &str, function: TFunc, derivative: TFuncPrime) -> Self {
-        return Self{name: name.to_owned(), function, derivative}
+        Self { name: name.to_owned(), function, derivative }
     }
 
     /// Convenience constructor for known/available cost functions.
@@ -42,7 +42,7 @@ impl CostFunction {
         } else {
             panic!();
         }
-        return CostFunction {
+        CostFunction {
             name: name.to_owned(),
             function,
             derivative,
@@ -51,19 +51,19 @@
 
     /// Proxy for the actual cost function.
     pub fn call(&self, output: &T, desired_output: &T) -> f32 {
-        return (self.function)(output, desired_output);
+        (self.function)(output, desired_output)
     }
 
     /// Proxy for the derivative of the cost function.
     pub fn call_derivative(&self, output: &T, desired_output: &T) -> T {
-        return (self.derivative)(output, desired_output);
+        (self.derivative)(output, desired_output)
     }
 }
 
 
 impl<T: TensorOp> Default for CostFunction<T> {
     fn default() -> Self {
-        return Self::from_name("quadratic");
+        Self::from_name("quadratic")
     }
 }
 
@@ -71,7 +71,7 @@ impl<T: TensorOp> Default for CostFunction<T> {
 /// Allows `serde` to serialize `CostFunction` objects.
 impl<T: TensorOp> Serialize for CostFunction<T> {
     fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
-        return serializer.serialize_str(&self.name);
+        serializer.serialize_str(&self.name)
     }
 }
 
@@ -80,7 +80,7 @@ impl<T: TensorOp> Serialize for CostFunction<T> {
 impl<'de, T: TensorOp> Deserialize<'de> for CostFunction<T> {
     fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
         let name = String::deserialize(deserializer)?;
-        return Ok(Self::from_name(name.as_str()));
+        Ok(Self::from_name(name.as_str()))
     }
 }
 
@@ -97,10 +97,7 @@ impl<'de, T: TensorOp> Deserialize<'de> for CostFunction<T> {
 ///
 /// # Returns
 /// The cost as 32 bit float.
-pub fn quadratic(
-    output: &T,
-    desired_output: &T,
-) -> f32 {
+pub fn quadratic(output: &T, desired_output: &T) -> f32 {
     (desired_output - output).norm().to_f32().unwrap() / 2.
 }
 
@@ -116,9 +113,6 @@ pub fn quadratic(
 ///
 /// # Returns
 /// Another tensor.
-pub fn quadratic_prime(
-    output: &T,
-    desired_output: &T,
-) -> T {
+pub fn quadratic_prime(output: &T, desired_output: &T) -> T {
     output - desired_output
 }
diff --git a/src/individual.rs b/src/individual.rs
index f2dc4e0..f790ace 100644
--- a/src/individual.rs
+++ b/src/individual.rs
@@ -53,10 +53,7 @@ impl Individual {
     /// # Returns
     /// New `Individual` with the given layers
     pub fn new(layers: Vec<Layer<T>>, cost_function: CostFunction<T>) -> Self {
-        return Individual {
-            layers,
-            cost_function,
-        }
+        Self { layers, cost_function }
     }
 
     /// Load an individual from a json file
@@ -67,7 +64,7 @@ impl Individual {
     /// # Returns
     /// A new `Individual` instance or a `LoadError`
     pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, LoadError> {
-        return Ok(serde_json::from_str(read_to_string(path)?.as_str())?);
+        Ok(serde_json::from_str(read_to_string(path)?.as_str())?)
     }
 
     /// Performs a full forward pass for a given input and returns the network's output.
@@ -80,6 +77,7 @@ impl Individual {
     pub fn forward_pass(&self, input: &T) -> T {
         let mut _weighted_input: T;
         let mut output: T = input.clone();
+        // TODO: Consider replacing this loop with `Iterator::fold`.
         for layer in (self.layers).iter() {
             (_weighted_input, output) = layer.feed_forward(&output);
         }
@@ -102,6 +100,7 @@ impl Individual {
         let mut activation: T = input.clone();
         let mut weighted_input: T;
         activations.push(activation.clone());
+        // TODO: Consider replacing this loop with `Iterator::map` and `collect`.
         for layer in &self.layers {
             (weighted_input, activation) = layer.feed_forward(&activation);
             weighted_inputs.push(weighted_input);
@@ -158,7 +157,7 @@ impl Individual {
             nabla_weights.push(delta.dot(previous_activation.transpose()));
             nabla_biases.push(delta.sum_axis(1));
         }
-        return (nabla_weights, nabla_biases);
+        (nabla_weights, nabla_biases)
     }
 
     /// Updates the weights and biases of the individual,
@@ -225,7 +224,7 @@ impl Individual {
     /// The error value
     pub fn calculate_error(&self, input: &T, desired_output: &T,) -> f32 {
         let output = self.forward_pass(input);
-        return self.cost_function.call(&output, desired_output);
+        self.cost_function.call(&output, desired_output)
     }
 }
 
@@ -273,6 +272,6 @@ mod tests {
         individual_file.write_all(individual_json.as_bytes())?;
         let individual_loaded = Individual::from_file(individual_file.path())?;
         assert_eq!(individual_expected, individual_loaded);
-        return Ok(());
+        Ok(())
     }
 }
diff --git a/src/layer.rs b/src/layer.rs
index a1797db..3d1f7d2 100644
--- a/src/layer.rs
+++ b/src/layer.rs
@@ -14,13 +14,13 @@ pub struct Layer {
 
 
 impl Layer {
     pub fn new(weights: T, biases: T, activation: Activation<T>) -> Self {
-        return Self{weights, biases, activation}
+        Self { weights, biases, activation }
     }
 
     pub fn feed_forward(&self, input: &T) -> (T, T) {
         let weighted_input = &self.weights.dot(input) + &self.biases;
         let activation = self.activation.call(&weighted_input);
-        return (weighted_input, activation);
+        (weighted_input, activation)
     }
 }
diff --git a/src/population.rs b/src/population.rs
index 814b2b9..f729f62 100644
--- a/src/population.rs
+++ b/src/population.rs
@@ -24,12 +24,12 @@ impl Population {
         procreation_function: ProcreateFunc,
         determine_key_function: DetermineSpeciesKey,
     ) -> Self {
-        return Population {
+        Self {
            species: HashMap::new(),
            kill_weak_and_select_parents: selection_function,
            procreate_pair: procreation_function,
            determine_species_key: determine_key_function,
-        };
+        }
     }
 
     // mutates sets of individuals within one species
diff --git a/src/species.rs b/src/species.rs
index a186949..a492887 100644
--- a/src/species.rs
+++ b/src/species.rs
@@ -9,7 +9,7 @@ pub struct Species {
 impl Species {
     pub fn new() -> Self {
-        return Self {
+        Self {
            individuals: Vec::new()
        }
     }
@@ -22,6 +22,6 @@ impl Species {
 
 impl Default for Species {
     fn default() -> Self {
-        return Species::new();
+        Species::new()
     }
 }
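
A note on the two `// TODO` comments introduced above: the loop in `Individual::forward_pass` threads each layer's output into the next layer, which is exactly the shape `Iterator::fold` expects. Below is a minimal sketch of that refactor in the same expression-based style this patch adopts; `SimpleLayer` and the free function `forward_pass` are simplified stand-ins for illustration, not the crate's generic `Layer<T>` / `Individual<T>` API.

    // Stand-in type for illustration only; the real crate uses a generic `Layer<T>`.
    struct SimpleLayer {
        weight: f64,
        bias: f64,
    }

    impl SimpleLayer {
        // Expression-based body: the last expression is the return value,
        // matching the style this patch switches to.
        fn feed_forward(&self, input: f64) -> f64 {
            self.weight * input + self.bias
        }
    }

    // The `for` loop over layers can be written as a fold: the accumulator is
    // the running output, and each step feeds it through the next layer.
    fn forward_pass(layers: &[SimpleLayer], input: f64) -> f64 {
        layers
            .iter()
            .fold(input, |output, layer| layer.feed_forward(output))
    }

    fn main() {
        let layers = vec![
            SimpleLayer { weight: 2.0, bias: 1.0 },
            SimpleLayer { weight: 0.5, bias: 0.0 },
        ];
        // (3.0 * 2.0 + 1.0) * 0.5 = 3.5
        assert_eq!(forward_pass(&layers, 3.0), 3.5);
    }

Since `fold` itself is the tail expression, no `return` is needed there either. The second TODO (`Iterator::map` plus `collect`) would apply the same idea to the pass that accumulates the per-layer weighted inputs and activations.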