diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 200b91441..cb7a32e7f 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -3,8 +3,10 @@ from .eval_muzero_with_gym_env import eval_muzero_with_gym_env from .train_alphazero import train_alphazero from .train_muzero import train_muzero +from .train_muzero_reanalyze import train_muzero_reanalyze from .train_muzero_with_gym_env import train_muzero_with_gym_env from .train_muzero_with_gym_env import train_muzero_with_gym_env from .train_muzero_with_reward_model import train_muzero_with_reward_model from .train_rezero import train_rezero from .train_unizero import train_unizero +from .train_unizero_reanalyze import train_unizero_reanalyze \ No newline at end of file diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 0f8e4d165..0f7110d20 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -18,7 +18,7 @@ from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from .utils import random_collect, initialize_zeros_batch +from .utils import random_collect def train_muzero( diff --git a/lzero/entry/train_muzero_reanalyze.py b/lzero/entry/train_muzero_reanalyze.py new file mode 100644 index 000000000..b6596b112 --- /dev/null +++ b/lzero/entry/train_muzero_reanalyze.py @@ -0,0 +1,256 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from .utils import random_collect + +timer = EasyTimer() + + +def train_muzero_reanalyze( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel Muzero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'" + + if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']: + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_muzero': + from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + if cfg.policy.eval_offline: + cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. 
+ learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. + # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + if cfg.policy.eval_offline: + eval_train_iter_list = [] + eval_train_envstep_list = [] + + # Evaluate the random agent + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + if cfg.policy.eval_offline: + eval_train_iter_list.append(learner.train_iter) + eval_train_envstep_list.append(collector.envstep) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. + collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. 
+ replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze buffer + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch + if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # Learn policy from collected data. + for i in range(update_per_collect): + + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + train_epoch += 1 + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + if cfg.policy.eval_offline: + logging.info(f'eval offline beginning...') + ckpt_dirname = './{}/ckpt'.format(learner.exp_name) + # Evaluate the performance of the pretrained model. + for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list): + ckpt_name = 'iteration_{}.pth.tar'.format(train_iter) + ckpt_path = os.path.join(ckpt_dirname, ckpt_name) + # load the ckpt of pretrained model + policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device)) + stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep) + logging.info( + f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}') + logging.info(f'eval offline finished!') + break + + # Learner's after_run hook. 
+ learner.call_hook('after_run') + return policy diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index 56b910077..06caeb112 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -17,8 +17,8 @@ from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy -from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroCollector as Collector from .utils import random_collect diff --git a/lzero/entry/train_unizero_reanalyze.py b/lzero/entry/train_unizero_reanalyze.py new file mode 100644 index 000000000..aa3fc6aae --- /dev/null +++ b/lzero/entry/train_unizero_reanalyze.py @@ -0,0 +1,218 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from .utils import random_collect + +timer = EasyTimer() + +def train_unizero_reanalyze( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'" + + # Import the correct GameBuffer class based on the policy type + game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + + GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), + game_buffer_classes[create_cfg.policy.type]) + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + # collector_env.seed(cfg.seed, dynamic_seed=False) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create worker components: learner, collector, evaluator, replay buffer, commander + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # MCTS+RL algorithms related core code + policy_config = cfg.policy + replay_buffer = GameBuffer(policy_config) + collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name, + policy_config=policy_config) + evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config) + + # Learner's before_run hook + learner.call_hook('before_run') + + # Collect random data before training + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + batch_size = policy._cfg.batch_size + + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + while True: + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + # Set temperature for visit count distributions + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + 
trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + # Configure epsilon for epsilon-greedy exploration + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # Evaluate policy performance + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect new data + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + update_per_collect = cfg.policy.update_per_collect + if update_per_collect is None: + collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze buffer + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch + if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # Train the policy if sufficient data is available + if collector.envstep > cfg.policy.train_start_after_envsteps: + if cfg.policy.sample_type == 'episode': + data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size + else: + data_sufficient = replay_buffer.get_num_of_transitions() > batch_size + if not data_sufficient: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....' 
+ ) + continue + + for i in range(update_per_collect): + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0: + # Clear caches and precompute positional embedding matrices + policy.recompute_pos_emb_diff_and_clear_cache() # TODO + + train_data.append({'train_which_component': 'transformer'}) + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # Check stopping criteria + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + return policy diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index f9fbde1f8..bad188912 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -125,9 +125,6 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: probs /= probs.sum() # sample according to transition index - # TODO(pu): replace=True - # print(f"num transitions is {num_of_transitions}") - # print(f"length of probs is {len(probs)}") batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) if self._cfg.reanalyze_outdated is True: @@ -146,13 +143,65 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) + # pos_in_game_segment_list.append(pos_in_game_segment) + # TODO: check + if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) + make_time = [time.time() for _ in range(len(batch_index_list))] orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) return orig_data + def _sample_orig_reanalyze_batch_data(self, batch_size: int) -> Tuple: + """ + Overview: + sample orig_data that contains: + game_segment_list: a list of game segments + pos_in_game_segment_list: transition index in game (relative index) + batch_index_list: the index of start transition of sampled minibatch in replay buffer + weights_list: the weight concerning the priority + make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Arguments: + - batch_size (:obj:`int`): batch size + - beta: float the parameter in PER for calculating the priority + """ + assert self._beta > 0 + train_sample_num = (self.get_num_of_transitions()//self._cfg.num_unroll_steps) + + valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + base_decay_rate = 5 + # decay rate becomes smaller as the number of samples increases + decay_rate = base_decay_rate / valid_sample_num + # Generate exponentially decaying weights (only for the first 3/4 of the samples) + weights = 
np.exp(-decay_rate * np.arange(valid_sample_num)) + # Normalize the weights to a probability distribution + probabilities = weights / np.sum(weights) + batch_index_list = np.random.choice(valid_sample_num, batch_size, replace=False, p=probabilities) + + if self._cfg.reanalyze_outdated is True: + # NOTE: used in reanalyze part + batch_index_list.sort() + + game_segment_list = [] + pos_in_game_segment_list = [] + + for idx in batch_index_list: + game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx*self._cfg.num_unroll_steps] + game_segment_idx -= self.base_idx + game_segment = self.game_segment_buffer[game_segment_idx] + + game_segment_list.append(game_segment) + pos_in_game_segment_list.append(pos_in_game_segment) + + + make_time = [time.time() for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data + def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: """ Overview: diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 5c956ad14..7a07e1df9 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -72,6 +72,52 @@ def reset_runtime_metrics(self): self.sample_times = 0 self.active_root_num = 0 + def reanalyze_buffer( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. + """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + self.policy = policy + # obtain the current_batch and prepare target context + policy_re_context = self._make_batch_for_reanalyze(batch_size) + # target policy + self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + + def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_data()``, + then prepare the context of a batch: + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. 
+ Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # obtain the batch context from replay buffer + orig_data = self._sample_orig_reanalyze_batch_data(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list, game_segment_list, + pos_in_game_segment_list + ) + self.reanalyze_num = batch_size + return policy_re_context + def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -236,7 +282,7 @@ def _prepare_reward_value_context( action_mask_segment, to_play_segment = [], [] td_steps_list = [] - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): game_segment_len = len(game_segment) game_segment_lens.append(game_segment_len) @@ -504,8 +550,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values.append(value_list[value_index]) target_rewards.append(reward_list[current_index]) else: - target_values.append(np.array([0.])) - target_rewards.append(np.array([0.])) + target_values.append(np.array(0.)) + target_rewards.append(np.array(0.)) value_index += 1 batch_rewards.append(target_rewards) @@ -513,8 +559,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_rewards = np.asarray(batch_rewards) batch_target_values = np.asarray(batch_target_values) - batch_rewards = np.squeeze(batch_rewards, axis=-1) - batch_target_values = np.squeeze(batch_target_values, axis=-1) return batch_rewards, batch_target_values diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index fe57bebf0..9787bb21c 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -4,8 +4,7 @@ import torch from ding.utils import BUFFER_REGISTRY -from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree +from lzero.mcts.tree_search.mcts_ctree import UniZeroMCTSCtree as MCTSCtree from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer @@ -74,7 +73,7 @@ def sample( # target reward, target value batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is action_batch + reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is batch_action ) # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, @@ -167,7 +166,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy """ - reanalyze_num = int(batch_size * reanalyze_ratio) + reanalyze_num = max(int(batch_size * reanalyze_ratio), 1) if reanalyze_ratio > 0 else 0 + # 
print(f'reanalyze_ratio: {reanalyze_ratio}, reanalyze_num: {reanalyze_num}') self.reanalyze_num = reanalyze_num # reanalyzed policy if reanalyze_num > 0: @@ -192,6 +192,95 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: context = reward_value_context, policy_re_context, policy_non_re_context, current_batch return context + def reanalyze_buffer( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. + """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # obtain the current_batch and prepare target context + policy_re_context, current_batch = self._make_batch_for_reanalyze(batch_size) + # target policy + self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1]) + + def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_data()``, + then prepare the context of a batch: + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. + Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # obtain the batch context from replay buffer + if self.sample_type == 'transition': + orig_data = self._sample_orig_reanalyze_batch_data(batch_size) + # elif self.sample_type == 'episode': # TODO + # orig_data = self._sample_orig_data_episode(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + obs_list, action_list, mask_list = [], [], [] + # prepare the inputs of a batch + for i in range(batch_size): + game = game_segment_list[i] + pos_in_game_segment = pos_in_game_segment_list[i] + + actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid + mask_tmp = [1. for i in range(len(actions_tmp))] + mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + + # pad random action + actions_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + + # obtain the input observations + # pad if length of obs in game_segment is less than stack+num_unroll_steps + # e.g. 
stack+num_unroll_steps = 4+5 + obs_list.append( + game_segment_list[i].get_unroll_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) + mask_list.append(mask_tmp) + + # formalize the input observations + obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + + # formalize the inputs of a batch + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + for i in range(len(current_batch)): + current_batch[i] = np.asarray(current_batch[i]) + + # reanalyzed policy + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list, game_segment_list, + pos_in_game_segment_list + ) + + context = policy_re_context, current_batch + self.reanalyze_num = batch_size + return context + def _prepare_policy_reanalyzed_context( self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] ) -> List[Any]: @@ -245,7 +334,7 @@ def _prepare_policy_reanalyzed_context( ] return policy_re_context - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray: + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, batch_action) -> np.ndarray: """ Overview: prepare policy targets from the reanalyzed context of policies @@ -280,16 +369,19 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + # NOTE: TODO + model.world_model.reanalyze_phase = True + with torch.no_grad(): policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) network_output = [] - m_obs = torch.from_numpy(policy_obs_list).to(self._cfg.device) + batch_obs = torch.from_numpy(policy_obs_list).to(self._cfg.device) # =============== NOTE: The key difference with MuZero ================= # calculate the target value - # action_batch.shape (32, 10) - # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 - m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # batch_action.shape (32, 10) + # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num # ======================================================================= if not model.training: @@ -366,9 +458,12 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: batch_target_policies_re = np.array(batch_target_policies_re) + # NOTE: TODO + model.world_model.reanalyze_phase = False + return batch_target_policies_re - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[ + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action) -> Tuple[ Any, Any]: """ Overview: @@ -384,33 +479,17 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A to_play_segment = reward_value_context # noqa # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) transition_batch_size = len(value_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, 
pos_in_game_segment_list - ) - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] batch_target_values, batch_rewards = [], [] with torch.no_grad(): value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) network_output = [] - m_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) + batch_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) # =============== NOTE: The key difference with MuZero ================= # calculate the target value - # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(m_obs, action_batch) + # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 + m_output = model.initial_inference(batch_obs, batch_action) # ====================================================================== if not model.training: @@ -422,44 +501,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_output.policy_logits ] ) + network_output.append(m_output) - # concat the output slices after model inference - if self._cfg.use_root_value: - # use the root values from MCTS, as in EfficientZero - # the root values have limited improvement but require much more GPU actors; - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output( - network_output, data_type='muzero' - ) - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_values = roots.get_values() - value_list = np.array(roots_values) - else: - # use the predicted values - value_list = concat_output_value(network_output) + # use the predicted values + value_numpy = concat_output_value(network_output) # get last state value if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: # TODO(pu): for board_games, very important, to check - value_list = value_list.reshape(-1) * np.array( + value_numpy = value_numpy.reshape(-1) * np.array( [ self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % 2 == 0 else -self._cfg.discount_factor ** @@ -468,12 +519,12 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A ] ) else: - value_list = value_list.reshape(-1) * ( + value_numpy = 
value_numpy.reshape(-1) * ( np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list ) - value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() + value_numpy= value_numpy * np.array(value_mask) + value_list = value_numpy.tolist() horizon_id, value_index = 0, 0 for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, @@ -499,8 +550,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values.append(value_list[value_index]) target_rewards.append(reward_list[current_index]) else: - target_values.append(np.array([0.])) - target_rewards.append(np.array([0.])) + target_values.append(np.array(0.)) + target_rewards.append(np.array(0.)) value_index += 1 batch_rewards.append(target_rewards) @@ -508,7 +559,5 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_rewards = np.asarray(batch_rewards) batch_target_values = np.asarray(batch_target_values) - batch_rewards = np.squeeze(batch_rewards, axis=-1) - batch_target_values = np.squeeze(batch_target_values, axis=-1) return batch_rewards, batch_target_values diff --git a/lzero/model/common.py b/lzero/model/common.py index 30c860b6c..6f38e2f44 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -341,7 +341,12 @@ def __init__( self.activation = activation self.embedding_dim = embedding_dim - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + + if self.observation_shape[1] == 64: + self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + + elif self.observation_shape[1] == 96: + self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) self.sim_norm = SimNorm(simnorm_dim=group_size) @@ -365,7 +370,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Important: Transform the output feature plane to the latent state. # For example, for an Atari feature plane of shape (64, 8, 8), # flattening results in a size of 4096, which is then transformed to 768. - x = self.last_linear(x.reshape(-1, 64 * 8 * 8)) + x = self.last_linear(x.view(x.size(0), -1)) + x = x.view(-1, self.embedding_dim) # NOTE: very important for training stability. 
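The `_sample_orig_reanalyze_batch_data` helper added to `game_buffer.py` above draws reanalyze candidates only from the first `reanalyze_partition` fraction of unroll-aligned samples, with exponentially decaying probabilities over the sample index. A minimal standalone sketch of that weighting follows; the function name and arguments are illustrative rather than part of the diff, and it assumes the caller has already verified that enough transitions are in the buffer (as the entry scripts above do before calling `reanalyze_buffer`).

import numpy as np

def sample_reanalyze_indices(num_transitions: int, num_unroll_steps: int,
                             reanalyze_partition: float, batch_size: int,
                             base_decay_rate: float = 5.0) -> np.ndarray:
    # One candidate per unrolled sequence, matching the buffer's position look-up.
    train_sample_num = num_transitions // num_unroll_steps
    # Only the first `reanalyze_partition` fraction of candidates is eligible.
    valid_sample_num = int(train_sample_num * reanalyze_partition)
    # The decay rate shrinks as the buffer grows, flattening the distribution.
    decay_rate = base_decay_rate / valid_sample_num
    weights = np.exp(-decay_rate * np.arange(valid_sample_num))
    probabilities = weights / weights.sum()
    # Sample without replacement; lower (earlier) indices are favored.
    return np.random.choice(valid_sample_num, batch_size, replace=False, p=probabilities)

In the entry scripts, this reanalyze path is triggered either `buffer_reanalyze_freq` times per collection (every `update_per_collect // buffer_reanalyze_freq` training updates) when the frequency is at least 1, or once every `1 / buffer_reanalyze_freq` collection epochs otherwise.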
diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index df50088ce..b92b138f1 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -98,24 +98,17 @@ def __init__( embedding_dim=world_model_cfg.embed_dim, group_size=world_model_cfg.group_size, ) - # TODO: we should change the output_shape to the real observation shape - self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - - self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=True,) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') - print('==' * 20) print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') print('==' * 20) elif world_model_cfg.obs_type == 'image_memory': self.representation_network = LatentEncoderForMemoryEnv( diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py index 6f5afb9a8..ac83b6113 100644 --- a/lzero/model/unizero_world_models/kv_caching.py +++ b/lzero/model/unizero_world_models/kv_caching.py @@ -62,17 +62,17 @@ def get(self) -> torch.Tensor: """ return self._cache[:, :, :self._size, :] - def update(self, x: torch.Tensor) -> None: + def update(self, x: torch.Tensor, tokens: int) -> None: """ Overview: Update the cache with new values. Arguments: - x (:obj:`torch.Tensor`): The new values to update the cache with. """ - assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) - assert self._size + x.size(2) <= self._cache.shape[2] # TODO - self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + x.size(2)) - self._size += x.size(2) + # assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) + # assert self._size + tokens <= self._cache.shape[2] # TODO + self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens) + self._size += tokens class KVCache: @@ -136,8 +136,8 @@ def update(self, k: torch.Tensor, v: torch.Tensor): - k (:obj:`torch.Tensor`): The new values to update the key cache with. - v (:obj:`torch.Tensor`): The new values to update the value cache with. 
""" - self._k_cache.update(k) - self._v_cache.update(v) + self._k_cache.update(k, k.size(2)) + self._v_cache.update(v, v.size(2)) class KeysValues: @@ -203,23 +203,6 @@ def prune(self, mask: np.ndarray) -> None: for kv_cache in self._keys_values: kv_cache.prune(mask) - def to_device(self, device: str): - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - Not used in the current implementation. - - Arguments: - - self._keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. - """ - device = torch.device(device if torch.cuda.is_available() else 'cpu') - for kv_cache in self._keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return self._keys_values - class AssignWithoutInplaceCheck(torch.autograd.Function): """ diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index 0ebc7bfc4..c6ee6426c 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -21,16 +21,16 @@ def __init__(self, use_dropout: bool = True): self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features # Comment out the following line if you don't need perceptual loss - self.net = vgg16(pretrained=True, requires_grad=False) + # self.net = vgg16(pretrained=True, requires_grad=False) self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) # Comment out the following line if you don't need perceptual loss - self.load_from_pretrained() - for param in self.parameters(): - param.requires_grad = False + # self.load_from_pretrained() + # for param in self.parameters(): + # param.requires_grad = False def load_from_pretrained(self) -> None: ckpt = get_ckpt_path(name="vgg_lpips", root=Path.home() / ".cache/iris/tokenizer_pretrained_vgg") # Download VGG if necessary diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 714bc13d6..62536c892 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -214,8 +214,8 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) if kv_cache is not None: - kv_cache.update(k, v) - k, v = kv_cache.get() + kv_cache.update(k, v) # time occupancy 21% + k, v = kv_cache.get() # time occupancy 5% att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index bb00e8947..fbca4c242 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -1,13 +1,123 @@ import hashlib +import xxhash from dataclasses import dataclass import numpy as np import torch import torch.nn as nn - +import time from .kv_caching import KeysValues +def custom_copy_kv_cache_to_dict_speed(src_kv: KeysValues, dst_dict: dict, cache_key: str, reuse_cache: bool = True) -> 
None: + """ + Overview: + Efficiently copy the contents of a KeysValues object to a new entry in a dictionary. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object to copy from. + - dst_dict (:obj:`dict`): The destination dictionary to copy to. + - cache_key (:obj:`str`): The key for the new entry in the destination dictionary. + - reuse_cache (:obj:`bool`, optional): Whether to reuse the existing cache if the cache_key already exists. + If True, the existing cache will not be overwritten. + If False, the cache will be overwritten every time. + Default: True. + """ + if reuse_cache and cache_key in dst_dict: + print(f"Cache key '{cache_key}' already exists in the destination dictionary. Reusing the existing cache.") + print(f"Dictionary size: {len(dst_dict)}") + return + + start_time = time.time() + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + dst_kv = KeysValues( + src_kv_shape[0], # n + src_kv_shape[1], # num_heads + src_kv_shape[2], # max_tokens + src_kv_shape[3] * src_kv_shape[1], # embed_dim + len(src_kv._keys_values), # num_layers + src_kv._keys_values[0]._k_cache._cache.device, # device + ) + shape_time = time.time() - start_time + + start_time = time.time() + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + copy_time = time.time() - start_time + + dst_dict[cache_key] = dst_kv + + print(f"Shape initialization time: {shape_time:.6f} seconds") + print(f"Cache copy time: {copy_time:.6f} seconds") + print(f"Total time: {shape_time + copy_time:.6f} seconds") + + # print(f"Cache key '{cache_key}' has been copied to the destination dictionary.") + # print(f"Dictionary size: {len(dst_dict)}") + + +def custom_copy_kv_cache_to_dict(src_kv: KeysValues, dst_dict: dict, cache_key: str, reuse_cache: bool = True) -> None: + """ + Overview: + Efficiently copy the contents of a KeysValues object to a new entry in a dictionary. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object to copy from. + - dst_dict (:obj:`dict`): The destination dictionary to copy to. + - cache_key (:obj:`str`): The key for the new entry in the destination dictionary. + - reuse_cache (:obj:`bool`, optional): Whether to reuse the existing cache if the cache_key already exists. + If True, the existing cache will not be overwritten. + If False, the cache will be overwritten every time. + Default: True. + """ + if reuse_cache and cache_key in dst_dict: + print(f"Cache key '{cache_key}' already exists in the destination dictionary. 
Reusing the existing cache.") + print(f"Dictionary size: {len(dst_dict)}") + return + + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + dst_kv = KeysValues( + src_kv_shape[0], # n + src_kv_shape[1], # num_heads + src_kv_shape[2], # max_tokens + src_kv_shape[3] * src_kv_shape[1], # embed_dim + len(src_kv._keys_values), # num_layers + src_kv._keys_values[0]._k_cache._cache.device, # device + ) + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + dst_dict[cache_key] = dst_kv + + +def custom_copy_kv_cache(src_kv: KeysValues) -> KeysValues: + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + dst_kv = KeysValues( + src_kv_shape[0], # n + src_kv_shape[1], # num_heads + src_kv_shape[2], # max_tokens + src_kv_shape[3] * src_kv_shape[1], # embed_dim + len(src_kv), # num_layers + src_kv._keys_values[0]._k_cache._cache.device, # device + ) + + # with torch.no_grad(): + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + return dst_kv + + def to_device_for_kvcache(keys_values: KeysValues, device: str) -> KeysValues: """ Transfer all KVCache objects within the KeysValues object to a certain device. @@ -18,11 +128,13 @@ def to_device_for_kvcache(keys_values: KeysValues, device: str) -> KeysValues: Returns: - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. """ - device = torch.device(device if torch.cuda.is_available() else 'cpu') + target_device = torch.device(device) for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) + if kv_cache._k_cache._cache.device != target_device: + kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(target_device) + if kv_cache._v_cache._cache.device != target_device: + kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(target_device) return keys_values @@ -67,23 +179,17 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int): total_memory_gb = total_memory_bytes / (1024 ** 3) return total_memory_gb - -def quantize_state(state, num_buckets=100): +def hash_state(state): """ - Quantize the state vector. + Hash the state vector. Arguments: - state: The state vector to be quantized. - num_buckets: The number of quantization buckets. + state: The state vector to be hashed. Returns: - The hash value of the quantized state vector. + The hash value of the state vector. 
""" - # Use np.digitize to map each dimension value of the state vector into num_buckets - quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) - # Use a more stable hash function - quantized_state_bytes = quantized_state.tobytes() - hash_object = hashlib.sha256(quantized_state_bytes) - return hash_object.hexdigest() + # Use xxhash for faster hashing + return xxhash.xxh64(state).hexdigest() @dataclass class WorldModelOutput: diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 42ba540cf..9d6c6f286 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1,5 +1,3 @@ -import collections -import copy import logging from typing import Any, Tuple from typing import Optional @@ -10,15 +8,16 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from torch.distributions import Categorical, Independent, Normal from lzero.model.common import SimNorm from lzero.model.utils import cal_dormant_ratio +from .kv_caching import KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights, to_device_for_kvcache -from .utils import WorldModelOutput, quantize_state -from torch.distributions import Categorical, Independent, Normal +from .utils import LossWithIntermediateLosses, init_weights +from .utils import WorldModelOutput, hash_state logging.getLogger().setLevel(logging.DEBUG) @@ -61,6 +60,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Initialize patterns for block masks self._initialize_patterns() + self.hidden_size = config.embed_dim // config.num_heads + # Position embedding self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) self.precompute_pos_emb_diff_kv() @@ -110,6 +111,138 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Initialize keys and values for transformer self._initialize_transformer_keys_values() + # TODO: check + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + self.shared_pool_size = int(50*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # for self.kv_cache_wm + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + + #@profile + def custom_copy_kv_cache_to_shared_init_envs(self, src_kv: KeysValues, env_id) -> int: + """ + Overview: + Efficiently copy the contents of a KeysValues object to the shared pool. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object to copy from. + Returns: + - index (:obj:`int`): The index of the copied KeysValues object in the shared pool. 
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] is None: + self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] = KeysValues( + src_kv_shape[0], # n + src_kv_shape[1], # num_heads + src_kv_shape[2], # max_tokens + src_kv_shape[3] * src_kv_shape[1], # embed_dim + len(src_kv), # num_layers + src_kv._keys_values[0]._k_cache._cache.device, # device + ) + + dst_kv = self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index_init_envs[env_id] + self.shared_pool_index_init_envs[env_id] = (self.shared_pool_index_init_envs[env_id] + 1) % self.shared_pool_size_init + + return index + + #@profile + def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copy the contents of a KeysValues object to the shared pool. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object to copy from. + Returns: + - index (:obj:`int`): The index of the copied KeysValues object in the shared pool. + """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_wm[self.shared_pool_index_wm] is None: + self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( + src_kv_shape[0], # n + src_kv_shape[1], # num_heads + src_kv_shape[2], # max_tokens + src_kv_shape[3] * src_kv_shape[1], # embed_dim + len(src_kv), # num_layers + src_kv._keys_values[0]._k_cache._cache.device, # device + ) + + dst_kv = self.shared_pool_wm[self.shared_pool_index_wm] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + self.shared_pool_index_wm = (self.shared_pool_index_wm + 1) % self.shared_pool_size_wm + + return dst_kv + + #@profile + def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copy the contents of a KeysValues object to the shared pool. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object to copy from. + Returns: + - index (:obj:`int`): The index of the copied KeysValues object in the shared pool. 
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_recur_infer[self.shared_pool_index] is None: + self.shared_pool_recur_infer[self.shared_pool_index] = KeysValues( + src_kv_shape[0], # n + src_kv_shape[1], # num_heads + src_kv_shape[2], # max_tokens + src_kv_shape[3] * src_kv_shape[1], # embed_dim + len(src_kv), # num_layers + src_kv._keys_values[0]._k_cache._cache.device, # device + ) + + dst_kv = self.shared_pool_recur_infer[self.shared_pool_index] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + + return index + + def _initialize_config_parameters(self) -> None: """Initialize configuration parameters.""" self.policy_entropy_weight = self.config.policy_entropy_weight @@ -195,9 +328,14 @@ def _initialize_last_layer(self) -> None: def _initialize_cache_structures(self) -> None: """Initialize cache structures for past keys and values.""" - self.past_kv_cache_recurrent_infer = collections.OrderedDict() - self.past_kv_cache_init_infer = collections.OrderedDict() - self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + # self.past_kv_cache_init_infer = collections.OrderedDict() + # self.past_kv_cache_recurrent_infer = collections.OrderedDict() + # self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + # TODO: check + from collections import defaultdict + self.past_kv_cache_recurrent_infer = defaultdict(dict) + self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] @@ -221,6 +359,8 @@ def _initialize_transformer_keys_values(self) -> None: """Initialize keys and values for the transformer.""" self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + self.keys_values_wm_single_env_tmp = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.context_length) @@ -261,6 +401,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. 
@@ -282,6 +423,7 @@ def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads ).transpose(1, 2).detach() + #@profile def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], past_keys_values: Optional[torch.Tensor] = None, kvcache_independent: bool = False, is_init_infer: bool = True, @@ -353,6 +495,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu # logits_ends is None return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -381,6 +524,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -417,6 +561,7 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -446,6 +591,7 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): """ Pass sequences through the transformer. @@ -466,8 +612,9 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + #@profile @torch.no_grad() - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: """ Reset the model state based on initial observations and actions. @@ -478,50 +625,51 @@ def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> to """ # Extract observations, actions, and current observations from the dictionary. if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - current_obs = obs_act_dict['current_obs'] + batch_obs = obs_act_dict['obs'] # obs_act_dict['obs'] is at timestep t + batch_action = obs_act_dict['action'] # obs_act_dict['action'] is at timestep t + batch_current_obs = obs_act_dict['current_obs'] # obs_act_dict['current_obs'] is at timestep t+1 # Encode observations to latent embeddings. 
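# A sketch of the (obs, act, obs, act, ...) interleaving that _process_obs_act_combined above
# builds before the transformer pass, simplified to one observation token per step; the sizes
# are illustrative, not the model's real dimensions.
import torch

B, L, E = 2, 3, 8
obs_embeddings = torch.randn(B, L, E)   # one embedding per observation step
act_embeddings = torch.randn(B, L, E)   # one embedding per action token

sequence = torch.empty(B, 2 * L, E)
sequence[:, 0::2] = obs_embeddings      # even positions hold observation tokens
sequence[:, 1::2] = act_embeddings      # odd positions hold action tokens
print(sequence.shape)                   # torch.Size([2, 6, 8])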
- obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations) + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs) - if current_obs is not None: + if batch_current_obs is not None: # ================ Collect and Evaluation Phase ================ # Encode current observations to latent embeddings - current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs) + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs) # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") self.latent_state = current_obs_embeddings - outputs_wm = self.refresh_kvs_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action, + outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, current_obs_embeddings) else: # ================ calculate the target value in Train phase ================ self.latent_state = obs_embeddings - outputs_wm = self.refresh_kvs_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action, None) + outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, None) return outputs_wm, self.latent_state + #@profile @torch.no_grad() - def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, - buffer_action=None, + def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, current_obs_embeddings=None) -> torch.FloatTensor: """ Refresh key-value pairs with the initial latent state for inference. Arguments: - - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. - - buffer_action (optional): Actions taken. + - last_obs_embeddings (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. - current_obs_embeddings (optional): Current observation embeddings. Returns: - torch.FloatTensor: The outputs from the world model. 
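# A small illustration of the timestep convention annotated above: 'obs' and 'action' belong to
# step t, while 'current_obs' is the step t+1 observation that collect/eval-time initial
# inference conditions on; in the discrete-action branch, -1 marks the first step of an episode.
# All values here are dummies.
import torch

obs_act_dict = {
    'obs': torch.randn(4, 3, 64, 64),          # o_t for 4 environments
    'action': [-1, -1, -1, -1],                # a_t; -1 means "first step, no previous action"
    'current_obs': torch.randn(4, 3, 64, 64),  # o_{t+1}
}
first_step_flag = max(obs_act_dict['action']) == -1
print(first_step_flag)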
""" - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: # ================ Collect and Evaluation Phase ================ if current_obs_embeddings is not None: if self.continuous_action_space: - first_step_flag = not isinstance(buffer_action[0], np.ndarray) + first_step_flag = not isinstance(batch_action[0], np.ndarray) else: - first_step_flag = max(buffer_action) == -1 + first_step_flag = max(batch_action) == -1 if first_step_flag: # First step in an episode self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], @@ -533,26 +681,31 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor # Copy and store keys_values_wm for a single environment self.update_cache_context(current_obs_embeddings, is_init_infer=True) else: - # Assume latest_state is the new latent_state, containing information from ready_env_num environments + # current_obs_embeddings is the new latent_state, containing information from ready_env_num environments ready_env_num = current_obs_embeddings.shape[0] self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] for i in range(ready_env_num): # Retrieve latent state for a single environment - state_single_env = latent_state[i] - quantized_state = state_single_env.detach().cpu().numpy() - # Compute hash value using quantized state - cache_key = quantize_state(quantized_state) + # NOTE: len(last_obs_embeddings) may smaller than len(current_obs_embeddings), because some environments may have done + + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state(state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + # Retrieve cached value - matched_value = self.past_kv_cache_init_infer_envs[i].get(cache_key) + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None self.root_total_query_cnt += 1 if matched_value is not None: # If a matching value is found, add it to the list self.root_hit_cnt += 1 - # deepcopy is needed because forward modifies matched_value in place - self.keys_values_wm_list.append( - copy.deepcopy(to_device_for_kvcache(matched_value, self.device))) + # NOTE: deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) self.keys_values_wm_size_list.append(matched_value.size) else: # Reset using zero values @@ -567,14 +720,14 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor # Input self.keys_values_wm_list, output self.keys_values_wm self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) - buffer_action = buffer_action[:ready_env_num] + batch_action = batch_action[:ready_env_num] # # only for debug # if ready_env_num < self.env_num: # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') if self.continuous_action_space: - act_tokens = torch.from_numpy(np.array(buffer_action)).to(latent_state.device).unsqueeze(1) + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) else: - act_tokens = torch.from_numpy(np.array(buffer_action)).to(latent_state.device).unsqueeze(-1) + 
act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_init_infer=True) @@ -584,26 +737,25 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor # Copy and store keys_values_wm for a single environment self.update_cache_context(current_obs_embeddings, is_init_infer=True) - # elif n > self.env_num and buffer_action is not None and current_obs_embeddings is None: - elif buffer_action is not None and current_obs_embeddings is None: + elif batch_action is not None and current_obs_embeddings is None: # ================ calculate the target value in Train phase ================ # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - latent_state = latent_state[:, :-1, :] - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) if self.continuous_action_space: - act_tokens = buffer_action + act_tokens = batch_action else: - act_tokens = rearrange(buffer_action, 'b l -> b l 1') + act_tokens = rearrange(batch_action, 'b l -> b l 1') # select the last timestep for each sample # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. last_steps_act = act_tokens[:, -1:, :] act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}) + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}) # select the last timestep for each sample last_steps_value = outputs_wm.logits_value[:, -1:, :] @@ -619,6 +771,7 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor return outputs_wm + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict): """ @@ -630,12 +783,13 @@ def forward_initial_inference(self, obs_act_dict): - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. 
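# A sketch of the train-phase reshaping above with illustrative sizes: flat per-step latent
# embeddings [B*L, K, E] are folded back to [B, L, K, E], the final observation step is dropped,
# and discrete actions get a trailing token dimension via einops.rearrange.
import torch
from einops import rearrange

B, L, K, E = 32, 6, 16, 64
flat_obs_embeddings = torch.randn(B * L, K, E)                     # e.g. [192, 16, 64]
obs_embeddings = flat_obs_embeddings.view(B, L, K, E)[:, :-1, :]   # [32, 5, 16, 64]
actions = torch.randint(0, 6, (B, L - 1))                          # discrete actions per step
act_tokens = rearrange(actions, 'b l -> b l 1')                    # [32, 5, 1]
print(obs_embeddings.shape, act_tokens.shape)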
""" # UniZero has context in the root node - outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict) self.past_kv_cache_recurrent_infer.clear() return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, latent_state_index_in_search_path=[]): @@ -678,7 +832,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) - # Trim and pad kv_cache + # Trim and pad kv_cache: modify self.keys_values_wm in-place self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) self.keys_values_wm_size_list_current = self.keys_values_wm_size_list @@ -718,9 +872,10 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, latent_state_index_in_search_path=latent_state_index_in_search_path ) - return ( - outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. @@ -773,6 +928,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, latent_state_index_in_search_path=[], valid_context_lengths=None): """ @@ -790,9 +946,7 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde return for i in range(latent_state.size(0)): # ============ Iterate over each environment ============ - state_single_env = latent_state[i] - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor context_length = self.context_length if not is_init_infer: @@ -867,8 +1021,7 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde for layer in range(self.num_layers): # ============ Apply trimming and padding to each layer of kv_cache ============ - if self.keys_values_wm._keys_values[ - layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = \ self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze( 0) # Shape torch.Size([2, 100, 512]) @@ -909,13 +1062,15 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde if is_init_infer: # Store the latest key-value cache for initial inference - self.past_kv_cache_init_infer_envs[i][cache_key] = copy.deepcopy( - to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + 
self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index else: # Store the latest key-value cache for recurrent inference - self.past_kv_cache_recurrent_infer[cache_key] = copy.deepcopy( - to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0) -> list: """ @@ -934,21 +1089,29 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, """ for i in range(ready_env_num): self.total_query_count += 1 - state_single_env = latent_state[i] # Get the latent state for a single environment - cache_key = quantize_state(state_single_env) # Compute the hash value using the quantized state + state_single_env = latent_state[i] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) - # Try to retrieve the cached value from past_kv_cache_init_infer_envs - matched_value = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if self.reanalyze_phase: + # TODO: check + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None - # If not found, try to retrieve from past_kv_cache_recurrent_infer - if matched_value is None: - matched_value = self.past_kv_cache_recurrent_infer.get(cache_key) + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] if matched_value is not None: # If a matching cache is found, add it to the lists self.hit_count += 1 # Perform a deep copy because the transformer's forward pass might modify matched_value in-place - self.keys_values_wm_list.append(copy.deepcopy(to_device_for_kvcache(matched_value, self.device))) + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) self.keys_values_wm_size_list.append(matched_value.size) else: # If no matching cache is found, generate a new one using zero reset @@ -964,6 +1127,8 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, return self.keys_values_wm_size_list + + #@profile def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, **kwargs: Any) -> LossWithIntermediateLosses: # Encode observations into latent state representations @@ -984,7 +1149,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), percentage=self.dormant_threshold) - self.past_kv_cache_init_infer.clear() + # self.past_kv_cache_init_infer.clear() self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() @@ -996,7 +1161,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar if self.obs_type == 'image': # Reconstruct observations from latent state representations - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # reconstructed_images = 
self.tokenizer.decode_to_obs(obs_embeddings) # ========== for visualization ========== # Uncomment the lines below for visual analysis @@ -1011,10 +1176,9 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Calculate reconstruction loss and perceptual loss # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 - latent_recon_loss = torch.tensor(0., device=batch['observations'].device, - dtype=batch['observations'].dtype) - perceptual_loss = torch.tensor(0., device=batch['observations'].device, - dtype=batch['observations'].dtype) + + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss elif self.obs_type == 'vector': perceptual_loss = torch.tensor(0., device=batch['observations'].device, @@ -1063,7 +1227,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_world_model = cal_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, percentage=self.dormant_threshold) - self.past_kv_cache_init_infer.clear() + # self.past_kv_cache_init_infer.clear() self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() @@ -1088,7 +1252,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations']) # Compute labels for observations, rewards, and ends - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings, batch['rewards'], batch['ends'], batch['mask_padding']) @@ -1143,6 +1307,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + # ==== TODO: calculate the new priorities for each transition. 
==== + # value_priority = L1Loss(reduction='none')(labels_value.squeeze(-1), outputs['logits_value'][:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + # Compute timesteps timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) # Compute discount coefficients for each timestep @@ -1188,12 +1356,12 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_perceptual_loss = perceptual_loss # Calculate overall discounted loss - discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).mean() - discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).mean() - discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).mean() - discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).mean() - discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).mean() - discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).mean() + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() if self.continuous_action_space: return LossWithIntermediateLosses( @@ -1237,6 +1405,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar latent_state_l2_norms=latent_state_l2_norms, ) + #@profile def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[ torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -1301,6 +1470,7 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[ return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma + #@profile def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. 
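# A sketch of the loss-normalization change above: per-step losses, which the masked
# cross-entropy already zeroes at padded positions, are discounted and then divided by the
# number of valid timesteps from mask_padding rather than by batch_size * num_steps, so padded
# steps no longer dilute the average. Shapes and the discount value are illustrative.
import torch

B, T = 4, 5
mask_padding = torch.tensor([[1, 1, 1, 0, 0]] * B, dtype=torch.bool)
loss_per_step = torch.rand(B, T) * mask_padding          # padded positions contribute zero loss
discounts = 0.97 ** torch.arange(T)

plain_mean = (loss_per_step * discounts).mean()                        # divides by B * T
masked_mean = (loss_per_step * discounts).sum() / mask_padding.sum()   # divides by valid steps only
print(plain_mean.item(), masked_mean.item())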
@@ -1327,6 +1497,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1336,9 +1507,10 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag mask_fill = torch.logical_not(mask_padding) # Prepare observation labels @@ -1349,10 +1521,13 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) # Fill the masked areas of ends - labels_ends = ends.masked_fill(mask_fill, -100) + # labels_endgs = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. """ @@ -1375,7 +1550,7 @@ def clear_caches(self): """ Clears the caches of the world model. """ - self.past_kv_cache_init_infer.clear() + # self.past_kv_cache_init_infer.clear() for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() self.past_kv_cache_recurrent_infer.clear() diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 94e007edd..cd479d5f9 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -154,8 +154,10 @@ class MuZeroPolicy(Policy): momentum=0.9, # (float) The maximum constraint value of gradient norm clipping. grad_clip_value=10, - # (int) The number of episodes in each collecting stage. + # (int) The number of episodes in each collecting stage when use muzero_collector. n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, # (int) the number of simulations in MCTS. num_simulations=50, # (float) Discount factor (gamma) for returns. diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index a8f318c8b..1a64ee5ae 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -182,7 +182,7 @@ class UniZeroPolicy(MuZeroPolicy): replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam'] + # (str) Optimizer for training policy network. optim_type='AdamW', # (float) Learning rate for training policy network. Initial lr for manually decay schedule. learning_rate=0.0001, @@ -198,8 +198,10 @@ class UniZeroPolicy(MuZeroPolicy): momentum=0.9, # (float) The maximum constraint value of gradient norm clipping. grad_clip_value=5, - # (int) The number of episodes in each collecting stage. + # (int) The number of episodes in each collecting stage when use muzero_collector. 
n_episode=8, + # (int) The number of segments in each collecting stage when using muzero_segment_collector. + num_segments=8, # (int) the number of simulations in MCTS. num_simulations=50, # (float) Discount factor (gamma) for returns. @@ -214,8 +216,6 @@ class UniZeroPolicy(MuZeroPolicy): value_loss_weight=0.25, # (float) The weight of policy loss. policy_loss_weight=1, - # (float) The weight of policy entropy loss. - policy_entropy_loss_weight=0, # (float) The weight of ssl (self-supervised learning) loss. ssl_loss_weight=0, # (bool) Whether to use piecewise constant learning rate decay. @@ -354,7 +354,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in if self._cfg.model.frame_stack_num == 4: obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) else: - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) # TODO: optimize # Apply augmentations if needed if self._cfg.use_augmentation: @@ -368,12 +368,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in data_list = [mask_batch, target_reward, target_value, target_policy, weights] mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) - target_reward = target_reward.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - # Transform rewards and values to their scaled forms transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) @@ -404,7 +401,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Extract valid target policy data and compute entropy valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) - average_target_policy_entropy = target_policy_entropy.mean().item() + average_target_policy_entropy = target_policy_entropy.mean() # Update world model losses = self._learn_model.world_model.compute_loss( @@ -448,10 +445,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() self._target_model.encoder_hook.clear_data() - if self._cfg.multi_gpu: - self.sync_gradients(self._learn_model) total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) self._optimizer_world_model.step() if self._cfg.lr_piecewise_constant_decay: @@ -492,24 +489,24 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'collect_epsilon': self._collect_epsilon, 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], 'weighted_total_loss': weighted_total_loss.item(), - 'obs_loss': obs_loss, - 'latent_recon_loss': latent_recon_loss, - 'perceptual_loss': perceptual_loss, - 'policy_loss': policy_loss, - 'orig_policy_loss': orig_policy_loss, - 'policy_entropy': policy_entropy, - 'target_policy_entropy': average_target_policy_entropy, - 'reward_loss': reward_loss, - 'value_loss': value_loss, - 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO
+ 'obs_loss': obs_loss.item(), + 'latent_recon_loss': latent_recon_loss.item(), + 'perceptual_loss': perceptual_loss.item(), + 'policy_loss': policy_loss.item(), + 'orig_policy_loss': orig_policy_loss.item(), + 'policy_entropy': policy_entropy.item(), + 'target_policy_entropy': average_target_policy_entropy.item(), + 'reward_loss': reward_loss.item(), + 'value_loss': value_loss.item(), + # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, - 'analysis/dormant_ratio_world_model': dormant_ratio_world_model, - 'analysis/latent_state_l2_norms': latent_state_l2_norms, + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), + 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, @@ -660,6 +657,12 @@ def _forward_collect( self.last_batch_obs = data self.last_batch_action = batch_action + # ========= TODO: for muzero_segment_collector now ========= + if active_collect_env_num < self.collector_env_num: + print('='*20) + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True) + return output def _init_eval(self) -> None: @@ -798,7 +801,6 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in # Clear various caches in the collect model's world model world_model = self._collect_model.world_model - world_model.past_kv_cache_init_infer.clear() for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() @@ -843,7 +845,6 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ # Clear various caches in the eval model's world model world_model = self._eval_model.world_model - world_model.past_kv_cache_init_infer.clear() for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index b74e1e745..ece5213be 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -1,4 +1,5 @@ from .alphazero_collector import AlphaZeroCollector from .alphazero_evaluator import AlphaZeroEvaluator from .muzero_collector import MuZeroCollector +from .muzero_segment_collector import MuZeroSegmentCollector from .muzero_evaluator import MuZeroEvaluator diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 9933f816e..ff9863817 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -333,11 +333,10 @@ def collect(self, collected_episode = 0 collected_step = 0 env_nums = self._env_num + retry_waiting_time = 0.05 # initializations init_obs = self._env.ready_obs - - retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to # len(self._env.ready_obs), especially in 
tictactoe env. @@ -369,7 +368,6 @@ def collect(self, [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], maxlen=self.policy_config.model.frame_stack_num ) - game_segments[env_id].reset(observation_window_stack[env_id]) dones = np.array([False for _ in range(env_nums)]) @@ -400,9 +398,20 @@ def collect(self, with self._timer: # Get current ready env obs. obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) + + ready_env_id = set(obs.keys()) + while len(obs.keys()) != self._env_num: + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + obs = self._env.ready_obs + ready_env_id = set(obs.keys()) stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) @@ -638,7 +647,7 @@ def collect(self, init_obs = self._env.ready_obs retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to # len(self._env.ready_obs), especially in tictactoe env. 
            self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
            self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))
@@ -692,6 +701,15 @@ def collect(self,
                        self._reset_stat(env_id)
                        ready_env_id.remove(env_id)
 
+                        # ===== NOTE: when an episode is done but collect() has not returned yet, =====
+                        # create a new GameSegment for this env and continue collecting.
+                        game_segments[env_id] = GameSegment(
+                            self._env.action_space,
+                            game_segment_length=self.policy_config.game_segment_length,
+                            config=self.policy_config
+                        )
+                        game_segments[env_id].reset(observation_window_stack[env_id])
+
            if collected_episode >= n_episode:
                # [data, meta_data]
                return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [
diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py
new file mode 100644
index 000000000..2d4bb05ef
--- /dev/null
+++ b/lzero/worker/muzero_segment_collector.py
@@ -0,0 +1,766 @@
+import time
+from collections import deque, namedtuple
+from typing import Optional, Any, List
+
+import numpy as np
+import torch
+from ding.envs import BaseEnvManager
+from ding.torch_utils import to_ndarray
+from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \
+    allreduce_data
+from ding.worker.collector.base_serial_collector import ISerialCollector
+from torch.nn import L1Loss
+
+from lzero.mcts.buffer.game_segment import GameSegment
+from lzero.mcts.utils import prepare_observation
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('segment_muzero')
+class MuZeroSegmentCollector(ISerialCollector):
+    """
+    Overview:
+        The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero.
+        It manages the data collection process for training these algorithms using a serial mechanism.
+    Interfaces:
+        ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``,
+        ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close``
+    Properties:
+        ``envstep``
+    """
+
+    # To be compatible with ISerialCollector
+    config = dict()
+
+    def __init__(
+            self,
+            collect_print_freq: int = 100,
+            env: BaseEnvManager = None,
+            policy: namedtuple = None,
+            tb_logger: 'SummaryWriter' = None,  # noqa
+            exp_name: Optional[str] = 'default_experiment',
+            instance_name: Optional[str] = 'collector',
+            policy_config: 'policy_config' = None,  # noqa
+    ) -> None:
+        """
+        Overview:
+            Initialize the MuZeroSegmentCollector with the given parameters.
+        Arguments:
+            - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information.
+            - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager.
+            - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API.
+            - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance.
+            - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes.
+            - instance_name (:obj:`str`): Unique identifier for this collector instance.
+            - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy.
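# A toy sketch, with stand-in classes, of the episode-boundary handling added above: when an
# environment finishes an episode but collection keeps going, the finished segment is pushed
# into the segment pool and a fresh segment is immediately started from the current stacked
# observation window. GameSegment itself is not used here.
from collections import deque

class ToySegment:
    def __init__(self, window):
        self.obs = list(window)       # seeded from the stacked observation window
        self.transitions = []

frame_stack_num = 4
segment_pool = deque(maxlen=1000)
window = deque([0.0] * frame_stack_num, maxlen=frame_stack_num)

segment = ToySegment(window)
segment.transitions.append(('action', 'reward'))
segment_pool.append(segment)          # episode done: store the finished segment
segment = ToySegment(window)          # and start a new one without returning
print(len(segment_pool), len(segment.transitions))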
+ """ + self._exp_name = exp_name + self._instance_name = instance_name + self._collect_print_freq = collect_print_freq + self._timer = EasyTimer() + self._end_flag = False + + self._rank = get_rank() + self._world_size = get_world_size() + if self._rank == 0: + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + else: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = None + + self.policy_config = policy_config + self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy + + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset or replace the environment managed by this collector. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + Arguments: + - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + """ + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + """ + Overview: + Reset or replace the policy used by this collector. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + """ + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + + self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) + self._logger.debug( + 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) + ) + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the collector with the given policy and/or environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. 
+        Arguments:
+            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
+            - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+                env_manager(BaseEnvManager)
+        """
+        if _env is not None:
+            self.reset_env(_env)
+        if _policy is not None:
+            self.reset_policy(_policy)
+
+        self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
+
+        # Initialize action_mask_dict, to_play_dict and chance_dict here to make sure they contain entries for every env_id.
+        self.action_mask_dict = {i: None for i in range(self._env_num)}
+        self.to_play_dict = {i: None for i in range(self._env_num)}
+        if self.policy_config.use_ture_chance_label_in_chance_encoder:
+            self.chance_dict = {i: None for i in range(self._env_num)}
+
+        self.dones = np.array([False for _ in range(self._env_num)])
+        self.last_game_segments = [None for _ in range(self._env_num)]
+        self.last_game_priorities = [None for _ in range(self._env_num)]
+
+        self._episode_info = []
+        self._total_envstep_count = 0
+        self._total_episode_count = 0
+        self._total_duration = 0
+        self._last_train_iter = 0
+        self._end_flag = False
+
+        # A game_segment_pool implementation based on the deque structure.
+        self.game_segment_pool = deque(maxlen=int(1e6))
+        self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps
+
+    def _reset_stat(self, env_id: int) -> None:
+        """
+        Overview:
+            Reset the collector's state for the given env_id, including the traj_buffer, obs_pool, \
+                policy_output_pool and env_info. You can refer to base_serial_collector for more details.
+        Arguments:
+            - env_id (:obj:`int`): the id where we need to reset the collector's state
+        """
+        self._env_info[env_id] = {'time': 0., 'step': 0}
+
+    @property
+    def envstep(self) -> int:
+        """
+        Overview:
+            Get the total number of environment steps collected.
+        Returns:
+            - envstep (:obj:`int`): Total number of environment steps collected.
+        """
+        return self._total_envstep_count
+
+    def close(self) -> None:
+        """
+        Overview:
+            Close the collector. If end_flag is False, close the environment, flush the tb_logger \
+                and close the tb_logger.
+        """
+        if self._end_flag:
+            return
+        self._end_flag = True
+        self._env.close()
+        if self._tb_logger:
+            self._tb_logger.flush()
+            self._tb_logger.close()
+
+    def __del__(self) -> None:
+        """
+        Overview:
+            Execute the close command and close the collector. __del__ is automatically called to \
+                destroy the collector instance when the collector finishes its work.
+        """
+        self.close()
+
+    # ==============================================================
+    # MCTS+RL related core code
+    # ==============================================================
+    def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray:
+        """
+        Overview:
+            Compute the priorities for transitions based on prediction and search value discrepancies.
+        Arguments:
+            - i (:obj:`int`): Index of the values in the list to compute the priority for.
+            - pred_values_lst (:obj:`List[float]`): List of predicted values.
+            - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS.
+        Returns:
+            - priorities (:obj:`np.ndarray`): Array of computed priorities.
+        """
+        if self.policy_config.use_priority:
+            # Calculate priorities. The priorities are the L1 losses between the predicted
+            # values and the search values.
We use 'none' as the reduction parameter, which + # means the loss is calculated for each element individually, instead of being summed or averaged. + # A small constant (1e-6) is added to the results to avoid zero priorities. This + # is done because zero priorities could potentially cause issues in some scenarios. + pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + 1e-6 + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities + + def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], + last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray) -> None: + """ + Overview: + Save the game segment to the pool if the current game is finished, padding it if necessary. + Arguments: + - i (:obj:`int`): Index of the current game segment. + - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. + - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of the current game segments. + - done (:obj:`np.ndarray`): Array indicating whether each game is done. + Note: + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + """ + # pad over last segment trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + + # the start obs is init zero obs, so we take the + # [ : +] obs as the pad obs + # e.g. 
the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps] + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + + beg_index = 0 + # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_lst = game_segments[i].chance_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + + if self.policy_config.gumbel_algo: + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + + # pad over and save + if self.policy_config.gumbel_algo: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob) + else: + if self.policy_config.use_ture_chance_label_in_chance_encoder: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst) + else: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i].game_segment_to_array() + + # put the game segment into the pool + self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) + + # reset last game_segments + last_game_segments[i] = None + last_game_priorities[i] = None + + + return None + + def collect(self, + num_segments: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False) -> List[Any]: + """ + Overview: + Collect `num_segments` segments of data with policy_kwargs, trained for `train_iter` iterations. + Arguments: + - num_segments (:obj:`Optional[int]`): Number of segments to collect. + - train_iter (:obj:`int`): Number of training iterations completed so far. + - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. + - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + Returns: + - return_data (:obj:`List[Any]`): Collected data in the form of a list. 
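# A sketch of the priority computation in _compute_priorities above: each transition's priority
# is the elementwise L1 distance between the predicted value and the MCTS search value, plus a
# small constant so no priority is exactly zero. The values below are dummies.
import torch
from torch.nn import L1Loss

pred_values = torch.tensor([0.1, 0.5, -0.2])
search_values = torch.tensor([0.3, 0.4, 0.0])
priorities = L1Loss(reduction='none')(pred_values, search_values).cpu().numpy() + 1e-6
print(priorities)   # approximately [0.2, 0.1, 0.2]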
+ """ + if num_segments is None: + if self._default_num_segments is None: + raise RuntimeError("Please specify collect num_segments") + else: + num_segments = self._default_num_segments + assert num_segments == self._env_num, "Please make sure num_segments == env_num{}/{}".format(num_segments, self._env_num) + + if policy_kwargs is None: + policy_kwargs = {} + temperature = policy_kwargs['temperature'] + epsilon = policy_kwargs['epsilon'] + + collected_episode = 0 + collected_step = 0 + env_nums = self._env_num + + # initializations + init_obs = self._env.ready_obs + + retry_waiting_time = 0.05 + while len(init_obs.keys()) != self._env_num: + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + init_obs = self._env.ready_obs + + for env_id in range(env_nums): + if env_id in init_obs.keys(): + self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[] for _ in range(env_nums)] + for env_id in range(env_nums): + observation_window_stack[env_id] = deque( + [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + + game_segments[env_id].reset(observation_window_stack[env_id]) + + # for priorities in self-play + search_values_lst = [[] for _ in range(env_nums)] + pred_values_lst = [[] for _ in range(env_nums)] + if self.policy_config.gumbel_algo: + improved_policy_lst = [[] for _ in range(env_nums)] + + # some logs + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self.policy_config.gumbel_algo: + completed_value_lst = np.zeros(env_nums) + self_play_moves = 0. + self_play_episodes = 0. + self_play_moves_max = 0 + self_play_visit_entropy = [] + total_transitions = 0 + + if collect_with_pure_policy: + temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + + while True: + with self._timer: + # Get current ready env obs. + obs = self._env.ready_obs + ready_env_id = set(obs.keys()) + if len(ready_env_id) < self._env_num: + print(f'ready_env_id: {ready_env_id}') + + while len(obs.keys()) != self._env_num: + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. 
+ self._logger.info('The current init_obs.keys() is {}'.format(obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + obs = self._env.ready_obs + ready_env_id = set(obs.keys()) + + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} + self.to_play_dict_tmp = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id} + + action_mask = [self.action_mask_dict_tmp[env_id] for env_id in ready_env_id] + to_play = [self.to_play_dict_tmp[env_id] for env_id in ready_env_id] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict_tmp = {env_id: self.chance_dict[env_id] for env_id in ready_env_id} + + stack_obs = to_ndarray(stack_obs) + # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + + # ============================================================== + # Key policy forward step + # ============================================================== + # print(f'ready_env_id:{ready_env_id}') + + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + + # Extract relevant policy outputs + actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} + value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} + pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} + + if self.policy_config.sampled_algo: + root_sampled_actions_dict_with_env_id = { + k: v['root_sampled_actions'] for k, v in policy_output.items() + } + + if not collect_with_pure_policy: + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in + policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in + policy_output.items()} + + if self.policy_config.gumbel_algo: + improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in + policy_output.items()} + completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} + + # Initialize dictionaries to store results + actions = {} + value_dict = {} + pred_value_dict = {} + + if not collect_with_pure_policy: + distributions_dict = {} + visit_entropy_dict = {} + + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + + if self.policy_config.gumbel_algo: + improved_policy_dict = {} + completed_value_dict = {} + + # Populate the result dictionaries + for env_id in ready_env_id: + actions[env_id] = actions_with_env_id.pop(env_id) + value_dict[env_id] = value_dict_with_env_id.pop(env_id) + pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) + + if not collect_with_pure_policy: + distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) + + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) + + visit_entropy_dict[env_id] = 
visit_entropy_dict_with_env_id.pop(env_id) + + if self.policy_config.gumbel_algo: + improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) + completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + + # ============================================================== + # Interact with the environment + # ============================================================== + timesteps = self._env.step(actions) + + interaction_duration = self._timer.value / len(timesteps) + + for env_id, timestep in timesteps.items(): + with self._timer: + if timestep.info.get('abnormal', False): + # If there is an abnormal timestep, reset all the related variables(including this env). + # suppose there is no reset param, reset this env + self._env.reset({env_id: None}) + self._policy.reset([env_id]) + self._reset_stat(env_id) + self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info)) + continue + obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info + + if collect_with_pure_policy: + game_segments[env_id].store_search_stats(temp_visit_list, 0) + else: + if self.policy_config.sampled_algo: + game_segments[env_id].store_search_stats( + distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + ) + elif self.policy_config.gumbel_algo: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], + improved_policy=improved_policy_dict[env_id]) + else: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + + # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} + # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + if self.policy_config.use_ture_chance_label_in_chance_encoder: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], + self.to_play_dict_tmp[env_id], self.chance_dict_tmp[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], + self.to_play_dict_tmp[env_id] + ) + + # NOTE: the position of code snippet is very important. 
+ # the obs['action_mask'] and obs['to_play'] are corresponding to the next action + self.action_mask_dict_tmp[env_id] = to_ndarray(obs['action_mask']) + self.to_play_dict_tmp[env_id] = to_ndarray(obs['to_play']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict_tmp[env_id] = to_ndarray(obs['chance']) + + if self.policy_config.ignore_done: + self.dones[env_id] = False + else: + self.dones[env_id] = done + + if not collect_with_pure_policy: + visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + + eps_steps_lst[env_id] += 1 + if self._policy.get_attribute('cfg').type == 'unizero': + # ============ only for UniZero now ============ + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + + total_transitions += 1 + + if self.policy_config.use_priority: + pred_values_lst[env_id].append(pred_value_dict[env_id]) + search_values_lst[env_id].append(value_dict[env_id]) + if self.policy_config.gumbel_algo and not collect_with_pure_policy: + improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + + # append the newest obs + observation_window_stack[env_id].append(to_ndarray(obs['observation'])) + + # ============================================================== + # we will save a game segment if it is the end of the game or the next game segment is finished. + # ============================================================== + + # if game segment is full, we will save the last game segment + if game_segments[env_id].is_full(): + # pad over last segment trajectory + if self.last_game_segments[env_id] is not None: + # TODO(pu): return the one game segment + self.pad_and_save_last_trajectory( + env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones + ) + + # calculate priority + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + if self.policy_config.gumbel_algo and not collect_with_pure_policy: + improved_policy_lst[env_id] = [] + + # the current game_segments become last_game_segment + self.last_game_segments[env_id] = game_segments[env_id] + self.last_game_priorities[env_id] = priorities + + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id].reset(observation_window_stack[env_id]) + + self._env_info[env_id]['step'] += 1 + collected_step += 1 + + self._env_info[env_id]['time'] += self._timer.value + interaction_duration + # =========== NOTE: =========== + if timestep.done: + print(f'========env {env_id} done!========') + self._total_episode_count += 1 + + reward = timestep.info['eval_episode_return'] + info = { + 'reward': reward, + 'time': self._env_info[env_id]['time'], + 'step': self._env_info[env_id]['step'], + } + if not collect_with_pure_policy: + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + if self.policy_config.gumbel_algo: + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + + collected_episode += 1 + self._episode_info.append(info) + + # ============================================================== + # if it is the end of the game, we will save the game segment + # ============================================================== + + # NOTE: put 
the penultimate game segment of the episode into the trajectory_pool
+                    # pad over the second-to-last game_segment using the last game_segment
+                    if self.last_game_segments[env_id] is not None:
+                        self.pad_and_save_last_trajectory(
+                            env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones
+                        )
+
+                    # store the trajectory of the current segment
+                    priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst)
+
+                    # NOTE: convert the last game segment of the episode to arrays before putting it into the trajectory_pool
+                    game_segments[env_id].game_segment_to_array()
+
+                    # assert len(game_segments[env_id]) == len(priorities)
+                    # NOTE: save the last game segment of the episode into the trajectory_pool if it is not empty
+                    if len(game_segments[env_id].reward_segment) != 0:
+                        self.game_segment_pool.append((game_segments[env_id], priorities, self.dones[env_id]))
+
+                    # log
+                    self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id])
+                    if not collect_with_pure_policy:
+                        self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id])
+                    self_play_moves += eps_steps_lst[env_id]
+                    self_play_episodes += 1
+
+                    pred_values_lst[env_id] = []
+                    search_values_lst[env_id] = []
+                    eps_steps_lst[env_id] = 0
+                    visit_entropies_lst[env_id] = 0
+
+                    # Env reset is done by the env_manager automatically.
+                    # NOTE: ============ reset the policy for this env_id. Default reset_init_data=True. ================
+                    self._policy.reset([env_id])
+                    self._reset_stat(env_id)
+                    ready_env_id.remove(env_id)
+
+                    # ===== NOTE: the episode is done, but collect() does not return yet, so prepare a fresh segment =====
+                    # create new GameSegment
+                    game_segments[env_id] = GameSegment(
+                        self._env.action_space,
+                        game_segment_length=self.policy_config.game_segment_length,
+                        config=self.policy_config
+                    )
+                    game_segments[env_id].reset(observation_window_stack[env_id])
+
+            # NOTE: this check must come after the for loop to make sure the data of every env_id has been collected
+            if len(self.game_segment_pool) >= self._default_num_segments:
+                print(f'collect {len(self.game_segment_pool)} segments now!')
+
+                # [data, meta_data]
+                return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [
+                    {
+                        'priorities': self.game_segment_pool[i][1],
+                        'done': self.game_segment_pool[i][2],
+                        'unroll_plus_td_steps': self.unroll_plus_td_steps
+                    } for i in range(len(self.game_segment_pool))
+                ]
+                self.game_segment_pool.clear()
+                break
+
+        collected_duration = sum([d['time'] for d in self._episode_info])
+        # reduce the data when DDP is enabled
+        if self._world_size > 1:
+            collected_step = allreduce_data(collected_step, 'sum')
+            collected_episode = allreduce_data(collected_episode, 'sum')
+            collected_duration = allreduce_data(collected_duration, 'sum')
+        self._total_envstep_count += collected_step
+        self._total_episode_count += collected_episode
+        self._total_duration += collected_duration
+
+        # log
+        self._output_log(train_iter)
+        return return_data
+
+    def _output_log(self, train_iter: int) -> None:
+        """
+        Overview:
+            Log the collector's data and output the log information.
+        Arguments:
+            - train_iter (:obj:`int`): Current training iteration number for logging context.
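+        Returns:
+            - None: The aggregated statistics are written to the text logger and TensorBoard rather than returned.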
+ """ + if self._rank != 0: + return + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + self._last_train_iter = train_iter + episode_count = len(self._episode_info) + envstep_count = sum([d['step'] for d in self._episode_info]) + duration = sum([d['time'] for d in self._episode_info]) + episode_reward = [d['reward'] for d in self._episode_info] + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + else: + visit_entropy = [0.0] + if self.policy_config.gumbel_algo: + completed_value = [d['completed_value'] for d in self._episode_info] + self._total_duration += duration + info = { + 'episode_count': episode_count, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / episode_count, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_episode_per_sec': episode_count / duration, + 'collect_time': duration, + 'reward_mean': np.mean(episode_reward), + 'reward_std': np.std(episode_reward), + 'reward_max': np.max(episode_reward), + 'reward_min': np.min(episode_reward), + 'total_envstep_count': self._total_envstep_count, + 'total_episode_count': self._total_episode_count, + 'total_duration': self._total_duration, + 'visit_entropy': np.mean(visit_entropy), + } + if self.policy_config.gumbel_algo: + info['completed_value'] = np.mean(completed_value) + self._episode_info.clear() + self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + for k, v in info.items(): + if k in ['each_reward']: + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if k in ['total_envstep_count']: + continue + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index fa79c2e96..2cad7cc15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pycolab pytest pooltool-billiards>=0.3.1 line_profiler +xxhash diff --git a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py index e33c293b9..320d32522 100644 --- a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py @@ -34,6 +34,7 @@ manager=dict(shared_memory=False, ), ), policy=dict( + model_path=None, model=dict( observation_shape=(4, 96, 96), frame_stack_num=4, @@ -99,7 +100,7 @@ Overview: This script should be executed with GPUs. 
Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_muzero_multigpu_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_multigpu_ddp_config.py """ from ding.utils import DDPContext from lzero.entry import train_muzero diff --git a/zoo/atari/config/atari_muzero_reanalyze_config.py b/zoo/atari/config/atari_muzero_reanalyze_config.py new file mode 100644 index 000000000..b3e7c099b --- /dev/null +++ b/zoo/atari/config/atari_muzero_reanalyze_config.py @@ -0,0 +1,130 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + game_segment_length = 20 + evaluator_env_num = 3 + num_simulations = 50 + update_per_collect = None + replay_ratio = 0.25 + num_unroll_steps = 5 + batch_size = 256 + max_env_step = int(2e5) + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq = 1/10 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=1 + + # =========== for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # update_per_collect = 2 + # batch_size = 2 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_muzero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + analysis_sim_norm=False, + cal_dormant_ratio=False, + model=dict( + observation_shape=(4, 96, 96), + image_channel=1, + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + use_sim_norm=True, # NOTE + use_sim_norm_kl_loss=False, + model_type='conv' + ), + cuda=True, + env_type='not_board_games', + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_simulations=num_simulations, + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + 
evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + atari_muzero_config = EasyDict(atari_muzero_config) + main_config = atari_muzero_config + + atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) + atari_muzero_create_config = EasyDict(atari_muzero_create_config) + create_config = atari_muzero_create_config + + main_config.exp_name = f'data_muzero_reanalyze/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_muzero_reanalyze + train_muzero_reanalyze([main_config, create_config], seed=seed, max_env_step=max_env_step) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py index b3b42afc1..fbb900a7d 100644 --- a/zoo/atari/config/atari_rezero_mz_config.py +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -19,6 +19,17 @@ reuse_search = True collect_with_pure_policy = True buffer_reanalyze_freq = 1 + +# ====== only for debug ===== +# collector_env_num = 8 +# num_segments = 8 +# evaluator_env_num = 2 +# num_simulations = 5 +# max_env_step = int(2e5) +# reanalyze_ratio = 0.1 +# batch_size = 64 +# num_unroll_steps = 10 +# replay_ratio = 0.01 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -33,6 +44,9 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), ), policy=dict( model=dict( diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 1c549010f..d29c2f40b 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -1,106 +1,120 @@ from easydict import EasyDict from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map -env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here -action_space_size = atari_env_action_space_map[env_id] -# 
============================================================== -# begin of the most frequently changed config specified by the user -# ============================================================== -update_per_collect = None -replay_ratio = 0.25 -collector_env_num = 8 -n_episode = 8 -evaluator_env_num = 3 -num_simulations = 50 -max_env_step = int(5e5) -reanalyze_ratio = 0. -batch_size = 64 -num_unroll_steps = 10 -infer_context_length = 4 +def main(env_id='PongNoFrameskip-v4', seed=0): + action_space_size = atari_env_action_space_map[env_id] -# ====== only for debug ===== -# collector_env_num = 2 -# n_episode = 2 -# evaluator_env_num = 2 -# num_simulations = 5 -# max_env_step = int(5e5) -# reanalyze_ratio = 0. -# batch_size = 2 -# num_unroll_steps = 10 -# ============================================================== -# end of the most frequently changed config specified by the user -# ============================================================== + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + replay_ratio = 1 + collector_env_num = 8 + game_segment_length = 20 + evaluator_env_num = 5 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = 64 + num_unroll_steps = 10 + infer_context_length = 4 + num_layers = 4 -atari_unizero_config = dict( - env=dict( - stop_value=int(1e6), - env_id=env_id, - observation_shape=(3, 64, 64), - gray_scale=False, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - # TODO: only for debug - # collect_max_episode_steps=int(50), - # eval_max_episode_steps=int(50), - ), - policy=dict( - model=dict( - observation_shape=(3, 64, 64), - action_space_size=action_space_size, - world_model_cfg=dict( - max_blocks=num_unroll_steps, - max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action - context_length=2 * infer_context_length, - device='cuda', - # device='cpu', + # ====== only for debug ===== + # collector_env_num = 8 + # evaluator_env_num = 2 + # num_simulations = 5 + # max_env_step = int(2e5) + # reanalyze_ratio = 0.1 + # batch_size = 64 + # num_unroll_steps = 10 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 96, 96), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + model=dict( + observation_shape=(3, 96, 96), action_space_size=action_space_size, - num_layers=4, - num_heads=8, - embed_dim=768, - obs_type='image', - env_num=max(collector_env_num, evaluator_env_num), + world_model_cfg=dict( + policy_entropy_weight=0, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + 
num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + ), ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + use_augmentation=False, + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(5e4), + use_priority=False, + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=0.0001, + num_simulations=num_simulations, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + grad_clip_value=20, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], ), - # (str) The path of the pretrained model. If None, the model will be initialized by the default model. - model_path=None, - num_unroll_steps=num_unroll_steps, - update_per_collect=update_per_collect, - replay_ratio=replay_ratio, - batch_size=batch_size, - optim_type='AdamW', - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - replay_buffer_size=int(1e6), - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - ), -) -atari_unizero_config = EasyDict(atari_unizero_config) -main_config = atari_unizero_config + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + main_config.exp_name = f'data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) -atari_unizero_create_config = dict( - env=dict( - type='atari_lightzero', - import_names=['zoo.atari.envs.atari_lightzero_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='unizero', - import_names=['lzero.policy.unizero'], - ), -) -atari_unizero_create_config = EasyDict(atari_unizero_create_config) -create_config = atari_unizero_create_config if __name__ == "__main__": - # Define a list of seeds for multiple runs - seeds = [0] # You can add more seed values here - for seed in seeds: - # Update exp_name to include the current seed - main_config.exp_name = f'data_unizero/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}' - from lzero.entry import train_unizero - train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed) + diff --git a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py 
b/zoo/atari/config/atari_unizero_multigpu_ddp_config.py new file mode 100644 index 000000000..82f64f141 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multigpu_ddp_config.py @@ -0,0 +1,116 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +action_space_size = atari_env_action_space_map[env_id] + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +gpu_num = 2 +update_per_collect = None +replay_ratio = 0.25 +collector_env_num = 8 +num_segments = int(8*gpu_num) +n_episode = int(8*gpu_num) +evaluator_env_num = 3 +num_simulations = 50 +max_env_step = int(2e5) +batch_size = 64 +num_unroll_steps = 10 +infer_context_length = 4 +seed = 0 + +# ====== only for debug ===== +# num_simulations = 2 +# max_env_step = int(2e5) +# batch_size = 2 +# num_unroll_steps = 10 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_unizero_config = dict( + exp_name = f'data_unizero/{env_id[:-14]}/{env_id[:-14]}_stack1_unizero_ddp_{gpu_num}gpu_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}', + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + world_model_cfg=dict( + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + # device='cpu', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + multi_gpu=True, + num_unroll_steps=num_unroll_steps, + update_per_collect=update_per_collect, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + num_segments=num_segments, + n_episode=n_episode, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_unizero_config = EasyDict(atari_unizero_config) +main_config = atari_unizero_config + +atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), +) +atari_unizero_create_config = EasyDict(atari_unizero_create_config) +create_config = atari_unizero_create_config + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py + torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py + + """ + from ding.utils import DDPContext + from lzero.entry import train_unizero + from lzero.config.utils import lz_to_ddp_config + with DDPContext(): + main_config = lz_to_ddp_config(main_config) + # TODO: first test muzero_collector + train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) diff --git a/zoo/atari/config/atari_unizero_reanalyze_config.py b/zoo/atari/config/atari_unizero_reanalyze_config.py new file mode 100644 index 000000000..981351dbb --- /dev/null +++ b/zoo/atari/config/atari_unizero_reanalyze_config.py @@ -0,0 +1,136 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + + +def main(env_id='PongNoFrameskip-v4', seed=0): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + replay_ratio = 1 + collector_env_num = 8 + num_segments = 8 + game_segment_length = 20 + evaluator_env_num = 5 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = 64 + num_unroll_steps = 10 + infer_context_length = 4 + num_layers = 4 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq = 1/10 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=1 + + # ====== only for debug ===== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # update_per_collect = 2 + # batch_size = 2 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 96, 96), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + model=dict( + observation_shape=(3, 96, 96), + action_space_size=action_space_size, + world_model_cfg=dict( + policy_entropy_weight=0, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + use_augmentation=False, + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(5e4), + use_priority=False, + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=0.0001, + num_simulations=num_simulations, + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + grad_clip_value=20, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + main_config.exp_name = f'data_unizero_reanalyze/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_reanalyze + train_unizero_reanalyze([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed) + diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 84288feb5..67a27c5f0 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -132,7 +132,10 @@ def reset(self) -> dict: self.obs = to_ndarray(obs) self._eval_episode_return = 0. 
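+        # Track the number of environment steps taken in the current episode; ``step()`` increments this counter
+        # and ``observe()`` exposes it as ``obs['timestep']``.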
+ self.timestep = 0 + obs = self.observe() + return obs def step(self, action: int) -> BaseEnvTimestep: @@ -148,6 +151,8 @@ def step(self, action: int) -> BaseEnvTimestep: self.obs = to_ndarray(obs) self.reward = np.array(reward).astype(np.float32) self._eval_episode_return += self.reward + self.timestep += 1 + # print(f'self.timestep: {self.timestep}') observation = self.observe() if done: info['eval_episode_return'] = self._eval_episode_return @@ -169,7 +174,7 @@ def observe(self) -> dict: observation = np.transpose(observation, (2, 0, 1)) action_mask = np.ones(self._action_space.n, 'int8') - return {'observation': observation, 'action_mask': action_mask, 'to_play': -1} + return {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'timestep': self.timestep} @property def legal_actions(self):
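For reference, a minimal sketch of how the new ``timestep`` field surfaces to callers of the Atari LightZero env after this change; the ``env`` variable is assumed to be an already-constructed instance of the env class in ``zoo/atari/envs/atari_lightzero_env.py``, and the chosen action is purely illustrative:

    obs = env.reset()
    # reset() now returns the full observation dict; the new 'timestep' key starts at 0:
    # {'observation': ..., 'action_mask': ..., 'to_play': -1, 'timestep': 0}
    assert obs['timestep'] == 0

    timestep = env.step(0)  # returns a BaseEnvTimestep(obs, reward, done, info); action 0 is arbitrary here
    assert timestep.obs['timestep'] == 1  # the counter is incremented once per step() call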