diff --git a/docker/docker_manager.py b/docker/docker_manager.py index 96a7d746..116e2d9c 100644 --- a/docker/docker_manager.py +++ b/docker/docker_manager.py @@ -302,6 +302,7 @@ def ping(self) -> bool: return self._get_client().ping() except Exception: print('Docker server is not responding!') + print(f'Client is: {self._get_client()}') return False # PRIVATE METHODS @@ -320,8 +321,10 @@ def _get_client(cls) -> docker.DockerClient: if cls._client is not None: return cls._client try: + print('Trying to get a new client from docker...') cls._client = docker.from_env() - except docker.errors.DockerException: + except docker.errors.DockerException as e: + print(f'Got a DockerException: {e}') cls._client = None return cls._client diff --git a/environment.yml b/environment.yml index 790e5268..0d98dce3 100644 --- a/environment.yml +++ b/environment.yml @@ -14,5 +14,6 @@ dependencies: - coverage>=5.5 - coverage-badge>=1.1.0 - pytest=6.2.4 + - lxml>=4.8.0 - pytest-randomly=3.11.0 - pytest-xdist=2.5.0 diff --git a/recommerce/configuration/common_rules.py b/recommerce/configuration/common_rules.py new file mode 100644 index 00000000..4c53c902 --- /dev/null +++ b/recommerce/configuration/common_rules.py @@ -0,0 +1,14 @@ +def greater_zero_rule(field_name: str): + return (lambda x: x > 0, f'{field_name} should be positive') + + +def non_negative_rule(field_name: str): + return (lambda x: x >= 0, f'{field_name} should be non-negative') + + +def between_zero_one_rule(field_name: str): + return (lambda x: x >= 0 and x <= 1, f'{field_name} should be between 0 (included) and 1 (included)') + + +def greater_zero_even_rule(field_name: str): + return (lambda x: x > 0 and x % 2 == 0, f'{field_name} should be even and positive') diff --git a/recommerce/configuration/config_validation.py b/recommerce/configuration/config_validation.py index 53dbbed0..ab80e1c3 100644 --- a/recommerce/configuration/config_validation.py +++ b/recommerce/configuration/config_validation.py @@ -1,161 +1,108 @@ # This file contains logic used by the webserver to validate configuration files - from recommerce.configuration.environment_config import EnvironmentConfig from recommerce.configuration.hyperparameter_config import HyperparameterConfigValidator +from recommerce.configuration.utils import get_class -def validate_config(config: dict, config_is_final: bool) -> tuple: +def validate_config(config: dict) -> tuple: """ Validates a given config dictionary either uploaded by the user or entered into the form before starting a container. Args: config (dict): The config to validate. - config_is_final (bool): Whether or not the config must contain all required keys. Returns: tuple: success: A status (True) and the split hyperparameter_config and environment_config dictionaries as a tuple. failure: A status (False) and the errormessage as a string. """ try: - # first check if the environment and hyperparameter parts are already split up + # first check if we got a complete config from the webserver + # in which case we will have two keys on the top-level + # this can either be an uploaded complete config, or a config sent when pressing the launch/check button if 'environment' in config and 'hyperparameter' in config: assert len(config) == 2, 'Your config should not contain keys other than "environment" and "hyperparameter"' hyperparameter_config = config['hyperparameter'] environment_config = config['environment'] - elif 'environment' in config or 'hyperparameter' in config: - raise AssertionError('If your config contains one of "environment" or "hyperparameter" it must also contain the other') - else: - # try to split the config. If any keys are unknown, an AssertionError will be thrown - hyperparameter_config, environment_config = split_mixed_config(config) - # then validate that all given values have the correct types - check_config_types(hyperparameter_config, environment_config, config_is_final) - if 'rl' in hyperparameter_config: - HyperparameterConfigValidator.check_rl_ranges(hyperparameter_config['rl'], config_is_final) - if 'sim_market' in hyperparameter_config: - HyperparameterConfigValidator.check_sim_market_ranges(hyperparameter_config['sim_market'], config_is_final) + market_class = get_class(environment_config['marketplace']) + agent_class = get_class(environment_config['agents'][0]['agent_class']) + + HyperparameterConfigValidator.validate_config(hyperparameter_config['sim_market'], market_class) + HyperparameterConfigValidator.validate_config(hyperparameter_config['rl'], agent_class) + EnvironmentConfig.check_types(environment_config, environment_config['task'], False, True) + + return True, ({'rl': hyperparameter_config['rl']}, {'sim_market': hyperparameter_config['sim_market']}, + {'environment': environment_config}) + # if the two keys are not present, the config MUST be one of environment, rl, or market + # this is only the case when uploading a config + else: + config_type = find_config_type(config) + + # we can only validate types for the environment_config, as we do not know the agent/market class for the rl/market configs + if config_type == 'environment': + # validate that all given values have the correct types + task = config['task'] if 'task' in config else 'None' + EnvironmentConfig.check_types(config, task, False, False) + + # the webserver needs another format for the config + config.pop('config_type') + if config_type == 'rl': + return True, ({'rl': config}, None, None) + elif config_type == 'sim_market': + return True, ({'sim_market': config}, None, None) + else: + return True, ({'environment': config}, None, None) - return True, (hyperparameter_config, environment_config) except Exception as error: return False, str(error) -def validate_sub_keys(config_class: HyperparameterConfigValidator or EnvironmentConfig, config: dict, top_level_keys: dict) -> None: +def find_config_type(config: dict) -> str: """ - Utility function that validates if a given config contains only allowed keys. - Can be used recursively for dictionaries within dictionaries. - "Unisex": Works for both HyperparameterConfigValidator and EnvironmentConfig. + Extract the config type from the config dictionary. Config type is defined by the "config_type" key, which must always be present. Args: - config_class (HyperparameterConfigValidator or EnvironmentConfig): The config class from which to get the required fields. - config (dict): The config given by the user. - top_level_keys (dict): The keys of the current level. Their values indicate if there is another dictionary expected for that key. + config (dict): The config to check. Raises: - AssertionError: If the given config contains a key that is invalid. - """ - for key, _ in config.items(): - # we need to separately check agents, since it is a list of dictionaries - if key == 'agents': - assert isinstance(config['agents'], list), f'The "agents" key must have a value of type list, but was {type(config["agents"])}' - for agent in config['agents']: - assert isinstance(agent, dict), f'All agents must be of type dict, but this one was {type(agent)}' - assert all(agent_key in {'name', 'agent_class', 'argument'} for agent_key in agent.keys()), \ - f'An invalid key for agents was provided: {agent.keys()}' - # the key is key of a dictionary in the config - elif top_level_keys[key]: - assert isinstance(config[key], dict), f'The value of this key must be of type dict: {key}, but was {type(config[key])}' - # these are the valid keys that sub-key can have as keys in the dictionary - key_fields = config_class.get_required_fields(key) - # check that only valid keys were given by the user - for sub_key, _ in config[key].items(): - assert sub_key in key_fields.keys(), \ - f'The key "{sub_key}" should not exist within a {config_class.__name__} config (was checked at sub-key "{key}")' - # if there is an additional layer of dictionaries, check it recursively - validate_sub_keys(config_class, config[key], key_fields) - - -def split_mixed_config(config: dict) -> tuple: - """ - Utility function that splits a potentially mixed config of hyperparameters and environment-variables - into two dictionaries for the two configurations. - - Args: - config (dict): The potentially mixed configuration. + AssertionError: If the config_type key has an invalid value or is missing. Returns: - dict: The hyperparameter_config - dict: The environment_config - - Raises: - AssertionError: If the user provides a key that should not exist. + str: The config type. """ - top_level_hyperparameter = HyperparameterConfigValidator.get_required_fields('top-dict') - top_level_environment = EnvironmentConfig.get_required_fields('top-dict') - - hyperparameter_config = {} - environment_config = {} - - for key, value in config.items(): - if key in top_level_hyperparameter.keys(): - hyperparameter_config[key] = value - elif key in top_level_environment.keys(): - environment_config[key] = value + try: + if config['config_type'] in ['rl', 'sim_market', 'environment']: + return config['config_type'] else: - raise AssertionError(f'Your config contains an invalid key: {key}') - - validate_sub_keys(HyperparameterConfigValidator, hyperparameter_config, top_level_hyperparameter) - validate_sub_keys(EnvironmentConfig, environment_config, top_level_environment) - - return hyperparameter_config, environment_config - - -def check_config_types(hyperparameter_config: dict, environment_config: dict, must_contain: bool = False) -> None: - """ - Utility function that checks (incomplete) config dictionaries for their correct types. - - Args: - hyperparameter_config (dict): The config containing hyperparameter_config-keys. - environment_config (dict): The config containing environment_config-keys. - must_contain (bool): Whether or not the configuration should contain all required keys. - - Raises: - AssertionError: If one of the values has the wrong type. - """ - # check types for hyperparameter_config - # @NikkelM Why was this here? - # HyperparameterConfigValidator.check_types(hyperparameter_config, 'top-dict', must_contain) - if 'rl' in hyperparameter_config: - HyperparameterConfigValidator.check_types(hyperparameter_config['rl'], 'rl', must_contain) - if 'sim_market' in hyperparameter_config: - HyperparameterConfigValidator.check_types(hyperparameter_config['sim_market'], 'sim_market', must_contain) - - # check types for environment_config - task = environment_config['task'] if must_contain else 'None' - EnvironmentConfig.check_types(environment_config, task, False, must_contain) + raise AssertionError(f'the "config_type" key must be one of "rl", "sim_market", "environment" but was {config["config_type"]}') + except KeyError as e: + raise AssertionError(f"your config is missing the 'config_type' key, must be one of 'rl', 'sim_market', 'environment': {config}") from e if __name__ == '__main__': # pragma: no cover - test_config = { - 'rl': { - 'batch_size': 32, - 'replay_size': 100000, - 'learning_rate': 1e-6, - 'sync_target_frames': 1000, - 'replay_start_size': 10000, - 'epsilon_decay_last_frame': 75000, - 'epsilon_start': 1.0, - 'epsilon_final': 0.1 - }, - 'sim_market': { - 'max_storage': 100, - 'episode_length': 50, - 'max_price': 10, - 'max_quality': 50, - 'production_price': 3, - 'storage_cost_per_product': 0.1 - }, + test_config_rl = { + 'config_type': 'rl', + 'batch_size': 32, + 'replay_size': 100000, + 'learning_rate': 1e-6, + 'sync_target_frames': 1000, + 'replay_start_size': 10000, + 'epsilon_decay_last_frame': 75000, + 'epsilon_start': 1.0, + 'epsilon_final': 0.1 + } + test_config_market = { + 'config_type': 'sim_market', + 'max_storage': 100, + 'episode_length': 50, + 'max_price': 10, + 'max_quality': 50, + 'production_price': 3, + 'storage_cost_per_product': 0.1 + } + test_config_environment = { + 'config_type': 'environment', 'episodes': 5, 'agents': [ { @@ -170,5 +117,7 @@ def check_config_types(hyperparameter_config: dict, environment_config: dict, mu } ] } - hyper, env = split_mixed_config(test_config) - check_config_types(hyper, env) + print('Testing config validation...') + print(validate_config(test_config_rl)) + print(validate_config(test_config_market)) + print(validate_config(test_config_environment)) diff --git a/recommerce/configuration/environment_config.py b/recommerce/configuration/environment_config.py index a6c535f6..0214654f 100644 --- a/recommerce/configuration/environment_config.py +++ b/recommerce/configuration/environment_config.py @@ -13,8 +13,6 @@ from recommerce.market.linear.linear_vendors import LinearAgent from recommerce.market.sim_market import SimMarket from recommerce.market.vendors import FixedPriceAgent -from recommerce.rl.actorcritic.actorcritic_agent import ActorCriticAgent -from recommerce.rl.q_learning.q_learning_agent import QLearningAgent from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent @@ -159,7 +157,7 @@ def _parse_and_set_agents(self, agent_list: list, needs_modelfile: bool) -> None agent['agent_class'] = get_class(agent['agent_class']) # This if-else contains the parsing logic for the different types of arguments agents can have, e.g. modelfiles or fixed-price-lists - if needs_modelfile and issubclass(agent['agent_class'], (QLearningAgent, ActorCriticAgent)): + if needs_modelfile and issubclass(agent['agent_class'], ReinforcementLearningAgent): assert isinstance(agent['argument'], str), \ f'The "argument" field of this agent ({agent["name"]}) must be a string but was ({type(agent["argument"])})' assert agent['argument'].endswith('.dat') or agent['argument'].endswith('.zip'), \ @@ -291,7 +289,7 @@ def _validate_config(self, config: dict) -> None: self.agent = [] for current_agent in passed_agents: # with modelfile - if issubclass(current_agent['agent_class'], (QLearningAgent, ActorCriticAgent)): + if issubclass(current_agent['agent_class'], ReinforcementLearningAgent): self.agent.append((current_agent['agent_class'], [current_agent['argument'], current_agent['name']])) # without modelfile else: @@ -312,8 +310,6 @@ class ExampleprinterEnvironmentConfig(EnvironmentConfig): """ def _validate_config(self, config: dict) -> None: super(ExampleprinterEnvironmentConfig, self)._validate_config(config, single_agent=False, needs_modelfile=True) - # Since we only have one agent, we extract it from the provided list - self.agent = self.agent[0] def _get_task(self) -> str: return 'exampleprinter' diff --git a/recommerce/configuration/hyperparameter_config.py b/recommerce/configuration/hyperparameter_config.py index f0397870..c0758632 100644 --- a/recommerce/configuration/hyperparameter_config.py +++ b/recommerce/configuration/hyperparameter_config.py @@ -1,60 +1,16 @@ -#!/usr/bin/env python3 - -# helper import json import os from attrdict import AttrDict +from recommerce.configuration.json_configurable import JSONConfigurable from recommerce.configuration.path_manager import PathManager +from recommerce.configuration.utils import get_class +from recommerce.market.sim_market import SimMarket +from recommerce.market.vendors import Agent class HyperparameterConfigValidator(): - - @classmethod - def get_required_fields(cls, dict_key) -> dict: - """ - Utility function that returns all of the keys required for a hyperparameter_config.json at the given level. - The value of any given key indicates whether or not it is the key of a dictionary within the config (i.e. they are a level themselves). - - Args: - dict_key (str): The key for which the required fields are needed. 'top-dict' for getting the keys of the first level. - 'top-dict', 'rl' or 'sim_market'. - - Returns: - dict: The required keys for the config at the given level, together with a boolean indicating of they are the key - of another level. - - Raises: - AssertionError: If the given level is invalid. - """ - if dict_key == 'top-dict': - return {'rl': True, 'sim_market': True} - elif dict_key == 'rl': - return { - 'gamma': False, - 'batch_size': False, - 'replay_size': False, - 'learning_rate': False, - 'sync_target_frames': False, - 'replay_start_size': False, - 'epsilon_decay_last_frame': False, - 'epsilon_start': False, - 'epsilon_final': False - } - elif dict_key == 'sim_market': - return { - 'max_storage': False, - 'episode_length': False, - 'max_price': False, - 'max_quality': False, - 'number_of_customers': False, - 'production_price': False, - 'storage_cost_per_product': False - } - else: - raise AssertionError(f'The given level does not exist in a hyperparameter-config: {dict_key}') - def __str__(self) -> str: """ This overwrites the internal function that get called when you call `print(class_instance)`. @@ -67,197 +23,93 @@ def __str__(self) -> str: return f'{self.__class__.__name__}: {self.__dict__}' @classmethod - def validate_config(self, config: AttrDict) -> None: + def validate_config(cls, config: dict, checked_class: SimMarket or Agent) -> None: """ - Validate the given config dictionary and set the instance variables. + Validate the given config dictionary. Args: config (dict): The config to validate and take the values from. + checked_class (SimMarket or Agent): The relevant class for which the fields are to be checked. """ - if 'sim_market' in config: - self._check_config_sim_market_completeness(config['sim_market']) - self.check_types(config['sim_market'], 'sim_market') - self.check_sim_market_ranges(config['sim_market']) - - if 'rl' in config: - self._check_config_rl_completeness(config['rl']) - self.check_types(config['rl'], 'rl') - self.check_rl_ranges(config['rl']) - - # TODO: replace 'rl' option with 7 different rl branches - - @classmethod - def _check_config_rl_completeness(cls, config: dict) -> None: - """ - Check if the passed config dictionary contains all rl values. - - Args: - config (dict): The dictionary to be checked. - """ - assert 'gamma' in config, 'your config_rl is missing gamma' - assert 'batch_size' in config, 'your config_rl is missing batch_size' - assert 'replay_size' in config, 'your config_rl is missing replay_size' - assert 'learning_rate' in config, 'your config_rl is missing learning_rate' - assert 'sync_target_frames' in config, 'your config_rl is missing sync_target_frames' - assert 'replay_start_size' in config, 'your config_rl is missing replay_start_size' - assert 'epsilon_decay_last_frame' in config, 'your config_rl is missing epsilon_decay_last_frame' - assert 'epsilon_start' in config, 'your config_rl is missing epsilon_start' - assert 'epsilon_final' in config, 'your config_rl is missing epsilon_final' + demanded_fields = [field for field, _, _ in checked_class.get_configurable_fields()] + cls._check_only_valid_keys(config, demanded_fields) + cls._check_types(config, checked_class.get_configurable_fields()) + cls._check_rules(config, checked_class.get_configurable_fields()) @classmethod - def _check_config_sim_market_completeness(cls, config: dict) -> None: + def _check_only_valid_keys(cls, config: dict, demanded_fields: list) -> None: """ - Check if the passed config dictionary contains all sim_market values. + Checks if only valid keys were provided. Args: - config (dict): The dictionary to be checked. - """ - assert 'max_storage' in config, 'your config is missing max_storage' - assert 'episode_length' in config, 'your config is missing episode_length' - assert 'max_price' in config, 'your config is missing max_price' - assert 'max_quality' in config, 'your config is missing max_quality' - assert 'number_of_customers' in config, 'your config is missing number_of_customers' - assert 'production_price' in config, 'your config is missing production_price' - assert 'storage_cost_per_product' in config, 'your config is missing storage_cost_per_product' + config (dict): The config which should contain all values in demanded_fields. + demanded_fields (list): The list containing all values that should be contained in config. + """ + config_keys = set(config.keys()) + # the config_type key is completely optional as it is only used for webserver validation, so we don't prevent people from adding it + if 'config_type' in config_keys: + config_keys.remove('config_type') + demanded_keys = set(demanded_fields) + + if config_keys != demanded_keys: + missing_keys = demanded_keys.difference(config_keys) + redundant_keys = config_keys.difference(demanded_keys) + if missing_keys: + assert False, f'your config is missing {missing_keys}' + if redundant_keys: + assert False, f'your config provides {redundant_keys} which was not demanded' @classmethod - def check_types(cls, config: dict, key: str, must_contain: bool = True) -> None: - """ - Check if all given variables have the correct types. - If must_contain is True, all keys must exist, else non-existing keys will be skipped. - - Args: - config (dict): The config to check. - key (str): The key for which to check the values. 'top-dict', 'rl' or 'sim_market'. - must_contain (bool, optional): Whether or not all variables must be present in the config. Defaults to True. - - Raises: - KeyError: If the dictionary is missing a key but should contain all keys. - """ - """ - deprecated - if key == 'top-dict': - types_dict = { - 'rl': dict, - 'sim_market': dict - } - """ - if key == 'rl': - types_dict = { - 'gamma': (int, float), - 'batch_size': int, - 'replay_size': int, - 'learning_rate': (int, float), - 'sync_target_frames': int, - 'replay_start_size': int, - 'epsilon_decay_last_frame': int, - 'epsilon_start': (int, float), - 'epsilon_final': (int, float) - } - elif key == 'sim_market': - types_dict = { - 'max_storage': int, - 'episode_length': int, - 'max_price': int, - 'max_quality': int, - 'number_of_customers': int, - 'production_price': int, - 'storage_cost_per_product': float - } - else: - raise AssertionError(f'Your config contains an invalid key: {key}') - - for key, value in types_dict.items(): + def _check_types(cls, config: dict, configurable_fields: list, must_contain: bool = True) -> None: + for field_name, type, _ in configurable_fields: try: - assert isinstance(config[key], value), f'{key} must be a {value} but was {type(config[key])}' + assert isinstance(config[field_name], type), f'{field_name} must be a {type} but was {type(config[field_name])}' except KeyError as error: if must_contain: - raise KeyError(f'Your config is missing the following required key: {key}') from error + raise KeyError(f'Your config is missing the following required key: {field_name}') from error @classmethod - def check_rl_ranges(cls, config: dict, must_contain: bool = True) -> None: - """ - Check if all rl variables are within their (pre-defined) ranges. - - Args: - config (dict): The config for which to check the variables. - must_contain (bool, optional): Whether or not all variables must be present in the config. Defaults to True. - """ - if must_contain or 'gamma' in config: - assert config['gamma'] >= 0 and config['gamma'] < 1, 'gamma should be between 0 (included) and 1 (excluded)' - if must_contain or 'batch_size' in config: - assert config['batch_size'] > 0, 'batch_size should be greater than 0' - if must_contain or 'replay_size' in config: - assert config['replay_size'] > 0, 'replay_size should be greater than 0' - if must_contain or 'learning_rate' in config: - assert config['learning_rate'] > 0 and config['learning_rate'] < 1, 'learning_rate should be between 0 and 1 (excluded)' - if must_contain or 'sync_target_frames' in config: - assert config['sync_target_frames'] > 0, 'sync_target_frames should be greater than 0' - if must_contain or 'replay_start_size' in config: - assert config['replay_start_size'] > 0, 'replay_start_size should be greater than 0' - if must_contain or 'epsilon_decay_last_frame' in config: - assert config['epsilon_decay_last_frame'] >= 0, 'epsilon_decay_last_frame should not be negative' - if must_contain or 'epsilon_start' in config: - assert config['epsilon_start'] > 0 and config['epsilon_start'] <= 1, 'epsilon_start should be between 0 and 1 (excluded)' - if must_contain or 'epsilon_final' in config: - assert config['epsilon_final'] > 0 and config['epsilon_final'] <= 1, 'epsilon_final should be between 0 and 1 (excluded)' - if must_contain or ('epsilon_start' in config and 'epsilon_final' in config): - assert config['epsilon_start'] > config['epsilon_final'], 'epsilon_start should be greater than epsilon_final' - - @classmethod - def check_sim_market_ranges(cls, config: dict, must_contain: bool = True) -> None: - """ - Check if all sim_market variables are within their (pre-defined) ranges. - - Args: - config (dict): The config for which to check the variables. - must_contain (bool, optional): Whether or not all variables must be present in the config. Defaults to True. - """ - if must_contain or 'max_storage' in config: - assert config['max_storage'] >= 0, 'max_storage must be positive' - if must_contain or 'number_of_customers' in config: - assert config['number_of_customers'] > 0 and config['number_of_customers'] % 2 == 0, 'number_of_customers should be even and positive' - if must_contain or 'production_price' in config: - assert config['production_price'] <= config['max_price'] and config['production_price'] >= 0, \ - 'production_price needs to be smaller than max_price and >=0' - if must_contain or 'max_quality' in config: - assert config['max_quality'] > 0, 'max_quality should be positive' - if must_contain or 'max_price' in config: - assert config['max_price'] > 0, 'max_price should be positive' - if must_contain or 'episode_length' in config: - assert config['episode_length'] > 0, 'episode_length should be positive' - if must_contain or 'storage_cost_per_product' in config: - assert config['storage_cost_per_product'] >= 0, 'storage_cost_per_product should be non-negative' + def _check_rules(cls, config: dict, configurable_fields: list, must_contain: bool = True) -> None: + for field_name, _, rule in configurable_fields: + if rule is not None: + assert callable(rule) + check_method, error_string = rule(field_name) + try: + assert check_method(config[field_name]), error_string + except KeyError as error: + if must_contain: + raise KeyError(f'Your config is missing the following required key: {field_name}') from error class HyperparameterConfigLoader(): - @classmethod - def flat_and_convert_to_attrdict(cls, config): - config_flatten = {} - if 'rl' in config: - config_flatten.update(config['rl']) - if 'sim_market' in config: - config_flatten.update(config['sim_market']) - assert config_flatten is not {} - return AttrDict(config_flatten) @classmethod - def load(cls, filename: str) -> AttrDict: + def load(cls, filename: str, checked_class: SimMarket or Agent) -> AttrDict: """ - Load the configuration json file from the `configuration_files` folder, validate all keys and retruning an AttrDict instance - without top level keys. + Load the market configuration json file from the `configuration_files` folder, validate all keys and return an AttrDict instance. + This can only be done after the relevant `environment_config` has been loaded, if both are needed, as the checked_class needs to be known. Args: filename (str): The name of the json file containing the configuration values. Must be located in the `configuration_files` directory in the user's datapath folder. + checked_class (SimMarket or Agent): The relevant class for which the fields are to be checked. Returns: AttrDict: An Arribute Dict containing the hyperparameters. """ - filename += '.json' + # In case the class is still in string format, extract it + if issubclass(checked_class, str): + checked_class = get_class(checked_class) + + assert issubclass(checked_class, (SimMarket, Agent)), f'the provided checked_class must be a subclass of SimMarket \ + if the config is a market_config or of Agent if it is an rl_config: {checked_class}' + assert issubclass(checked_class, JSONConfigurable), f'the provided checked_class must be a subclass of JSONConfigurable: {checked_class}' + + if not filename.endswith('.json'): + filename += '.json' path = os.path.join(PathManager.user_path, 'configuration_files', filename) with open(path) as config_file: - config = json.load(config_file) - HyperparameterConfigValidator.validate_config(config) - config_attr_dict = cls.flat_and_convert_to_attrdict(config) - return config_attr_dict + hyperparameter_config = json.load(config_file) + + HyperparameterConfigValidator.validate_config(config=hyperparameter_config, checked_class=checked_class) + return AttrDict(hyperparameter_config) diff --git a/recommerce/configuration/json_configurable.py b/recommerce/configuration/json_configurable.py new file mode 100644 index 00000000..c84285f8 --- /dev/null +++ b/recommerce/configuration/json_configurable.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod + + +class JSONConfigurable(ABC): + @staticmethod + @abstractmethod + def get_configurable_fields(): + raise NotImplementedError diff --git a/recommerce/default_data/configuration_files/environment_config_agent_monitoring.json b/recommerce/default_data/configuration_files/environment_config_agent_monitoring.json index 32c72f05..ac16cd57 100644 --- a/recommerce/default_data/configuration_files/environment_config_agent_monitoring.json +++ b/recommerce/default_data/configuration_files/environment_config_agent_monitoring.json @@ -1,4 +1,5 @@ { + "config_type": "environment", "task": "agent_monitoring", "enable_live_draw": false, "episodes": 500, diff --git a/recommerce/default_data/configuration_files/environment_config_exampleprinter.json b/recommerce/default_data/configuration_files/environment_config_exampleprinter.json index 35c805e1..248885a6 100644 --- a/recommerce/default_data/configuration_files/environment_config_exampleprinter.json +++ b/recommerce/default_data/configuration_files/environment_config_exampleprinter.json @@ -1,4 +1,5 @@ { + "config_type": "environment", "task": "exampleprinter", "marketplace": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceDuopoly", "agents": [ diff --git a/recommerce/default_data/configuration_files/environment_config_training.json b/recommerce/default_data/configuration_files/environment_config_training.json index ae6cb936..26a250a0 100644 --- a/recommerce/default_data/configuration_files/environment_config_training.json +++ b/recommerce/default_data/configuration_files/environment_config_training.json @@ -1,4 +1,5 @@ { + "config_type": "environment", "task": "training", "marketplace": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", "agents": [ diff --git a/recommerce/default_data/configuration_files/hyperparameter_config.json b/recommerce/default_data/configuration_files/hyperparameter_config.json deleted file mode 100644 index 5fe8956f..00000000 --- a/recommerce/default_data/configuration_files/hyperparameter_config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "rl": { - "gamma" : 0.99, - "batch_size" : 32, - "replay_size" : 100000, - "learning_rate" : 1e-6, - "sync_target_frames" : 1000, - "replay_start_size" : 10000, - "epsilon_decay_last_frame" : 75000, - "epsilon_start" : 1.0, - "epsilon_final" : 0.1 - }, - "sim_market": { - "max_storage": 100, - "episode_length": 50, - "max_price": 10, - "max_quality": 50, - "number_of_customers": 20, - "production_price": 3, - "storage_cost_per_product": 0.1 - } -} diff --git a/recommerce/default_data/configuration_files/market_config.json b/recommerce/default_data/configuration_files/market_config.json index 00638718..7f8236be 100644 --- a/recommerce/default_data/configuration_files/market_config.json +++ b/recommerce/default_data/configuration_files/market_config.json @@ -1,11 +1,10 @@ { - "sim_market": { - "max_storage": 100, - "episode_length": 50, - "max_price": 10, - "max_quality": 50, - "number_of_customers": 20, - "production_price": 3, - "storage_cost_per_product": 0.1 - } + "config_type": "sim_market", + "max_storage": 100, + "episode_length": 50, + "max_price": 10, + "max_quality": 50, + "number_of_customers": 20, + "production_price": 3, + "storage_cost_per_product": 0.1 } diff --git a/recommerce/default_data/configuration_files/q_learning_config.json b/recommerce/default_data/configuration_files/q_learning_config.json new file mode 100644 index 00000000..6388e885 --- /dev/null +++ b/recommerce/default_data/configuration_files/q_learning_config.json @@ -0,0 +1,12 @@ +{ + "config_type": "rl", + "gamma": 0.99, + "batch_size": 32, + "replay_size": 100000, + "learning_rate": 1e-6, + "sync_target_frames": 1000, + "replay_start_size": 10000, + "epsilon_decay_last_frame": 75000, + "epsilon_start": 1.0, + "epsilon_final": 0.1 +} diff --git a/recommerce/default_data/configuration_files/rl_config.json b/recommerce/default_data/configuration_files/rl_config.json deleted file mode 100644 index 48216830..00000000 --- a/recommerce/default_data/configuration_files/rl_config.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "rl": { - "gamma" : 0.99, - "batch_size" : 32, - "replay_size" : 100000, - "learning_rate" : 1e-6, - "sync_target_frames" : 1000, - "replay_start_size" : 10000, - "epsilon_decay_last_frame" : 75000, - "epsilon_start" : 1.0, - "epsilon_final" : 0.1 - } -} diff --git a/recommerce/market/circular/circular_sim_market.py b/recommerce/market/circular/circular_sim_market.py index e9c7fa47..09007caa 100644 --- a/recommerce/market/circular/circular_sim_market.py +++ b/recommerce/market/circular/circular_sim_market.py @@ -6,6 +6,7 @@ import recommerce.configuration.utils as ut import recommerce.market.circular.circular_vendors as circular_vendors import recommerce.market.owner as owner +from recommerce.configuration.common_rules import greater_zero_even_rule, greater_zero_rule, non_negative_rule from recommerce.market.circular.circular_customers import CustomerCircular from recommerce.market.customer import Customer from recommerce.market.owner import Owner @@ -18,6 +19,19 @@ def get_competitor_classes() -> list: import recommerce.market.circular.circular_vendors as c_vendors return sorted(ut.filtered_class_str_from_dir('recommerce.market.circular.circular_vendors', dir(c_vendors), '.*CE.*Agent.*')) + @staticmethod + def get_configurable_fields() -> list: + # TODO: reduce this list to only the required fields (remove max_quality) + return [ + ('max_storage', int, greater_zero_rule), + ('episode_length', int, greater_zero_rule), + ('max_price', int, greater_zero_rule), + ('max_quality', int, greater_zero_rule), + ('number_of_customers', int, greater_zero_even_rule), + ('production_price', int, non_negative_rule), + ('storage_cost_per_product', (int, float), non_negative_rule), + ] + def _setup_action_observation_space(self, support_continuous_action_space: bool) -> None: # cell 0: number of products in the used storage, cell 1: number of products in circulation self.max_storage = self.config.max_storage diff --git a/recommerce/market/linear/linear_sim_market.py b/recommerce/market/linear/linear_sim_market.py index 5987e716..d81189f8 100644 --- a/recommerce/market/linear/linear_sim_market.py +++ b/recommerce/market/linear/linear_sim_market.py @@ -4,6 +4,7 @@ import numpy as np import recommerce.configuration.utils as ut +from recommerce.configuration.common_rules import greater_zero_even_rule, greater_zero_rule, non_negative_rule from recommerce.market.customer import Customer from recommerce.market.linear.linear_customers import CustomerLinear from recommerce.market.linear.linear_vendors import Just2PlayersLEAgent, LERandomAgent, LinearRatio1LEAgent @@ -16,6 +17,19 @@ def get_competitor_classes() -> list: import recommerce.market.linear.linear_vendors as l_vendors return sorted(ut.filtered_class_str_from_dir('recommerce.market.linear.linear_vendors', dir(l_vendors), '.*LE.*Agent.*')) + @staticmethod + def get_configurable_fields() -> list: + # TODO: reduce this list to only the required fields + return [ + ('max_storage', int, greater_zero_rule), + ('episode_length', int, greater_zero_rule), + ('max_price', int, greater_zero_rule), + ('max_quality', int, greater_zero_rule), + ('number_of_customers', int, greater_zero_even_rule), + ('production_price', int, non_negative_rule), + ('storage_cost_per_product', (int, float), non_negative_rule), + ] + def _setup_action_observation_space(self, support_continuous_action_space: bool) -> None: """ The observation array has the following format: diff --git a/recommerce/market/sim_market.py b/recommerce/market/sim_market.py index d3063b07..cd8a8ea9 100644 --- a/recommerce/market/sim_market.py +++ b/recommerce/market/sim_market.py @@ -1,10 +1,11 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Tuple import gym import numpy as np from attrdict import AttrDict +from recommerce.configuration.json_configurable import JSONConfigurable from recommerce.configuration.utils import filtered_class_str_from_dir # An offer is a market state that contains all prices and qualities @@ -15,7 +16,7 @@ # Third: vendor's actions from the former round which needs to be saved and influence the other's decision e.g. prices -class SimMarket(gym.Env, ABC): +class SimMarket(gym.Env, JSONConfigurable): """ The superclass to all market environments. Abstract class that cannot be instantiated. @@ -341,3 +342,14 @@ def _ensure_output_dict_has(self, name, init_for_all_vendors=None) -> None: self._output_dict[name] = 0 else: self._output_dict[name] = dict(zip([f'vendor_{i}' for i in range(self._number_of_vendors)], init_for_all_vendors)) + + @abstractmethod + def get_configurable_fields() -> list: + """ + Return a list of keys that can be used to configure this marketplace using a `market_config.json`. + Also contains key types and validation logic. + + Returns: + list: The list of (key, type, validation). + """ + raise NotImplementedError diff --git a/recommerce/monitoring/agent_monitoring/am_configuration.py b/recommerce/monitoring/agent_monitoring/am_configuration.py index b0d417c0..67689614 100644 --- a/recommerce/monitoring/agent_monitoring/am_configuration.py +++ b/recommerce/monitoring/agent_monitoring/am_configuration.py @@ -29,8 +29,8 @@ def __init__(self) -> None: self.plot_interval = 50 self.marketplace = circular_market.CircularEconomyMonopoly default_agent = FixedPriceCEAgent - self.config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - self.config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + self.config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) + self.config_rl: AttrDict = HyperparameterConfigLoader.load('q_learning_config', QLearningAgent) self.agents = [default_agent(config_market=self.config_market)] self.agent_colors = [(0.0, 0.0, 1.0, 1.0)] self.folder_path = os.path.abspath(os.path.join(PathManager.results_path, 'monitoring', 'plots_' + time.strftime('%b%d_%H-%M-%S'))) @@ -88,7 +88,6 @@ def _update_agents(self, agents) -> None: 'If the market is linear, the agent must be linear too!' self.agents = [] - # Instantiate all agents. If they are not rule-based, use the marketplace parameters accordingly agents_with_config = [(current_agent[0], [self.config_market] + current_agent[1]) for current_agent in agents] @@ -102,7 +101,8 @@ def _update_agents(self, agents) -> None: assert (1 <= len(current_agent[1]) <= 3), 'the argument list for a RL-agent must have length between 0 and 2' assert all(isinstance(argument, str) for argument in current_agent[1][1:]), 'the arguments for a RL-agent must be of type str' - # Stablebaselines ends in .zip - we don't + # Stablebaselines ends in .zip - so if you use it, you need to specify a modelfile name + # For many others, it can be omitted since we use a default format agent_modelfile = f'{type(self.marketplace).__name__}_{current_agent[0].__name__}.dat' agent_name = 'q_learning' if issubclass(current_agent[0], QLearningAgent) else 'actor_critic' # no arguments @@ -121,7 +121,7 @@ def _update_agents(self, agents) -> None: agent_name = current_agent[1][1] # both arguments, first must be the modelfile, second the name elif len(current_agent[1]) == 3: - assert current_agent[1][1].endswith('.dat'), \ + assert current_agent[1][1].endswith('.dat') or current_agent[1][1].endswith('.zip'), \ 'if two arguments as well as a config are provided, ' + \ f'the first extra one must be the modelfile. Arg1: {current_agent[1][1]}, Arg2: {current_agent[1][2]}' agent_modelfile = current_agent[1][1] diff --git a/recommerce/monitoring/agent_monitoring/am_monitoring.py b/recommerce/monitoring/agent_monitoring/am_monitoring.py index abd612be..da3471b9 100644 --- a/recommerce/monitoring/agent_monitoring/am_monitoring.py +++ b/recommerce/monitoring/agent_monitoring/am_monitoring.py @@ -47,9 +47,8 @@ def run_marketplace(self) -> list: Returns: list: A list with a list of rewards for each agent """ - config_market = HyperparameterConfigLoader.load('market_config') # initialize the watcher list with a list for each agent - watchers = [Watcher(config_market=config_market) for _ in range(len(self.configurator.agents))] + watchers = [Watcher(config_market=self.configurator.marketplace.config) for _ in range(len(self.configurator.agents))] for episode in trange(1, self.configurator.episodes + 1, unit=' episodes', leave=False): # reset the state & marketplace once to be used by all agents @@ -98,7 +97,7 @@ def main(): # pragma: no cover """ monitor = Monitor() config_environment_am: AgentMonitoringEnvironmentConfig = EnvironmentConfigLoader.load('environment_config_agent_monitoring') - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', config_environment_am.marketplace) monitor.configurator.setup_monitoring( enable_live_draw=config_environment_am.enable_live_draw, episodes=config_environment_am.episodes, diff --git a/recommerce/monitoring/exampleprinter.py b/recommerce/monitoring/exampleprinter.py index 4300d23a..7567d1a2 100644 --- a/recommerce/monitoring/exampleprinter.py +++ b/recommerce/monitoring/exampleprinter.py @@ -98,24 +98,25 @@ def main(): # pragma: no cover """ Defines what is performed when the `agent_monitoring` command is chosen in `main.py`. """ - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + config_environment: ExampleprinterEnvironmentConfig = EnvironmentConfigLoader.load('environment_config_exampleprinter') + + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', config_environment.marketplace) + config_rl: AttrDict = HyperparameterConfigLoader.load('q_learning_config', config_environment.agent[0]['agent_class']) printer = ExamplePrinter(config_market=config_market) - config_environment: ExampleprinterEnvironmentConfig = EnvironmentConfigLoader.load('environment_config_exampleprinter') # TODO: Theoretically, the name of the agent is saved in config_environment['name'], but we don't use it yet. - marketplace = config_environment.marketplace(config=config_market) + marketplace = config_environment.marketplace(config=config_market, competitors=config_environment.agent[1:]) # QLearningAgents need more initialization - if issubclass(config_environment.agent['agent_class'], QLearningAgent): + if issubclass(config_environment.agent[0]['agent_class'], QLearningAgent): printer.setup_exampleprinter(marketplace=marketplace, - agent=config_environment.agent['agent_class']( + agent=config_environment.agent[0]['agent_class']( config_market=config_market, config_rl=config_rl, marketplace=marketplace, - load_path=os.path.abspath(os.path.join(PathManager.data_path, config_environment.agent['argument'])))) + load_path=os.path.abspath(os.path.join(PathManager.data_path, config_environment.agent[0]['argument'])))) else: - printer.setup_exampleprinter(marketplace=marketplace, agent=config_environment.agent['agent_class']()) + printer.setup_exampleprinter(marketplace=marketplace, agent=config_environment.agent[0]['agent_class']()) print(f'The final profit was: {printer.run_example()}') diff --git a/recommerce/monitoring/policyanalyzer.py b/recommerce/monitoring/policyanalyzer.py index c352e302..7581cec3 100644 --- a/recommerce/monitoring/policyanalyzer.py +++ b/recommerce/monitoring/policyanalyzer.py @@ -4,7 +4,9 @@ import numpy as np import recommerce.configuration.utils as ut +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader from recommerce.configuration.path_manager import PathManager +from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceMonopoly from recommerce.market.circular.circular_vendors import RuleBasedCERebuyAgentCompetitive from recommerce.rl.actorcritic.actorcritic_agent import ContinuousActorCriticAgent @@ -117,7 +119,8 @@ def analyze_policy(self, base_input, analyzed_features, title='add a title here' if __name__ == '__main__': - pa = PolicyAnalyzer(RuleBasedCERebuyAgentCompetitive(), 'default_configuration') + config_market = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) + pa = PolicyAnalyzer(RuleBasedCERebuyAgentCompetitive(config_market=config_market), 'default_configuration') one_competitor_examples = [ ('rule based own refurbished price', 0), ('rule based own new price', 1), diff --git a/recommerce/rl/actorcritic/actorcritic_agent.py b/recommerce/rl/actorcritic/actorcritic_agent.py index d59ff2b6..6702c638 100644 --- a/recommerce/rl/actorcritic/actorcritic_agent.py +++ b/recommerce/rl/actorcritic/actorcritic_agent.py @@ -6,6 +6,7 @@ import recommerce.configuration.utils as ut import recommerce.rl.model as model +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule from recommerce.market.circular.circular_vendors import CircularAgent from recommerce.market.linear.linear_vendors import LinearAgent from recommerce.market.sim_market import SimMarket @@ -150,6 +151,14 @@ def agent_output_to_market_form(self, action) -> None: # pragma: no cover """ raise NotImplementedError('This method is abstract. Use a subclass') + @staticmethod + def get_configurable_fields() -> list: + return [ + ('gamma', float, between_zero_one_rule), + ('sync_target_frames', int, greater_zero_rule), + ('testvalue2', float, greater_zero_rule) + ] + class DiscreteActorCriticAgent(ActorCriticAgent, LinearAgent, CircularAgent): """ diff --git a/recommerce/rl/q_learning/q_learning_agent.py b/recommerce/rl/q_learning/q_learning_agent.py index 4fd3aaf9..79512686 100644 --- a/recommerce/rl/q_learning/q_learning_agent.py +++ b/recommerce/rl/q_learning/q_learning_agent.py @@ -6,6 +6,7 @@ from attrdict import AttrDict import recommerce.rl.model as model +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule from recommerce.market.circular.circular_vendors import CircularAgent from recommerce.market.linear.linear_vendors import LinearAgent from recommerce.market.sim_market import SimMarket @@ -125,3 +126,17 @@ def save(self, model_path: str) -> None: """ assert model_path.endswith('.dat'), f'the modelname must end in ".dat": {model_path}' torch.save(self.net.state_dict(), model_path) + + @staticmethod + def get_configurable_fields() -> list: + return [ + ('gamma', float, between_zero_one_rule), + ('batch_size', int, greater_zero_rule), + ('replay_size', int, greater_zero_rule), + ('learning_rate', float, greater_zero_rule), + ('sync_target_frames', int, greater_zero_rule), + ('replay_start_size', int, greater_zero_rule), + ('epsilon_decay_last_frame', int, greater_zero_rule), + ('epsilon_start', float, between_zero_one_rule), + ('epsilon_final', float, between_zero_one_rule), + ] diff --git a/recommerce/rl/reinforcement_learning_agent.py b/recommerce/rl/reinforcement_learning_agent.py index ffd497fa..111aa242 100644 --- a/recommerce/rl/reinforcement_learning_agent.py +++ b/recommerce/rl/reinforcement_learning_agent.py @@ -1,13 +1,14 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod import torch from attrdict import AttrDict +from recommerce.configuration.json_configurable import JSONConfigurable from recommerce.market.sim_market import SimMarket from recommerce.market.vendors import Agent -class ReinforcementLearningAgent(Agent, ABC): +class ReinforcementLearningAgent(Agent, JSONConfigurable): @abstractmethod def __init__( self, diff --git a/recommerce/rl/rl_vs_rl_training.py b/recommerce/rl/rl_vs_rl_training.py index bb0b320d..c631dd91 100644 --- a/recommerce/rl/rl_vs_rl_training.py +++ b/recommerce/rl/rl_vs_rl_training.py @@ -6,7 +6,8 @@ from recommerce.configuration.path_manager import PathManager from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceDuopoly -from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesPPO, StableBaselinesSAC +from recommerce.rl.stable_baselines.sb_ppo import StableBaselinesPPO +from recommerce.rl.stable_baselines.sb_sac import StableBaselinesSAC def train_rl_vs_rl(config_market: AttrDict, config_rl: AttrDict, num_switches: int = 30, num_steps_per_switch: int = 25000): diff --git a/recommerce/rl/self_play.py b/recommerce/rl/self_play.py index 76dd9af7..d2c75bbe 100644 --- a/recommerce/rl/self_play.py +++ b/recommerce/rl/self_play.py @@ -1,7 +1,8 @@ from attrdict import AttrDict from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceDuopoly -from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent, StableBaselinesPPO +from recommerce.rl.stable_baselines.sb_ppo import StableBaselinesPPO +from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent def train_self_play( diff --git a/recommerce/rl/stable_baselines/sb_a2c.py b/recommerce/rl/stable_baselines/sb_a2c.py new file mode 100644 index 00000000..6b08aa1b --- /dev/null +++ b/recommerce/rl/stable_baselines/sb_a2c.py @@ -0,0 +1,24 @@ +from stable_baselines3 import A2C + +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule +from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent + + +class StableBaselinesA2C(StableBaselinesAgent): + """ + This a stable baseline agent using A2C. + """ + name = 'Stable_Baselines_A2C' + + def _initialize_model(self, marketplace): + self.model = A2C('MlpPolicy', marketplace, verbose=False, tensorboard_log=self.tensorboard_log) + + def _load(self, load_path): + self.model = A2C.load(load_path, tensorboard_log=self.tensorboard_log) + + @staticmethod + def get_configurable_fields() -> list: + return [ + ('testvalue1', float, between_zero_one_rule), + ('a2cvalue', float, greater_zero_rule) + ] diff --git a/recommerce/rl/stable_baselines/sb_ddpg.py b/recommerce/rl/stable_baselines/sb_ddpg.py new file mode 100644 index 00000000..87351d0d --- /dev/null +++ b/recommerce/rl/stable_baselines/sb_ddpg.py @@ -0,0 +1,28 @@ +import numpy as np +from stable_baselines3 import DDPG +from stable_baselines3.common.noise import NormalActionNoise + +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule +from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent + + +class StableBaselinesDDPG(StableBaselinesAgent): + """ + This a stable baseline agent using Deep Deterministic Policy Gradient (DDPG) algorithm. + """ + name = 'Stable_Baselines_DDPG' + + def _initialize_model(self, marketplace): + n_actions = marketplace.get_actions_dimension() + action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions)) + self.model = DDPG('MlpPolicy', marketplace, action_noise=action_noise, verbose=False, tensorboard_log=self.tensorboard_log) + + def _load(self, load_path): + self.model = DDPG.load(load_path, tensorboard_log=self.tensorboard_log) + + @staticmethod + def get_configurable_fields() -> list: + return [ + ('testvalue1', float, between_zero_one_rule), + ('ddpgvalue', float, greater_zero_rule) + ] diff --git a/recommerce/rl/stable_baselines/sb_ppo.py b/recommerce/rl/stable_baselines/sb_ppo.py new file mode 100644 index 00000000..e86cf354 --- /dev/null +++ b/recommerce/rl/stable_baselines/sb_ppo.py @@ -0,0 +1,24 @@ +from stable_baselines3 import PPO + +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule +from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent + + +class StableBaselinesPPO(StableBaselinesAgent): + """ + This a stable baseline agent using Proximal Policy Optimization algorithm (PPO). + """ + name = 'Stable_Baselines_PPO' + + def _initialize_model(self, marketplace): + self.model = PPO('MlpPolicy', marketplace, verbose=False, tensorboard_log=self.tensorboard_log) + + def _load(self, load_path): + self.model = PPO.load(load_path, tensorboard_log=self.tensorboard_log) + + @staticmethod + def get_configurable_fields() -> list: + return [ + ('testvalue1', float, between_zero_one_rule), + ('ppovalue', float, greater_zero_rule) + ] diff --git a/recommerce/rl/stable_baselines/sb_sac.py b/recommerce/rl/stable_baselines/sb_sac.py new file mode 100644 index 00000000..c2f863fd --- /dev/null +++ b/recommerce/rl/stable_baselines/sb_sac.py @@ -0,0 +1,24 @@ +from stable_baselines3 import SAC + +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule +from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent + + +class StableBaselinesSAC(StableBaselinesAgent): + """ + This a stable baseline agent using Soft Actor Critic (SAC). + """ + name = 'Stable_Baselines_SAC' + + def _initialize_model(self, marketplace): + self.model = SAC('MlpPolicy', marketplace, verbose=False, tensorboard_log=self.tensorboard_log) + + def _load(self, load_path): + self.model = SAC.load(load_path, tensorboard_log=self.tensorboard_log) + + @staticmethod + def get_configurable_fields() -> list: + return [ + ('testvalue1', float, between_zero_one_rule), + ('sacvalue', float, greater_zero_rule) + ] diff --git a/recommerce/rl/stable_baselines/sb_td3.py b/recommerce/rl/stable_baselines/sb_td3.py new file mode 100644 index 00000000..f826a812 --- /dev/null +++ b/recommerce/rl/stable_baselines/sb_td3.py @@ -0,0 +1,28 @@ +import numpy as np +from stable_baselines3 import TD3 +from stable_baselines3.common.noise import NormalActionNoise + +from recommerce.configuration.common_rules import between_zero_one_rule, greater_zero_rule +from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesAgent + + +class StableBaselinesTD3(StableBaselinesAgent): + """ + This a stable baseline agent using TD3 which is a direct successor of DDPG. + """ + name = 'Stable_Baselines_TD3' + + def _initialize_model(self, marketplace): + n_actions = marketplace.get_actions_dimension() + action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions)) + self.model = TD3('MlpPolicy', marketplace, action_noise=action_noise, verbose=False, tensorboard_log=self.tensorboard_log) + + def _load(self, load_path): + self.model = TD3.load(load_path, tensorboard_log=self.tensorboard_log) + + @staticmethod + def get_configurable_fields() -> list: + return [ + ('testvalue1', float, between_zero_one_rule), + ('td3value', float, greater_zero_rule) + ] diff --git a/recommerce/rl/stable_baselines/stable_baselines_model.py b/recommerce/rl/stable_baselines/stable_baselines_model.py index de1dcc25..04726a52 100644 --- a/recommerce/rl/stable_baselines/stable_baselines_model.py +++ b/recommerce/rl/stable_baselines/stable_baselines_model.py @@ -3,9 +3,8 @@ import numpy as np from attrdict import AttrDict -from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3 -from stable_baselines3.common.noise import NormalActionNoise +from recommerce.configuration.common_rules import greater_zero_rule from recommerce.configuration.path_manager import PathManager from recommerce.market.circular.circular_vendors import CircularAgent from recommerce.market.linear.linear_vendors import LinearAgent @@ -61,71 +60,8 @@ def train_agent(self, training_steps=100000, iteration_length=500, analyze_after self.model.learn(training_steps, callback=callback) return callback.watcher.all_dicts - -class StableBaselinesDDPG(StableBaselinesAgent): - """ - This a stable baseline agent using Deep Deterministic Policy Gradient (DDPG) algorithm. - """ - name = 'Stable_Baselines_DDPG' - - def _initialize_model(self, marketplace): - n_actions = marketplace.get_actions_dimension() - action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions)) - self.model = DDPG('MlpPolicy', marketplace, action_noise=action_noise, verbose=False, tensorboard_log=self.tensorboard_log) - - def _load(self, load_path): - self.model = DDPG.load(load_path, tensorboard_log=self.tensorboard_log) - - -class StableBaselinesTD3(StableBaselinesAgent): - """ - This a stable baseline agent using TD3 which is a direct successor of DDPG. - """ - name = 'Stable_Baselines_TD3' - - def _initialize_model(self, marketplace): - n_actions = marketplace.get_actions_dimension() - action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions)) - self.model = TD3('MlpPolicy', marketplace, action_noise=action_noise, verbose=False, tensorboard_log=self.tensorboard_log) - - def _load(self, load_path): - self.model = TD3.load(load_path, tensorboard_log=self.tensorboard_log) - - -class StableBaselinesA2C(StableBaselinesAgent): - """ - This a stable baseline agent using A2C. - """ - name = 'Stable_Baselines_A2C' - - def _initialize_model(self, marketplace): - self.model = A2C('MlpPolicy', marketplace, verbose=False, tensorboard_log=self.tensorboard_log) - - def _load(self, load_path): - self.model = A2C.load(load_path, tensorboard_log=self.tensorboard_log) - - -class StableBaselinesPPO(StableBaselinesAgent): - """ - This a stable baseline agent using Proximal Policy Optimization algorithm (PPO). - """ - name = 'Stable_Baselines_PPO' - - def _initialize_model(self, marketplace): - self.model = PPO('MlpPolicy', marketplace, verbose=False, tensorboard_log=self.tensorboard_log) - - def _load(self, load_path): - self.model = PPO.load(load_path, tensorboard_log=self.tensorboard_log) - - -class StableBaselinesSAC(StableBaselinesAgent): - """ - This a stable baseline agent using Soft Actor Critic (SAC). - """ - name = 'Stable_Baselines_SAC' - - def _initialize_model(self, marketplace): - self.model = SAC('MlpPolicy', marketplace, verbose=False, tensorboard_log=self.tensorboard_log) - - def _load(self, load_path): - self.model = SAC.load(load_path, tensorboard_log=self.tensorboard_log) + @staticmethod + def get_configurable_fields() -> list: + return [ + ('stable_baseline_test', float, greater_zero_rule) + ] diff --git a/recommerce/rl/training_scenario.py b/recommerce/rl/training_scenario.py index 609f9f5f..4604d49e 100644 --- a/recommerce/rl/training_scenario.py +++ b/recommerce/rl/training_scenario.py @@ -8,7 +8,6 @@ import recommerce.rl.q_learning.q_learning_agent as q_learning_agent import recommerce.rl.rl_vs_rl_training as rl_vs_rl_training import recommerce.rl.self_play as self_play -import recommerce.rl.stable_baselines.stable_baselines_model as sbmodel from recommerce.configuration.environment_config import EnvironmentConfigLoader, TrainingEnvironmentConfig from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader from recommerce.configuration.path_manager import PathManager @@ -17,23 +16,33 @@ from recommerce.market.vendors import FixedPriceAgent from recommerce.rl.actorcritic.actorcritic_training import ActorCriticTrainer from recommerce.rl.q_learning.q_learning_training import QLearningTrainer +from recommerce.rl.stable_baselines.sb_ppo import StableBaselinesPPO +from recommerce.rl.stable_baselines.sb_sac import StableBaselinesSAC print('successfully imported torch: cuda?', torch.cuda.is_available()) def run_training_session( - config_market: AttrDict = HyperparameterConfigLoader.load('market_config'), - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config'), marketplace=circular_market.CircularEconomyRebuyPriceDuopoly, agent=q_learning_agent.QLearningAgent, + config_market: AttrDict = None, + config_rl: AttrDict = None, competitors: list = None) -> None: """ - Run a training session with the passed marketplace and QLearningAgent. + Run a training session with the passed marketplace and Agent. + Args: marketplace (SimMarket subclass): What marketplace to run the training session on. agent (QLearningAgent subclass): What kind of QLearningAgent to train. + config_market (AttrDict, optional): The config to be used for the marketplace. Defaults to loading the `market_config`. + config_rl (AttrDict, optional): The config to be used for the agent. Defaults to loading the `q_learning_config`. competitors (list | None, optional): If set, which competitors should be used instead of the default ones. """ + if config_market is None: + config_market = HyperparameterConfigLoader.load('market_config', marketplace) + if config_rl is None: + config_rl = HyperparameterConfigLoader.load('q_learning_config', agent) + assert issubclass(marketplace, sim_market.SimMarket), \ f'the type of the passed marketplace must be a subclass of SimMarket: {marketplace}' assert issubclass(agent, (q_learning_agent.QLearningAgent, actorcritic_agent.ActorCriticAgent)), \ @@ -79,45 +88,51 @@ def train_q_learning_circular_economy_rebuy(): """ run_training_session( marketplace=circular_market.CircularEconomyRebuyPriceDuopoly, - agent=q_learning_agent.QLearningAgent) + agent=q_learning_agent.QLearningAgent,) def train_continuous_a2c_circular_economy_rebuy(): """ Train an ActorCriticAgent on a Circular Economy Market with Rebuy Prices and one competitor. """ + used_agent = actorcritic_agent.ContinuousActorCriticAgentFixedOneStd run_training_session( marketplace=circular_market.CircularEconomyRebuyPriceDuopoly, - agent=actorcritic_agent.ContinuousActorCriticAgentFixedOneStd) + agent=used_agent, + config_rl=HyperparameterConfigLoader.load('actor_critic_config', used_agent)) def train_stable_baselines_ppo(): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') - sbmodel.StableBaselinesPPO( + used_marketplace = circular_market.CircularEconomyRebuyPriceDuopoly + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', used_marketplace) + config_rl: AttrDict = HyperparameterConfigLoader.load('sb_ppo_config', StableBaselinesPPO) + StableBaselinesPPO( config_market=config_market, config_rl=config_rl, - marketplace=circular_market.CircularEconomyRebuyPriceDuopoly(config_market, True)).train_agent() + marketplace=used_marketplace(config_market, True)).train_agent() def train_stable_baselines_sac(): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') - sbmodel.StableBaselinesSAC( + used_marketplace = circular_market.CircularEconomyRebuyPriceDuopoly + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', used_marketplace) + config_rl: AttrDict = HyperparameterConfigLoader.load('sb_sac_config', StableBaselinesSAC) + StableBaselinesSAC( config_market=config_market, config_rl=config_rl, - marketplace=circular_market.CircularEconomyRebuyPriceDuopoly(config_market, True)).train_agent() + marketplace=used_marketplace(config_market, True)).train_agent() def train_rl_vs_rl(): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + # marketplace is currently hardcoded in train_rl_vs_rl + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceDuopoly) + config_rl: AttrDict = HyperparameterConfigLoader.load('sb_ppo_config', StableBaselinesPPO) rl_vs_rl_training.train_rl_vs_rl(config_market, config_rl) def train_self_play(): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + # marketplace is currently hardcoded in train_self_play + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceDuopoly) + config_rl: AttrDict = HyperparameterConfigLoader.load('sb_ppo_config', StableBaselinesPPO) self_play.train_self_play(config_market, config_rl) @@ -125,21 +140,22 @@ def train_from_config(): """ Use the `environment_config_training.json` file to decide on the training parameters. """ - config: TrainingEnvironmentConfig = EnvironmentConfigLoader.load('environment_config_training') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') - # TODO: Theoretically, the name of the agent is saved in config['name'], but we don't use it yet. + config_environment: TrainingEnvironmentConfig = EnvironmentConfigLoader.load('environment_config_training') + config_rl: AttrDict = HyperparameterConfigLoader.load('q_learning_config', config_environment.agent[0]['agent_class']) + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', config_environment.marketplace) + competitor_list = [] - for competitor in config.agent[1:]: + for competitor in config_environment.agent[1:]: if issubclass(competitor['agent_class'], FixedPriceAgent): competitor_list.append( - competitor['agent_class'](config_market=config, fixed_price=competitor['argument'], name=competitor['name'])) + competitor['agent_class'](config_market=config_market, fixed_price=competitor['argument'], name=competitor['name'])) else: - competitor_list.append(competitor['agent_class'](config_market=config, name=competitor['name'])) + competitor_list.append(competitor['agent_class'](config_market=config_market, name=competitor['name'])) run_training_session( config_rl=config_rl, - marketplace=config.marketplace, - agent=config.agent[0]['agent_class'], + marketplace=config_environment.marketplace, + agent=config_environment.agent[0]['agent_class'], competitors=competitor_list) diff --git a/setup.cfg b/setup.cfg index db6dcc61..ad9d6b77 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ install_requires = tensorboard>=2.7.0 tqdm>=4.63.0 stable-baselines3[extra]>=1.5.0 + # names is a webserver dependency, but unfortunately not available through conda names>=0.3.0 scipy>=1.8.0 attrdict>=2.0.1 @@ -57,6 +58,9 @@ exclude = webserver/alpha_business_app/migrations/* per-file-ignores = recommerce/monitoring/performance.py:F401 webserver/alpha_business_app/tests/test_prefill.py:E501 + webserver/alpha_business_app/config_parser.py:F401 + webserver/alpha_business_app/models/config.py:F401 + webserver/alpha_business_app/models/hyperparameter_config.py:F401 [coverage:run] omit = diff --git a/tests/test_actorcritic_agent.py b/tests/test_actorcritic_agent.py index 8b045ccb..208180f0 100644 --- a/tests/test_actorcritic_agent.py +++ b/tests/test_actorcritic_agent.py @@ -7,9 +7,10 @@ import recommerce.market.linear.linear_sim_market as linear_market import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.rl.actorcritic.actorcritic_agent import ContinuousActorCriticAgentFixedOneStd -config_market: AttrDict = HyperparameterConfigLoader.load('market_config') -config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) +config_rl: AttrDict = HyperparameterConfigLoader.load('actor_critic_config', ContinuousActorCriticAgentFixedOneStd) abstract_agent_classes_testcases = [ actorcritic_agent.ActorCriticAgent, diff --git a/tests/test_actorcritic_training.py b/tests/test_actorcritic_training.py index 2934e92f..09110ab6 100644 --- a/tests/test_actorcritic_training.py +++ b/tests/test_actorcritic_training.py @@ -30,8 +30,8 @@ @pytest.mark.slow @pytest.mark.parametrize('market_class, agent_class, verbose', test_scenarios) def test_training_configurations(market_class, agent_class, verbose): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) + config_rl: AttrDict = HyperparameterConfigLoader.load('actor_critic_config', actorcritic_agent.ContinuousActorCriticAgentFixedOneStd) config_rl.batch_size = 8 ActorCriticTrainer(market_class, agent_class, config_market, config_rl).train_agent( verbose=verbose, diff --git a/tests/test_agent_monitoring/test_am_configuration.py b/tests/test_agent_monitoring/test_am_configuration.py index 72fa8f1b..c1350eea 100644 --- a/tests/test_agent_monitoring/test_am_configuration.py +++ b/tests/test_agent_monitoring/test_am_configuration.py @@ -3,13 +3,13 @@ from unittest.mock import patch import pytest -import utils_tests as ut_t +from attrdict import AttrDict import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.linear.linear_sim_market as linear_market import recommerce.monitoring.agent_monitoring.am_monitoring as monitoring import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent -from recommerce.configuration.hyperparameter_config import HyperparameterConfigValidator +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader from recommerce.configuration.path_manager import PathManager from recommerce.market.circular.circular_vendors import FixedPriceCEAgent, FixedPriceCERebuyAgent, HumanPlayerCERebuy, RuleBasedCEAgent from recommerce.market.linear.linear_vendors import FixedPriceLEAgent @@ -17,7 +17,7 @@ monitor = monitoring.Monitor() -config_hyperparameter: HyperparameterConfigValidator = ut_t.mock_config_hyperparameter() +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) # setup before each test @@ -31,7 +31,7 @@ def setup_function(function): plot_interval=10, marketplace=circular_market.CircularEconomyMonopoly, agents=[(FixedPriceCERebuyAgent, [])], - config_market=config_hyperparameter, + config_market=config_market, subfolder_name=f'test_plots_{function.__name__}') @@ -66,7 +66,7 @@ def test_get_modelfile_path(): @pytest.mark.parametrize('agents, expected_message', incorrect_update_agents_RL_testcases) def test_incorrect_update_agents_RL(agents, expected_message): with pytest.raises(AssertionError) as assertion_message: - monitor.configurator.setup_monitoring(agents=agents, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(agents=agents, config_market=config_market) assert expected_message in str(assertion_message.value) @@ -81,7 +81,7 @@ def test_incorrect_update_agents_RL(agents, expected_message): @pytest.mark.parametrize('agents', correct_update_agents_RL_testcases) def test_correct_update_agents_RL(agents): - monitor.configurator.setup_monitoring(agents=agents, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(agents=agents, config_market=config_market) def test_correct_setup_monitoring(): @@ -92,7 +92,7 @@ def test_correct_setup_monitoring(): marketplace=circular_market.CircularEconomyMonopoly, agents=[(HumanPlayerCERebuy, ['reptiloid']), (QLearningAgent, ['CircularEconomyMonopoly_QLearningAgent.dat', 'q_learner'])], - config_market=config_hyperparameter, + config_market=config_market, subfolder_name='subfoldername') assert monitor.configurator.enable_live_draw is False assert 10 == monitor.configurator.episodes @@ -116,11 +116,11 @@ def test_correct_setup_monitoring(): @pytest.mark.parametrize('agents', setting_multiple_agents_testcases) def test_setting_multiple_agents(agents): - monitor.configurator.setup_monitoring(agents=agents, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(agents=agents, config_market=config_market) def test_setting_market_not_agents(): - monitor.configurator.setup_monitoring(marketplace=circular_market.CircularEconomyMonopoly, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(marketplace=circular_market.CircularEconomyMonopoly, config_market=config_market) correct_setup_monitoring_testcases = [ @@ -164,7 +164,7 @@ def test_correct_setup_monitoring_parametrized(parameters): plot_interval=dict['plot_interval'], marketplace=dict['marketplace'], agents=dict['agents'], - config_market=config_hyperparameter, + config_market=config_market, subfolder_name=dict['subfolder_name'] ) @@ -241,16 +241,16 @@ def test_incorrect_setup_monitoring(parameters, expected_message): plot_interval=dict['plot_interval'], marketplace=dict['marketplace'], agents=dict['agents'], - config_market=config_hyperparameter, + config_market=config_market, subfolder_name=dict['subfolder_name'] ) assert expected_message in str(assertion_message.value) incorrect_setup_monitoring_type_errors_testcases = [ - {'marketplace': linear_market.LinearEconomyDuopoly(config=config_hyperparameter)}, - {'agents': [(linear_market.LinearEconomyDuopoly(config=config_hyperparameter), [])]}, - {'agents': [(RuleBasedCEAgent(config_market=config_hyperparameter), [])]} + {'marketplace': linear_market.LinearEconomyDuopoly(config=config_market)}, + {'agents': [(linear_market.LinearEconomyDuopoly(config=config_market), [])]}, + {'agents': [(RuleBasedCEAgent(config_market=config_market), [])]} ] @@ -275,7 +275,7 @@ def test_incorrect_setup_monitoring_type_errors(parameters): plot_interval=dict['plot_interval'], marketplace=dict['marketplace'], agents=dict['agents'], - config_market=config_hyperparameter, + config_market=config_market, subfolder_name=dict['subfolder_name'] ) @@ -288,14 +288,14 @@ def test_incorrect_setup_monitoring_type_errors(parameters): @pytest.mark.parametrize('agents', print_configuration_testcases) def test_print_configuration(agents): - monitor.configurator.setup_monitoring(agents=agents, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(agents=agents, config_market=config_market) monitor.configurator.print_configuration() @pytest.mark.parametrize('agents', print_configuration_testcases) def test_print_configuration_ratio(agents): - monitor.configurator.setup_monitoring(config_market=config_hyperparameter, episodes=51, plot_interval=1, agents=agents) + monitor.configurator.setup_monitoring(config_market=config_market, episodes=51, plot_interval=1, agents=agents) with patch('recommerce.monitoring.agent_monitoring.am_configuration.input', create=True) as mocked_input: mocked_input.side_effect = ['n'] diff --git a/tests/test_agent_monitoring/test_am_evaluation.py b/tests/test_agent_monitoring/test_am_evaluation.py index cefd3f23..a0bbb077 100644 --- a/tests/test_agent_monitoring/test_am_evaluation.py +++ b/tests/test_agent_monitoring/test_am_evaluation.py @@ -3,18 +3,18 @@ import numpy as np import pytest -import utils_tests as ut_t from attrdict import AttrDict import recommerce.market.circular.circular_sim_market as circular_market import recommerce.monitoring.agent_monitoring.am_monitoring as monitoring +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader from recommerce.configuration.path_manager import PathManager from recommerce.market.circular.circular_vendors import FixedPriceCEAgent, RuleBasedCEAgent from recommerce.rl.q_learning.q_learning_agent import QLearningAgent monitor = monitoring.Monitor() -config_hyperparameter: AttrDict = ut_t.mock_config_hyperparameter() +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) # setup before each test @@ -28,7 +28,7 @@ def setup_function(function): plot_interval=10, marketplace=circular_market.CircularEconomyMonopoly, agents=[(QLearningAgent, [os.path.join(PathManager.data_path, 'CircularEconomyMonopoly_QLearningAgent.dat')])], - config_market=config_hyperparameter, + config_market=config_market, subfolder_name=f'test_plots_{function.__name__}') @@ -50,7 +50,7 @@ def test_evaluate_session(agents, rewards): patch('recommerce.monitoring.agent_monitoring.am_configuration.os.makedirs'), \ patch('recommerce.monitoring.agent_monitoring.am_configuration.os.path.exists') as exists_mock: exists_mock.return_value = True - monitor.configurator.setup_monitoring(episodes=4, plot_interval=1, agents=agents, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(episodes=4, plot_interval=1, agents=agents, config_market=config_market) monitor.evaluator.evaluate_session(rewards) @@ -79,7 +79,7 @@ def test_rewards_array_size(): @pytest.mark.parametrize('agents, rewards, plot_bins, agent_color, lower_upper_range', create_histogram_statistics_plots_testcases) def test_create_histogram(agents, rewards, plot_bins, agent_color, lower_upper_range): - monitor.configurator.setup_monitoring(enable_live_draw=True, agents=agents, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(enable_live_draw=True, agents=agents, config_market=config_market) name_list = [agent.name for agent in monitor.configurator.agents] with patch('recommerce.monitoring.agent_monitoring.am_evaluation.plt.clf'), \ patch('recommerce.monitoring.agent_monitoring.am_evaluation.plt.xlabel'), \ @@ -101,7 +101,7 @@ def test_create_histogram(agents, rewards, plot_bins, agent_color, lower_upper_r def test_create_histogram_without_saving_to_directory(): - monitor.configurator.setup_monitoring(enable_live_draw=False, agents=[(RuleBasedCEAgent, [])], config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(enable_live_draw=False, agents=[(RuleBasedCEAgent, [])], config_market=config_market) with patch('recommerce.monitoring.agent_monitoring.am_evaluation.plt.clf'), \ patch('recommerce.monitoring.agent_monitoring.am_evaluation.plt.xlabel'), \ patch('recommerce.monitoring.agent_monitoring.am_evaluation.plt.title'), \ @@ -120,7 +120,7 @@ def test_create_histogram_without_saving_to_directory(): @pytest.mark.parametrize('agents, rewards, plot_bins, agent_color, lower_upper_range', create_histogram_statistics_plots_testcases) def test_create_statistics_plots(agents, rewards, plot_bins, agent_color, lower_upper_range): - monitor.configurator.setup_monitoring(agents=agents, episodes=len(rewards[0]), plot_interval=1, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(agents=agents, episodes=len(rewards[0]), plot_interval=1, config_market=config_market) with patch('recommerce.monitoring.agent_monitoring.am_evaluation.plt'), \ patch('recommerce.monitoring.agent_monitoring.am_configuration.os.makedirs'), \ patch('recommerce.monitoring.agent_monitoring.am_configuration.os.path.exists') as exists_mock: @@ -138,14 +138,14 @@ def test_create_statistics_plots(agents, rewards, plot_bins, agent_color, lower_ @pytest.mark.parametrize('x_values, y_values, plot_type, expected_message', incorrect_create_line_plot_testcases) def test_incorrect_create_line_plot(x_values, y_values, plot_type, expected_message): - monitor.configurator.setup_monitoring(episodes=4, plot_interval=2, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(episodes=4, plot_interval=2, config_market=config_market) with pytest.raises(AssertionError) as assertion_message: monitor.evaluator._create_line_plot(x_values, y_values, 'test_plot', plot_type) assert expected_message in str(assertion_message.value) def test_incorrect_create_line_plot_runtime_errors(): - monitor.configurator.setup_monitoring(episodes=4, plot_interval=2, config_market=config_hyperparameter) + monitor.configurator.setup_monitoring(episodes=4, plot_interval=2, config_market=config_market) with pytest.raises(RuntimeError) as assertion_message: monitor.evaluator._create_line_plot([1, 2], [[1, 3]], 'test_plot', 'Unknown_metric_type') assert 'this metric_type is unknown: Unknown_metric_type' in str(assertion_message.value) diff --git a/tests/test_agent_monitoring/test_am_monitoring.py b/tests/test_agent_monitoring/test_am_monitoring.py index a4d83b2f..4db3c2b6 100644 --- a/tests/test_agent_monitoring/test_am_monitoring.py +++ b/tests/test_agent_monitoring/test_am_monitoring.py @@ -13,7 +13,7 @@ monitor = monitoring.Monitor() -config_market: AttrDict = HyperparameterConfigLoader.load('market_config') +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) # setup before each test diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py index 541fb95a..599507ca 100644 --- a/tests/test_config_validation.py +++ b/tests/test_config_validation.py @@ -1,387 +1,64 @@ +import os + import pytest import utils_tests as ut_t import recommerce.configuration.config_validation as config_validation -from recommerce.configuration.environment_config import EnvironmentConfig -from recommerce.configuration.hyperparameter_config import HyperparameterConfigValidator - -########## -# Tests with already combined configs (== hyperparameter and/or environment key on the top-level) -########## -validate_config_valid_combined_final_testcases = [ - ut_t.create_combined_mock_dict(), - ut_t.create_combined_mock_dict(hyperparameter=ut_t.create_hyperparameter_mock_dict(rl=ut_t.create_hyperparameter_mock_dict_rl(gamma=0.5))), - ut_t.create_combined_mock_dict(hyperparameter=ut_t.create_hyperparameter_mock_dict( - sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=25))), - ut_t.create_combined_mock_dict(environment=ut_t.create_environment_mock_dict(task='exampleprinter')), - ut_t.create_combined_mock_dict(environment=ut_t.create_environment_mock_dict(agents=[ - { - 'name': 'Test_agent', - 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', - 'argument': '' - }, - { - 'name': 'Test_agent2', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', - 'argument': '' - } - ])), -] - - -@pytest.mark.parametrize('config', validate_config_valid_combined_final_testcases) -def test_validate_config_valid_combined_final(config): - # If the config is valid, the first member of the tuple returned will be True - validate_status, validate_data = config_validation.validate_config(config, True) - assert validate_status, validate_data - assert isinstance(validate_data, tuple) - assert 'rl' in validate_data[0] - assert 'sim_market' in validate_data[0] - assert 'gamma' in validate_data[0]['rl'] - assert 'max_price' in validate_data[0]['sim_market'] - assert 'task' in validate_data[1] - assert 'agents' in validate_data[1] - - -# These testcases do not cover everything, nor should they, there are simply too many combinations -validate_config_valid_combined_not_final_testcases = [ - ut_t.create_combined_mock_dict( - hyperparameter=ut_t.remove_key('rl', ut_t.create_hyperparameter_mock_dict())), - ut_t.create_combined_mock_dict( - hyperparameter=ut_t.create_hyperparameter_mock_dict( - rl=ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl(gamma=0.5)))), - ut_t.create_combined_mock_dict( - hyperparameter=ut_t.create_hyperparameter_mock_dict( - rl=ut_t.remove_key('epsilon_start', ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl())))), - ut_t.create_combined_mock_dict(environment=ut_t.remove_key('task', ut_t.create_environment_mock_dict())), - ut_t.create_combined_mock_dict(environment=ut_t.remove_key('agents', ut_t.remove_key('task', ut_t.create_environment_mock_dict()))), -] + validate_config_valid_combined_final_testcases - - -@pytest.mark.parametrize('config', validate_config_valid_combined_not_final_testcases) -def test_validate_config_valid_combined_not_final(config): - # If the config is valid, the first member of the returned tuple will be True - validate_status, validate_data = config_validation.validate_config(config, False) - assert validate_status, validate_data - - -validate_config_one_top_key_missing_testcases = [ - (ut_t.create_combined_mock_dict(hyperparameter=None), True), - (ut_t.create_combined_mock_dict(environment=None), True), - (ut_t.create_combined_mock_dict(hyperparameter=None), False), - (ut_t.create_combined_mock_dict(environment=None), False) -] +from recommerce.configuration.path_manager import PathManager - -@pytest.mark.parametrize('config, is_final', validate_config_one_top_key_missing_testcases) -def test_validate_config_one_top_key_missing(config, is_final): - validate_status, validate_data = config_validation.validate_config(config, is_final) - assert not validate_status, validate_data - assert 'If your config contains one of "environment" or "hyperparameter" it must also contain the other' == validate_data - - -validate_config_too_many_keys_testcases = [ - True, - False -] - - -@pytest.mark.parametrize('is_final', validate_config_too_many_keys_testcases) -def test_validate_config_too_many_keys(is_final): - test_config = ut_t.create_combined_mock_dict() - test_config['additional_key'] = "this should'nt be allowed" - validate_status, validate_data = config_validation.validate_config(test_config, is_final) - assert not validate_status, validate_data - assert 'Your config should not contain keys other than "environment" and "hyperparameter"' == validate_data -########## -# End of tests with already combined configs (== hyperparameter and/or environment key on the top-level) -########## - - -########## -# Tests without the already split top-level (config keys are mixed and need to be matched) -########## -# These are singular dicts that will get combined for the actual testcases -validate_config_valid_not_final_dicts = [ - { - 'rl': { - 'gamma': 0.5, - 'epsilon_start': 0.9 - } - }, - { - 'sim_market': { - 'max_price': 40 - } - }, - { - 'task': 'training' - }, - { - 'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly' - }, - { - 'agents': [ - { - 'name': 'Rule_Based Agent', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', - 'argument': '' - }, - { - 'name': 'CE Rebuy Agent (QLearning)', - 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', - 'argument': 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat' - } - ] - }, - { - 'agents': [ - { - 'name': 'Rule_Based Agent', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', - 'argument': '' - } - ] - } -] - - -# get all combinations of the dicts defined above to mix and match as much as possible -mixed_configs = [ - {**dict1, **dict2} for dict1 in validate_config_valid_not_final_dicts for dict2 in validate_config_valid_not_final_dicts -] +env_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'environment_config_training.json') +market_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'market_config.json') +rl_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'q_learning_config.json') -@pytest.mark.parametrize('config', mixed_configs) -def test_validate_config_valid_not_final(config): - validate_status, validate_data = config_validation.validate_config(config, False) - assert validate_status, f'Test failed with error: {validate_data} on config: {config}' +config_environment = ut_t.load_json(env_config_file) +config_market = ut_t.load_json(market_config_file) +config_rl = ut_t.load_json(rl_config_file) +config_environment['config_type'] = 'environment' +config_market['config_type'] = 'sim_market' +config_rl['config_type'] = 'rl' -validate_config_valid_final_testcases = [ - {**ut_t.create_hyperparameter_mock_dict(), **ut_t.create_environment_mock_dict()}, - {**ut_t.create_hyperparameter_mock_dict(rl=ut_t.create_hyperparameter_mock_dict_rl(gamma=0.2)), **ut_t.create_environment_mock_dict()}, - {**ut_t.create_hyperparameter_mock_dict(), **ut_t.create_environment_mock_dict(episodes=20)} -] - - -@pytest.mark.parametrize('config', validate_config_valid_final_testcases) -def test_validate_config_valid_final(config): - validate_status, validate_data = config_validation.validate_config(config, True) - assert validate_status, f'Test failed with error: {validate_data} on config: {config}' - assert 'rl' in validate_data[0] - assert 'sim_market' in validate_data[0] - assert 'agents' in validate_data[1] - - -@pytest.mark.parametrize('config', mixed_configs) -def test_split_mixed_config_valid(config): - config_validation.split_mixed_config(config) - - -split_mixed_config_invalid_testcases = [ - { - 'invalid_key': 2 - }, - { - 'rl': { - 'gamma': 0.5 - }, - 'invalid_key': 2 - }, - { - 'agents': [ - { - 'name': 'test', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', - 'argument': '' - } - ], - 'invalid_key': 2 - } -] - - -@pytest.mark.parametrize('config', split_mixed_config_invalid_testcases) -def test_split_mixed_config_invalid(config): - with pytest.raises(AssertionError) as error_message: - config_validation.split_mixed_config(config) - assert 'Your config contains an invalid key:' in str(error_message.value) - - -validate_sub_keys_invalid_keys_hyperparameter_testcases = [ - { - 'rl': { - 'gamma': 0.5, - 'invalid_key': 2 - } - }, - { - 'sim_market': { - 'max_price': 50, - 'invalid_key': 2 - } - }, - { - 'rl': { - 'gamma': 0.5, - 'invalid_key': 2 - }, - 'sim_market': { - 'max_price': 50, - 'invalid_key': 2 - } - }, - { - 'rl': { - 'gamma': 0.5 - }, - 'sim_market': { - 'max_price': 50, - 'invalid_key': 2 - } - } -] - - -@pytest.mark.parametrize('config', validate_sub_keys_invalid_keys_hyperparameter_testcases) -def test_validate_sub_keys_invalid_keys_hyperparameter(config): - with pytest.raises(AssertionError) as error_message: - top_level_keys = HyperparameterConfigValidator.get_required_fields('top-dict') - config_validation.validate_sub_keys(HyperparameterConfigValidator, config, top_level_keys) - assert 'The key "invalid_key" should not exist within a HyperparameterConfigValidator config' in str(error_message.value) - - -validate_sub_keys_agents_invalid_keys_testcases = [ - { - 'task': 'training', - 'agents': [ - { - 'name': 'name', - 'invalid_key': 2 - } - ] - }, - { - 'agents': [ - { - 'name': '', - 'argument': '', - 'invalid_key': 2 - } - ] - }, - { - 'agents': [ - { - 'argument': '' - }, - { - 'name': '', - 'agent_class': '', - 'argument': '', - 'invalid_key': 2 - } - ] - } -] - - -@pytest.mark.parametrize('config', validate_sub_keys_agents_invalid_keys_testcases) -def test_validate_sub_keys_agents_invalid_keys(config): - with pytest.raises(AssertionError) as error_message: - top_level_keys = EnvironmentConfig.get_required_fields('top-dict') - config_validation.validate_sub_keys(EnvironmentConfig, config, top_level_keys) - assert 'An invalid key for agents was provided:' in str(error_message.value) - - -validate_sub_keys_agents_wrong_type_testcases = [ - { - 'agents': 2 - }, - { - 'agents': 'string' - }, - { - 'agents': 2.0 - }, - { - 'agents': {} - } -] +def setup_function(function): + print('***SETUP***') + global config_environment + global config_market + global config_rl -@pytest.mark.parametrize('config', validate_sub_keys_agents_wrong_type_testcases) -def test_validate_sub_keys_agents_wrong_type(config): - with pytest.raises(AssertionError) as error_message: - top_level_keys = EnvironmentConfig.get_required_fields('top-dict') - config_validation.validate_sub_keys(EnvironmentConfig, config, top_level_keys) - assert 'The "agents" key must have a value of type list, but was' in str(error_message.value) + config_environment['config_type'] = 'environment' + config_market['config_type'] = 'sim_market' + config_rl['config_type'] = 'rl' -validate_sub_keys_agents_wrong_type_testcases = [ - { - 'agents': [ - 2 - ] - }, - { - 'agents': [ - 'string' - ] - }, - { - 'agents': [ - 2.0 - ] - }, - { - 'agents': [ - [] - ] - } +test_valid_config_validation_complete_testcases = [ + config_environment, + config_market, + config_rl ] -@pytest.mark.parametrize('config', validate_sub_keys_agents_wrong_type_testcases) -def test_validate_sub_keys_agents_wrong_subtype(config): - with pytest.raises(AssertionError) as error_message: - top_level_keys = EnvironmentConfig.get_required_fields('top-dict') - config_validation.validate_sub_keys(EnvironmentConfig, config, top_level_keys) - assert 'All agents must be of type dict, but this one was' in str(error_message.value) +@pytest.mark.parametrize('config', test_valid_config_validation_complete_testcases) +def test_valid_config_validation_complete(config): + config_type = config['config_type'] + success, result = config_validation.validate_config(config) + assert success, result + assert result == ({config_type: config}, None, None) -validate_sub_keys_wrong_type_hyperparameter_testcases = [ - { - 'rl': [] - }, - { - 'sim_market': [] - }, - { - 'rl': 2 - }, - { - 'sim_market': 2 - }, - { - 'rl': 'string' - }, - { - 'sim_market': 'string' - }, - { - 'rl': 2.0 - }, - { - 'sim_market': 2.0 - }, +test_valid_config_validation_incomplete_testcases = [ + (config_environment, 'agents'), + (config_market, 'max_price'), + (config_rl, 'learning_rate') ] -@pytest.mark.parametrize('config', validate_sub_keys_wrong_type_hyperparameter_testcases) -def test_validate_sub_keys_wrong_type_hyperparameter(config): - with pytest.raises(AssertionError) as error_message: - top_level_keys = HyperparameterConfigValidator.get_required_fields('top-dict') - config_validation.validate_sub_keys(HyperparameterConfigValidator, config, top_level_keys) - assert 'The value of this key must be of type dict:' in str(error_message.value) +@pytest.mark.parametrize('config, removed_key', test_valid_config_validation_incomplete_testcases) +def test_valid_config_validation_incomplete(config, removed_key): + # Hacky, thx pytest! + tested_config = config.copy() + tested_config = ut_t.remove_key(removed_key, tested_config) + config_type = tested_config['config_type'] + success, result = config_validation.validate_config(tested_config) + assert success + assert result == ({config_type: tested_config}, None, None) diff --git a/tests/test_customers.py b/tests/test_customers.py index 5e81a672..660b8d68 100644 --- a/tests/test_customers.py +++ b/tests/test_customers.py @@ -1,16 +1,16 @@ import numpy as np import pytest -import utils_tests as ut_t from attrdict import AttrDict import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.customer as customer import recommerce.market.linear.linear_sim_market as linear_market +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader from recommerce.market.circular.circular_customers import CustomerCircular from recommerce.market.linear.linear_customers import CustomerLinear from recommerce.market.sim_market import SimMarket -config_hyperparameter: AttrDict = ut_t.mock_config_hyperparameter() +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) # Test the Customer parent class, i.e. make sure it cannot be used @@ -59,7 +59,7 @@ def test_generate_purchase_probabilities_from_offer(customer, common_state, vend def test_customer_action_range(customer, market): offers = random_offer(market) probability_distribution = customer.generate_purchase_probabilities_from_offer(customer, *offers) - assert len(probability_distribution) == market(config=config_hyperparameter)._get_number_of_vendors() * \ + assert len(probability_distribution) == market(config=config_market)._get_number_of_vendors() * \ (1 if issubclass(market, linear_market.LinearEconomy) else 2) + 1 @@ -116,7 +116,7 @@ def random_offer(marketplace: SimMarket): Args: marketplace (SimMarket): The marketplace for which offers should be generated. """ - marketplace = marketplace(config=config_hyperparameter) + marketplace = marketplace(config=config_market) marketplace.reset() marketplace.vendor_actions[0] = marketplace.action_space.sample() return marketplace._get_common_state_array(), marketplace.vendor_specific_state, marketplace.vendor_actions diff --git a/tests/test_data/configuration_files/actor_critic_config.json b/tests/test_data/configuration_files/actor_critic_config.json new file mode 100644 index 00000000..66a0b97d --- /dev/null +++ b/tests/test_data/configuration_files/actor_critic_config.json @@ -0,0 +1,5 @@ +{ + "gamma": 0.99, + "sync_target_frames": 35, + "testvalue2": 15.0 +} diff --git a/tests/test_data/configuration_files/hyperparameter_config.json b/tests/test_data/configuration_files/hyperparameter_config.json deleted file mode 100644 index 0cdc7b97..00000000 --- a/tests/test_data/configuration_files/hyperparameter_config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "rl": { - "gamma" : 0.99, - "batch_size" : 32, - "replay_size" : 500, - "learning_rate" : 1e-6, - "sync_target_frames" : 10, - "replay_start_size" : 100, - "epsilon_decay_last_frame" : 400, - "epsilon_start" : 1.0, - "epsilon_final" : 0.1 - }, - "sim_market": { - "max_storage": 100, - "episode_length": 25, - "max_price": 10, - "max_quality": 50, - "number_of_customers": 10, - "production_price": 3, - "storage_cost_per_product": 0.1 - } -} diff --git a/tests/test_data/configuration_files/market_config.json b/tests/test_data/configuration_files/market_config.json index 00638718..bdf37b24 100644 --- a/tests/test_data/configuration_files/market_config.json +++ b/tests/test_data/configuration_files/market_config.json @@ -1,11 +1,9 @@ { - "sim_market": { - "max_storage": 100, - "episode_length": 50, - "max_price": 10, - "max_quality": 50, - "number_of_customers": 20, - "production_price": 3, - "storage_cost_per_product": 0.1 - } + "max_storage": 100, + "episode_length": 50, + "max_price": 10, + "max_quality": 50, + "number_of_customers": 20, + "production_price": 3, + "storage_cost_per_product": 0.1 } diff --git a/tests/test_data/configuration_files/q_learning_config.json b/tests/test_data/configuration_files/q_learning_config.json new file mode 100644 index 00000000..45736dd0 --- /dev/null +++ b/tests/test_data/configuration_files/q_learning_config.json @@ -0,0 +1,11 @@ +{ + "gamma": 0.99, + "batch_size": 8, + "replay_size": 350, + "learning_rate": 1e-6, + "sync_target_frames": 35, + "replay_start_size": 20, + "epsilon_decay_last_frame": 400, + "epsilon_start": 1.0, + "epsilon_final": 0.1 +} diff --git a/tests/test_data/configuration_files/rl_config.json b/tests/test_data/configuration_files/rl_config.json deleted file mode 100644 index e50abc29..00000000 --- a/tests/test_data/configuration_files/rl_config.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "rl": { - "gamma" : 0.99, - "batch_size" : 8, - "replay_size" : 350, - "learning_rate" : 1e-6, - "sync_target_frames" : 35, - "replay_start_size" : 20, - "epsilon_decay_last_frame" : 400, - "epsilon_start" : 1.0, - "epsilon_final" : 0.1 - } -} diff --git a/tests/test_data/configuration_files/sb_a2c_config.json b/tests/test_data/configuration_files/sb_a2c_config.json new file mode 100644 index 00000000..49201571 --- /dev/null +++ b/tests/test_data/configuration_files/sb_a2c_config.json @@ -0,0 +1,4 @@ +{ + "testvalue1": 0.99, + "a2cvalue": 15.0 +} diff --git a/tests/test_data/configuration_files/sb_ddpg_config.json b/tests/test_data/configuration_files/sb_ddpg_config.json new file mode 100644 index 00000000..f855a5dc --- /dev/null +++ b/tests/test_data/configuration_files/sb_ddpg_config.json @@ -0,0 +1,4 @@ +{ + "testvalue1": 0.99, + "ddpgvalue": 15.0 +} diff --git a/tests/test_data/configuration_files/sb_ppo_config.json b/tests/test_data/configuration_files/sb_ppo_config.json new file mode 100644 index 00000000..af3e26ba --- /dev/null +++ b/tests/test_data/configuration_files/sb_ppo_config.json @@ -0,0 +1,4 @@ +{ + "testvalue1": 0.99, + "ppovalue": 15.0 +} diff --git a/tests/test_data/configuration_files/sb_sac_config.json b/tests/test_data/configuration_files/sb_sac_config.json new file mode 100644 index 00000000..f5d6bc05 --- /dev/null +++ b/tests/test_data/configuration_files/sb_sac_config.json @@ -0,0 +1,4 @@ +{ + "testvalue1": 0.99, + "sacvalue": 15.0 +} diff --git a/tests/test_data/configuration_files/sb_td3_config.json b/tests/test_data/configuration_files/sb_td3_config.json new file mode 100644 index 00000000..d797d0be --- /dev/null +++ b/tests/test_data/configuration_files/sb_td3_config.json @@ -0,0 +1,4 @@ +{ + "testvalue1": 0.99, + "td3value": 15.0 +} diff --git a/tests/test_exampleprinter.py b/tests/test_exampleprinter.py index 29ba4037..d09a0a65 100644 --- a/tests/test_exampleprinter.py +++ b/tests/test_exampleprinter.py @@ -16,8 +16,9 @@ # The load path for the agent modelfiles parameters_path = os.path.join('tests', 'test_data') -config_market: AttrDict = HyperparameterConfigLoader.load('market_config') -config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) +config_q_learning: AttrDict = HyperparameterConfigLoader.load('q_learning_config', QLearningAgent) +config_actor_critic: AttrDict = HyperparameterConfigLoader.load('actor_critic_config', ContinuousActorCriticAgentFixedOneStd) def test_setup_exampleprinter(): @@ -60,22 +61,23 @@ def test_full_episode_rule_based(marketplace, agent): full_episode_testcases_rl_agent = [ - (linear_market.LinearEconomyDuopoly(config=config_market), QLearningAgent, 'LinearEconomyDuopoly_QLearningAgent.dat'), + (linear_market.LinearEconomyDuopoly(config=config_market), QLearningAgent, + 'LinearEconomyDuopoly_QLearningAgent.dat', config_q_learning), (circular_market.CircularEconomyMonopoly(config=config_market), QLearningAgent, - 'CircularEconomyMonopoly_QLearningAgent.dat'), + 'CircularEconomyMonopoly_QLearningAgent.dat', config_q_learning), (circular_market.CircularEconomyRebuyPriceMonopoly(config=config_market), QLearningAgent, - 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat'), + 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat', config_q_learning), (circular_market.CircularEconomyRebuyPriceDuopoly(config=config_market), QLearningAgent, - 'CircularEconomyRebuyPriceDuopoly_QLearningAgent.dat'), + 'CircularEconomyRebuyPriceDuopoly_QLearningAgent.dat', config_q_learning), (circular_market.CircularEconomyRebuyPriceDuopoly(config=config_market), ContinuousActorCriticAgentFixedOneStd, - 'actor_parametersCircularEconomyRebuyPriceDuopoly_ContinuousActorCriticAgentFixedOneStd.dat'), + 'actor_parametersCircularEconomyRebuyPriceDuopoly_ContinuousActorCriticAgentFixedOneStd.dat', config_actor_critic), (circular_market.CircularEconomyRebuyPriceDuopoly(config=config_market), DiscreteActorCriticAgent, - 'actor_parametersCircularEconomyRebuyPriceDuopoly_DiscreteACACircularEconomyRebuy.dat') + 'actor_parametersCircularEconomyRebuyPriceDuopoly_DiscreteACACircularEconomyRebuy.dat', config_actor_critic) ] -@pytest.mark.parametrize('marketplace, agent_class, parameters_file', full_episode_testcases_rl_agent) -def test_full_episode_rl_agents(marketplace, agent_class, parameters_file): +@pytest.mark.parametrize('marketplace, agent_class, parameters_file, config_rl', full_episode_testcases_rl_agent) +def test_full_episode_rl_agents(marketplace, agent_class, parameters_file, config_rl): agent = agent_class( marketplace=marketplace, config_market=config_market, diff --git a/tests/test_hyperparameter_config_rl.py b/tests/test_hyperparameter_config_rl.py index c1e32dba..b280e2ec 100644 --- a/tests/test_hyperparameter_config_rl.py +++ b/tests/test_hyperparameter_config_rl.py @@ -1,123 +1,70 @@ import json -from importlib import reload +import os from unittest.mock import mock_open, patch import pytest import utils_tests as ut_t import recommerce.configuration.hyperparameter_config as hyperparameter_config +from recommerce.configuration.path_manager import PathManager +from recommerce.rl.q_learning.q_learning_agent import QLearningAgent +q_learning_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'q_learning_config.json') -def teardown_module(module): - print('***TEARDOWN***') - reload(hyperparameter_config) - -###### -# General tests for the HyperParameter parent class -##### -get_required_fields_valid_testcases = [ - ('top-dict', {'rl': True, 'sim_market': True}), - ('rl', { - 'gamma': False, - 'batch_size': False, - 'replay_size': False, - 'learning_rate': False, - 'sync_target_frames': False, - 'replay_start_size': False, - 'epsilon_decay_last_frame': False, - 'epsilon_start': False, - 'epsilon_final': False - }), - ('sim_market', { - 'max_storage': False, - 'episode_length': False, - 'max_price': False, - 'max_quality': False, - 'number_of_customers': False, - 'production_price': False, - 'storage_cost_per_product': False - }) -] - - -@pytest.mark.parametrize('level, expected_dict', get_required_fields_valid_testcases) -def test_get_required_fields_valid(level, expected_dict): - fields = hyperparameter_config.HyperparameterConfigValidator.get_required_fields(level) - assert fields == expected_dict - - -def test_get_required_fields_invalid(): - with pytest.raises(AssertionError) as error_message: - hyperparameter_config.HyperparameterConfigValidator.get_required_fields('wrong_key') - assert 'The given level does not exist in a hyperparameter-config: wrong_key' in str(error_message.value) -###### -# End general tests -##### - - -# mock format taken from: -# https://stackoverflow.com/questions/1289894/how-do-i-mock-an-open-used-in-a-with-statement-using-the-mock-framework-in-pyth # Test that checks if the config.json is read correctly def test_reading_file_values(): - mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict()) - with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: - ut_t.check_mock_file(mock_file, mock_json) - config = hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') - - assert config.gamma == 0.99 - assert config.batch_size == 32 - assert config.replay_size == 500 - assert config.learning_rate == 1e-6 - assert config.sync_target_frames == 10 - assert config.replay_start_size == 100 - assert config.epsilon_decay_last_frame == 400 - assert config.epsilon_start == 1.0 - assert config.epsilon_final == 0.1 - - # Test a second time with other values to ensure that the values are read correctly - mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(rl=ut_t.create_hyperparameter_mock_dict_rl(learning_rate=1e-4))) - with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: - ut_t.check_mock_file(mock_file, mock_json) - - config = hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') + config = hyperparameter_config.HyperparameterConfigLoader.load('q_learning_config', QLearningAgent) - assert config.learning_rate == 1e-4 + assert config.gamma == 0.99 + assert config.batch_size == 8 + assert config.replay_size == 350 + assert config.learning_rate == 1e-6 + assert config.sync_target_frames == 35 + assert config.replay_start_size == 20 + assert config.epsilon_decay_last_frame == 400 + assert config.epsilon_start == 1.0 + assert config.epsilon_final == 0.1 # The following variables are input mock-json strings for the test_invalid_values test # These tests have invalid values in their input file, the import should throw a specific error message -learning_rate_larger_one = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=1.5), - 'learning_rate should be between 0 and 1 (excluded)') -negative_learning_rate = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=0), 'learning_rate should be between 0 and 1 (excluded)') -large_gamma = (ut_t.create_hyperparameter_mock_dict_rl(gamma=1.0), 'gamma should be between 0 (included) and 1 (excluded)') -negative_gamma = ((ut_t.create_hyperparameter_mock_dict_rl(gamma=-1.0), 'gamma should be between 0 (included) and 1 (excluded)')) -negative_batch_size = (ut_t.create_hyperparameter_mock_dict_rl(batch_size=-5), 'batch_size should be greater than 0') -negative_replay_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_size=-5), - 'replay_size should be greater than 0') -negative_sync_target_frames = (ut_t.create_hyperparameter_mock_dict_rl(sync_target_frames=-5), - 'sync_target_frames should be greater than 0') -negative_replay_start_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_start_size=-5), 'replay_start_size should be greater than 0') -negative_epsilon_decay_last_frame = (ut_t.create_hyperparameter_mock_dict_rl(epsilon_decay_last_frame=-5), - 'epsilon_decay_last_frame should not be negative') +negative_learning_rate = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'learning_rate', 0.0), + 'learning_rate should be positive') +large_gamma = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'gamma', 1.1), + 'gamma should be between 0 (included) and 1 (included)') +negative_gamma = ((ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'gamma', -1.0), + 'gamma should be between 0 (included) and 1 (included)')) +negative_batch_size = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'batch_size', -5), + 'batch_size should be positive') +negative_replay_size = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'replay_size', -5), + 'replay_size should be positive') +negative_sync_target_frames = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'sync_target_frames', -5), + 'sync_target_frames should be positive') +negative_replay_start_size = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'replay_start_size', -5), + 'replay_start_size should be positive') +negative_epsilon_decay_last_frame = (ut_t.replace_field_in_dict(ut_t.load_json(q_learning_config_file), 'epsilon_decay_last_frame', -5), + 'epsilon_decay_last_frame should be positive') # These tests are missing a line in the config file, the import should throw a specific error message -missing_gamma = (ut_t.remove_key('gamma', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing gamma') -missing_batch_size = (ut_t.remove_key('batch_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing batch_size') -missing_replay_size = (ut_t.remove_key('replay_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing replay_size') -missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl()), - 'your config_rl is missing learning_rate') -missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.create_hyperparameter_mock_dict_rl()), - 'your config_rl is missing sync_target_frames') -missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.create_hyperparameter_mock_dict_rl()), - 'your config_rl is missing replay_start_size') -missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.create_hyperparameter_mock_dict_rl()), - 'your config_rl is missing epsilon_decay_last_frame') -missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.create_hyperparameter_mock_dict_rl()), - 'your config_rl is missing epsilon_start') -missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.create_hyperparameter_mock_dict_rl()), - 'your config_rl is missing epsilon_final') +missing_two_items = (ut_t.remove_key('batch_size', ut_t.remove_key('gamma', ut_t.load_json(q_learning_config_file))), + "your config is missing {'batch_size', 'gamma'}") +missing_gamma = (ut_t.remove_key('gamma', ut_t.load_json(q_learning_config_file)), "your config is missing {'gamma'}") +missing_batch_size = (ut_t.remove_key('batch_size', ut_t.load_json(q_learning_config_file)), "your config is missing {'batch_size'}") +missing_replay_size = (ut_t.remove_key('replay_size', ut_t.load_json(q_learning_config_file)), "your config is missing {'replay_size'}") +missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.load_json(q_learning_config_file)), + "your config is missing {'learning_rate'}") +missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.load_json(q_learning_config_file)), + "your config is missing {'sync_target_frames'}") +missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.load_json(q_learning_config_file)), + "your config is missing {'replay_start_size'}") +missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.load_json(q_learning_config_file)), + "your config is missing {'epsilon_decay_last_frame'}") +missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.load_json(q_learning_config_file)), + "your config is missing {'epsilon_start'}") +missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.load_json(q_learning_config_file)), + "your config is missing {'epsilon_final'}") invalid_values_testcases = [ @@ -130,7 +77,6 @@ def test_reading_file_values(): missing_epsilon_decay_last_frame, missing_epsilon_start, missing_epsilon_final, - learning_rate_larger_one, negative_learning_rate, large_gamma, negative_gamma, @@ -145,9 +91,9 @@ def test_reading_file_values(): # Test that checks that an invalid/broken config.json gets detected correctly @pytest.mark.parametrize('rl_json, expected_message', invalid_values_testcases) def test_invalid_values(rl_json, expected_message): - mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(rl=rl_json)) + mock_json = json.dumps(rl_json) with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: ut_t.check_mock_file(mock_file, mock_json) with pytest.raises(AssertionError) as assertion_message: - hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') + hyperparameter_config.HyperparameterConfigLoader.load('q_learning_config', QLearningAgent) assert expected_message in str(assertion_message.value) diff --git a/tests/test_hyperparameter_config_sim_market.py b/tests/test_hyperparameter_config_sim_market.py index d873ba3b..b3ac53b0 100644 --- a/tests/test_hyperparameter_config_sim_market.py +++ b/tests/test_hyperparameter_config_sim_market.py @@ -1,89 +1,65 @@ import json -from importlib import reload +import os from unittest.mock import mock_open, patch import pytest import utils_tests as ut_t -import recommerce.configuration.hyperparameter_config as hyperparameter_config +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.configuration.path_manager import PathManager +from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceMonopoly +market_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'market_config.json') -def teardown_module(module): - print('***TEARDOWN***') - reload(hyperparameter_config) - -# mock format taken from: -# https://stackoverflow.com/questions/1289894/how-do-i-mock-an-open-used-in-a-with-statement-using-the-mock-framework-in-pyth # Test that checks if the config.json is read correctly def test_reading_file_values(): - mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(sim_market=ut_t.create_hyperparameter_mock_dict_sim_market())) - with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: - ut_t.check_mock_file(mock_file, mock_json) - - config = hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') - - assert config.max_storage == 100 - assert config.episode_length == 25 - assert config.max_price == 10 - assert config.max_quality == 50 - assert config.number_of_customers == 10 - assert config.production_price == 3 - assert config.storage_cost_per_product == 0.1 - - # Test a second time with other values to ensure, that the values are read correctly - mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( - sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(50, 50, 50, 80, 20, 10, 0.7))) - with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: - ut_t.check_mock_file(mock_file, mock_json) - - config = hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') + config = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) - assert config.max_storage == 50 - assert config.episode_length == 50 - assert config.max_price == 50 - assert config.max_quality == 80 - assert config.number_of_customers == 20 - assert config.production_price == 10 - assert config.storage_cost_per_product == 0.7 + assert config.max_storage == 100 + assert config.episode_length == 50 + assert config.max_price == 10 + assert config.max_quality == 50 + assert config.number_of_customers == 20 + assert config.production_price == 3 + assert config.storage_cost_per_product == 0.1 # The following variables are input mock-json strings for the test_invalid_values test # These tests have invalid values in their input file, the import should throw a specific error message -odd_number_of_customers = (ut_t.create_hyperparameter_mock_dict_sim_market(number_of_customers=21), +odd_number_of_customers = (ut_t.replace_field_in_dict(ut_t.load_json(market_config_file), 'number_of_customers', 21), 'number_of_customers should be even and positive') -negative_number_of_customers = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 50, 50, 80, -10, 10, 0.15), +negative_number_of_customers = (ut_t.replace_field_in_dict(ut_t.load_json(market_config_file), 'number_of_customers', -10), 'number_of_customers should be even and positive') -prod_price_higher_max_price = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 50, 10, 80, 20, 50, 0.15), - 'production_price needs to be smaller than max_price and >=0') -negative_production_price = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 50, 50, 80, 20, -10, 0.15), - 'production_price needs to be smaller than max_price and >=0') -negative_max_quality = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 20, 15, -80, 30, 5, 0.15), +negative_production_price = (ut_t.replace_field_in_dict(ut_t.load_json(market_config_file), 'production_price', -10), + 'production_price should be non-negative') +negative_max_quality = (ut_t.replace_field_in_dict(ut_t.load_json(market_config_file), 'max_quality', -80), 'max_quality should be positive') -non_negative_storage_cost = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 20, 15, 80, 30, 5, -3.5), +non_negative_storage_cost = (ut_t.replace_field_in_dict(ut_t.load_json(market_config_file), 'storage_cost_per_product', -3.5), 'storage_cost_per_product should be non-negative') # These tests are missing a line in the config file, the import should throw a specific error message -missing_max_storage = (ut_t.remove_key('max_storage', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing max_storage') -missing_episode_length = (ut_t.remove_key('episode_length', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing episode_length') -missing_max_price = (ut_t.remove_key('max_price', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing max_price') -missing_max_quality = (ut_t.remove_key('max_quality', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing max_quality') -missing_number_of_customers = (ut_t.remove_key('number_of_customers', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing number_of_customers') -missing_production_price = (ut_t.remove_key('production_price', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing production_price') -missing_storage_cost = (ut_t.remove_key('storage_cost_per_product', ut_t.create_hyperparameter_mock_dict_sim_market()), - 'your config is missing storage_cost_per_product') +missing_two_items = (ut_t.remove_key('episode_length', ut_t.remove_key('max_storage', ut_t.load_json(market_config_file))), + "your config is missing {'episode_length', 'max_storage'}") +missing_max_storage = (ut_t.remove_key('max_storage', ut_t.load_json(market_config_file)), + "your config is missing {'max_storage'}") +missing_episode_length = (ut_t.remove_key('episode_length', ut_t.load_json(market_config_file)), + "your config is missing {'episode_length'}") +missing_max_price = (ut_t.remove_key('max_price', ut_t.load_json(market_config_file)), + "your config is missing {'max_price'}") +missing_max_quality = (ut_t.remove_key('max_quality', ut_t.load_json(market_config_file)), + "your config is missing {'max_quality'}") +missing_number_of_customers = (ut_t.remove_key('number_of_customers', ut_t.load_json(market_config_file)), + "your config is missing {'number_of_customers'}") +missing_production_price = (ut_t.remove_key('production_price', ut_t.load_json(market_config_file)), + "your config is missing {'production_price'}") +missing_storage_cost = (ut_t.remove_key('storage_cost_per_product', ut_t.load_json(market_config_file)), + "your config is missing {'storage_cost_per_product'}") # All pairs concerning themselves with invalid config.json values should be added to this array to get tested in test_invalid_values invalid_values_testcases = [ odd_number_of_customers, negative_number_of_customers, - prod_price_higher_max_price, negative_production_price, negative_max_quality, non_negative_storage_cost, @@ -98,11 +74,11 @@ def test_reading_file_values(): # Test that checks that an invalid/broken config.json gets detected correctly -@pytest.mark.parametrize('sim_market_json, expected_message', invalid_values_testcases) -def test_invalid_values(sim_market_json, expected_message): - mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(sim_market=sim_market_json)) +@pytest.mark.parametrize('market_json, expected_message', invalid_values_testcases) +def test_invalid_values(market_json, expected_message): + mock_json = json.dumps(market_json) with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: ut_t.check_mock_file(mock_file, mock_json) with pytest.raises(AssertionError) as assertion_message: - hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') + HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) assert expected_message in str(assertion_message.value) diff --git a/tests/test_policyanalyzer.py b/tests/test_policyanalyzer.py index 8e7f4311..1ba47a90 100644 --- a/tests/test_policyanalyzer.py +++ b/tests/test_policyanalyzer.py @@ -5,6 +5,7 @@ import pytest from attrdict import AttrDict +import recommerce.market.circular.circular_sim_market as circular_market from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader from recommerce.configuration.path_manager import PathManager from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceDuopoly, CircularEconomyRebuyPriceMonopoly @@ -16,8 +17,7 @@ write_to_path = os.path.join(PathManager.results_path, 'policyanalyzer') -config_market: AttrDict = HyperparameterConfigLoader.load('market_config') -config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) def test_rule_based_linear_competitor1(): @@ -80,7 +80,7 @@ def test_circular_monopoly_q_learning(title, policyaccess, expected_filename): q_learing_agent = QLearningAgent( marketplace=CircularEconomyRebuyPriceMonopoly(config=config_market), config_market=config_market, - config_rl=config_rl, + config_rl=HyperparameterConfigLoader.load('q_learning_config', QLearningAgent), load_path=os.path.join(PathManager.data_path, 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat') ) pa = PolicyAnalyzer(q_learing_agent) @@ -108,7 +108,7 @@ def test_circular_duopol_q_learning(title, policyaccess, expected_filename): q_learing_agent = QLearningAgent( marketplace=CircularEconomyRebuyPriceDuopoly(config=config_market), config_market=config_market, - config_rl=config_rl, + config_rl=HyperparameterConfigLoader.load('actor_critic_config', ContinuousActorCriticAgentFixedOneStd), load_path=os.path.join(PathManager.data_path, 'CircularEconomyRebuyPriceDuopoly_QLearningAgent.dat') ) pa = PolicyAnalyzer(q_learing_agent) @@ -136,7 +136,7 @@ def test_circular_duopol_continuos_actorcritic(title, policyaccess, expected_fil a2c_agent = ContinuousActorCriticAgentFixedOneStd( marketplace=CircularEconomyRebuyPriceDuopoly(config=config_market), config_market=config_market, - config_rl=config_rl, + config_rl=HyperparameterConfigLoader.load('actor_critic_config', ContinuousActorCriticAgentFixedOneStd), load_path=os.path.join(PathManager.data_path, 'actor_parametersCircularEconomyRebuyPriceDuopoly_ContinuousActorCriticAgentFixedOneStd.dat') ) diff --git a/tests/test_q_learning_training.py b/tests/test_q_learning_training.py index 88c2b590..35e523ef 100644 --- a/tests/test_q_learning_training.py +++ b/tests/test_q_learning_training.py @@ -20,8 +20,8 @@ @pytest.mark.slow @pytest.mark.parametrize('marketplace_class', test_scenarios) def test_market_scenario(marketplace_class): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) + config_rl: AttrDict = HyperparameterConfigLoader.load('q_learning_config', QLearningAgent) config_rl.replay_start_size = 500 config_rl.sync_target_frames = 100 q_learning_training.QLearningTrainer( diff --git a/tests/test_rl_vs_rl.py b/tests/test_rl_vs_rl.py index 355f94fd..37b2e8be 100644 --- a/tests/test_rl_vs_rl.py +++ b/tests/test_rl_vs_rl.py @@ -2,12 +2,14 @@ from attrdict import AttrDict from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceMonopoly from recommerce.rl.rl_vs_rl_training import train_rl_vs_rl +from recommerce.rl.stable_baselines.sb_ppo import StableBaselinesPPO @pytest.mark.training @pytest.mark.slow def test_rl_vs_rl(): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) + config_rl: AttrDict = HyperparameterConfigLoader.load('sb_ppo_config', StableBaselinesPPO) train_rl_vs_rl(config_market, config_rl, num_switches=4, num_steps_per_switch=230) diff --git a/tests/test_self_play.py b/tests/test_self_play.py index fea7980d..8e9b4a9f 100644 --- a/tests/test_self_play.py +++ b/tests/test_self_play.py @@ -2,8 +2,10 @@ from attrdict import AttrDict from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceMonopoly from recommerce.rl.self_play import train_self_play -from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesPPO, StableBaselinesSAC +from recommerce.rl.stable_baselines.sb_ppo import StableBaselinesPPO +from recommerce.rl.stable_baselines.sb_sac import StableBaselinesSAC agents = [StableBaselinesPPO, StableBaselinesSAC] @@ -12,6 +14,6 @@ @pytest.mark.slow @pytest.mark.parametrize('agent_class', agents) def test_self_play(agent_class): - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') - config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) + config_rl: AttrDict = HyperparameterConfigLoader.load('sb_ppo_config', StableBaselinesPPO) train_self_play(config_market=config_market, config_rl=config_rl, agent_class=agent_class, training_steps=230) diff --git a/tests/test_sim_market.py b/tests/test_sim_market.py index e24f8836..8aaebf4b 100644 --- a/tests/test_sim_market.py +++ b/tests/test_sim_market.py @@ -4,8 +4,7 @@ import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.linear.linear_sim_market as linear_market - -config_hyperparameter: AttrDict = ut_t.mock_config_hyperparameter() +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader unique_output_dict_testcases = [ linear_market.LinearEconomyDuopoly, @@ -18,7 +17,8 @@ @pytest.mark.parametrize('marketclass', unique_output_dict_testcases) def test_unique_output_dict(marketclass): - market = marketclass(config=config_hyperparameter) + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) + market = marketclass(config=config_market) _, _, _, info_dict_1 = market.step(ut_t.create_mock_action(marketclass)) _, _, _, info_dict_2 = market.step(ut_t.create_mock_action(marketclass)) assert id(info_dict_1) != id(info_dict_2) diff --git a/tests/test_stable_baselines_training.py b/tests/test_stable_baselines_training.py index 7f81f167..62c097bc 100644 --- a/tests/test_stable_baselines_training.py +++ b/tests/test_stable_baselines_training.py @@ -2,19 +2,22 @@ from attrdict import AttrDict import recommerce.market.circular.circular_sim_market as circular_market -import recommerce.rl.stable_baselines.stable_baselines_model as sb_model from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.rl.stable_baselines.sb_a2c import StableBaselinesA2C +from recommerce.rl.stable_baselines.sb_ddpg import StableBaselinesDDPG +from recommerce.rl.stable_baselines.sb_ppo import StableBaselinesPPO +from recommerce.rl.stable_baselines.sb_sac import StableBaselinesSAC +from recommerce.rl.stable_baselines.sb_td3 import StableBaselinesTD3 -config_market: AttrDict = HyperparameterConfigLoader.load('market_config') -config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) @pytest.mark.training @pytest.mark.slow def test_ddpg_training(): - sb_model.StableBaselinesDDPG( + StableBaselinesDDPG( config_market, - config_rl, + HyperparameterConfigLoader.load('sb_ddpg_config', StableBaselinesDDPG), circular_market.CircularEconomyRebuyPriceDuopoly( config=config_market, support_continuous_action_space=True) @@ -24,9 +27,9 @@ def test_ddpg_training(): @pytest.mark.training @pytest.mark.slow def test_td3_training(): - sb_model.StableBaselinesTD3( + StableBaselinesTD3( config_market, - config_rl, + HyperparameterConfigLoader.load('sb_td3_config', StableBaselinesTD3), circular_market.CircularEconomyRebuyPriceDuopoly(config=config_market, support_continuous_action_space=True) ).train_agent(1500, 30) @@ -35,9 +38,9 @@ def test_td3_training(): @pytest.mark.training @pytest.mark.slow def test_a2c_training(): - sb_model.StableBaselinesA2C( + StableBaselinesA2C( config_market, - config_rl, + HyperparameterConfigLoader.load('sb_a2c_config', StableBaselinesA2C), circular_market.CircularEconomyRebuyPriceDuopoly( config=config_market, support_continuous_action_space=True) @@ -47,9 +50,9 @@ def test_a2c_training(): @pytest.mark.training @pytest.mark.slow def test_ppo_training(): - sb_model.StableBaselinesPPO( + StableBaselinesPPO( config_market, - config_rl, + HyperparameterConfigLoader.load('sb_ppo_config', StableBaselinesPPO), marketplace=circular_market.CircularEconomyRebuyPriceDuopoly( config=config_market, support_continuous_action_space=True) @@ -59,9 +62,9 @@ def test_ppo_training(): @pytest.mark.training @pytest.mark.slow def test_sac_training(): - sb_model.StableBaselinesSAC( + StableBaselinesSAC( config_market, - config_rl, + HyperparameterConfigLoader.load('sb_sac_config', StableBaselinesSAC), marketplace=circular_market.CircularEconomyRebuyPriceDuopoly( config=config_market, support_continuous_action_space=True) diff --git a/tests/test_svg_manipulation.py b/tests/test_svg_manipulation.py index 800517ec..ee65e179 100644 --- a/tests/test_svg_manipulation.py +++ b/tests/test_svg_manipulation.py @@ -7,6 +7,7 @@ import recommerce.monitoring.svg_manipulation as svg_manipulation from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceMonopoly from recommerce.monitoring.exampleprinter import ExamplePrinter svg_manipulator = svg_manipulation.SVGManipulator() @@ -30,7 +31,7 @@ def test_correct_template(): assert correct_template == svg_manipulator.template_svg # run one exampleprinter and to make sure the template does not get changed - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) # initialize all functions to be mocked with patch('recommerce.monitoring.exampleprinter.ut.write_dict_to_tensorboard'), \ patch('recommerce.monitoring.svg_manipulation.os.path.isfile') as mock_isfile, \ @@ -192,7 +193,7 @@ def test_time_not_int(): def test_one_exampleprinter_run(): # run only three episodes to be able to reuse the correct_html - config_market: AttrDict = HyperparameterConfigLoader.load('market_config') + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) # initialize all functions to be mocked with patch('recommerce.monitoring.exampleprinter.ut.write_dict_to_tensorboard'), \ patch('recommerce.monitoring.svg_manipulation.os.path.isfile') as mock_isfile, \ diff --git a/tests/test_training_scenario.py b/tests/test_training_scenario.py index f510d09b..c41bc595 100644 --- a/tests/test_training_scenario.py +++ b/tests/test_training_scenario.py @@ -22,13 +22,13 @@ def test_train_continuous_a2c_circular_economy_rebuy(): def test_train_stable_baselines_ppo(): - with patch('recommerce.rl.stable_baselines.stable_baselines_model.StableBaselinesPPO.train_agent') as mock_train_agent: + with patch('recommerce.rl.stable_baselines.sb_ppo.StableBaselinesPPO.train_agent') as mock_train_agent: training_scenario.train_stable_baselines_ppo() assert mock_train_agent.called def test_train_stable_baselines_sac(): - with patch('recommerce.rl.stable_baselines.stable_baselines_model.StableBaselinesSAC.train_agent') as mock_train_agent: + with patch('recommerce.rl.stable_baselines.sb_sac.StableBaselinesSAC.train_agent') as mock_train_agent: training_scenario.train_stable_baselines_sac() assert mock_train_agent.called diff --git a/tests/test_utils.py b/tests/test_utils.py index 6c450787..cc59dde8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,15 +3,14 @@ import numpy as np import pytest -import utils_tests as ut_t from attrdict import AttrDict import recommerce.configuration.hyperparameter_config as hyperparameter_config import recommerce.configuration.utils as ut +from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceMonopoly from recommerce.monitoring.svg_manipulation import SVGManipulator -config_hyperparameter: AttrDict = ut_t.mock_config_hyperparameter() - def teardown_module(module): reload(hyperparameter_config) @@ -23,7 +22,7 @@ def teardown_module(module): @pytest.mark.parametrize('max_quality', testcases_shuffle_quality) def test_shuffle_quality(max_quality: int): - edited_config = ut_t.mock_config_hyperparameter() + edited_config: AttrDict = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) edited_config.max_quality = max_quality quality = ut.shuffle_quality(edited_config) assert quality <= max_quality and quality >= 1 @@ -266,11 +265,7 @@ def test_write_content_of_dict_to_overview_svg( episode_dictionary: dict, cumulated_dictionary: dict, expected: dict): - mock_json = ut_t.create_hyperparameter_mock_dict( - sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(episode_length=50, number_of_customers=20, production_price=3)) - mock_json_flatten = mock_json['sim_market'] - mock_json_flatten.update(mock_json['rl']) - mock_attr_dict = AttrDict(mock_json_flatten) + config_market: AttrDict = HyperparameterConfigLoader.load('market_config', CircularEconomyRebuyPriceMonopoly) with patch('recommerce.monitoring.svg_manipulation.SVGManipulator.write_dict_to_svg') as mock_write_dict_to_svg: - ut.write_content_of_dict_to_overview_svg(SVGManipulator(), episode, episode_dictionary, cumulated_dictionary, mock_attr_dict) + ut.write_content_of_dict_to_overview_svg(SVGManipulator(), episode, episode_dictionary, cumulated_dictionary, config_market) mock_write_dict_to_svg.assert_called_once_with(target_dictionary=expected) diff --git a/tests/test_vendors.py b/tests/test_vendors.py index 7a7e78ce..f8d65dbe 100644 --- a/tests/test_vendors.py +++ b/tests/test_vendors.py @@ -1,9 +1,9 @@ import numpy as np import pytest -import utils_tests as ut_t from attrdict import AttrDict from numpy import random +import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.circular.circular_vendors as circular_vendors import recommerce.market.linear.linear_vendors as linear_vendors import recommerce.market.vendors as vendors @@ -12,8 +12,7 @@ from recommerce.rl.q_learning.q_learning_agent import QLearningAgent from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent -config_market: AttrDict = HyperparameterConfigLoader.load('market_config') -config_rl: AttrDict = HyperparameterConfigLoader.load('rl_config') +config_market: AttrDict = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) abstract_agent_classes_testcases = [ vendors.Agent, @@ -53,7 +52,11 @@ def test_non_abstract_agent_classes(agent): def test_non_abstract_qlearning_agent(): - QLearningAgent(marketplace=LinearEconomyOligopoly(config=config_market), config_market=config_market, config_rl=config_rl) + QLearningAgent( + marketplace=LinearEconomyOligopoly(config=config_market), + config_market=config_market, + config_rl=HyperparameterConfigLoader.load('q_learning_config', QLearningAgent) + ) fixed_price_agent_observation_policy_pairs_testcases = [ @@ -96,16 +99,15 @@ def test_storage_evaluation(state, expected_prices): @pytest.mark.parametrize('state, expected_prices', storage_evaluation_with_rebuy_price_testcases) def test_storage_evaluation_with_rebuy_price(state, expected_prices): - changed_config = ut_t.mock_config_hyperparameter() + changed_config = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) changed_config.max_price = 10 changed_config.production_price = 2 agent = circular_vendors.RuleBasedCERebuyAgent(config_market=changed_config) - print('*********************************') assert expected_prices == agent.policy(state) def test_prices_are_not_higher_than_allowed(): - changed_config = ut_t.mock_config_hyperparameter() + changed_config = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) changed_config.max_price = 10 changed_config.production_price = 9 test_agent = circular_vendors.RuleBasedCEAgent(config_market=changed_config) @@ -170,7 +172,7 @@ def random_offer_circular_oligopoly(is_rebuy_economy: bool): # TODO: Update this test for all current competitors @pytest.mark.parametrize('competitor_class, state', policy_plus_one_testcases) def test_policy_plus_one(competitor_class, state): - changed_config = ut_t.mock_config_hyperparameter() + changed_config = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) changed_config.max_price = 10 changed_config.production_price = 2 @@ -187,7 +189,7 @@ def test_policy_plus_one(competitor_class, state): @pytest.mark.parametrize('price', clamp_price_testcases) def test_clamp_price(price): - changed_config = ut_t.mock_config_hyperparameter() + changed_config = HyperparameterConfigLoader.load('market_config', circular_market.CircularEconomyRebuyPriceMonopoly) changed_config.max_price = 9 assert 0 <= circular_vendors.RuleBasedCEAgent(config_market=changed_config)._clamp_price(price) <= 9 diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 88fbbe3c..bb12ce87 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -1,105 +1,38 @@ import json from typing import Tuple, Union -from unittest.mock import mock_open, patch - -from attrdict import AttrDict import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.linear.linear_sim_market as linear_market -from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader - - -def create_hyperparameter_mock_dict_rl(gamma: float = 0.99, - batch_size: int = 32, - replay_size: int = 500, - learning_rate: float = 1e-6, - sync_target_frames: int = 10, - replay_start_size: int = 100, - epsilon_decay_last_frame: int = 400, - epsilon_start: float = 1.0, - epsilon_final: float = 0.1) -> dict: - """ - Create dictionary that can be used to mock the rl part of the hyperparameter_config.json file by calling json.dumps() on it. - Args: - gamma (float, optional): Defaults to 0.99. - batch_size (int, optional): Defaults to 32. - replay_size (int, optional): Defaults to 100000. - learning_rate (float, optional): Defaults to 1e-6. - sync_target_frames (int, optional): Defaults to 1000. - replay_start_size (int, optional): Defaults to 10000. - epsilon_decay_last_frame (int, optional): Defaults to 75000. - epsilon_start (float, optional): Defaults to 1.0. - epsilon_final (float, optional): Defaults to 0.1. - Returns: - dict: The mock dictionary. +def load_json(path: str): """ - return { - 'gamma': gamma, - 'batch_size': batch_size, - 'replay_size': replay_size, - 'learning_rate': learning_rate, - 'sync_target_frames': sync_target_frames, - 'replay_start_size': replay_start_size, - 'epsilon_decay_last_frame': epsilon_decay_last_frame, - 'epsilon_start': epsilon_start, - 'epsilon_final': epsilon_final, - } - - -def create_hyperparameter_mock_dict_sim_market( - max_storage: int = 100, - episode_length: int = 25, - max_price: int = 10, - max_quality: int = 50, - number_of_customers: int = 10, - production_price: int = 3, - storage_cost_per_product: float = 0.1) -> dict: - """ - Create dictionary that can be used to mock the sim_market part of the hyperparameter_config.json file by calling json.dumps() on it. + Load a json file. Args: - max_storage (int, optional): Defaults to 20. - episode_length (int, optional): Defaults to 20. - max_price (int, optional): Defaults to 15. - max_quality (int, optional): Defaults to 100. - number_of_customers (int, optional): Defaults to 30. - production_price (int, optional): Defaults to 5. - storage_cost_per_product (float, optional): Defaults to 0.3. + path (str): The path to the json file. Returns: - dict: The mock dictionary. + dict: The json file as a dictionary. """ - return { - 'max_storage': max_storage, - 'episode_length': episode_length, - 'max_price': max_price, - 'max_quality': max_quality, - 'number_of_customers': number_of_customers, - 'production_price': production_price, - 'storage_cost_per_product': storage_cost_per_product, - } + with open(path) as file: + return json.load(file) -def create_hyperparameter_mock_dict(rl: dict = create_hyperparameter_mock_dict_rl(), - sim_market: dict = create_hyperparameter_mock_dict_sim_market()) -> dict: +def replace_field_in_dict(initial_dict: dict, key: str, value: Union[str, int, float]) -> dict: """ - Create a dictionary in the format of the hyperparameter_config.json. - Call json.dumps() on the return value of this to mock the json file. + Replace a field in a dictionary with a new value. Args: - rl (dict, optional): The dictionary that should be used for the rl-part. Defaults to create_hyperparameter_mock_dict_rl(). - sim_market (dict, optional): The dictionary that should be used for the sim_market-part. - Defaults to create_hyperparameter_mock_dict_sim_market(). + initial_dict (dict): The dictionary in which to replace the field. + key (str): The key of the field to be replaced. + value (Union[str, int, float]): The new value of the field. Returns: - dict: The mock dictionary. + dict: The dictionary with the field replaced. """ - return { - 'rl': rl, - 'sim_market': sim_market - } + initial_dict[key] = value + return initial_dict def create_environment_mock_dict(task: str = 'agent_monitoring', @@ -143,35 +76,6 @@ def create_environment_mock_dict(task: str = 'agent_monitoring', } -def create_combined_mock_dict(hyperparameter: dict or None = create_hyperparameter_mock_dict(), - environment: dict or None = create_environment_mock_dict()) -> dict: - """ - Create a mock dictionary in the format of a configuration file with both a hyperparameter and environment part. - If any of the two parameters is `None`, leave that key out of the resulting dictionary. - - Args: - hyperparameter (dict | None, optional): The hyperparameter part of the combined config. Defaults to create_hyperparameter_mock_dict(). - environment (dict | None, optional): The environment part of the combined config. Defaults to create_environment_mock_dict(). - - Returns: - dict: The mock dictionary. - """ - if hyperparameter is None and environment is None: - return {} - elif hyperparameter is None: - return { - 'environment': environment - } - elif environment is None: - return { - 'hyperparameter': hyperparameter - } - return { - 'hyperparameter': hyperparameter, - 'environment': environment - } - - def check_mock_file(mock_file, mocked_file_content) -> None: """ Confirm that a mock JSON is read correctly. @@ -202,19 +106,6 @@ def remove_key(key: str, original_dict: dict) -> dict: return original_dict -def create_mock_rewards(num_entries) -> list: - """ - Create a list of ints to be used as e.g. mock rewards. - - Args: - num_entries (int): How many numbers should be in the list going from 1 to num_entries. - - Returns: - list: The list of rewards. - """ - return list(range(1, num_entries)) - - def create_mock_action(market_subclass) -> Union[int, Tuple]: """ Create an array to be used as an action. The length of the array fits to the argument's class. @@ -231,16 +122,3 @@ def create_mock_action(market_subclass) -> Union[int, Tuple]: return (1, 2, 3) elif issubclass(market_subclass, circular_market.CircularEconomy): return (1, 2) - - -def mock_config_hyperparameter() -> AttrDict: - """ - Reload the hyperparameter_config file to update the config variable with the mocked values. - - Returns: - HyperparameterConfig: The mocked hyperparameter config object. - """ - mock_json = json.dumps(create_hyperparameter_mock_dict()) - with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: - check_mock_file(mock_file, mock_json) - return HyperparameterConfigLoader.load('hyperparameter_config') diff --git a/webserver/alpha_business_app/adjustable_fields.py b/webserver/alpha_business_app/adjustable_fields.py new file mode 100644 index 00000000..3034f426 --- /dev/null +++ b/webserver/alpha_business_app/adjustable_fields.py @@ -0,0 +1,99 @@ +from recommerce.configuration.utils import get_class + +from .utils import convert_python_type_to_input_type + + +def get_agent_hyperparameter(agent: str, formdata: dict) -> list: + """ + Gets all hyperparameters for a specific agent in our list of dict format needed for the view. + + Args: + agent (str): classname as string of a recommerce agent + formdata (dict): content of the current configuration form + + Returns: + list: of dict, the dicts contain the following values (currently needed by view): + name: name of the hyperparameter + input_type: html type for the input field e.g. number + prefill: value that is already stored for this hyperparameter + """ + # get all fields that are possible for this agent + agent_class = get_class(agent) + agent_specs = agent_class.get_configurable_fields() + + # we want to keep values already inside the configuration form, so we need to parse existing html + parameter_values = _convert_form_to_value_dict(formdata) + + # convert parameter into special list format for view + all_parameter = [] + for spec in agent_specs: + this_parameter = {} + this_parameter['name'] = spec[0] + this_parameter['input_type'] = convert_python_type_to_input_type(spec[1]) + this_parameter['prefill'] = _get_value_from_dict(spec[0], parameter_values) + all_parameter += [this_parameter] + return all_parameter + + +def get_rl_parameter_prefill(prefill: dict, error: dict) -> list: + """ + Converts a prefill and error dict to our list of dictionary format needed by view. + + Args: + prefill (dict): 'rl' prefill dictionary + error (dict): 'rl' error dictionary produced by merging config objects + + Returns: + list: of dict, the dicts contain the following values (currently needed by view): + name: name of the hyperparameter + input_type: html type for the input field e.g. number + prefill: value that is already stored for this hyperparameter + error: error value for this parameter + """ + # returns list of dictionaries + all_parameter = [] + for key, value in prefill.items(): + this_parameter = {} + this_parameter['name'] = key + this_parameter['prefill'] = value if value else '' + this_parameter['error'] = error[key] if error[key] else '' + all_parameter += [this_parameter] + return all_parameter + + +def _convert_form_to_value_dict(config_form: dict) -> dict: + """ + Extracts the 'rl' part from the formdata dict as hierarchical dict. + + Args: + config_form (dict): flat config form from the website + + Returns: + dict: hierarchical rl dict with extracted values + """ + final_values = {} + # the formdata is a flat dict, containing two values per config parameter, name and value + # num_experiments and experiment_name are included in the form as well, but we do not consider those + for index in range((len(config_form) - 2) // 2): + current_name = config_form[f'formdata[{index}][name]'] + current_value = config_form[f'formdata[{index}][value]'] + if 'hyperparameter-rl' in current_name: + final_values[current_name.replace('hyperparameter-rl-', '')] = current_value + return final_values + + +def _get_value_from_dict(key: str, value_dict: dict) -> str: + """ + Save way to get either the key value of a key or ''. + + Args: + key (str): key the value should be retrieved from + value_dict (dict): dict the value should be retrieved from + + Returns: + str: value of the key in dict or '' + """ + try: + return value_dict[key] + except KeyError: + return '' diff --git a/webserver/alpha_business_app/buttons.py b/webserver/alpha_business_app/buttons.py index 6b6affb5..2018e899 100644 --- a/webserver/alpha_business_app/buttons.py +++ b/webserver/alpha_business_app/buttons.py @@ -4,6 +4,7 @@ from recommerce.configuration.config_validation import validate_config +from .adjustable_fields import get_rl_parameter_prefill from .config_merger import ConfigMerger from .config_parser import ConfigFlatDictParser from .container_parser import parse_response_to_database @@ -46,7 +47,7 @@ def __init__(self, self.selection_manager = SelectionManager() if request.method == 'POST': - self.wanted_key = request.POST['action'] + self.wanted_key = request.POST['action'].strip() if 'container_id' in request.POST: wanted_container_id = request.POST['container_id'].strip() self.wanted_container = Container.objects.get(id=wanted_container_id) @@ -73,8 +74,8 @@ def do_button_click(self) -> HttpResponse: return self._remove() if self.wanted_key == 'start': return self._start() - if self.wanted_key == 'pre-fill': - return self._pre_fill() + if self.wanted_key == 'prefill': + return self._prefill() if self.wanted_key == 'logs': return self._logs() if self.wanted_key == 'manage_config': @@ -169,7 +170,7 @@ def _render_configuration(self) -> HttpResponse: """ return render(self.request, self.view_to_render, {**self._params_for_config(), **self._message_for_view()}) - def _render_prefill(self, pre_fill_dict: dict, error_dict: dict) -> HttpResponse: + def _render_prefill(self, prefill_dict: dict, error_dict: dict) -> HttpResponse: """ This will return a rendering for `self.view` with params for selection a prefill dict and the error dict @@ -181,7 +182,7 @@ def _render_prefill(self, pre_fill_dict: dict, error_dict: dict) -> HttpResponse HttpResponse: _description_ """ return render(self.request, self.view_to_render, - {'prefill': pre_fill_dict, + {'prefill': prefill_dict, 'error_dict': error_dict, 'all_configurations': Config.objects.all().filter(user=self.request.user), **self._params_for_selection()}) @@ -266,7 +267,7 @@ def _manage_config(self) -> HttpResponse: self.wanted_config.delete() return redirect('/configurator') - def _pre_fill(self) -> HttpResponse: + def _prefill(self) -> HttpResponse: """ This function will be called when the config form should be prefilled with values from the config. It converts a list of given config objects to dicts and merges these dicts. @@ -278,8 +279,10 @@ def _pre_fill(self) -> HttpResponse: post_request = dict(self.request.POST.lists()) if 'config_id' not in post_request: return self._decide_rendering() + config_dict = ConfigFlatDictParser().flat_dict_to_complete_hierarchical_config_dict(post_request) merger = ConfigMerger() - final_dict, error_dict = merger.merge_config_objects(post_request['config_id']) + final_dict, error_dict = merger.merge_config_objects(post_request['config_id'], config_dict) + final_dict['hyperparameter']['rl'] = get_rl_parameter_prefill(final_dict['hyperparameter']['rl'], error_dict['hyperparameter']['rl']) # set an id for each agent (necessary for view) for agent_index in range(len(final_dict['environment']['agents'])): final_dict['environment']['agents'][agent_index]['display_name'] = 'Agent' if agent_index == 0 else 'Competitor' @@ -308,7 +311,7 @@ def _start(self) -> HttpResponse: config_dict = ConfigFlatDictParser().flat_dict_to_hierarchical_config_dict(post_request) - validate_status, validate_data = validate_config(config=config_dict, config_is_final=True) + validate_status, validate_data = validate_config(config=config_dict.copy()) if not validate_status: self.message = ['error', validate_data] return self._decide_rendering() diff --git a/webserver/alpha_business_app/config_merger.py b/webserver/alpha_business_app/config_merger.py index e8b3db78..dbb7d41c 100644 --- a/webserver/alpha_business_app/config_merger.py +++ b/webserver/alpha_business_app/config_merger.py @@ -5,25 +5,27 @@ class ConfigMerger(): def __init__(self) -> None: self.error_dict = Config.get_empty_structure_dict() - def merge_config_objects(self, config_object_ids: list) -> tuple: + def merge_config_objects(self, config_object_ids: list, base_config: dict = Config.get_empty_structure_dict()) -> tuple: """ merge a list of config objects given by their id. Args: config_object_ids (list): The id's of the config objects that should be merged. + base_config (dict): The config all other Configs should be merged into. Defaults to an empty config Returns: tuple (dict, dict): the final merged dict and the error dict with the latest error """ configuration_objects = [Config.objects.get(id=config_id) for config_id in config_object_ids] configuration_dicts = [config.as_dict() for config in configuration_objects] - # get initial empty dict to merge into - final_config = Config.get_empty_structure_dict() + + # merge configs + final_config = base_config for config in configuration_dicts: - final_config = self._merge_config_into_base_config(final_config, config) + final_config = self.merge_config_into_base_config(final_config, config) return final_config, self.error_dict - def _merge_config_into_base_config(self, base_config: dict, merging_config: dict, current_config_path: str = '') -> dict: + def merge_config_into_base_config(self, base_config: dict, merging_config: dict, current_config_path: str = '') -> dict: """ merges one config dict recursively into a base_config dict. @@ -51,7 +53,7 @@ def _merge_config_into_base_config(self, base_config: dict, merging_config: dict # base_config[key] = self._merge_agents_into_base_agents(base_config[key], sub_dict) continue new_config_path = f'{current_config_path}-{key}' if current_config_path else key - base_config[key] = self._merge_config_into_base_config(base_config[key], sub_dict, new_config_path) + base_config[key] = self.merge_config_into_base_config(base_config[key], sub_dict, new_config_path) # update values for key, value in contained_values_merge: @@ -63,27 +65,6 @@ def _merge_config_into_base_config(self, base_config: dict, merging_config: dict return base_config - # working version of comparing the agent lists for agents with the same names. - # Currently not used since we simply concatenate the agent lists - # def _merge_agents_into_base_agents(self, base_agent_config: list, merge_agent_config: list) -> list: - # """ - # Merges an agents config part into a base agents config part. It will be checked if two of the merged agents have the same name. - - # Args: - # base_agent_config (list): the config that will be merged into - # merge_agent_config (list): the config that should be merged - - # Returns: - # list: a final merged agents config - # """ - # base_names = [agent['name'] for agent in base_agent_config] - # for agent in merge_agent_config: - # if agent['name'] in base_names: - # self._update_error_dict(['environment', 'agents'], f'multiple agents named {agent["name"]}') - # else: - # base_agent_config.append(agent) - # return base_agent_config - def _update_error_dict(self, key_words: list, update_message: str) -> None: """ helper function, that updates a value in the error dict given by the list of key words diff --git a/webserver/alpha_business_app/config_parser.py b/webserver/alpha_business_app/config_parser.py index 5c31d30b..282443c4 100644 --- a/webserver/alpha_business_app/config_parser.py +++ b/webserver/alpha_business_app/config_parser.py @@ -1,4 +1,12 @@ -from .models.config import * +from .config_merger import ConfigMerger +from .models.agent_config import AgentConfig +from .models.agents_config import AgentsConfig +from .models.config import Config +from .models.environment_config import EnvironmentConfig +from .models.hyperparameter_config import HyperparameterConfig +from .models.rl_config import RlConfig +from .models.sim_market_config import SimMarketConfig +from .utils import remove_none_values_from_dict, to_config_class_name class ConfigFlatDictParser(): @@ -33,6 +41,10 @@ def flat_dict_to_hierarchical_config_dict(self, flat_dict: dict) -> dict: 'hyperparameter': self._flat_hyperparameter_to_hierarchical(hyperparameter) } + def flat_dict_to_complete_hierarchical_config_dict(self, flat_dict: dict) -> dict: + not_complete_config_dict = self.flat_dict_to_hierarchical_config_dict(flat_dict) + return ConfigMerger().merge_config_into_base_config(Config.get_empty_structure_dict(), not_complete_config_dict) + def _flat_environment_to_hierarchical(self, flat_dict: dict) -> dict: """ Parses the environment part of the flat dict to a hierarchical environment dict. @@ -232,4 +244,5 @@ def _create_object_from(self, class_name: str, parameters: dict): Returns: an instance of the `class_name` with the parameters given. """ + assert class_name in globals(), f'The provided name: {class_name} not in {globals()}' return globals()[class_name].objects.create(**parameters) diff --git a/webserver/alpha_business_app/handle_files.py b/webserver/alpha_business_app/handle_files.py index 1f2d784f..b5cb5b19 100644 --- a/webserver/alpha_business_app/handle_files.py +++ b/webserver/alpha_business_app/handle_files.py @@ -1,4 +1,5 @@ import json +import re import tarfile import zipfile from io import BytesIO @@ -9,7 +10,7 @@ from recommerce.configuration.config_validation import validate_config from .config_parser import ConfigModelParser -from .models.config import * +from .models.config import Config from .models.container import Container @@ -56,23 +57,31 @@ def handle_uploaded_file(request, uploaded_config) -> HttpResponse: except ValueError as value: return render(request, 'upload.html', {'error': str(value)}) - validate_status, validate_data = validate_config(content_as_dict, False) + # Validate the config file using the recommerce validation functionality + validate_status, validate_data = validate_config(content_as_dict) if not validate_status: return render(request, 'upload.html', {'error': validate_data}) - hyperparameter_config, environment_config = validate_data - parser = ConfigModelParser() - web_hyperparameter_config = None - web_environment_config = None - try: - web_hyperparameter_config = parser.parse_config_dict_to_datastructure('hyperparameter', hyperparameter_config) - web_environment_config = parser.parse_config_dict_to_datastructure('environment', environment_config) - except ValueError: - return render(request, 'upload.html', {'error': 'Your config is wrong'}) + # configs and their corresponding top level keys as list + config_objects = _get_top_level_and_configs(validate_data) + # parse config model to datastructure + parser = ConfigModelParser() + resulting_config_parts = [] + for top_level, config in config_objects: + try: + resulting_config_parts += [(top_level, parser.parse_config_dict_to_datastructure(top_level, config))] + except ValueError: + return render(request, 'upload.html', {'error': 'Your config is wrong'}) + except TypeError as error: + invalid_keyword_search = re.search('.*keyword argument (.*)', str(error)) + return render(request, 'upload.html', {'error': f'Your config contains an invalid key: {invalid_keyword_search.group(1)}'}) + + # Make it a real config object + environment_config, hyperparameter_config = _get_config_parts(resulting_config_parts) given_name = request.POST['config_name'] config_name = given_name if given_name else uploaded_config.name - Config.objects.create(environment=web_environment_config, hyperparameter=web_hyperparameter_config, name=config_name, user=request.user) + Config.objects.create(environment=environment_config, hyperparameter=hyperparameter_config, name=config_name, user=request.user) return redirect('/configurator', {'success': 'You successfully uploaded a config file'}) @@ -162,3 +171,62 @@ def _convert_tar_file_to_zip(fake_tar_archive: BytesIO) -> BytesIO: tar_archive.close() return file_like_zip + + +def _get_top_level_and_configs(validate_data: tuple) -> list: + """ + Prepares data returned by the recommerce validation function for parsing. + Should only be used when the validated config was correct + + Args: + validate_data (tuple): return of recommerce validation function, when config was correct + + Returns: + list: of tuples, the first tuple value indecates the top level ('hyperparameter' or 'environment') + and the second value is the corresponding config. + Length will be between 1 and 2. + """ + assert tuple == type(validate_data), \ + f'Data returned by "vaidate_config" for correct config should be tuple, but was {validate_data}' + result = [] + for config in validate_data: + if not config: + continue + if 'environment' in config: + result += [('environment', config['environment'])] + elif 'hyperparameter' in config: + result += [('hyperparameter', config['hyperparameter'])] + elif 'rl' in config or 'sim_market' in config: + # we need to add those two to the same hyperparameter name + existing_hyperparameter = [item for item in result if 'hyperparameter' in item] + if existing_hyperparameter: + new_hyperparameter = ('hyperparameter', {**existing_hyperparameter[0][1], **config}) + result.remove(existing_hyperparameter[0]) + result += [new_hyperparameter] + else: + result += [('hyperparameter', config)] + return result + + +def _get_config_parts(config_objects: list) -> tuple: + """ + Takes a list of tuple with the parsed objects from the config and their top level key + and returns an 'environment_config' and a 'hyperparameter_config' object to be inserted into the Config object + + Args: + config_objects (list): list of tuples, first tuple value indecating top-level key ('hyperparameter' / 'environment') + second value, the actual parsed config object + + Returns: + tuple: (instance of EnvironmentConfig, instance of HyperparameterConfig) + """ + assert len(config_objects) <= 2 and len(config_objects) >= 1, \ + 'At least one, at max two config parts should have been parsed' + environment_config = None + hyperparameter_config = None + for top_level, config_part in config_objects: + if top_level == 'environment': + environment_config = config_part + elif top_level == 'hyperparameter': + hyperparameter_config = config_part + return environment_config, hyperparameter_config diff --git a/webserver/alpha_business_app/handle_requests.py b/webserver/alpha_business_app/handle_requests.py index 104ef8a1..48b691e1 100644 --- a/webserver/alpha_business_app/handle_requests.py +++ b/webserver/alpha_business_app/handle_requests.py @@ -9,6 +9,7 @@ from .models.container import update_container DOCKER_API = 'https://vm-midea03.eaalab.hpi.uni-potsdam.de:8000' # remember to include the port and the protocol, i.e. http:// +# DOCKER_API = 'http://localhost:8000' def _get_api_token() -> str: @@ -150,7 +151,7 @@ def get_api_status() -> dict: return {'api_success': f'API available - {current_time}'} if api_is_available.status_code == 401: return {} - return {'api_docker_timeout': f'Docker unavailable - {current_time}'} + return {'api_docker_timeout': f'Docker unavailable - {current_time}'} def _error_handling_API(response) -> APIResponse: diff --git a/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py b/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py new file mode 100644 index 00000000..adb9aeec --- /dev/null +++ b/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py @@ -0,0 +1,103 @@ +# Generated by Django 4.0.1 on 2022-06-02 19:39 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('alpha_business_app', '0012_config_user_container_user'), + ] + + operations = [ + migrations.AddField( + model_name='rlconfig', + name='stable_baseline_test', + field=models.FloatField(default=None, null=True), + ), + migrations.AddField( + model_name='rlconfig', + name='testvalue2', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='batch_size', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='epsilon_decay_last_frame', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='epsilon_final', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='epsilon_start', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='gamma', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='learning_rate', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='replay_size', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='replay_start_size', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='sync_target_frames', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='episode_length', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='max_price', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='max_quality', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='max_storage', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='number_of_customers', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='production_price', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='simmarketconfig', + name='storage_cost_per_product', + field=models.FloatField(default=None, null=True), + ), + ] diff --git a/webserver/alpha_business_app/models/abstract_config.py b/webserver/alpha_business_app/models/abstract_config.py new file mode 100644 index 00000000..b5bcb4aa --- /dev/null +++ b/webserver/alpha_business_app/models/abstract_config.py @@ -0,0 +1,18 @@ +from django.db import models + +from ..utils import get_structure_dict_for, remove_none_values_from_dict, to_config_keyword + + +class AbstractConfig(): + def as_dict(self) -> dict: + config_field_values = vars(self) + resulting_dict = {} + for key, value in config_field_values.items(): + if key.startswith('_') or 'id' in key: + continue + resulting_dict[key] = value + return remove_none_values_from_dict(resulting_dict) + + @classmethod + def get_empty_structure_dict(cls: models.Model) -> dict: + return get_structure_dict_for(to_config_keyword(str(cls))) diff --git a/webserver/alpha_business_app/models/agent_config.py b/webserver/alpha_business_app/models/agent_config.py new file mode 100644 index 00000000..e28e2573 --- /dev/null +++ b/webserver/alpha_business_app/models/agent_config.py @@ -0,0 +1,10 @@ +from django.db import models + +from .abstract_config import AbstractConfig + + +class AgentConfig(AbstractConfig, models.Model): + agents_config = models.ForeignKey('alpha_business_app.AgentsConfig', on_delete=models.CASCADE, null=True) + name = models.CharField(max_length=100, default='') + agent_class = models.CharField(max_length=100, null=True) + argument = models.CharField(max_length=200, default='') diff --git a/webserver/alpha_business_app/models/agents_config.py b/webserver/alpha_business_app/models/agents_config.py new file mode 100644 index 00000000..ccb251e7 --- /dev/null +++ b/webserver/alpha_business_app/models/agents_config.py @@ -0,0 +1,12 @@ +from django.db import models + +from .abstract_config import AbstractConfig + + +class AgentsConfig(AbstractConfig, models.Model): + def as_list(self) -> dict: + referencing_agents = self.agentconfig_set.all() + return [agent.as_dict() for agent in referencing_agents] + + def as_dict(self) -> dict: + assert False, 'This should not be implemented as agents are a list.' diff --git a/webserver/alpha_business_app/models/config.py b/webserver/alpha_business_app/models/config.py index 352dfb77..067d0fa7 100644 --- a/webserver/alpha_business_app/models/config.py +++ b/webserver/alpha_business_app/models/config.py @@ -1,10 +1,15 @@ from django.contrib.auth.models import User from django.db import models +from ..utils import remove_none_values_from_dict +from .abstract_config import AbstractConfig +from .environment_config import EnvironmentConfig +from .hyperparameter_config import HyperparameterConfig -class Config(models.Model): - environment = models.ForeignKey('EnvironmentConfig', on_delete=models.CASCADE, null=True) - hyperparameter = models.ForeignKey('HyperparameterConfig', on_delete=models.CASCADE, null=True) + +class Config(AbstractConfig, models.Model): + environment = models.ForeignKey('alpha_business_app.EnvironmentConfig', on_delete=models.CASCADE, null=True) + hyperparameter = models.ForeignKey('alpha_business_app.HyperparameterConfig', on_delete=models.CASCADE, null=True) name = models.CharField(max_length=100, editable=False, default='') user = models.ForeignKey(User, on_delete=models.CASCADE, null=True,) @@ -16,169 +21,3 @@ def as_dict(self) -> dict: def is_referenced(self): # Query set is empty so we are not referenced by any container return bool(self.container_set.all()) - - @staticmethod - def get_empty_structure_dict(): - return { - 'environment': EnvironmentConfig.get_empty_structure_dict(), - 'hyperparameter': HyperparameterConfig.get_empty_structure_dict() - } - - -class EnvironmentConfig(models.Model): - agents = models.ForeignKey('AgentsConfig', on_delete=models.CASCADE, null=True) - enable_live_draw = models.BooleanField(null=True) - episodes = models.IntegerField(null=True) - plot_interval = models.IntegerField(null=True) - marketplace = models.CharField(max_length=150, null=True) - task = models.CharField(max_length=14, choices=((1, 'training'), (2, 'agent_monitoring'), (3, 'exampleprinter')), null=True) - - def as_dict(self) -> dict: - agents_list = self.agents.as_list() if self.agents is not None else None - return remove_none_values_from_dict({ - 'enable_live_draw': self.enable_live_draw, - 'episodes': self.episodes, - 'plot_interval': self.plot_interval, - 'marketplace': self.marketplace, - 'task': self.task, - 'agents': agents_list - }) - - @staticmethod - def get_empty_structure_dict(): - return { - 'enable_live_draw': None, - 'episodes': None, - 'plot_interval': None, - 'marketplace': None, - 'task': None, - 'agents': AgentsConfig.get_empty_structure_list() - } - - -class AgentsConfig(models.Model): - def as_list(self) -> dict: - referencing_agents = self.agentconfig_set.all() - return [agent.as_dict() for agent in referencing_agents] - - @staticmethod - def get_empty_structure_list(): - return [] - - -class AgentConfig(models.Model): - agents_config = models.ForeignKey('AgentsConfig', on_delete=models.CASCADE, null=True) - name = models.CharField(max_length=100, default='') - agent_class = models.CharField(max_length=100, null=True) - argument = models.CharField(max_length=200, default='') - - def as_dict(self) -> dict: - return remove_none_values_from_dict({ - 'name': self.name, - 'agent_class': self.agent_class, - 'argument': self.argument - }) - - -class HyperparameterConfig(models.Model): - rl = models.ForeignKey('RLConfig', on_delete=models.CASCADE, null=True) - sim_market = models.ForeignKey('SimMarketConfig', on_delete=models.CASCADE, null=True) - - def as_dict(self) -> dict: - sim_market_dict = self.sim_market.as_dict() if self.sim_market is not None else {'sim_market': None} - rl_dict = self.rl.as_dict() if self.rl is not None else {'rl': None} - return remove_none_values_from_dict({ - 'rl': rl_dict, - 'sim_market': sim_market_dict - }) - - @staticmethod - def get_empty_structure_dict(): - return { - 'rl': RlConfig.get_empty_structure_dict(), - 'sim_market': SimMarketConfig.get_empty_structure_dict() - } - - -class RlConfig(models.Model): - gamma = models.FloatField(null=True) - batch_size = models.IntegerField(null=True) - replay_size = models.IntegerField(null=True) - learning_rate = models.FloatField(null=True) - sync_target_frames = models.IntegerField(null=True) - replay_start_size = models.IntegerField(null=True) - epsilon_decay_last_frame = models.IntegerField(null=True) - epsilon_start = models.FloatField(null=True) - epsilon_final = models.FloatField(null=True) - - def as_dict(self) -> dict: - return remove_none_values_from_dict({ - 'gamma': self.gamma, - 'batch_size': self.batch_size, - 'replay_size': self.replay_size, - 'learning_rate': self.learning_rate, - 'sync_target_frames': self.sync_target_frames, - 'replay_start_size': self.replay_start_size, - 'epsilon_decay_last_frame': self.epsilon_decay_last_frame, - 'epsilon_start': self.epsilon_start, - 'epsilon_final': self.epsilon_final - }) - - @staticmethod - def get_empty_structure_dict(): - return { - 'gamma': None, - 'batch_size': None, - 'replay_size': None, - 'learning_rate': None, - 'sync_target_frames': None, - 'replay_start_size': None, - 'epsilon_decay_last_frame': None, - 'epsilon_start': None, - 'epsilon_final': None - } - - -class SimMarketConfig(models.Model): - max_storage = models.IntegerField(null=True) - episode_length = models.IntegerField(null=True) - max_price = models.IntegerField(null=True) - max_quality = models.IntegerField(null=True) - number_of_customers = models.IntegerField(null=True) - production_price = models.IntegerField(null=True) - storage_cost_per_product = models.FloatField(null=True) - - def as_dict(self) -> dict: - return remove_none_values_from_dict({ - 'max_storage': self.max_storage, - 'episode_length': self.episode_length, - 'max_price': self.max_price, - 'max_quality': self.max_quality, - 'number_of_customers': self.number_of_customers, - 'production_price': self.production_price, - 'storage_cost_per_product': self.storage_cost_per_product - }) - - @staticmethod - def get_empty_structure_dict(): - return { - 'max_storage': None, - 'episode_length': None, - 'max_price': None, - 'max_quality': None, - 'number_of_customers': None, - 'production_price': None, - 'storage_cost_per_product': None - } - - -def capitalize(word: str) -> str: - return word.upper() if len(word) <= 1 else word[0].upper() + word[1:] - - -def to_config_class_name(name: str) -> str: - return ''.join([capitalize(x) for x in name.split('_')]) + 'Config' - - -def remove_none_values_from_dict(dict_with_none_values: dict) -> dict: - return {key: value for key, value in dict_with_none_values.items() if value is not None} diff --git a/webserver/alpha_business_app/models/environment_config.py b/webserver/alpha_business_app/models/environment_config.py new file mode 100644 index 00000000..d6f7aa86 --- /dev/null +++ b/webserver/alpha_business_app/models/environment_config.py @@ -0,0 +1,24 @@ +from django.db import models + +from ..utils import remove_none_values_from_dict +from .abstract_config import AbstractConfig + + +class EnvironmentConfig(AbstractConfig, models.Model): + agents = models.ForeignKey('alpha_business_app.AgentsConfig', on_delete=models.CASCADE, null=True) + enable_live_draw = models.BooleanField(null=True) + episodes = models.IntegerField(null=True) + plot_interval = models.IntegerField(null=True) + marketplace = models.CharField(max_length=150, null=True) + task = models.CharField(max_length=14, choices=((1, 'training'), (2, 'agent_monitoring'), (3, 'exampleprinter')), null=True) + + def as_dict(self) -> dict: + agents_list = self.agents.as_list() if self.agents is not None else None + return remove_none_values_from_dict({ + 'enable_live_draw': self.enable_live_draw, + 'episodes': self.episodes, + 'plot_interval': self.plot_interval, + 'marketplace': self.marketplace, + 'task': self.task, + 'agents': agents_list + }) diff --git a/webserver/alpha_business_app/models/hyperparameter_config.py b/webserver/alpha_business_app/models/hyperparameter_config.py new file mode 100644 index 00000000..a8cc7267 --- /dev/null +++ b/webserver/alpha_business_app/models/hyperparameter_config.py @@ -0,0 +1,19 @@ +from django.db import models + +from ..utils import remove_none_values_from_dict +from .abstract_config import AbstractConfig +from .rl_config import RlConfig +from .sim_market_config import SimMarketConfig + + +class HyperparameterConfig(AbstractConfig, models.Model): + rl = models.ForeignKey('alpha_business_app.RLConfig', on_delete=models.CASCADE, null=True) + sim_market = models.ForeignKey('alpha_business_app.SimMarketConfig', on_delete=models.CASCADE, null=True) + + def as_dict(self) -> dict: + sim_market_dict = self.sim_market.as_dict() if self.sim_market is not None else {'sim_market': None} + rl_dict = self.rl.as_dict() if self.rl is not None else {'rl': None} + return remove_none_values_from_dict({ + 'rl': rl_dict, + 'sim_market': sim_market_dict + }) diff --git a/webserver/alpha_business_app/models/rl_config.py b/webserver/alpha_business_app/models/rl_config.py new file mode 100644 index 00000000..6b97452b --- /dev/null +++ b/webserver/alpha_business_app/models/rl_config.py @@ -0,0 +1,17 @@ +from django.db import models + +from .abstract_config import AbstractConfig + + +class RlConfig(AbstractConfig, models.Model): + epsilon_decay_last_frame = models.IntegerField(null=True, default=None) + epsilon_start = models.FloatField(null=True, default=None) + sync_target_frames = models.IntegerField(null=True, default=None) + replay_size = models.IntegerField(null=True, default=None) + replay_start_size = models.IntegerField(null=True, default=None) + gamma = models.FloatField(null=True, default=None) + batch_size = models.IntegerField(null=True, default=None) + epsilon_final = models.FloatField(null=True, default=None) + stable_baseline_test = models.FloatField(null=True, default=None) + learning_rate = models.FloatField(null=True, default=None) + testvalue2 = models.FloatField(null=True, default=None) diff --git a/webserver/alpha_business_app/models/sim_market_config.py b/webserver/alpha_business_app/models/sim_market_config.py new file mode 100644 index 00000000..34b29584 --- /dev/null +++ b/webserver/alpha_business_app/models/sim_market_config.py @@ -0,0 +1,13 @@ +from django.db import models + +from .abstract_config import AbstractConfig + + +class SimMarketConfig(AbstractConfig, models.Model): + max_storage = models.IntegerField(null=True, default=None) + storage_cost_per_product = models.FloatField(null=True, default=None) + episode_length = models.IntegerField(null=True, default=None) + max_price = models.IntegerField(null=True, default=None) + number_of_customers = models.IntegerField(null=True, default=None) + max_quality = models.IntegerField(null=True, default=None) + production_price = models.IntegerField(null=True, default=None) diff --git a/webserver/alpha_business_app/on_recommerce_change.py b/webserver/alpha_business_app/on_recommerce_change.py new file mode 100644 index 00000000..4230348b --- /dev/null +++ b/webserver/alpha_business_app/on_recommerce_change.py @@ -0,0 +1,52 @@ +# This file can be used to write own config files. +# It should be executed before running migrations. +# When using this file or changing the implementation, +# please keep in mind, that this is a potential security rist + +import os + +from utils import get_structure_with_types_of, to_config_class_name + + +class ConfigModelWriter: + def __init__(self, top_level: str, second_level: str = None) -> None: + self.whitespace = '\t' + self.top_level = top_level + self.second_level = second_level + self.name = second_level if second_level else top_level + self.class_name = to_config_class_name(self.name) + + def write_file(self) -> None: + print(f'{self._warning()}WARNING: This action will override the {self.class_name} file.{self._end()}') + print('Press enter to continue') + input() + # imports + lines = ['from django.db import models', '', 'from .abstract_config import AbstractConfig', ''] + # class definition + lines += [f'class {self.class_name}(AbstractConfig, models.Model):'] + # fields + attributes = get_structure_with_types_of(self.top_level, self.second_level) + for attr in attributes: + django_class = str(attr[1]).rsplit('.')[-1][:-2] + additional_attributes = self._get_additional_attributes(django_class) + lines += [f'{self.whitespace}{attr[0]} = models.{django_class}(null=True, default=None{additional_attributes})'] + path_to_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', f'{self.name}_config.py') + # write to file + print(f'Writing class definition of {self.class_name} to file.') + with open(path_to_file, 'w') as config_file: + config_file.write('\n'.join(lines)) + + def _get_additional_attributes(self, django_class: str) -> str: + if 'CharField' in django_class: + return ', max_length=100' + return '' + + def _warning(self) -> str: + return '\033[93m' + + def _end(self) -> str: + return '\033[0m' + + +ConfigModelWriter(top_level='hyperparameter', second_level='rl').write_file() +ConfigModelWriter(top_level='hyperparameter', second_level='sim_market').write_file() diff --git a/webserver/alpha_business_app/selection_manager.py b/webserver/alpha_business_app/selection_manager.py index 88e4ad47..caf2d213 100644 --- a/webserver/alpha_business_app/selection_manager.py +++ b/webserver/alpha_business_app/selection_manager.py @@ -4,17 +4,17 @@ import lxml.html from django.shortcuts import render -import recommerce.market.circular.circular_sim_market as circular_market -import recommerce.market.linear.linear_sim_market as linear_market from recommerce.configuration.utils import get_class +from .utils import get_recommerce_agents_for_marketplace, get_recommerce_marketplaces + class SelectionManager: def __init__(self) -> None: self.current_marketplace = None def get_agent_options_for_marketplace(self) -> list: - return self._to_tuple_list(self.current_marketplace.get_possible_rl_agents()) + return self._to_tuple_list(get_recommerce_agents_for_marketplace(self.current_marketplace)) def get_competitor_options_for_marketplace(self) -> list: return self._to_tuple_list(self.current_marketplace.get_competitor_classes()) @@ -71,25 +71,9 @@ def get_correct_agents_html_on_marketplace_change(self, request, marketplace_cla return lxml.html.tostring(html) def get_marketplace_options(self) -> list: - """ - Matches marketplaces of recommerce.market.circular.circular_sim_market and recommerce.market.linear.linear_sim_market, - which contain one of the Keywords: Oligopoly, Duopoly, Monopoly - - Returns: - list: tuple list for selection - """ - keywords = ['Monopoly', 'Duopoly', 'Oligopoly'] - # get all circular marketplaces - circular_marketplaces = list(set(filter(lambda class_name: any(keyword in class_name for keyword in keywords), dir(circular_market)))) - circular_market_str = [f'recommerce.market.circular.circular_sim_market.{market}' for market in sorted(circular_marketplaces)] - circular_tuples = self._to_tuple_list(circular_market_str) - # get all linear marketplaces - visible_linear_names = list(set(filter(lambda class_name: any(keyword in class_name for keyword in keywords), dir(linear_market)))) - linear_market_str = [f'recommerce.market.linear.linear_sim_market.{market}' for market in sorted(visible_linear_names)] - linear_tuples = self._to_tuple_list(linear_market_str) - - self.current_marketplace = get_class(circular_tuples[0][0]) - return circular_tuples + linear_tuples + recommerce_marketplaces = get_recommerce_marketplaces() + self.current_marketplace = get_class(recommerce_marketplaces[0]) + return self._to_tuple_list(recommerce_marketplaces) def _get_task_options(self) -> list: return [ diff --git a/webserver/alpha_business_app/static/js/custom.js b/webserver/alpha_business_app/static/js/custom.js index 26efabc7..7eec55f9 100644 --- a/webserver/alpha_business_app/static/js/custom.js +++ b/webserver/alpha_business_app/static/js/custom.js @@ -11,6 +11,12 @@ $(document).ready(function() { }); }; addEventToAddMoreButton() + + function getFormData () { + var form = $("form.config-form"); + var formdata = form.serializeArray(); + return formdata; + }; function updateAPIHealth() { // replaces the element by the element returned by ajax (html) and adds this click event to it @@ -55,10 +61,33 @@ $(document).ready(function() { success: function (data) { all_agents.empty().append(data); addEventToAddMoreButton(); + addChangeToAgent(); } }); - }).trigger('change'); - + }).trigger("change"); + + + function addChangeToAgent () { + $("select.agent-agent-class").change(function () { + // will be called when agent dropdown has changed, we need to change rl hyperparameter for that + var self = $(this); + var formdata = getFormData(); + const csrftoken = getCookie("csrftoken"); + $.ajax({ + type: "POST", + url: self.data("url"), + data: { + csrfmiddlewaretoken: csrftoken, + formdata, + "agent": self.val() + }, + success: function (data) { + $("div.rl-parameter").empty().append(data) + } + }); + }).trigger("change"); + } + addChangeToAgent() function getCookie(name) { let cookieValue = null; @@ -76,12 +105,9 @@ $(document).ready(function() { return cookieValue; } - $("button.form-check").click(function () { - $("table.config-status-display").remove(); - + $("button.form-check").click(function () { var self = $(this); - var form = $("form.config-form"); - var formdata = form.serializeArray(); + var formdata = getFormData(); const csrftoken = getCookie('csrftoken'); $.ajax({ diff --git a/webserver/alpha_business_app/tests/constant_tests.py b/webserver/alpha_business_app/tests/constant_tests.py index 2223628f..d636024a 100644 --- a/webserver/alpha_business_app/tests/constant_tests.py +++ b/webserver/alpha_business_app/tests/constant_tests.py @@ -7,8 +7,8 @@ 'environment-episodes': [''], 'environment-plot_interval': [''], 'environment-marketplace': ['recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly'], - 'environment-agents-name': ['Rule_Based Agent'], - 'environment-agents-agent_class': ['recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent'], + 'environment-agents-name': ['QLearning Agent'], + 'environment-agents-agent_class': ['recommerce.rl.q_learning.q_learning_agent.QLearningAgent'], 'environment-agents-argument': [''], 'hyperparameter-rl-gamma': ['0.99'], 'hyperparameter-rl-batch_size': ['32'], @@ -35,8 +35,8 @@ 'enable_live_draw': False, 'agents': [ { - 'name': 'Rule_Based Agent', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'name': 'QLearning Agent', + 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', 'argument': '' } ] @@ -126,7 +126,9 @@ 'replay_start_size': None, 'epsilon_decay_last_frame': None, 'epsilon_start': None, - 'epsilon_final': None + 'epsilon_final': None, + 'testvalue2': None, + 'stable_baseline_test': None }, 'sim_market': { 'max_storage': None, @@ -139,3 +141,17 @@ } } } + +EXAMPLE_RL_DICT = { + 'rl': { + 'gamma': 0.99, + 'batch_size': 32, + 'replay_size': 100000, + 'learning_rate': 1e-6, + 'sync_target_frames': 1000, + 'replay_start_size': 10000, + 'epsilon_decay_last_frame': 75000, + 'epsilon_start': 1.0, + 'epsilon_final': 0.1 + } +} diff --git a/webserver/alpha_business_app/tests/test_adjustable_fields.py b/webserver/alpha_business_app/tests/test_adjustable_fields.py new file mode 100644 index 00000000..27a016e3 --- /dev/null +++ b/webserver/alpha_business_app/tests/test_adjustable_fields.py @@ -0,0 +1,16 @@ +from django.test import TestCase + +from ..adjustable_fields import get_rl_parameter_prefill + + +class AdjustableFieldsTests(TestCase): + def test_rl_hyperparameter_with_prefill(self): + prefill_dict = {'gamma': 0.9, 'learning_rate': 0.4, 'test': None} + error_dict = {'gamma': 'test', 'learning_rate': None, 'test': None} + expected_list = [ + {'name': 'gamma', 'prefill': 0.9, 'error': 'test'}, + {'name': 'learning_rate', 'prefill': 0.4, 'error': ''}, + {'name': 'test', 'prefill': '', 'error': ''} + ] + actual_list = get_rl_parameter_prefill(prefill_dict, error_dict) + assert actual_list == expected_list diff --git a/webserver/alpha_business_app/tests/test_configuration_parser.py b/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py similarity index 92% rename from webserver/alpha_business_app/tests/test_configuration_parser.py rename to webserver/alpha_business_app/tests/test_config_flat_dict_parser.py index 7bd99a19..8cc585b9 100644 --- a/webserver/alpha_business_app/tests/test_configuration_parser.py +++ b/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py @@ -1,7 +1,11 @@ from django.test import TestCase from ..config_parser import ConfigFlatDictParser, ConfigModelParser -from ..models.config import AgentsConfig, Config, EnvironmentConfig, RlConfig, SimMarketConfig +from ..models.agents_config import AgentsConfig +from ..models.config import Config +from ..models.environment_config import EnvironmentConfig +from ..models.rl_config import RlConfig +from ..models.sim_market_config import SimMarketConfig from .constant_tests import EXAMPLE_HIERARCHY_DICT @@ -114,8 +118,8 @@ def test_flat_environment(self): 'plot_interval': [''], 'enable_live_draw': [''], 'marketplace': ['recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly'], - 'agents-name': ['Rule_Based Agent'], - 'agents-agent_class': ['recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent'], + 'agents-name': ['QLearning Agent'], + 'agents-agent_class': ['recommerce.rl.q_learning.q_learning_agent.QLearningAgent'], 'agents-argument': [''], } expected_environment_dict = EXAMPLE_HIERARCHY_DICT['environment'].copy() @@ -124,8 +128,8 @@ def test_flat_environment(self): def test_flat_agents(self): test_dict = { - 'name': ['Rule_Based Agent'], - 'agent_class': ['recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent'], + 'name': ['QLearning Agent'], + 'agent_class': ['recommerce.rl.q_learning.q_learning_agent.QLearningAgent'], 'argument': [''], } assert EXAMPLE_HIERARCHY_DICT['environment']['agents'] == self.flat_parser._flat_agents_to_hierarchical(test_dict) @@ -205,8 +209,8 @@ def test_parsing_config_dict(self): all_agents = environment_agents.agentconfig_set.all() assert 1 == len(all_agents) - assert 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent' == all_agents[0].agent_class - assert 'Rule_Based Agent' == all_agents[0].name + assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class + assert 'QLearning Agent' == all_agents[0].name assert '' == all_agents[0].argument def test_parsing_agents(self): diff --git a/webserver/alpha_business_app/tests/test_config_model.py b/webserver/alpha_business_app/tests/test_config_model.py index 59ba0641..db33aee3 100644 --- a/webserver/alpha_business_app/tests/test_config_model.py +++ b/webserver/alpha_business_app/tests/test_config_model.py @@ -1,9 +1,16 @@ from django.test import TestCase -from ..models.config import * -from ..models.config import remove_none_values_from_dict +from ..models.agent_config import AgentConfig +from ..models.agents_config import AgentsConfig +from ..models.config import Config from ..models.container import Container -from .constant_tests import EMPTY_STRUCTURE_CONFIG +from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig +from ..models.rl_config import RlConfig +from ..models.sim_market_config import SimMarketConfig +from ..utils import remove_none_values_from_dict, to_config_class_name + +# from .constant_tests import EMPTY_STRUCTURE_CONFIG class ConfigTest(TestCase): @@ -25,15 +32,6 @@ def test_class_name_rl_config(self): def test_class_name_sim_market_config(self): assert 'SimMarketConfig' == to_config_class_name('sim_market') - def test_capitalize(self): - assert 'TestTesTTest' == capitalize('testTesTTest') - - def test_capitalize_empty_strings(self): - assert '' == capitalize('') - - def test_capitalize_one_letter_strings(self): - assert 'A' == capitalize('a') - def test_is_referenced(self): test_config_not_referenced = Config.objects.create() test_config_referenced = Config.objects.create() @@ -119,7 +117,16 @@ def test_config_to_dict(self): def test_dict_representation_of_agent(self): test_agent = AgentConfig.objects.create(name='test_agent', agent_class='test_class', argument='1234') expected_dict = {'name': 'test_agent', 'agent_class': 'test_class', 'argument': '1234'} - assert expected_dict == test_agent.as_dict(), (expected_dict, test_agent.as_dict()) + assert expected_dict == test_agent.as_dict() + + def test_dict_representation_of_environment_config(self): + test_environment = EnvironmentConfig.objects.create(enable_live_draw=True, episodes=50, plot_interval=12, marketplace='test') + expected_dict = {'enable_live_draw': True, 'episodes': 50, 'plot_interval': 12, 'marketplace': 'test'} + assert expected_dict == test_environment.as_dict() + + def test_dict_representation_of_empty_config(self): + test_config = Config.objects.create() + assert {} == test_config.as_dict() def test_list_representation_of_agents(self): test_agents = AgentsConfig.objects.create() @@ -140,13 +147,21 @@ def test_list_representation_of_agents(self): ] assert expected_list == test_agents.as_list() - def test_dict_representation_of_empty_config(self): - test_config = Config.objects.create() - assert {} == test_config.as_dict() - - def test_get_empty_structure_dict(self): - actual_dict = Config.get_empty_structure_dict() - assert EMPTY_STRUCTURE_CONFIG == actual_dict + def test_get_empty_structure_dict_for_rl(self): + expected_dict = { + 'sync_target_frames': None, + 'testvalue2': None, + 'gamma': None, + 'epsilon_start': None, + 'replay_size': None, + 'stable_baseline_test': None, + 'epsilon_decay_last_frame': None, + 'batch_size': None, + 'epsilon_final': None, + 'replay_start_size': None, + 'learning_rate': None + } + assert expected_dict == RlConfig.get_empty_structure_dict() def test_remove_none_values_from_dict(self): test_dict = {'test': 'test', 'test2': None} diff --git a/webserver/alpha_business_app/tests/test_config_model_parser.py b/webserver/alpha_business_app/tests/test_config_model_parser.py new file mode 100644 index 00000000..4132894e --- /dev/null +++ b/webserver/alpha_business_app/tests/test_config_model_parser.py @@ -0,0 +1,151 @@ +from django.test import TestCase + +from ..config_parser import ConfigModelParser +from ..models.agents_config import AgentsConfig +from ..models.config import Config +from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig +from ..models.rl_config import RlConfig +from ..models.sim_market_config import SimMarketConfig +from .constant_tests import EXAMPLE_HIERARCHY_DICT, EXAMPLE_RL_DICT + + +class ConfigModelParserTest(TestCase): + expected_dict = { + 'hyperparameter': { + 'rl': { + 'gamma': 0.99, + 'batch_size': 32, + 'replay_size': 100000, + 'learning_rate': 1e-06, + 'sync_target_frames': 1000, + 'replay_start_size': 10000, + 'epsilon_decay_last_frame': 75000, + 'epsilon_start': 1.0, + 'epsilon_final': 0.1 + }, + 'sim_market': { + 'max_storage': 100, + 'episode_length': 50, + 'max_price': 10, + 'max_quality': 50, + 'number_of_customers': 20, + 'production_price': 3, + 'storage_cost_per_product': 0.1 + } + }, + 'environment': { + 'task': 'training', + 'enable_live_draw': False, + 'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly', + 'agents': [ + { + 'name': 'Rule_Based Agent', + 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'argument': '' + } + ] + } + } + + def setUp(self) -> None: + self.parser = ConfigModelParser() + + def test_parsing_config_dict(self): + test_dict = EXAMPLE_HIERARCHY_DICT.copy() + + final_config = self.parser.parse_config(test_dict) + + assert Config == type(final_config) + assert final_config.hyperparameter is not None + + # assert all hyperparameters + hyperparameter_rl_config: RlConfig = final_config.hyperparameter.rl + hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market + + assert hyperparameter_rl_config is not None + assert final_config.hyperparameter.sim_market is not None + + assert 0.99 == hyperparameter_rl_config.gamma + assert 32 == hyperparameter_rl_config.batch_size + assert 100000 == hyperparameter_rl_config.replay_size + assert 1e-06 == hyperparameter_rl_config.learning_rate + assert 1000 == hyperparameter_rl_config.sync_target_frames + assert 10000 == hyperparameter_rl_config.replay_start_size + assert 75000 == hyperparameter_rl_config.epsilon_decay_last_frame + assert 1.0 == hyperparameter_rl_config.epsilon_start + assert 0.1 == hyperparameter_rl_config.epsilon_final + + assert 100 == hyperparameter_sim_market_config.max_storage + assert 50 == hyperparameter_sim_market_config.episode_length + assert 10 == hyperparameter_sim_market_config.max_price + assert 50 == hyperparameter_sim_market_config.max_quality + assert 20 == hyperparameter_sim_market_config.number_of_customers + assert 3 == hyperparameter_sim_market_config.production_price + assert 0.1 == hyperparameter_sim_market_config.storage_cost_per_product + + # assert all environment + assert final_config.environment is not None + environment_config: EnvironmentConfig = final_config.environment + assert 'training' == environment_config.task + assert environment_config.enable_live_draw is False + assert environment_config.episodes is None + assert environment_config.plot_interval is None + assert 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly' == environment_config.marketplace + assert environment_config.agents is not None + + environment_agents: AgentsConfig = environment_config.agents + + all_agents = environment_agents.agentconfig_set.all() + assert 1 == len(all_agents) + assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class + assert 'QLearning Agent' == all_agents[0].name + assert '' == all_agents[0].argument + + def test_parsing_agents(self): + test_dict = [ + { + 'name': 'test_agent1', + 'agent_class': 'test_class', + 'argument': '' + }, + { + 'name': 'test_agent2', + 'agent_class': 'test_class', + 'argument': '1234' + } + ] + agents = self.parser._parse_agents_to_datastructure(test_dict) + all_agents = agents.agentconfig_set.all() + + assert 'test_agent1' == all_agents[0].name + assert 'test_class' == all_agents[0].agent_class + assert '' == all_agents[0].argument + + assert 'test_agent2' == all_agents[1].name + assert 'test_class' == all_agents[1].agent_class + assert '1234' == all_agents[1].argument + + def test_parse_rl(self): + test_dict = EXAMPLE_RL_DICT.copy() + + final_config = self.parser.parse_config_dict_to_datastructure('hyperparameter', test_dict) + + assert HyperparameterConfig == type(final_config) + + # assert all hyperparameters + hyperparameter_rl_config: RlConfig = final_config.rl + hyperparameter_sim_market_config: SimMarketConfig = final_config.sim_market + assert hyperparameter_rl_config is not None + assert hyperparameter_sim_market_config is None + + assert 0.99 == hyperparameter_rl_config.gamma + assert 32 == hyperparameter_rl_config.batch_size + assert 100000 == hyperparameter_rl_config.replay_size + assert 1e-6 == hyperparameter_rl_config.learning_rate + assert 1000 == hyperparameter_rl_config.sync_target_frames + assert 10000 == hyperparameter_rl_config.replay_start_size + assert 75000 == hyperparameter_rl_config.epsilon_decay_last_frame + assert 1.0 == hyperparameter_rl_config.epsilon_start + assert 0.1 == hyperparameter_rl_config.epsilon_final + assert hyperparameter_rl_config.stable_baseline_test is None diff --git a/webserver/alpha_business_app/tests/test_data/test_config_complete.json b/webserver/alpha_business_app/tests/test_data/test_config_complete.json new file mode 100644 index 00000000..ba3fdf52 --- /dev/null +++ b/webserver/alpha_business_app/tests/test_data/test_config_complete.json @@ -0,0 +1,34 @@ +{ + "environment": { + "task": "training", + "marketplace": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", + "agents": [ + { + "name": "QLearning Agent", + "agent_class": "recommerce.rl.q_learning.q_learning_agent.QLearningAgent", + "argument": "" + } + ] + }, "hyperparameter": { + "rl": { + "gamma": 0.99, + "batch_size": 32, + "replay_size": 100000, + "learning_rate": 1e-06, + "sync_target_frames": 1000, + "replay_start_size": 10000, + "epsilon_decay_last_frame": 75000, + "epsilon_start": 1.0, + "epsilon_final": 0.1 + }, + "sim_market": { + "max_storage": 100, + "episode_length": 50, + "max_price": 10, + "max_quality": 50, + "number_of_customers": 20, + "production_price": 3, + "storage_cost_per_product": 0.1 + } + } +} diff --git a/webserver/alpha_business_app/tests/test_data/test_environment_config.json b/webserver/alpha_business_app/tests/test_data/test_environment_config.json index edfa9644..b06c4d07 100644 --- a/webserver/alpha_business_app/tests/test_data/test_environment_config.json +++ b/webserver/alpha_business_app/tests/test_data/test_environment_config.json @@ -6,7 +6,7 @@ "marketplace": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", "agents": [ { - "name": "Rule_Based Agent", + "name": "QLearning Agent", "agent_class": "recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent", "argument": "" }, @@ -15,5 +15,6 @@ "agent_class": "recommerce.rl.q_learning.q_learning_agent.QLearningAgent", "argument": "CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat" } - ] + ], + "config_type": "environment" } diff --git a/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json b/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json deleted file mode 100644 index 5fe8956f..00000000 --- a/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "rl": { - "gamma" : 0.99, - "batch_size" : 32, - "replay_size" : 100000, - "learning_rate" : 1e-6, - "sync_target_frames" : 1000, - "replay_start_size" : 10000, - "epsilon_decay_last_frame" : 75000, - "epsilon_start" : 1.0, - "epsilon_final" : 0.1 - }, - "sim_market": { - "max_storage": 100, - "episode_length": 50, - "max_price": 10, - "max_quality": 50, - "number_of_customers": 20, - "production_price": 3, - "storage_cost_per_product": 0.1 - } -} diff --git a/webserver/alpha_business_app/tests/test_data/test_mixed_config.json b/webserver/alpha_business_app/tests/test_data/test_mixed_config.json deleted file mode 100644 index cec271cc..00000000 --- a/webserver/alpha_business_app/tests/test_data/test_mixed_config.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "task": "training", - "sim_market": { - "max_storage": 100, - "episode_length": 50 - }, - "enable_live_draw": false, - "rl": { - "gamma" : 0.99, - "batch_size" : 32 - }, - "episodes": 50 -} diff --git a/webserver/alpha_business_app/tests/test_data/test_rl_config.json b/webserver/alpha_business_app/tests/test_data/test_rl_config.json new file mode 100644 index 00000000..7d1eea3b --- /dev/null +++ b/webserver/alpha_business_app/tests/test_data/test_rl_config.json @@ -0,0 +1,12 @@ +{ + "gamma" : 0.99, + "batch_size" : 32, + "replay_size" : 100000, + "learning_rate" : 1e-6, + "sync_target_frames" : 1000, + "replay_start_size" : 10000, + "epsilon_decay_last_frame" : 75000, + "epsilon_start" : 1.0, + "epsilon_final" : 0.1, + "config_type": "rl" +} diff --git a/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json b/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json new file mode 100644 index 00000000..0656dce4 --- /dev/null +++ b/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json @@ -0,0 +1,10 @@ +{ + "max_storage": 100, + "episode_length": 50, + "max_price": 10, + "max_quality": 50, + "number_of_customers": 20, + "production_price": 3, + "storage_cost_per_product": 0.1, + "config_type": "sim_market" +} diff --git a/webserver/alpha_business_app/tests/test_file_handling.py b/webserver/alpha_business_app/tests/test_file_handling.py index 82aaa0b5..29356c10 100644 --- a/webserver/alpha_business_app/tests/test_file_handling.py +++ b/webserver/alpha_business_app/tests/test_file_handling.py @@ -8,7 +8,12 @@ from ..config_parser import ConfigModelParser from ..handle_files import handle_uploaded_file -from ..models.config import * +from ..models.agents_config import AgentsConfig +from ..models.config import Config +from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig +from ..models.rl_config import RlConfig +from ..models.sim_market_config import SimMarketConfig class MockedResponse(): @@ -56,7 +61,7 @@ def test_uploaded_file_invalid_json(self): assert {'error': 'Your JSON is not valid'} == actual_arguments[2] def test_uploaded_file_with_unknown_key(self): - test_uploaded_file = MockedUploadedFile('test_file.json', b'{ "test": "1234" }') + test_uploaded_file = MockedUploadedFile('test_file.json', b'{ "test": "1234", "config_type": "rl"}') with patch('alpha_business_app.handle_files.render') as render_mock: handle_uploaded_file(self._setup_request(), test_uploaded_file) @@ -64,7 +69,7 @@ def test_uploaded_file_with_unknown_key(self): render_mock.assert_called_once() assert 'upload.html' == actual_arguments[1] - assert {'error': 'Your config contains an invalid key: test'} == actual_arguments[2], f'{actual_arguments[2]}' + assert {'error': "Your config contains an invalid key: 'test'"} == actual_arguments[2], f'{actual_arguments[2]}' def test_objects_from_parse_dict(self): test_dict = {'rl': {'batch_size': 32}, 'sim_market': {'episode_length': 50}} @@ -92,14 +97,14 @@ def test_objects_from_parse_dict(self): else: assert 32 == getattr(resulting_config.rl, name) - def test_parsing_with_only_hyperparameter(self): + def test_parsing_with_rl_hyperparameter(self): # get a test config to be parsed path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') - with open(os.path.join(path_to_test_data, 'test_hyperparameter_config.json'), 'r') as file: + with open(os.path.join(path_to_test_data, 'test_rl_config.json'), 'r') as file: content = file.read() # mock uploaded file with test config test_uploaded_file = MockedUploadedFile('config.json', content.encode()) - # test method + with patch('alpha_business_app.handle_files.redirect') as redirect_mock: handle_uploaded_file(self._setup_request(), test_uploaded_file) redirect_mock.assert_called_once() @@ -108,12 +113,11 @@ def test_parsing_with_only_hyperparameter(self): assert Config == type(final_config) assert final_config.environment is None assert final_config.hyperparameter is not None + assert final_config.hyperparameter.sim_market is None hyperparameter_rl_config: RlConfig = final_config.hyperparameter.rl - hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market assert hyperparameter_rl_config is not None - assert final_config.hyperparameter.sim_market is not None assert 0.99 == hyperparameter_rl_config.gamma assert 32 == hyperparameter_rl_config.batch_size @@ -125,6 +129,28 @@ def test_parsing_with_only_hyperparameter(self): assert 1.0 == hyperparameter_rl_config.epsilon_start assert 0.1 == hyperparameter_rl_config.epsilon_final + def test_parsing_with_sim_market_hyperparameter(self): + # get a test config to be parsed + path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') + with open(os.path.join(path_to_test_data, 'test_sim_market_config.json'), 'r') as file: + content = file.read() + # mock uploaded file with test config + test_uploaded_file = MockedUploadedFile('config.json', content.encode()) + # test method + with patch('alpha_business_app.handle_files.redirect') as redirect_mock: + handle_uploaded_file(self._setup_request(), test_uploaded_file) + redirect_mock.assert_called_once() + # assert the datastructure, that should be present afterwards + final_config: Config = Config.objects.all().first() + assert Config == type(final_config) + assert final_config.environment is None + assert final_config.hyperparameter is not None + assert final_config.hyperparameter.rl is None + + hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market + + assert final_config.hyperparameter.sim_market is not None + assert 100 == hyperparameter_sim_market_config.max_storage assert 50 == hyperparameter_sim_market_config.episode_length assert 10 == hyperparameter_sim_market_config.max_price @@ -168,10 +194,20 @@ def test_parsing_with_only_environment(self): assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[1].agent_class assert 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat' == all_agents[1].argument - def test_parsing_mixed_config(self): - # get a test config to be parsed + def test_parsing_invalid_rl_parameters(self): + test_uploaded_file = MockedUploadedFile('config.json', b'{"test":"bla", "config_type": "rl"}') + with patch('alpha_business_app.handle_files.render') as render_mock: + handle_uploaded_file(self._setup_request(), test_uploaded_file) + + actual_arguments = render_mock.call_args.args + render_mock.assert_called_once() + assert 'upload.html' == actual_arguments[1] + assert {'error': "Your config contains an invalid key: 'test'"} \ + == actual_arguments[2], f'{actual_arguments[2]}' + + def test_parsing_complete_config(self): path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') - with open(os.path.join(path_to_test_data, 'test_mixed_config.json'), 'r') as file: + with open(os.path.join(path_to_test_data, 'test_config_complete.json'), 'r') as file: content = file.read() # mock uploaded file with test config test_uploaded_file = MockedUploadedFile('config.json', content.encode()) @@ -179,6 +215,7 @@ def test_parsing_mixed_config(self): with patch('alpha_business_app.handle_files.redirect') as redirect_mock: handle_uploaded_file(self._setup_request(), test_uploaded_file) redirect_mock.assert_called_once() + # assert the datastructure, that should be present afterwards final_config: Config = Config.objects.all().first() assert Config == type(final_config) @@ -189,28 +226,44 @@ def test_parsing_mixed_config(self): hyperparameter_config: HyperparameterConfig = final_config.hyperparameter assert 'training' == environment_config.task - assert environment_config.enable_live_draw is False - assert 50 == environment_config.episodes + assert environment_config.enable_live_draw is None + assert environment_config.episodes is None + assert environment_config.plot_interval is None + assert 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly' == environment_config.marketplace + assert environment_config.agents is not None - assert hyperparameter_config.sim_market is not None - assert hyperparameter_config.rl is not None + environment_agents: AgentsConfig = environment_config.agents - assert 100 == hyperparameter_config.sim_market.max_storage - assert 50 == hyperparameter_config.sim_market.episode_length - assert 0.99 == hyperparameter_config.rl.gamma - assert 32 == hyperparameter_config.rl.batch_size + all_agents = environment_agents.agentconfig_set.all() + assert 1 == len(all_agents) + assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class + assert '' == all_agents[0].argument - def test_parsing_invalid_rl_parameters(self): - test_uploaded_file = MockedUploadedFile('config.json', b'{"rl": {"test":"bla"}}') - with patch('alpha_business_app.handle_files.render') as render_mock: - handle_uploaded_file(self._setup_request(), test_uploaded_file) + hyperparameter_sim_market_config: SimMarketConfig = hyperparameter_config.sim_market - actual_arguments = render_mock.call_args.args + assert final_config.hyperparameter.sim_market is not None - render_mock.assert_called_once() - assert 'upload.html' == actual_arguments[1] - assert {'error': 'The key "test" should not exist within a HyperparameterConfigValidator config (was checked at sub-key "rl")'} \ - == actual_arguments[2], f'{actual_arguments[2]}' + assert 100 == hyperparameter_sim_market_config.max_storage + assert 50 == hyperparameter_sim_market_config.episode_length + assert 10 == hyperparameter_sim_market_config.max_price + assert 50 == hyperparameter_sim_market_config.max_quality + assert 20 == hyperparameter_sim_market_config.number_of_customers + assert 3 == hyperparameter_sim_market_config.production_price + assert 0.1 == hyperparameter_sim_market_config.storage_cost_per_product + + hyperparameter_rl_config: RlConfig = hyperparameter_config.rl + + assert hyperparameter_rl_config is not None + + assert 0.99 == hyperparameter_rl_config.gamma + assert 32 == hyperparameter_rl_config.batch_size + assert 100000 == hyperparameter_rl_config.replay_size + assert 1e-6 == hyperparameter_rl_config.learning_rate + assert 1000 == hyperparameter_rl_config.sync_target_frames + assert 10000 == hyperparameter_rl_config.replay_start_size + assert 75000 == hyperparameter_rl_config.epsilon_decay_last_frame + assert 1.0 == hyperparameter_rl_config.epsilon_start + assert 0.1 == hyperparameter_rl_config.epsilon_final def test_parsing_duplicate_keys(self): test_uploaded_file = MockedUploadedFile('config.json', b'{"rl": {"test":"bla"}, "rl": {"test":"bla"}}') diff --git a/webserver/alpha_business_app/tests/test_prefill.py b/webserver/alpha_business_app/tests/test_prefill.py index 7e0f215a..df39f97d 100644 --- a/webserver/alpha_business_app/tests/test_prefill.py +++ b/webserver/alpha_business_app/tests/test_prefill.py @@ -8,7 +8,10 @@ # from ..buttons import ButtonHandler from ..config_merger import ConfigMerger from ..config_parser import ConfigModelParser -from ..models.config import Config, EnvironmentConfig, HyperparameterConfig, RlConfig +from ..models.config import Config +from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig +from ..models.rl_config import RlConfig from .constant_tests import EMPTY_STRUCTURE_CONFIG, EXAMPLE_HIERARCHY_DICT, EXAMPLE_HIERARCHY_DICT2 # from unittest.mock import patch @@ -35,9 +38,12 @@ def test_merge_one_config(self): expected_dict = copy.deepcopy(config_dict) expected_dict['environment']['episodes'] = None expected_dict['environment']['plot_interval'] = None + expected_dict['hyperparameter']['rl']['testvalue2'] = None + expected_dict['hyperparameter']['rl']['stable_baseline_test'] = None + empty_config = Config.get_empty_structure_dict() merger = ConfigMerger() - actual_config = merger._merge_config_into_base_config(empty_config, config_dict) + actual_config = merger.merge_config_into_base_config(empty_config, config_dict) assert expected_dict == actual_config @@ -50,7 +56,7 @@ def test_merge_two_configs_without_conflicts(self): test_config2 = Config.objects.create(hyperparameter=test_hyper_parameter_config) merger = ConfigMerger() - final_dict, error_dict = merger.merge_config_objects([test_config1.id, test_config2.id]) + final_dict, error_dict = merger.merge_config_objects([test_config1.id, test_config2.id], Config.get_empty_structure_dict()) expected_dict = copy.deepcopy(EMPTY_STRUCTURE_CONFIG) expected_dict['hyperparameter']['rl']['gamma'] = 0.99 @@ -67,7 +73,7 @@ def test_merge_two_small_configs_with_conflicts(self): test_config2 = Config.objects.create(environment=test_environment_config2) merger = ConfigMerger() - final_dict, error_dict = merger.merge_config_objects([test_config1.id, test_config2.id]) + final_dict, error_dict = merger.merge_config_objects([test_config1.id, test_config2.id], Config.get_empty_structure_dict()) expected_dict = copy.deepcopy(EMPTY_STRUCTURE_CONFIG) expected_dict['environment']['task'] = 'monitoring' @@ -87,7 +93,7 @@ def test_merge_two_configs_with_conflicts(self): config_object2 = parser.parse_config(test_config2) merger = ConfigMerger() - final_config, error_dict = merger.merge_config_objects([config_object1.id, config_object2.id]) + final_config, error_dict = merger.merge_config_objects([config_object1.id, config_object2.id], Config.get_empty_structure_dict()) expected_final_config = { 'environment': { @@ -98,8 +104,8 @@ def test_merge_two_configs_with_conflicts(self): 'task': 'monitoring', 'agents': [ { - 'name': 'Rule_Based Agent', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'name': 'QLearning Agent', + 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', 'argument': '' }, { @@ -122,7 +128,10 @@ def test_merge_two_configs_with_conflicts(self): 'sync_target_frames': 100, 'replay_start_size': 1000, 'epsilon_decay_last_frame': 7500, - 'epsilon_start': 0.9, 'epsilon_final': 0.2 + 'epsilon_start': 0.9, + 'epsilon_final': 0.2, + 'testvalue2': None, + 'stable_baseline_test': None }, 'sim_market': { 'max_storage': 80, @@ -153,7 +162,9 @@ def test_merge_two_configs_with_conflicts(self): 'replay_start_size': 'changed hyperparameter-rl replay_start_size from 10000 to 1000', 'epsilon_decay_last_frame': 'changed hyperparameter-rl epsilon_decay_last_frame from 75000 to 7500', 'epsilon_start': 'changed hyperparameter-rl epsilon_start from 1.0 to 0.9', - 'epsilon_final': 'changed hyperparameter-rl epsilon_final from 0.1 to 0.2' + 'epsilon_final': 'changed hyperparameter-rl epsilon_final from 0.1 to 0.2', + 'testvalue2': None, + 'stable_baseline_test': None }, 'sim_market': { 'max_storage': 'changed hyperparameter-sim_market max_storage from 100 to 80', diff --git a/webserver/alpha_business_app/tests/test_utils.py b/webserver/alpha_business_app/tests/test_utils.py new file mode 100644 index 00000000..a2ca740e --- /dev/null +++ b/webserver/alpha_business_app/tests/test_utils.py @@ -0,0 +1,133 @@ +from django.test import TestCase + +from ..utils import (convert_python_type_to_django_type, get_all_possible_rl_hyperparameter, get_all_possible_sim_market_hyperparameter, + get_structure_dict_for, to_config_keyword) + + +class UtilsTest(TestCase): + def test_get_structure_dict_for_config(self): + expected_dict = { + 'environment': { + 'task': None, + 'enable_live_draw': None, + 'episodes': None, + 'plot_interval': None, + 'marketplace': None, + 'agents': [] + }, + 'hyperparameter': { + 'sim_market': { + 'max_storage': None, + 'episode_length': None, + 'max_price': None, + 'max_quality': None, + 'number_of_customers': None, + 'production_price': None, + 'storage_cost_per_product': None + }, + 'rl': { + 'replay_size': None, + 'epsilon_start': None, + 'replay_start_size': None, + 'epsilon_decay_last_frame': None, + 'testvalue2': None, + 'sync_target_frames': None, + 'batch_size': None, + 'epsilon_final': None, + 'stable_baseline_test': None, + 'gamma': None, + 'learning_rate': None + } + } + } + assert expected_dict == get_structure_dict_for('') + + def test_get_structure_dict_for_rl(self): + expected_dict = { + 'replay_size': None, + 'epsilon_start': None, + 'replay_start_size': None, + 'epsilon_decay_last_frame': None, + 'testvalue2': None, + 'sync_target_frames': None, + 'batch_size': None, + 'epsilon_final': None, + 'stable_baseline_test': None, + 'gamma': None, + 'learning_rate': None + } + assert expected_dict == get_structure_dict_for('rl') + + def test_get_structure_dict_for_sim_market(self): + expected_dict = { + 'max_storage': None, + 'episode_length': None, + 'max_price': None, + 'max_quality': None, + 'number_of_customers': None, + 'production_price': None, + 'storage_cost_per_product': None + } + assert expected_dict == get_structure_dict_for('sim_market') + + def test_get_structure_dict_for_environment(self): + expected_dict = { + 'task': None, + 'enable_live_draw': None, + 'episodes': None, + 'plot_interval': None, + 'marketplace': None, + 'agents': [] + } + assert expected_dict == get_structure_dict_for('environment') + + def test_get_structure_dict_for_agents(self): + assert [] == get_structure_dict_for('agents') + + def test_to_config_keyword(self): + from ..models.config import Config + assert '' == to_config_keyword(Config) + from ..models.agents_config import AgentsConfig + assert 'agents' == to_config_keyword(AgentsConfig) + from ..models.environment_config import EnvironmentConfig + assert 'environment' == to_config_keyword(EnvironmentConfig) + from ..models.hyperparameter_config import HyperparameterConfig + assert 'hyperparameter' == to_config_keyword(HyperparameterConfig) + from ..models.rl_config import RlConfig + assert 'rl' == to_config_keyword(RlConfig) + from ..models.sim_market_config import SimMarketConfig + assert 'sim_market' == to_config_keyword(SimMarketConfig) + + def test_get_all_rl_parameter(self): + expected_parameter = { + ('gamma', float), + ('batch_size', int), + ('replay_start_size', int), + ('sync_target_frames', int), + ('epsilon_decay_last_frame', int), + ('replay_size', int), + ('epsilon_final', float), + ('stable_baseline_test', float), + ('testvalue2', float), + ('epsilon_start', float), + ('learning_rate', float) + } + assert expected_parameter == get_all_possible_rl_hyperparameter() + + def test_get_all_sim_market_parameter(self): + expected_parameter = { + ('max_price', int), + ('production_price', int), + ('episode_length', int), + ('max_quality', int), + ('max_storage', int), + ('storage_cost_per_product', (int, float)), + ('number_of_customers', int) + } + assert expected_parameter == get_all_possible_sim_market_hyperparameter() + + def test_convert_to_django_type(self): + assert "" == convert_python_type_to_django_type(int) + assert "" == convert_python_type_to_django_type(float) + assert "" == convert_python_type_to_django_type(str) + assert "" == convert_python_type_to_django_type((int, float)) diff --git a/webserver/alpha_business_app/urls.py b/webserver/alpha_business_app/urls.py index 16d78f34..f7f7a732 100644 --- a/webserver/alpha_business_app/urls.py +++ b/webserver/alpha_business_app/urls.py @@ -12,7 +12,8 @@ path('delete_config/', views.delete_config, name='delete_config'), # AJAX relevant url's - path('agent', views.agent, name='agent'), + path('new_agent', views.new_agent, name='new_agent'), + path('agent_changed', views.agent_changed, name='agent_changed'), path('api_availability', views.api_availability, name='api_availability'), path('marketplace_changed', views.marketplace_changed, name='marketplace'), path('validate_config', views.config_validation, name='config_validation'), diff --git a/webserver/alpha_business_app/utils.py b/webserver/alpha_business_app/utils.py new file mode 100644 index 00000000..b08bc1a7 --- /dev/null +++ b/webserver/alpha_business_app/utils.py @@ -0,0 +1,182 @@ +import recommerce.market.circular.circular_sim_market as circular_market +import recommerce.market.linear.linear_sim_market as linear_market +from recommerce.configuration.environment_config import EnvironmentConfig +from recommerce.configuration.utils import get_class + + +def convert_python_type_to_input_type(to_convert) -> str: + return 'number' if to_convert == float or to_convert == int else 'text' + + +def convert_python_type_to_django_type(to_convert: type) -> str: + """ + Converts standard python types, into a string of a Django model classes. + At the moment float, and int are supported, the rest will be Charfield. + + Args: + to_convert (type): standard python type ro be converted + + Returns: + str: string of a corresponding Django model class + """ + from django.db import models + if to_convert == float or (type(to_convert) == tuple and float in to_convert): + return str(models.FloatField) + elif to_convert == int: + return str(models.IntegerField) + else: + return str(models.CharField) + + +def get_recommerce_marketplaces() -> list: + """ + Matches marketplaces of recommerce.market.circular.circular_sim_market and recommerce.market.linear.linear_sim_market, + which contain one of the Keywords: Oligopoly, Duopoly, Monopoly + + Returns: + list: tuple list for selection + """ + keywords = ['Monopoly', 'Duopoly', 'Oligopoly'] + # get all circular marketplaces + circular_marketplaces = list(set(filter(lambda class_name: any(keyword in class_name for keyword in keywords), dir(circular_market)))) + circular_market_str = [f'recommerce.market.circular.circular_sim_market.{market}' for market in sorted(circular_marketplaces)] + # get all linear marketplaces + visible_linear_names = list(set(filter(lambda class_name: any(keyword in class_name for keyword in keywords), dir(linear_market)))) + linear_market_str = [f'recommerce.market.linear.linear_sim_market.{market}' for market in sorted(visible_linear_names)] + + return circular_market_str + linear_market_str + + +def get_recommerce_agents_for_marketplace(marketplace) -> list: + return marketplace.get_possible_rl_agents() + + +def get_all_possible_rl_hyperparameter() -> set: + """ + Gets all hyperparameters for all possible recommerce agents + + Returns: + set: of tuples, containing the hyperparameter name and the hyperparameter type + """ + all_marketplaces = get_recommerce_marketplaces() + all_agents = [] + for marketplace_str in all_marketplaces: + marketplace = get_class(marketplace_str) + all_agents += get_recommerce_agents_for_marketplace(marketplace) + + return get_attributes(all_agents) + + +def get_all_possible_sim_market_hyperparameter() -> set: + """ + Gets all hyperparameters for all possible recommerce markets + + Returns: + set: of tuples, containing the hyperparameter name and the hyperparameter type + """ + all_marketplaces = get_recommerce_marketplaces() + return get_attributes(all_marketplaces) + + +def get_attributes(all_classes: list) -> set: + """ + Calls `get_configurable_fields` and collects the name and the type of the returned fields in a list. + + Args: + all_classes (list): list of strings of classes that implement `get_configurable_fields` + + Returns: + set: of tuples, containing the attribute name and the attribute type + """ + all_attributes = [] + for class_str in all_classes: + current_class = get_class(class_str) + try: + # we do not necessarily need to include the rule, as it is currently not used in the webserver + all_attributes += [attribute[:2] for attribute in current_class.get_configurable_fields()] + except NotImplementedError: + print(f'please check the installation of the recommerce package!{current_class} does not implement "get_configurable_fields"') + return set(all_attributes) + + +def get_structure_dict_for(keyword: str) -> dict: + """ + Will return a Dictionary of the complete structure (all possible fields) for one suitable keyword. + + Args: + keyword (str): must be 'environment', 'hyperparameter', 'sim_market', 'rl', 'agents' or ''. + '' means it will return the whole strucutre dict + + Returns: + dict: general structure of the given keywords, values will always be None + """ + assert keyword in ['environment', 'hyperparameter', 'sim_market', 'rl', 'agents', ''], f'Your keyword {keyword} is not recognized.' + environment_dict = EnvironmentConfig.get_required_fields('top-dict') + environment_dict_with_none = {key: None for key in environment_dict.keys()} + environment_dict_with_none['agents'] = [] + + hyperparameter_dict_sim_market = {parameter[0]: None for parameter in get_all_possible_sim_market_hyperparameter()} + hyperparameter_dict_rl = {parameter[0]: None for parameter in get_all_possible_rl_hyperparameter()} + + hyperparameter_dict = { + 'sim_market': hyperparameter_dict_sim_market, + 'rl': hyperparameter_dict_rl + } + + structure_config_dict = { + 'environment': environment_dict_with_none, + 'hyperparameter': hyperparameter_dict + } + if keyword == 'environment': + return environment_dict_with_none + elif keyword == 'agents': + return environment_dict_with_none[keyword] + elif keyword == 'hyperparameter': + return hyperparameter_dict + elif keyword == 'rl': + return hyperparameter_dict_rl + elif keyword == 'sim_market': + return hyperparameter_dict_sim_market + elif not keyword: + return structure_config_dict + else: + assert False + + +def get_structure_with_types_of(top_level: str, second_level: str = None) -> dict: + """ + Currently only implemented for 'rl' and 'sim_market'. + Will return the structure of these configs with the correspondig types. + + Args: + top_level (str): top level dict key. ('envionment', 'hyperparameter') + second_level (str, optional): second level dict key. ('rl', 'sim_market', 'agents') Defaults to None. + + Returns: + dict: with keyword and Django type. + """ + assert top_level == 'hyperparameter' and second_level == 'rl' \ + or top_level == 'hyperparameter' and second_level == 'sim_market', \ + f'It is only implemented for "hyperparameter" and "rl, sim_market" not {top_level}, {second_level}' + if second_level == 'rl': + possible_attributes = get_all_possible_rl_hyperparameter() + if second_level == 'sim_market': + possible_attributes = get_all_possible_sim_market_hyperparameter() + final_attributes = [] + for attr in possible_attributes: + final_attributes += [(attr[0], convert_python_type_to_django_type(attr[1]))] + return final_attributes + + +def remove_none_values_from_dict(dict_with_none_values: dict) -> dict: + return {key: value for key, value in dict_with_none_values.items() if value is not None} + + +def to_config_class_name(name: str) -> str: + return ''.join([x.title() for x in name.split('_')]) + 'Config' + + +def to_config_keyword(class_name: str) -> str: + name_without_config = str(class_name).rsplit('.')[-1][:-2].replace('Config', '').lower() + # revert the '_' + return 'sim_market' if name_without_config == 'simmarket' else name_without_config diff --git a/webserver/alpha_business_app/views.py b/webserver/alpha_business_app/views.py index 65a08308..5736ec93 100644 --- a/webserver/alpha_business_app/views.py +++ b/webserver/alpha_business_app/views.py @@ -6,6 +6,7 @@ from recommerce.configuration.config_validation import validate_config +from .adjustable_fields import get_agent_hyperparameter from .buttons import ButtonHandler from .config_parser import ConfigFlatDictParser from .forms import UploadFileForm @@ -94,13 +95,21 @@ def delete_config(request, config_id) -> HttpResponse: # AJAX relevant views @login_required -def agent(request) -> HttpResponse: +def new_agent(request) -> HttpResponse: if not request.user.is_authenticated: return HttpResponse('Unauthorized', status=401) return render(request, 'configuration_items/agent.html', {'id': str(uuid4()), 'name': 'Competitor', 'agent_selections': selection_manager.get_competitor_options_for_marketplace()}) +@login_required +def agent_changed(request) -> HttpResponse: + if not request.user.is_authenticated: + return HttpResponse('Unauthorized', status=401) + return render(request, 'configuration_items/rl_parameter.html', + {'parameters': get_agent_hyperparameter(request.POST['agent'], request.POST.dict())}) + + def api_availability(request) -> HttpResponse: if not request.user.is_authenticated: return render(request, 'api_buttons/api_health_button.html') @@ -130,19 +139,19 @@ def config_validation(request) -> HttpResponse: config_dict = ConfigFlatDictParser().flat_dict_to_hierarchical_config_dict(resulting_dict) - validate_status, validate_data = validate_config(config=config_dict, config_is_final=True) + validate_status, validate_data = validate_config(config=config_dict) if not validate_status: return render(request, 'notice_field.html', {'error': validate_data}) return render(request, 'notice_field.html', {'success': 'This config is valid'}) +@login_required def marketplace_changed(request) -> HttpResponse: if not request.user.is_authenticated: return HttpResponse('Unauthorized', status=401) marketplace_class = None if request.method == 'POST': post_request = request.POST - # print(post_request) marketplace_class = post_request['marketplace'] raw_html = post_request['agents_html'] return HttpResponse(content=selection_manager.get_correct_agents_html_on_marketplace_change(request, marketplace_class, raw_html)) diff --git a/webserver/templates/base.html b/webserver/templates/base.html index 729fd4d2..9154f3c8 100644 --- a/webserver/templates/base.html +++ b/webserver/templates/base.html @@ -70,7 +70,7 @@ - + diff --git a/webserver/templates/configuration_items/add_more_button.html b/webserver/templates/configuration_items/add_more_button.html index 6ba5af6b..b00f2a07 100644 --- a/webserver/templates/configuration_items/add_more_button.html +++ b/webserver/templates/configuration_items/add_more_button.html @@ -1 +1 @@ - + diff --git a/webserver/templates/configuration_items/agent.html b/webserver/templates/configuration_items/agent.html index ee78fe99..1e25018f 100644 --- a/webserver/templates/configuration_items/agent.html +++ b/webserver/templates/configuration_items/agent.html @@ -19,7 +19,7 @@

