From 6d316be7eb0c3b0eb0ed5e553251c3fa55421224 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Wed, 7 Feb 2024 20:59:53 +0100 Subject: [PATCH] fix(pytorch): Resolve critical action rescaling bug This commit addresses a critical bug related to action rescaling in PyTorch, which was preventing the agent from training effectively in specific environments. --- .../algos/pytorch/common/helpers.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/stable_learning_control/algos/pytorch/common/helpers.py b/stable_learning_control/algos/pytorch/common/helpers.py index 5129c22e..3f9d0c06 100644 --- a/stable_learning_control/algos/pytorch/common/helpers.py +++ b/stable_learning_control/algos/pytorch/common/helpers.py @@ -126,24 +126,17 @@ def rescale(data, min_bound, max_bound): the desired range. Returns: - torch.Tensor: Array which has it values scaled between the min and max - boundaries. + Union[torch.Tensor, numpy.ndarray]: Array which has its values scaled between + the min and max boundaries. """ - data = torch.tensor(data) if not isinstance(data, torch.Tensor) else data - min_bound = ( - torch.tensor(min_bound, device=data.device) - if not isinstance(min_bound, torch.Tensor) - else min_bound.to(data.device) - ) - max_bound = ( - torch.tensor(max_bound, device=data.device) - if not isinstance(max_bound, torch.Tensor) - else max_bound.to(data.device) - ) + was_numpy = isinstance(data, np.ndarray) + data = torch.as_tensor(data) + min_bound = torch.as_tensor(min_bound, device=data.device) + max_bound = torch.as_tensor(max_bound, device=data.device) # Return rescaled data in the same format as the input data. data_rescaled = (data + 1.0) * (max_bound - min_bound) / 2 + min_bound - return data_rescaled.astype(data.dtype) if isinstance(data, np.ndarray) else data + return data_rescaled.cpu().numpy() if was_numpy else data_rescaled def np_to_torch(input_object, dtype=None, device=None):