Skip to content

Commit

Permalink
Remove unnecessary return directives; make minor stylistic changes
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-berg committed Dec 11, 2023
1 parent ff96aed commit 3bda6d5
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 45 deletions.
29 changes: 13 additions & 16 deletions src/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<T: TensorBase> Activation<T> {
} else {
panic!();
}
return Self {
Self {
name: name.to_owned(),
function,
derivative,
Expand All @@ -53,20 +53,20 @@ impl<T: TensorBase> Activation<T> {

/// Proxy for the actual activation function.
pub fn call(&self, tensor: &T) -> T {
return (self.function)(tensor);
(self.function)(tensor)
}

/// Proxy for the derivative of the activation function.
pub fn call_derivative(&self, tensor: &T) -> T {
return (self.derivative)(tensor);
(self.derivative)(tensor)
}
}


/// Allows `serde` to serialize `Activation` objects.
///
/// Only the activation's `name` is written out; the function pointers are
/// recovered from the name on deserialization (see the `Deserialize` impl).
impl<T: TensorBase> Serialize for Activation<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(&self.name)
    }
}

Expand All @@ -75,7 +75,7 @@ impl<T: TensorBase> Serialize for Activation<T> {
/// Allows `serde` to deserialize `Activation` objects.
impl<'de, T: TensorBase> Deserialize<'de> for Activation<T> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Only the name was serialized; look the function pair back up by name.
        let name = String::deserialize(deserializer)?;
        Ok(Self::from_name(name.as_str()))
    }
}

