Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feature(rjy): add mamujoco env and related configs #153

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
1 change: 1 addition & 0 deletions zoo/multiagent_mujoco/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .multiagent_mujoco_lightzero_env import MAMujocoEnvLZ
98 changes: 98 additions & 0 deletions zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
from typing import Union

import gym
import numpy as np
from ding.envs import BaseEnvTimestep
from ding.envs.common import save_frames_as_gif
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from dizoo.multiagent_mujoco.envs.multi_mujoco_env import MujocoEnv,MujocoMulti


@ENV_REGISTRY.register('mujoco_lightzero')
nighood marked this conversation as resolved.
Show resolved Hide resolved
class MAMujocoEnvLZ(MujocoEnv):
    """
    Overview:
        The modified Multi-agent MuJoCo environment with continuous action space for LightZero's algorithms.
        Wraps DI-engine's ``MujocoMulti`` and repacks observations into LightZero's
        ``{'observation': ..., 'action_mask': ..., 'to_play': ...}`` format.
    """
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/opendilab/LightZero/blob/main/zoo/box2d/lunarlander/envs/lunarlander_env.py 类似这里增加详细清晰的注释,可以参考https://aicarrier.feishu.cn/wiki/N4bqwLRO5iyQcAkb4HCcflbgnpR 这里的提示词用gpt4优化,然后手动矫正。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR的description里面增加这个PR的简要描述

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似这里https://github.com/opendilab/LightZero/blob/main/zoo/box2d/lunarlander/envs/lunarlander_env.py#L30增加存储MP4和gif回复的功能。

原来DI-engine这里似乎还没replay,等把其他改完我再测试一下


config = dict(
stop_value=int(1e6),
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
)

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Initialize the LightZero Multi-agent MuJoCo environment wrapper.
        Arguments:
            - cfg (:obj:`dict`): Environment config. Must support attribute access
              (e.g. an ``EasyDict``) and contain ``env_name`` — TODO confirm full schema
              against ``MujocoMulti``'s expected ``env_args``.
        """
        super().__init__(cfg)
        self._cfg = cfg
        # We use env_name to indicate the env_id in LightZero.
        self._cfg.env_id = self._cfg.env_name
        # Lazy initialization: the underlying MujocoMulti env is created on first reset().
        self._init_flag = False

def reset(self) -> np.ndarray:
if not self._init_flag:
self._env = MujocoMulti(env_args=self._cfg)
self._init_flag = True

if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._env.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
self._env.seed(self._seed)

obs = self._env.reset()
obs = to_ndarray(obs)
self._eval_episode_return = 0.
self.env_info = self._env.get_env_info()

self._num_agents = self.env_info['n_agents']
self._agents = [i for i in range(self._num_agents)]
self._observation_space = gym.spaces.Dict(
{
'agent_state': gym.spaces.Box(
low=float("-inf"), high=float("inf"), shape=obs['agent_state'].shape, dtype=np.float32
),
'global_state': gym.spaces.Box(
low=float("-inf"), high=float("inf"), shape=obs['global_state'].shape, dtype=np.float32
),
}
)
self._action_space = gym.spaces.Dict({agent: self._env.action_space[agent] for agent in self._agents})
single_agent_obs_space = self._env.action_space[self._agents[0]]
if isinstance(single_agent_obs_space, gym.spaces.Box):
self._action_dim = single_agent_obs_space.shape
elif isinstance(single_agent_obs_space, gym.spaces.Discrete):
self._action_dim = (single_agent_obs_space.n, )
else:
raise Exception('Only support `Box` or `Discrte` obs space for single agent.')
self._reward_space = gym.spaces.Dict(
{
agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
for agent in self._agents
}
)

action_mask = None
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}

return obs

def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
action = to_ndarray(action)
obs, rew, done, info = self._env.step(action)
self._eval_episode_return += rew
if done:
info['eval_episode_return'] = self._eval_episode_return

obs = to_ndarray(obs)
rew = to_ndarray([rew]).astype(np.float32)

action_mask = None
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}

return BaseEnvTimestep(obs, rew, done, info)

def __repr__(self) -> str:
return "LightZero MAMujoco Env({})".format(self._cfg.env_name)

40 changes: 40 additions & 0 deletions zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from time import time
import pytest
import numpy as np
from easydict import EasyDict
from zoo.multiagent_mujoco.envs import MAMujocoEnvLZ


# Smoke test for the 2-agent Ant (2x4d) configuration. Requires a working
# MuJoCo installation, hence the `envtest` marker.
@pytest.mark.envtest
@pytest.mark.parametrize(
    'cfg', [
        EasyDict({
            'env_name': 'mujoco_lightzero',
            'scenario': 'Ant-v2',
            'agent_conf': "2x4d",  # presumably two agents, 4 joints each — TODO confirm
            'agent_obsk': 2,
            'add_agent_id': False,
            'episode_limit': 1000,
        },)
    ]
)

class TestMAMujocoEnvLZ:

    def test_naive(self, cfg):
        env = MAMujocoEnvLZ(cfg)
        env.seed(314)
        assert env._seed == 314
        obs = env.reset()
        assert isinstance(obs, dict)
        for i in range(10):
            random_action = env.random_action()
            timestep = env.step(random_action[0])
            print(timestep)
            assert isinstance(timestep.obs, dict)
            assert isinstance(timestep.done, bool)
            # Expected shapes for Ant-v2 with 2 agents: 111-dim global state,
            # 54-dim per-agent observation.
            assert timestep.obs['observation']['global_state'].shape == (2, 111)
            assert timestep.obs['observation']['agent_state'].shape == (2, 54)
            assert timestep.reward.shape == (1, )
            # BaseEnvTimestep is a namedtuple, hence a tuple instance.
            assert isinstance(timestep, tuple)
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()