diff --git a/deepbots/supervisor/controllers/robot_supervisor.py b/deepbots/supervisor/controllers/robot_supervisor.py
index 50675ca..4c97b8e 100644
--- a/deepbots/supervisor/controllers/robot_supervisor.py
+++ b/deepbots/supervisor/controllers/robot_supervisor.py
@@ -1,5 +1,5 @@
 from warnings import warn, simplefilter
-from deepbots.supervisor.controllers.supervisor_env import SupervisorEnv
+from deepbots.supervisor.controllers.supervisor_env import SupervisorEnv, SupervisorGoalEnv
 
 from controller import Supervisor
 
@@ -27,6 +27,7 @@ class RobotSupervisor(SupervisorEnv):
     action, e.g. motor speeds. Note that apply_action() is called during
     step().
     """
+
     def __init__(self, timestep=None):
         super(RobotSupervisor, self).__init__()
 
@@ -99,3 +100,63 @@ def apply_action(self, action):
         :param action: list, containing action data
         """
         raise NotImplementedError
+
+
+class RobotGoalSupervisor(SupervisorGoalEnv, RobotSupervisor):
+    """
+    The RobotGoalSupervisor class is just like RobotSupervisor, but it
+    uses compute_reward from gym.GoalEnv.
+
+    step():
+        (similar to use_step() of RobotSupervisor)
+        This method steps the controller.
+        Note that the gym-inherited compute_reward method is used here.
+    """
+
+    def __init__(self, timestep=None):
+        super(RobotGoalSupervisor, self).__init__()
+
+        if timestep is None:
+            self.timestep = int(self.getBasicTimeStep())
+        else:
+            self.timestep = timestep
+
+    def step(self, action):
+        """
+        The basic step method that steps the controller,
+        calls the method that applies the action on the robot
+        and returns the (observations, reward, done, info) object.
+
+        For RobotGoalSupervisor, the gym-inherited compute_reward
+        method is used. This method must be implemented by the
+        user, according to gym.GoalEnv, using achieved_goal and
+        desired_goal.
+
+        :param action: Whatever the use-case uses as an action, e.g.
+            an integer representing discrete actions
+        :type action: Defined by the implementation
+        :param achieved_goal: the goal that was achieved during execution
+        :type achieved_goal: object
+        :param desired_goal: the desired goal that we asked the agent to
+            attempt to achieve
+        :type desired_goal: object
+        :param info: an info dictionary with additional information
+        :type info: object
+        :return: tuple, (observations, reward, done, info) as provided by the
+            corresponding methods as implemented for the use-case
+        """
+        if super(Supervisor, self).step(self.timestep) == -1:
+            exit()
+
+        self.apply_action(action)
+        obs = self.get_observations()
+        info = self.get_info()
+
+        return (
+            obs,
+            self.compute_reward(obs["achieved_goal"],
+                                obs["desired_goal"],
+                                info),
+            self.is_done(),
+            info,
+        )
diff --git a/deepbots/supervisor/controllers/supervisor_env.py b/deepbots/supervisor/controllers/supervisor_env.py
index 78a58f0..4ed1f02 100644
--- a/deepbots/supervisor/controllers/supervisor_env.py
+++ b/deepbots/supervisor/controllers/supervisor_env.py
@@ -115,3 +115,35 @@ def get_info(self):
         information on each step, e.g. for debugging purposes.
         """
         raise NotImplementedError
+
+
+class SupervisorGoalEnv(gym.GoalEnv, SupervisorEnv):
+    """
+    This class is just like SupervisorEnv, but it imposes gym.GoalEnv.
+
+    Refer to gym.GoalEnv documentation on how to implement a custom
+    gym.GoalEnv for additional functionality.
+    """
+
+    def reset(self):
+        """
+        Used to reset the world to an initial state and enforce that each
+        SupervisorGoalEnv uses a Goal-compatible observation space.
+
+        Default, problem-agnostic, implementation of reset method,
+        using Webots-provided methods.
+
+        *Note that this works properly only with Webots versions >R2020b
+        and must be overridden with a custom reset method when using
+        earlier versions. It is backwards compatible due to the fact
+        that the new reset method gets overridden by whatever the user
+        has previously implemented, so an old supervisor can be migrated
+        easily to use this class.
+
+        :return: default observation provided by get_default_observation()
+        """
+        super().reset()
+        self.simulationReset()
+        self.simulationResetPhysics()
+        super(Supervisor, self).step(int(self.getBasicTimeStep()))
+        return self.get_default_observation()
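For reference, below is a minimal, hypothetical sketch of how a user-defined controller could subclass the new RobotGoalSupervisor. Only the overridden method names and the compute_reward signature come from the classes added in this patch and from gym.GoalEnv; the "TARGET" DEF name, the position-based observation layout, and the sparse distance reward are illustrative assumptions, not part of the patch.

```python
# Hypothetical user controller built on RobotGoalSupervisor (illustration only).
import numpy as np
from deepbots.supervisor.controllers.robot_supervisor import RobotGoalSupervisor


class ReachTargetSupervisor(RobotGoalSupervisor):
    def __init__(self):
        super().__init__()
        # "TARGET" is an assumed DEF name in the user's Webots world.
        self.target = self.getFromDef("TARGET")

    def get_observations(self):
        # gym.GoalEnv expects a dict observation with these three keys.
        robot_position = np.array(self.getSelf().getPosition())
        return {
            "observation": robot_position,
            "achieved_goal": robot_position,
            "desired_goal": np.array(self.target.getPosition()),
        }

    def compute_reward(self, achieved_goal, desired_goal, info):
        # Sparse goal reward: 0 when close enough to the goal, -1 otherwise.
        distance = np.linalg.norm(achieved_goal - desired_goal)
        return 0.0 if distance < 0.05 else -1.0

    def apply_action(self, action):
        # Map the agent's action to motor commands here (use-case specific).
        pass

    def is_done(self):
        obs = self.get_observations()
        return bool(
            np.linalg.norm(obs["achieved_goal"] - obs["desired_goal"]) < 0.05)

    def get_default_observation(self):
        return self.get_observations()

    def get_info(self):
        return {}
```

With such a subclass, the step() added in this patch applies the action, builds the goal-style observation dict, and computes the reward via compute_reward(obs["achieved_goal"], obs["desired_goal"], info), so the usual loop reads: obs = env.reset(); obs, reward, done, info = env.step(action).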