Expand All @@ -84,7 +84,7 @@ impl<'de, T: TensorBase> Deserialize<'de> for Activation<T> {
/// Reference implementation of the sigmoid activation function.
///
/// Takes a tensor as input and returns a new tensor.
pub fn sigmoid<T: TensorBase>(tensor: &T) -> T {
    tensor.map(sigmoid_component)
}


Expand All @@ -99,56 +99,53 @@ pub fn sigmoid_inplace<T: TensorBase>(tensor: &mut T) {
/// Sigmoid function for a scalar/number.
pub fn sigmoid_component<C: TensorComponent>(number: C) -> C {
    let one = C::from_usize(1).unwrap();
    // sigma(x) = 1 / (1 + e^(-x))
    one / (one + (-number).exp())
}


/// Reference implementation of the derivative of the sigmoid activation function.
///
/// Takes a tensor as input and returns a new tensor.
pub fn sigmoid_prime<T: TensorBase>(tensor: &T) -> T {
    tensor.map(sigmoid_prime_component)
}


/// Derivative of the sigmoid function for a scalar/number.
pub fn sigmoid_prime_component<C: TensorComponent>(number: C) -> C {
    let one = C::from_usize(1).unwrap();
    // sigma'(x) = sigma(x) * (1 - sigma(x)); compute sigma(x) once instead of twice.
    let s = sigmoid_component(number);
    s * (one - s)
}


/// Reference implementation of the Rectified Linear Unit (RELU) activation function.
///
/// Takes a tensor as input and returns a new tensor.
pub fn relu<T: TensorBase>(tensor: &T) -> T {
    tensor.map(relu_component)
}


/// Rectified Linear Unit (RELU) activation function for a scalar/number.
pub fn relu_component<C: TensorComponent>(number: C) -> C {
    let zero = C::from_usize(0).unwrap();
    // max(0, x) expressed as a conditional expression.
    if number < zero { zero } else { number }
}


/// Reference implementation of the derivative of the Rectified Linear Unit (RELU) activation function.
///
/// Takes a tensor as input and returns a new tensor.
pub fn relu_prime<T: TensorBase>(tensor: &T) -> T {
    tensor.map(relu_prime_component)
}


/// Derivative of the Rectified Linear Unit (RELU) function for a scalar/number.
///
/// Returns 0 for negative inputs and 1 otherwise (the derivative at 0 is
/// conventionally taken to be 1 here).
pub fn relu_prime_component<C: TensorComponent>(number: C) -> C {
    let zero = C::from_usize(0).unwrap();
    let one = C::from_usize(1).unwrap();
    if number < zero { zero } else { one }
}

#[cfg(test)]
Expand Down
24 changes: 9 additions & 15 deletions src/cost_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct CostFunction<T: TensorBase> {
impl<T: TensorBase> CostFunction<T> {
/// Basic constructor to manually define all fields.
///
/// # Arguments
/// * `name` - display/serialization name of the cost function
/// * `function` - the cost function itself
/// * `derivative` - its derivative
pub fn new(name: &str, function: TFunc<T>, derivative: TFuncPrime<T>) -> Self {
    Self { name: name.to_owned(), function, derivative }
}

/// Convenience constructor for known/available cost functions.
Expand All @@ -42,7 +42,7 @@ impl<T: TensorBase> CostFunction<T> {
} else {
panic!();
}
return CostFunction {
CostFunction {
name: name.to_owned(),
function,
derivative,
Expand All @@ -51,27 +51,27 @@ impl<T: TensorBase> CostFunction<T> {

/// Proxy for the actual cost function.
pub fn call(&self, output: &T, desired_output: &T) -> f32 {
return (self.function)(output, desired_output);
(self.function)(output, desired_output)
}

/// Proxy for the derivative of the cost function.
pub fn call_derivative(&self, output: &T, desired_output: &T) -> T {
return (self.derivative)(output, desired_output);
(self.derivative)(output, desired_output)
}
}


impl<T: TensorOp> Default for CostFunction<T> {
    /// The quadratic cost function is the default.
    fn default() -> Self {
        Self::from_name("quadratic")
    }
}


/// Allows `serde` to serialize `CostFunction` objects.
///
/// Only the `name` is written out; the function pointers are recovered from
/// the name on deserialization.
impl<T: TensorBase> Serialize for CostFunction<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(&self.name)
    }
}

Expand All @@ -80,7 +80,7 @@ impl<T: TensorBase> Serialize for CostFunction<T> {
/// Allows `serde` to deserialize `CostFunction` objects.
impl<'de, T: TensorOp> Deserialize<'de> for CostFunction<T> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Only the name was serialized; look the function pair back up by name.
        let name = String::deserialize(deserializer)?;
        Ok(Self::from_name(name.as_str()))
    }
}

Expand All @@ -97,10 +97,7 @@ impl<'de, T: TensorOp> Deserialize<'de> for CostFunction<T> {
///
/// # Returns
/// The cost as 32 bit float.
pub fn quadratic<T: TensorOp>(
output: &T,
desired_output: &T,
) -> f32 {
pub fn quadratic<T: TensorOp>(output: &T, desired_output: &T) -> f32 {
(desired_output - output).norm().to_f32().unwrap() / 2.
}

Expand All @@ -116,9 +113,6 @@ pub fn quadratic<T: TensorOp>(
///
/// # Returns
/// Another tensor.
pub fn quadratic_prime<T: TensorOp>(
output: &T,
desired_output: &T,
) -> T {
pub fn quadratic_prime<T: TensorOp>(output: &T, desired_output: &T) -> T {
output - desired_output
}
15 changes: 7 additions & 8 deletions src/individual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ impl<T: Tensor> Individual<T> {
/// Basic constructor.
///
/// # Arguments
/// * `layers` - the layers making up the network
/// * `cost_function` - cost function used to judge the network's output
///
/// # Returns
/// New `Individual` with the given layers
pub fn new(layers: Vec<Layer<T>>, cost_function: CostFunction<T>) -> Self {
    Self { layers, cost_function }
}

/// Load an individual from a json file
Expand All @@ -67,7 +64,7 @@ impl<T: Tensor> Individual<T> {
/// # Returns
/// A new `Individual` instance or a `LoadError`
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, LoadError> {
return Ok(serde_json::from_str(read_to_string(path)?.as_str())?);
Ok(serde_json::from_str(read_to_string(path)?.as_str())?)
}

/// Performs a full forward pass for a given input and returns the network's output.
Expand All @@ -80,6 +77,7 @@ impl<T: Tensor> Individual<T> {
pub fn forward_pass(&self, input: &T) -> T {
let mut _weighted_input: T;
let mut output: T = input.clone();
// TODO: Consider replacing this loop with `Iterator::fold`.
for layer in (self.layers).iter() {
(_weighted_input, output) = layer.feed_forward(&output);
}
Expand All @@ -102,6 +100,7 @@ impl<T: Tensor> Individual<T> {
let mut activation: T = input.clone();
let mut weighted_input: T;
activations.push(activation.clone());
// TODO: Consider replacing this loop with `Iterator::map` and `collect`.
for layer in &self.layers {
(weighted_input, activation) = layer.feed_forward(&activation);
weighted_inputs.push(weighted_input);
Expand Down Expand Up @@ -158,7 +157,7 @@ impl<T: Tensor> Individual<T> {
nabla_weights.push(delta.dot(previous_activation.transpose()));
nabla_biases.push(delta.sum_axis(1));
}
return (nabla_weights, nabla_biases);
(nabla_weights, nabla_biases)
}

/// Updates the weights and biases of the individual,
Expand Down Expand Up @@ -225,7 +224,7 @@ impl<T: Tensor> Individual<T> {
/// The error value
pub fn calculate_error(&self, input: &T, desired_output: &T,) -> f32 {
let output = self.forward_pass(input);
return self.cost_function.call(&output, desired_output);
self.cost_function.call(&output, desired_output)
}
}

Expand Down Expand Up @@ -273,6 +272,6 @@ mod tests {
individual_file.write_all(individual_json.as_bytes())?;
let individual_loaded = Individual::from_file(individual_file.path())?;
assert_eq!(individual_expected, individual_loaded);
return Ok(());
Ok(())
}
}
4 changes: 2 additions & 2 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ pub struct Layer<T: Tensor> {

impl<T: Tensor> Layer<T> {
/// Basic constructor to manually define all fields.
pub fn new(weights: T, biases: T, activation: Activation<T>) -> Self {
    Self { weights, biases, activation }
}

pub fn feed_forward(&self, input: &T) -> (T, T) {
let weighted_input = &self.weights.dot(input) + &self.biases;
let activation = self.activation.call(&weighted_input);
return (weighted_input, activation);
(weighted_input, activation)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/population.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ impl<T: Tensor> Population<T> {
procreation_function: ProcreateFunc<T>,
determine_key_function: DetermineSpeciesKey<T>,
) -> Self {
return Population {
Self {
species: HashMap::new(),
kill_weak_and_select_parents: selection_function,
procreate_pair: procreation_function,
determine_species_key: determine_key_function,
};
}
}

// mutates sets of individuals within one species
Expand Down
4 changes: 2 additions & 2 deletions src/species.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct Species<T: Tensor> {

impl<T: Tensor> Species<T> {
pub fn new() -> Self {
return Self {
Self {
individuals: Vec::new()
}
}
Expand All @@ -22,6 +22,6 @@ impl<T: Tensor> Species<T> {

impl<T: Tensor> Default for Species<T> {
    /// Equivalent to `Species::new` (an empty species).
    fn default() -> Self {
        // `Self::new()` rather than `Species::new()` to stay type-name-agnostic.
        Self::new()
    }
}

0 comments on commit 3bda6d5

Please sign in to comment.