Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Commit

Permalink
applying filters for a csv loaded dataset + some bug-fixes in data lo…
Browse files Browse the repository at this point in the history
…ading (#319)
  • Loading branch information
Gal Leibovich authored May 28, 2019
1 parent 6319387 commit 4c996e1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 22 deletions.
45 changes: 32 additions & 13 deletions rl_coach/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,6 @@ def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManage
if self.ap.memory.memory_backend_params.run_type != 'trainer':
self.memory.set_memory_backend(self.memory_backend)

if agent_parameters.memory.load_memory_from_file_path:
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
else:
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
.format(agent_parameters.memory.load_memory_from_file_path))

if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)

Expand Down Expand Up @@ -262,6 +249,38 @@ def set_session(self, sess) -> None:
self.output_filter.set_session(sess)
self.pre_network_filter.set_session(sess)
[network.set_session(sess) for network in self.networks.values()]
self.initialize_session_dependent_components()

def initialize_session_dependent_components(self):
"""
Initialize components which require a session as part of their initialization.
:return: None
"""

# Loading a memory from a CSV file, requires an input filter to filter through the data.
# The filter needs a session before it can be used.
if self.ap.memory.load_memory_from_file_path:
self.load_memory_from_file()

def load_memory_from_file(self):
"""
Load memory transitions from a file.
:return: None
"""

if isinstance(self.ap.memory.load_memory_from_file_path, PickledReplayBuffer):
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
.format(self.ap.memory.load_memory_from_file_path.filepath))
self.memory.load_pickled(self.ap.memory.load_memory_from_file_path.filepath)
elif isinstance(self.ap.memory.load_memory_from_file_path, CsvDataset):
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
.format(self.ap.memory.load_memory_from_file_path.filepath))
self.memory.load_csv(self.ap.memory.load_memory_from_file_path, self.input_filter)
else:
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
.format(self.ap.memory.load_memory_from_file_path))

def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True,
dump_one_value_per_step: bool=False) -> Signal:
Expand Down
19 changes: 13 additions & 6 deletions rl_coach/filters/observation/observation_stacking_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace
from rl_coach.spaces import ObservationSpace, VectorObservationSpace


class LazyStack(object):
Expand Down Expand Up @@ -63,6 +63,7 @@ def __init__(self, stack_size: int, stacking_axis: int=-1):
self.stack_size = stack_size
self.stacking_axis = stacking_axis
self.stack = []
self.input_observation_space = None

if stack_size <= 0:
raise ValueError("The stack shape must be a positive number")
Expand All @@ -86,22 +87,28 @@ def validate_input_observation_space(self, input_observation_space: ObservationS
raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")

def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:

if len(self.stack) == 0:
self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
else:
if update_internal_state:
self.stack.append(observation)
observation = LazyStack(self.stack, self.stacking_axis)

if isinstance(self.input_observation_space, VectorObservationSpace):
# when stacking vectors, we cannot avoid copying the memory as we're flattening it all
observation = np.array(observation).flatten()

return observation

def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
if self.stacking_axis == -1:
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
if isinstance(input_observation_space, VectorObservationSpace):
self.input_observation_space = input_observation_space = VectorObservationSpace(input_observation_space.shape * self.stack_size)
else:
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
values=[self.stack_size], axis=0)
if self.stacking_axis == -1:
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
else:
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
values=[self.stack_size], axis=0)
return input_observation_space

def reset(self) -> None:
Expand Down
20 changes: 17 additions & 3 deletions rl_coach/memories/episodic/episodic_experience_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import random

from rl_coach.core_types import Transition, Episode
from rl_coach.filters.filter import InputFilter
from rl_coach.logger import screen
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
from rl_coach.utils import ReaderWriterLock, ProgressBar
Expand Down Expand Up @@ -408,11 +409,12 @@ def mean_reward(self) -> np.ndarray:
self.reader_writer_lock.release_writing()
return mean

def load_csv(self, csv_dataset: CsvDataset) -> None:
def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
"""
Restore the replay buffer contents from a csv file.
The csv file is assumed to include a list of transitions.
:param csv_dataset: A construct which holds the dataset parameters
:param input_filter: A filter used to filter the CSV data before feeding it to the memory.
"""
self.assert_not_frozen()

Expand All @@ -429,18 +431,30 @@ def load_csv(self, csv_dataset: CsvDataset) -> None:
for e_id in episode_ids:
progress_bar.update(e_id)
df_episode_transitions = df[df['episode_id'] == e_id]
input_filter.reset()

if len(df_episode_transitions) < 2:
# we have to have at least 2 rows in each episode for creating a transition
continue

episode = Episode()
transitions = []
for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
df_episode_transitions[1:].iterrows()):
state = np.array([current_transition[col] for col in state_columns])
next_state = np.array([next_transition[col] for col in state_columns])

episode.insert(
transitions.append(
Transition(state={'observation': state},
action=current_transition['action'], reward=current_transition['reward'],
next_state={'observation': next_state}, game_over=False,
info={'all_action_probabilities':
ast.literal_eval(current_transition['all_action_probabilities'])}))
ast.literal_eval(current_transition['all_action_probabilities'])}),
)

transitions = input_filter.filter(transitions, deep_copy=False)
for t in transitions:
episode.insert(t)

# Set the last transition to end the episode
if csv_dataset.is_episodic:
Expand Down

0 comments on commit 4c996e1

Please sign in to comment.