Merge pull request #13 from theovincent/lunar_lander
Lunar lander
theovincent authored Jan 6, 2023
2 parents 6a3650f + 4d577c2 commit c7fcfa2
Showing 72 changed files with 3,424 additions and 3,787 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -145,7 +145,10 @@ dmypy.json
**/*.pdf

# Avoid pushing experiments output
experiments/bicycle/figures
experiments/car_on_hill/figures
experiments/lunar_lander/figures
experiments/lunar_lander/transfer_from_ias_cluster.sh
**/*.out

# Save optimal value functions of car on hill
15 changes: 10 additions & 5 deletions README.md
@@ -6,7 +6,7 @@ In the folder where the code is, create a Python virtual environment, activate i
```bash
python3 -m venv env
source env/bin/activate
-pip install -e .[cpu]
+pip install -e .
```

### With Docker
@@ -176,23 +176,28 @@ In the folder where the code is, create a Python virtual environment, activate i
```bash
python3 -m venv env_gpu
source env_gpu/bin/activate
-pip install -e .[gpu]
+pip install -e .
```

If JAX does not recognize the GPU, you may need to run:
```bash
-pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install -U jax[cuda11_cudnn82]==0.3.22 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Taken from https://github.com/google/jax/discussions/10323.
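A quick sanity check (not part of the original README) to confirm that JAX actually sees the GPU after installation:
```bash
# lists the devices JAX can use; a working CUDA setup should show a GPU device, not only the CPU
python3 -c "import jax; print(jax.devices())"
```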


-## Using in a cluster
+## Using a cluster
Download Miniconda on the server host to get Python 3.8:
```Bash
wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh
bash Miniconda3-py38_4.12.0-Linux-x86_64.sh
```
-Upgrade pip and install virtualenv
+Install the CUDA packages with:
+```Bash
+conda install -c conda-forge cudatoolkit-dev
+```
+Do not forget to set the environment variable *LD_LIBRARY_PATH* correctly.
+Finally, upgrade pip and install virtualenv:
```Bash
python3 -m pip install --user --upgrade pip
python3 -m pip install --user virtualenv
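As an illustration of the *LD_LIBRARY_PATH* remark in the README change above: assuming `cudatoolkit-dev` was installed into the currently active conda environment (an assumption, not stated in the diff), a common way to expose its libraries is:
```Bash
# prepend the conda environment's library directory so the CUDA runtime libraries can be found
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
```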
577 changes: 0 additions & 577 deletions experiments/bicycle/FQI.ipynb

This file was deleted.

112 changes: 112 additions & 0 deletions experiments/bicycle/FQI.py
@@ -0,0 +1,112 @@
import sys
import argparse
import json
import numpy as np
import jax
from tqdm import tqdm


def run_cli(argvs=sys.argv[1:]):
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

parser = argparse.ArgumentParser("Train FQI on Bicycle.")
parser.add_argument(
"-e",
"--experiment_name",
help="Experiment name.",
type=str,
required=True,
)
parser.add_argument(
"-s",
"--seed",
help="Seed of the training.",
type=int,
required=True,
)
parser.add_argument(
"-b",
"--max_bellman_iterations",
help="Maximum number of Bellman iteration.",
type=int,
required=True,
)
args = parser.parse_args(argvs)
print(f"{args.experiment_name}:")
print(f"Training FQI on Bicycle with {args.max_bellman_iterations} Bellman iterations and seed {args.seed}...")
p = json.load(open(f"experiments/bicycle/figures/{args.experiment_name}/parameters.json")) # p for parameters

from experiments.bicycle.utils import define_environment
from pbo.sample_collection.replay_buffer import ReplayBuffer
from pbo.sample_collection.dataloader import SampleDataLoader
from pbo.networks.learnable_q import FullyConnectedQ
from pbo.utils.params import save_params

key = jax.random.PRNGKey(args.seed)
shuffle_key, q_network_key, _ = jax.random.split(
key, 3
) # 3 keys are generated to stay consistent with the other training scripts

env = define_environment(jax.random.PRNGKey(p["env_seed"]), p["gamma"])

replay_buffer = ReplayBuffer(p["n_samples"])
replay_buffer.load(f"experiments/bicycle/figures/{args.experiment_name}/replay_buffer.npz")
data_loader_samples = SampleDataLoader(replay_buffer, p["batch_size_samples"], shuffle_key)

q = FullyConnectedQ(
state_dim=4,
action_dim=2,
actions_on_max=env.actions_on_max,
gamma=p["gamma"],
network_key=q_network_key,
layers_dimension=p["layers_dimension"],
zero_initializer=True,
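        # `duration` below is the gradient-step budget for one Bellman iteration:
        # fitting epochs times the number of mini-batches per pass over the replay buffer.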
learning_rate={
"first": p["starting_lr_fqi"],
"last": p["ending_lr_fqi"],
"duration": p["fitting_steps_fqi"] * replay_buffer.len // p["batch_size_samples"],
},
)

l2_losses = np.ones((args.max_bellman_iterations, p["fitting_steps_fqi"])) * np.nan
iterated_params = {}
iterated_params["0"] = q.params

for bellman_iteration in tqdm(range(1, args.max_bellman_iterations + 1)):
q.reset_optimizer()
params_target = q.params
best_loss = float("inf")
patience = 0
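        # Early stopping: fitting for this Bellman iteration halts once the epoch loss
        # has not improved for more than p["patience"] consecutive epochs.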

for step in range(p["fitting_steps_fqi"]):
cumulative_l2_loss = 0

data_loader_samples.shuffle()
for batch_samples in data_loader_samples:
q.params, q.optimizer_state, l2_loss = q.learn_on_batch(
q.params, params_target, q.optimizer_state, batch_samples
)
cumulative_l2_loss += l2_loss

l2_losses[bellman_iteration - 1, step] = cumulative_l2_loss
if cumulative_l2_loss < best_loss:
patience = 0
best_loss = cumulative_l2_loss
else:
patience += 1

if patience > p["patience"]:
break

iterated_params[f"{bellman_iteration}"] = q.params

save_params(
f"experiments/bicycle/figures/{args.experiment_name}/FQI/{args.max_bellman_iterations}_P_{args.seed}",
iterated_params,
)
np.save(
f"experiments/bicycle/figures/{args.experiment_name}/FQI/{args.max_bellman_iterations}_L_{args.seed}.npy",
l2_losses,
)
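For reference, a minimal (hypothetical) invocation of the trainer above, based only on the `run_cli` signature; it assumes the repository root is on `PYTHONPATH` and that `experiments/bicycle/figures/<experiment_name>/` already contains the `parameters.json` and `replay_buffer.npz` the script loads:
```bash
# hypothetical experiment name and hyper-parameters, shown only to illustrate the CLI arguments
python3 -c "from experiments.bicycle.FQI import run_cli; run_cli(['-e', 'my_experiment', '-s', '0', '-b', '5'])"
```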
91 changes: 91 additions & 0 deletions experiments/bicycle/FQI_evaluate.py
@@ -0,0 +1,91 @@
import sys
import argparse
import multiprocessing
import json
import jax
import jax.numpy as jnp
import numpy as np


def run_cli(argvs=sys.argv[1:]):
with jax.default_device(jax.devices("cpu")[0]):
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

parser = argparse.ArgumentParser("Evaluate a FQI on Bicycle.")
parser.add_argument(
"-e",
"--experiment_name",
help="Experiment name.",
type=str,
required=True,
)
parser.add_argument(
"-s",
"--seed",
help="Seed of the training.",
type=int,
required=True,
)
parser.add_argument(
"-b",
"--max_bellman_iterations",
help="Maximum number of Bellman iteration.",
type=int,
required=True,
)
args = parser.parse_args(argvs)
print(f"{args.experiment_name}:")
print(
f"Evaluating FQI on Bicycle with {args.max_bellman_iterations} Bellman iterations and seed {args.seed} ..."
)
p = json.load(open(f"experiments/bicycle/figures/{args.experiment_name}/parameters.json")) # p for parameters

from experiments.bicycle.utils import define_environment
from pbo.networks.learnable_q import FullyConnectedQ
from pbo.utils.params import load_params

key = jax.random.PRNGKey(args.seed)
_, q_network_key, _ = jax.random.split(key, 3)

env = define_environment(jax.random.PRNGKey(p["env_seed"]), p["gamma"])

q = FullyConnectedQ(
state_dim=4,
action_dim=2,
actions_on_max=env.actions_on_max,
gamma=p["gamma"],
network_key=q_network_key,
layers_dimension=p["layers_dimension"],
zero_initializer=True,
)
iterated_params = load_params(
f"experiments/bicycle/figures/{args.experiment_name}/FQI/{args.max_bellman_iterations}_P_{args.seed}"
)

def evaluate(iteration: int, metrics_list: list, q: FullyConnectedQ, q_weights: jnp.ndarray, horizon: int):
metrics_list[iteration] = env.evaluate(q, q.to_params(q_weights), horizon, p["n_simulations"])

manager = multiprocessing.Manager()
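        # The Manager-backed list is shared with the child processes: entry i receives the
        # evaluation metrics of the Q-function obtained after i Bellman iterations.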
iterated_metrics = manager.list(list(np.zeros((args.max_bellman_iterations + 1, p["n_simulations"], 2))))

processes = []
for iteration in range(args.max_bellman_iterations + 1):
processes.append(
multiprocessing.Process(
target=evaluate,
args=(iteration, iterated_metrics, q, q.to_weights(iterated_params[f"{iteration}"]), p["horizon"]),
)
)

for process in processes:
process.start()

for process in processes:
process.join()

np.save(
f"experiments/bicycle/figures/{args.experiment_name}/FQI/{args.max_bellman_iterations}_M_{args.seed}.npy",
iterated_metrics,
)
100 changes: 100 additions & 0 deletions experiments/bicycle/IFQI.py
@@ -0,0 +1,100 @@
import sys
import argparse
import json
import numpy as np
import jax
from tqdm import tqdm


def run_cli(argvs=sys.argv[1:]):
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

parser = argparse.ArgumentParser("Train IFQI on Bicycle.")
parser.add_argument(
"-e",
"--experiment_name",
help="Experiment name.",
type=str,
required=True,
)
parser.add_argument(
"-s",
"--seed",
help="Seed of the training.",
type=int,
required=True,
)
parser.add_argument(
"-b",
"--max_bellman_iterations",
help="Maximum number of Bellman iteration.",
type=int,
required=True,
)
args = parser.parse_args(argvs)
print(f"{args.experiment_name}:")
print(f"Training IFQI on Bicycle with {args.max_bellman_iterations} Bellman iterations and seed {args.seed}...")
p = json.load(open(f"experiments/bicycle/figures/{args.experiment_name}/parameters.json")) # p for parameters

from experiments.bicycle.utils import define_environment
from pbo.sample_collection.replay_buffer import ReplayBuffer
from pbo.sample_collection.dataloader import SampleDataLoader
from pbo.networks.learnable_multi_head_q import FullyConnectedMultiHeadQ
from pbo.utils.params import save_params

key = jax.random.PRNGKey(args.seed)
shuffle_key, q_network_key, _ = jax.random.split(
key, 3
) # 3 keys are generated to stay consistent with the other training scripts

env = define_environment(jax.random.PRNGKey(p["env_seed"]), p["gamma"])

replay_buffer = ReplayBuffer(p["n_samples"])
replay_buffer.load(f"experiments/bicycle/figures/{args.experiment_name}/replay_buffer.npz")
data_loader_samples = SampleDataLoader(replay_buffer, p["batch_size_samples"], shuffle_key)

q = FullyConnectedMultiHeadQ(
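        # a single network with args.max_bellman_iterations + 1 Q heads that are fitted jointly on the same batches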
n_heads=args.max_bellman_iterations + 1,
state_dim=4,
action_dim=2,
actions_on_max=env.actions_on_max,
gamma=p["gamma"],
network_key=q_network_key,
layers_dimension=p["layers_dimension"],
zero_initializer=True,
learning_rate={
"first": p["starting_lr_ifqi"],
"last": p["ending_lr_ifqi"],
"duration": p["training_steps_ifqi"]
* p["fitting_steps_ifqi"]
* replay_buffer.len
// p["batch_size_samples"],
},
)
l2_losses = np.ones((p["training_steps_ifqi"], p["fitting_steps_ifqi"])) * np.nan

for training_step in tqdm(range(p["training_steps_ifqi"])):
params_target = q.params

for fitting_step in range(p["fitting_steps_ifqi"]):
cumulative_l2_loss = 0

data_loader_samples.shuffle()
for batch_samples in data_loader_samples:
q.params, q.optimizer_state, l2_loss = q.learn_on_batch(
q.params, params_target, q.optimizer_state, batch_samples
)
cumulative_l2_loss += l2_loss

l2_losses[training_step, fitting_step] = cumulative_l2_loss

save_params(
f"experiments/bicycle/figures/{args.experiment_name}/IFQI/{args.max_bellman_iterations}_P_{args.seed}",
q.params,
)
np.save(
f"experiments/bicycle/figures/{args.experiment_name}/IFQI/{args.max_bellman_iterations}_L_{args.seed}.npy",
l2_losses,
)