class
- {% include "configuration_items/selection_list.html" with prefill_value=prefill.agent_class selections=agent_selections %}
diff --git a/webserver/templates/configuration_items/rl.html b/webserver/templates/configuration_items/rl.html index fa9c339e..cc2f781a 100644 --- a/webserver/templates/configuration_items/rl.html +++ b/webserver/templates/configuration_items/rl.html @@ -5,116 +5,8 @@

-
- {% load static %} -
-
- {% if error_dict.gamma %} - - {% endif %} - gamma -
-
- -
-
-
-
- {% if error_dict.batch_size %} - - {% endif %} - batch size -
-
- -
-
-
-
- {% if error_dict.replay_size %} - - {% endif %} - replay size -
-
- -
-
-
-
- {% if error_dict.learning_rate %} - - {% endif %} - learning rate -
-
- -
-
-
-
- {% if error_dict.sync_target_frames %} - - {% endif %} - sync target frames -
-
- -
-
-
-
- {% if error_dict.start_size %} - - {% endif %} - replay start size -
-
- -
-
-
-
- {% if error_dict.epsilon_decay_last_frame %} - - {% endif %} - epsilon decay last frame -
-
- -
-
-
-
- {% if error_dict.epsilon_start %} - - {% endif %} - epsilon start -
-
- -
-
-
-
- {% if error_dict.epsilon_final %} - - {% endif %} - epsilon final -
-
- -
-
+
+ {% include "configuration_items/rl_parameter.html" with parameters=prefill %}
diff --git a/webserver/templates/configuration_items/rl_parameter.html b/webserver/templates/configuration_items/rl_parameter.html new file mode 100644 index 00000000..9cc66af0 --- /dev/null +++ b/webserver/templates/configuration_items/rl_parameter.html @@ -0,0 +1,15 @@ +{% load static %} +{% for parameter in parameters %} +
+
+ {% if parameter.error %} + + {% endif %} + {{parameter.name}} +
+
+ +
+
+{% endfor %} diff --git a/webserver/templates/configurator.html b/webserver/templates/configurator.html index 6227caea..ede3e0c6 100644 --- a/webserver/templates/configurator.html +++ b/webserver/templates/configurator.html @@ -4,51 +4,48 @@ {% load static %}

You can configure your experiments here

-
-
- {% csrf_token %} - - - {% for config in all_configurations %} - - - - - - {% endfor %} -
- {% if not config.is_referenced %} - - - - {% endif %} - - - - {{config.name}} -
- {% if all_configurations %} - - {% endif %} -
-
-
-
- {% csrf_token %} - - {% include "configuration_items/config.html" with prefill=prefill should_show=True error_dict=error_dict %} -
- - + + {% csrf_token %} +
+
+ + {% for config in all_configurations %} + + + + + + {% endfor %} +
+ {% if not config.is_referenced %} + + + + {% endif %} + + + + {{config.name}} +
+ {% if all_configurations %} + + {% endif %}
+
+ {% include "configuration_items/config.html" with prefill=prefill should_show=True error_dict=error_dict %} +
+ + +
-
- -
- +
+ +
+ +
- -
+
+
{% endblock content %}