diff --git a/stable_learning_control/algos/pytorch/common/helpers.py b/stable_learning_control/algos/pytorch/common/helpers.py index 5129c22e..3f9d0c06 100644 --- a/stable_learning_control/algos/pytorch/common/helpers.py +++ b/stable_learning_control/algos/pytorch/common/helpers.py @@ -126,24 +126,17 @@ def rescale(data, min_bound, max_bound): the desired range. Returns: - torch.Tensor: Array which has it values scaled between the min and max - boundaries. + Union[Torch.Tensor, numpy.ndarray]: Array which has it values scaled between + the min and max boundaries. """ - data = torch.tensor(data) if not isinstance(data, torch.Tensor) else data - min_bound = ( - torch.tensor(min_bound, device=data.device) - if not isinstance(min_bound, torch.Tensor) - else min_bound.to(data.device) - ) - max_bound = ( - torch.tensor(max_bound, device=data.device) - if not isinstance(max_bound, torch.Tensor) - else max_bound.to(data.device) - ) + was_numpy = isinstance(data, np.ndarray) + data = torch.as_tensor(data) + min_bound = torch.as_tensor(min_bound, device=data.device) + max_bound = torch.as_tensor(max_bound, device=data.device) # Return rescaled data in the same format as the input data. data_rescaled = (data + 1.0) * (max_bound - min_bound) / 2 + min_bound - return data_rescaled.astype(data.dtype) if isinstance(data, np.ndarray) else data + return data_rescaled.cpu().numpy() if was_numpy else data_rescaled def np_to_torch(input_object, dtype=None, device=None):