diff --git a/.n6-version b/.n6-version index 2aa4d8f..005e92c 100644 --- a/.n6-version +++ b/.n6-version @@ -1 +1 @@ -3.0.0b2 +3.0.0b3 diff --git a/N6BrokerAuthApi/n6brokerauthapi/__init__.py b/N6BrokerAuthApi/n6brokerauthapi/__init__.py index c9cf5f8..16c7791 100644 --- a/N6BrokerAuthApi/n6brokerauthapi/__init__.py +++ b/N6BrokerAuthApi/n6brokerauthapi/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) 2013-2019 NASK. All rights reserved. - +# Copyright (c) 2013-2021 NASK. All rights reserved. +#TODO: Module modernized to Python 3, but no changes detected, comment to be deleted after MR """ This package provides a REST API implementation intended to cooperate with `rabbitmq-auth-backend-http` -- the RabbitMQ AMQP message broker's diff --git a/N6BrokerAuthApi/n6brokerauthapi/auth_base.py b/N6BrokerAuthApi/n6brokerauthapi/auth_base.py index 460f6e6..56bcea9 100644 --- a/N6BrokerAuthApi/n6brokerauthapi/auth_base.py +++ b/N6BrokerAuthApi/n6brokerauthapi/auth_base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2019 NASK. All rights reserved. +# Copyright (c) 2013-2021 NASK. All rights reserved. import sys import threading @@ -17,7 +17,7 @@ LOGGER = get_logger(__name__) -class BaseBrokerAuthManagerMaker(object): +class BaseBrokerAuthManagerMaker: def __init__(self, settings): self._db_connector = SQLAuthDBConnector(settings=settings) @@ -47,7 +47,7 @@ def get_manager_factory_kwargs(self, validated_view_params): params=validated_view_params) -class BaseBrokerAuthManager(object): +class BaseBrokerAuthManager: def __init__(self, db_connector, @@ -114,9 +114,9 @@ def client_type(self): assert self.client_obj is not None if isinstance(self.client_obj, models.User): return 'user' - elif isinstance(self.client_obj, models.Component): + if isinstance(self.client_obj, models.Component): return 'component' - raise TypeError('the client object {!r} is an instance of ' + raise TypeError('the client object {!a} is an instance of ' 'a wrong class'.format(self.client_obj)) @property @@ -134,7 +134,7 @@ def _get_admins_group(self): return self.db_session.query(models.SystemGroup).filter( models.SystemGroup.name == ADMINS_SYSTEM_GROUP_NAME).one() except NoResultFound: - LOGGER.error('System group %r not found in auth db!', ADMINS_SYSTEM_GROUP_NAME) + LOGGER.error('System group %a not found in auth db!', ADMINS_SYSTEM_GROUP_NAME) return None # diff --git a/N6BrokerAuthApi/n6brokerauthapi/auth_stream_api.py b/N6BrokerAuthApi/n6brokerauthapi/auth_stream_api.py index ec855b6..4870445 100644 --- a/N6BrokerAuthApi/n6brokerauthapi/auth_stream_api.py +++ b/N6BrokerAuthApi/n6brokerauthapi/auth_stream_api.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2019 NASK. All rights reserved. +# Copyright (c) 2013-2021 NASK. All rights reserved. 
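Side note (illustrative Python snippet, not part of the changeset): two idioms recur throughout the hunks above -- the `{!a}`/`%a` conversions, which run values through `ascii()` so log output stays ASCII-only whatever the data contains, and the zero-argument form of `super()`. All names below are made up for the example.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Base:
    def __init__(self, settings):
        self.settings = settings

class Derived(Base):
    def __init__(self, settings):
        super().__init__(settings=settings)       # Python-3 zero-argument super()

username = 'zażółć'                                # hypothetical non-ASCII value
assert '{!a}'.format(username) == ascii(username)  # '!a' applies ascii()
logger.error('illegal username: %a', username)     # so does the '%a' placeholder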
import re @@ -28,7 +28,7 @@ class StreamApiBrokerAuthManagerMaker(ConfigMixin, BaseBrokerAuthManagerMaker): """ def __init__(self, settings): - super(StreamApiBrokerAuthManagerMaker, self).__init__(settings=settings) + super().__init__(settings=settings) self._config = self.get_config_section(settings) self._thread_local = ThreadLocalNamespace(attr_factories={ 'autogenerated_queue_matcher': self._make_autogenerated_queue_matcher, @@ -46,7 +46,7 @@ def get_manager_factory(self, params): return StreamApiBrokerAuthManager def get_manager_factory_kwargs(self, params): - base = super(StreamApiBrokerAuthManagerMaker, self).get_manager_factory_kwargs(params) + base = super().get_manager_factory_kwargs(params) return dict(base, push_exchange_name=self._config['push_exchange_name'] or None, privileged_component_logins=self._config['privileged_component_logins'], @@ -63,7 +63,7 @@ def __init__(self, self._push_exchange_name = push_exchange_name self._privileged_component_logins = privileged_component_logins self._autogenerated_queue_matcher = autogenerated_queue_matcher - super(StreamApiBrokerAuthManager, self).__init__(**kwargs) + super().__init__(**kwargs) EXPLICITLY_ILLEGAL_USERNAMES = ('', 'guest') @@ -71,15 +71,15 @@ def __init__(self, def should_try_to_verify_client(self): if self.broker_username in self.EXPLICITLY_ILLEGAL_USERNAMES: LOGGER.error( - "The '%s' username is explicitly considered illegal!", + "The '%a' username is explicitly considered illegal!", ascii_str(self.broker_username)) return False if self.password is not None: LOGGER.error( - "Authentication by password is not supported - cannot authenticate '%s'!", + "Authentication by password is not supported - cannot authenticate '%a'!", ascii_str(self.broker_username)) return False - return super(StreamApiBrokerAuthManager, self).should_try_to_verify_client() + return super().should_try_to_verify_client() def verify_and_get_user_obj(self): user_obj = self._from_db(models.User, 'login', self.broker_username) diff --git a/N6BrokerAuthApi/n6brokerauthapi/tests/test_views_with_auth_stream_api.py b/N6BrokerAuthApi/n6brokerauthapi/tests/test_views_with_auth_stream_api.py index 9e47de3..99397fa 100644 --- a/N6BrokerAuthApi/n6brokerauthapi/tests/test_views_with_auth_stream_api.py +++ b/N6BrokerAuthApi/n6brokerauthapi/tests/test_views_with_auth_stream_api.py @@ -1,9 +1,9 @@ -# Copyright (c) 2013-2019 NASK. All rights reserved. +# Copyright (c) 2013-2021 NASK. All rights reserved. 
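Side note on the test changes that follow (illustrative only; `resp` is a mock stand-in, not a real test-app response): in Python 3 HTTP response bodies are `bytes`, so assertions must compare against `b'...'` literals, and the external `mock` package is replaced by the standard-library `unittest.mock`.

import unittest
from unittest.mock import MagicMock

class ResponseBodyExample(unittest.TestCase):
    def test_body_is_bytes(self):
        resp = MagicMock()                     # hypothetical stand-in for a response
        resp.body = b'allow administrator'
        self.assertIn(b'administrator', resp.body.split())
        self.assertNotEqual(resp.body, 'allow administrator')  # bytes never equal str

if __name__ == '__main__':
    unittest.main()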
import itertools import unittest -from mock import ( +from unittest.mock import ( MagicMock, call, patch, @@ -167,22 +167,22 @@ def assertConnectorUsedOnlyAfterEnsuredClean(self): # noinspection PyUnresolvedReferences -class _AssertResponseMixin(object): +class _AssertResponseMixin: def assertAllow(self, resp): - self.assertIn(resp.body, ['allow', 'allow administrator']) + self.assertIn(resp.body, [b'allow', b'allow administrator']) self.assertEqual(resp.status_code, 200) def assertDeny(self, resp): - self.assertEqual(resp.body, 'deny') + self.assertEqual(resp.body, b'deny') self.assertEqual(resp.status_code, 200) def assertAdministratorTagPresent(self, resp): - self.assertIn('administrator', resp.body.split()) + self.assertIn(b'administrator', resp.body.split()) self.assertEqual(resp.status_code, 200) def assertNoAdministratorTag(self, resp): - self.assertNotIn('administrator', resp.body.split()) + self.assertNotIn(b'administrator', resp.body.split()) self.assertEqual(resp.status_code, 200) @@ -210,7 +210,7 @@ def basic_allow_params(cls): @paramseq def __param_name_combinations(cls): required_param_names = sorted(cls.basic_allow_params()) - for i in xrange(len(required_param_names)): + for i in range(len(required_param_names)): for some_param_names in itertools.combinations(required_param_names, i+1): assert set(some_param_names).issubset(required_param_names) yield list(some_param_names) @@ -218,7 +218,7 @@ def __param_name_combinations(cls): @staticmethod def __adjust_params(params, kwargs): params.update(kwargs) - for name, value in list(params.iteritems()): + for name, value in list(params.items()): if value is None: del params[name] diff --git a/N6BrokerAuthApi/n6brokerauthapi/views.py b/N6BrokerAuthApi/n6brokerauthapi/views.py index f235cad..7e95989 100644 --- a/N6BrokerAuthApi/n6brokerauthapi/views.py +++ b/N6BrokerAuthApi/n6brokerauthapi/views.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2019 NASK. All rights reserved. +# Copyright (c) 2013-2021 NASK. All rights reserved. import logging @@ -21,7 +21,7 @@ class _DenyAccess(Exception): def __init__(self, error_log_message=None): - super(_DenyAccess, self).__init__(error_log_message) + super().__init__(error_log_message) self.error_log_message = error_log_message @@ -38,7 +38,7 @@ def __call__(self): try: # involves use of `iter_deduplicated_params()` and `make_response()`... try: - return super(_N6BrokerAuthViewBase, self).__call__() + return super().__call__() except ParamCleaningError as exc: raise _DenyAccess(error_log_message=exc.public_message) except _DenyAccess as deny_exc: @@ -81,7 +81,7 @@ def auth_manager_maker(self): @attr_required('param_name_to_required_flag') def get_required_param_names(cls): return {name - for name, required in cls.param_name_to_required_flag.iteritems() + for name, required in cls.param_name_to_required_flag.items() if required} def allow_response(self): @@ -100,18 +100,18 @@ def safe_name(self, name): # Private stuff def _log(self, level, log_message): - LOGGER.log(level, '[%r: %s] %s', + LOGGER.log(level, '[%a: %a] %a', self, ascii_str(self.request.url), ascii_str(log_message)) def _ensure_all_param_names_and_values_are_strings(self): - if not all(isinstance(key, basestring) and - isinstance(val, basestring) - for key, val in self.params.iteritems()): + if not all(isinstance(key, str) and + isinstance(val, str) + for key, val in self.params.items()): raise AssertionError( 'this should never happen: not all request param names and ' - 'values are strings! 
(params: {!r})'.format(self.params)) + 'values are strings! (params: {!a})'.format(self.params)) def _warn_if_unknown_params(self): known_param_names = set(self.param_name_to_required_flag) @@ -147,8 +147,8 @@ class _N6BrokerAuthResourceViewBase(_N6BrokerAuthViewBase): @attr_required('valid_permissions', 'valid_resources') def validate_params(self): - super(_N6BrokerAuthResourceViewBase, self).validate_params() - assert self.params.viewkeys() >= {'resource', 'permission'} + super().validate_params() + assert self.params.keys() >= {'resource', 'permission'} resource = self.params['resource'] permission = self.params['permission'] if resource not in self.valid_resources: diff --git a/N6BrokerAuthApi/setup.py b/N6BrokerAuthApi/setup.py index 4e5ab37..d2441fe 100644 --- a/N6BrokerAuthApi/setup.py +++ b/N6BrokerAuthApi/setup.py @@ -21,16 +21,16 @@ def get_n6_version(filename_base): path = matching_paths[0] except IndexError: sys.exit('[{}] Cannot determine the n6 version ' - '(no files match the pattern {!r}).' + '(no files match the pattern {!a}).' .format(setup_human_readable_ref, path_glob_pattern)) try: - with open(path) as f: #3: add: `, encoding='ascii'` + with open(path, encoding='ascii') as f: return f.read().strip() except (OSError, UnicodeError) as exc: sys.exit('[{}] Cannot determine the n6 version ' '(an error occurred when trying to ' - 'read it from the file {!r} - {}).' + 'read it from the file {!a} - {}).' .format(setup_human_readable_ref, path, exc)) @@ -41,7 +41,6 @@ def get_n6_version(filename_base): requires = [ 'n6lib==' + n6_version, 'pyramid==1.10.8', - 'typing', ] setup( @@ -51,12 +50,13 @@ def get_n6_version(filename_base): packages=find_packages(), include_package_data=True, zip_safe=False, + python_requires='==3.9.*', install_requires=requires, entry_points="""\ [paste.app_factory] main = n6brokerauthapi:main """, - tests_require=['mock==3.0.5', 'unittest_expander==0.3.1'], + tests_require=['unittest_expander==0.3.1'], test_suite='n6brokerauthapi.tests', description='Authentication and authorization API for RabbitMQ', url='https://github.com/CERT-Polska/n6', @@ -66,7 +66,7 @@ def get_n6_version(filename_base): 'License :: OSI Approved :: GNU Affero General Public License v3', 'Operating System :: POSIX :: Linux', 'Programming Language :: Python', - 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.9', "Framework :: Pyramid", 'Topic :: Security', ], diff --git a/N6Core/README.md b/N6Core/README.md new file mode 100644 index 0000000..762275c --- /dev/null +++ b/N6Core/README.md @@ -0,0 +1,8 @@ +**Note:** `N6Core` contains legacy *Python-2-only* stuff. Typically, +you will want to use -- instead of it -- the new, *Python-3-only* stuff +residing in `N6DataPipeline`. + +When it comes to data sources -- i.e., collectors and parsers -- +`N6DataSources` is the place where new sources should be implemented +(in Python 3). The collectors and parsers residing in `N6Core` will +be gradually migrated to `N6DataSources` (if not obsolete). diff --git a/N6Core/n6/archiver/recorder.py b/N6Core/n6/archiver/recorder.py index e9cc835..d90edde 100644 --- a/N6Core/n6/archiver/recorder.py +++ b/N6Core/n6/archiver/recorder.py @@ -9,6 +9,7 @@ ### TODO: this module is to be replaced with a new implementation...
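Side note (illustrative, made-up data): another recurring pattern above is wrapping dict views in `list(...)` before mutating the dict -- `iteritems()`/`viewkeys()` are gone in Python 3, and the remaining `items()`/`keys()` views are live, so deleting during direct iteration raises `RuntimeError`.

params = {'resource': 'queue', 'permission': None, 'name': 'q1'}   # example data only
for name, value in list(params.items()):    # list(...) snapshots the view first
    if value is None:
        del params[name]
assert params == {'resource': 'queue', 'name': 'q1'}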
+from builtins import range #3: -- import datetime import logging import os @@ -187,7 +188,7 @@ def _setup_db(self): def _install_session_variables_setter(self, engine, **session_variables): setter_sql = 'SET ' + ' , '.join( 'SESSION {} = {}'.format(name, value) - for name, value in session_variables.iteritems()) + for name, value in session_variables.items()) @sqlalchemy.event.listens_for(engine, 'connect') def set_session_variables(dbapi_connection, connection_record): @@ -270,7 +271,7 @@ def get_truncated_rk(rk, parts): rk = rk.split('.') parts_rk = [] try: - for i in xrange(parts): + for i in range(parts): parts_rk.append(rk[i]) except IndexError: LOGGER.warning("routing key %r contains less than %r segments", rk, parts) diff --git a/N6Core/n6/collectors/generic.py b/N6Core/n6/collectors/generic.py index 0ee4a01..d7730cb 100644 --- a/N6Core/n6/collectors/generic.py +++ b/N6Core/n6/collectors/generic.py @@ -35,6 +35,8 @@ make_exc_ascii_str, ) from n6corelib.email_message import ReceivedEmailMessage + +from n6lib.const import RAW_TYPE_ENUMS from n6lib.http_helpers import RequestPerformer from n6lib.log_helpers import ( get_logger, @@ -264,7 +266,6 @@ class BaseCollector(CollectorConfigMixin, QueuedBase, AbstractBaseCollector): # (note that this is something completely *different* than # .event_type and ['type']) type = None - limits_type_of = ('stream', 'file', 'blacklist') # the attribute has to be overridden, if a component should # accept the "--n6recovery" argument option and inherits from @@ -326,9 +327,9 @@ def set_queue_name(self): def _validate_type(self): """Validate type of message, should be one of: 'stream', 'file', 'blacklist.""" - if self.type not in self.limits_type_of: + if self.type not in RAW_TYPE_ENUMS: raise Exception('Wrong type of archived data in mongo: {0},' - ' should be one of: {1}'.format(self.type, self.limits_type_of)) + ' should be one of: {1}'.format(self.type, RAW_TYPE_ENUMS)) def update_connection_params_dict_before_run(self, params_dict): """ diff --git a/N6Core/n6/data/conf/pipeline.conf b/N6Core/n6/data/conf/00_pipeline.conf similarity index 98% rename from N6Core/n6/data/conf/pipeline.conf rename to N6Core/n6/data/conf/00_pipeline.conf index 422e3d7..9ffcabc 100644 --- a/N6Core/n6/data/conf/pipeline.conf +++ b/N6Core/n6/data/conf/00_pipeline.conf @@ -25,3 +25,4 @@ comparator = enriched filter = enriched, compared anonymizer = filtered recorder = filtered +counter= recorded diff --git a/N6Core/n6/data/conf/07_aggregator.conf b/N6Core/n6/data/conf/07_aggregator.conf index 61af786..c5010ff 100644 --- a/N6Core/n6/data/conf/07_aggregator.conf +++ b/N6Core/n6/data/conf/07_aggregator.conf @@ -1,15 +1,17 @@ [aggregator] -## path to the local aggregator's database file -## (the database file will be created automatically -## on the 1st aggregator run, if possible) +# path to the local aggregator's database file +# (the database file will be created automatically +# on the 1st aggregator run, if possible) dbpath=~/.n6aggregator/aggregator_db.pickle -## time interval (in seconds) within which non-monotonic times of -## events are tolerated +# time interval (in seconds) within which non-monotonic times of +# events are tolerated time_tolerance=600 -## time interval like `time_tolerance`, but defined for specific source -## (if it is not defined for the current source, -## `time_tolerance` is used) -time_tolerance_per_source={} +# time interval like `time_tolerance`, but defined for specific source +# (if it is not defined for the current source, +# 
`time_tolerance` is used) +;time_tolerance_per_source={ +; "some-src.its-channel": 1200, +; "other-src.foobar": 900} diff --git a/N6Core/n6/data/conf/09_auth_db.conf b/N6Core/n6/data/conf/09_auth_db.conf index c74d234..3f97a70 100644 --- a/N6Core/n6/data/conf/09_auth_db.conf +++ b/N6Core/n6/data/conf/09_auth_db.conf @@ -1,34 +1,34 @@ [auth_db] -## connection URL, e.g.: mysql+mysqldb://n6:somepassword@localhost/n6 -## it must start with `mysql+mysqldb:` (or just `mysql:`) because other -## dialects/drivers are not supported -## (see also: http://docs.sqlalchemy.org/en/rel_0_9/core/engines.html) -#url = mysql://user:password@host/dbname +# connection URL, e.g.: mysql+mysqldb://n6:somepassword@localhost/n6 +# it must start with `mysql+mysqldb:` (or just `mysql:`) because other +# dialects/drivers are not supported +# (see also: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) +;url = mysql://user:password@host/dbname -## if you want to use SSL, the following options must be set to -## appropriate file paths: -#ssl_cacert = /some/path/to/CACertificatesFile.pem -#ssl_cert = /some/path/to/ClientCertificateFile.pem -#ssl_key = /some/path/to/private/ClientCertificateKeyFile.pem +# if you want to use SSL, the following options must be set to +# appropriate file paths: +;ssl_cacert = /some/path/to/CACertificatesFile.pem +;ssl_cert = /some/path/to/ClientCertificateFile.pem +;ssl_key = /some/path/to/private/ClientCertificateKeyFile.pem [auth_db_session_variables] -## all MySQL variables specified within this section will be set by -## executing "SET SESSION = , ...". -## WARNING: for simplicity, the variable names and values are inserted -## into SQL code "as is", *without* any escaping (we assume we can treat -## configuration files as a *trusted* source of data). +# all MySQL variables specified within this section will be set by +# executing "SET SESSION = , ...". +# WARNING: for simplicity, the variable names and values are inserted +# into SQL code "as is", *without* any escaping (we assume we can treat +# configuration files as a *trusted* source of data). -## (`[auth_db_session_variables].wait_timeout` should be -## greater than `[auth_db_connection_pool].pool_recycle`) +# (`[auth_db_session_variables].wait_timeout` should be +# greater than `[auth_db_connection_pool].pool_recycle`) wait_timeout = 7200 [auth_db_connection_pool] -## (generally, the defaults should be OK in most cases; if you are -## interested in technical details -- see: SQLAlchemy docs...) +# (generally, the defaults should be OK in most cases; if you are +# interested in technical details -- see: SQLAlchemy docs...) pool_recycle = 3600 pool_timeout = 20 pool_size = 15 diff --git a/N6Core/n6/tests/utils/test_anonymizer.py b/N6Core/n6/tests/utils/test_anonymizer.py index 7d58cd3..7a4ce21 100644 --- a/N6Core/n6/tests/utils/test_anonymizer.py +++ b/N6Core/n6/tests/utils/test_anonymizer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# Copyright (c) 2013-2019 NASK. All rights reserved. +# Copyright (c) 2013-2021 NASK. All rights reserved. 
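Side note on the comment-style change applied to the config templates above (plain `#` for descriptions, `;` for commented-out option examples): Python's `configparser` treats both prefixes as full-line comments, so the `;`-prefixed sample options stay inert until uncommented. A self-contained check with made-up values:

import configparser

cfg_text = """
[aggregator]
# path to the local aggregator's database file
dbpath = ~/.n6aggregator/aggregator_db.pickle
;time_tolerance_per_source = {"some-src.its-channel": 1200}
time_tolerance = 600
"""
cp = configparser.ConfigParser()
cp.read_string(cfg_text)
assert cp.getint('aggregator', 'time_tolerance') == 600
assert not cp.has_option('aggregator', 'time_tolerance_per_source')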
import datetime import json diff --git a/N6Core/n6/utils/aggregator.py b/N6Core/n6/utils/aggregator.py index f11670c..d614a8e 100644 --- a/N6Core/n6/utils/aggregator.py +++ b/N6Core/n6/utils/aggregator.py @@ -281,11 +281,12 @@ class Aggregator(ConfigMixin, QueuedBase): [aggregator] dbpath time_tolerance :: int - time_tolerance_per_source = {} :: json + time_tolerance_per_source = {} :: py_namespaces_dict ''' def __init__(self, **kwargs): self.aggregator_config = self.get_config_section() + self.aggregator_config['dbpath'] = os.path.expanduser(self.aggregator_config['dbpath']) dbpath_dirname = os.path.dirname(self.aggregator_config['dbpath']) try: os.makedirs(dbpath_dirname, 0o700) diff --git a/N6Core/n6/utils/anonymizer.py b/N6Core/n6/utils/anonymizer.py index 166f669..4f6bb0d 100644 --- a/N6Core/n6/utils/anonymizer.py +++ b/N6Core/n6/utils/anonymizer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# Copyright (c) 2013-2019 NASK. All rights reserved. +# Copyright (c) 2013-2021 NASK. All rights reserved. """ Anonymizer -- performs validation and anonymization of event data @@ -72,7 +72,7 @@ def _process_input(self, event_type, event_data): force_exit_on_any_remaining_entered_contexts(self.auth_api) with self.auth_api: resource_to_org_ids = self._get_resource_to_org_ids(event_type, event_data) - if any(org_ids for org_ids in resource_to_org_ids.itervalues()): + if any(org_ids for org_ids in resource_to_org_ids.values()): (raw_result_dict, cleaned_result_dict, output_body) = self._get_result_dicts_and_output_body( @@ -114,7 +114,7 @@ def _get_resource_to_org_ids(self, org_id.decode('ascii', 'strict') for org_id in event_data.get('client', ())) for subsource_refint, ( - predicate, res_to_org_ids) in subsource_to_saa_info.iteritems(): + predicate, res_to_org_ids) in subsource_to_saa_info.items(): subs_inside_org_ids = res_to_org_ids['inside'] & client_org_ids subs_threats_org_ids = res_to_org_ids['threats'] if not subs_inside_org_ids and not subs_threats_org_ids: @@ -148,7 +148,7 @@ def _get_result_dicts_and_output_body(self, raw_result_dict = cleaned_result_dict = None try: raw_result_dict = { - k: v for k, v in event_data.iteritems() + k: v for k, v in event_data.items() if (k in self.data_spec.all_result_keys and # eliminating empty `address` and `client` sequences # (as the data spec will not accept them empty): @@ -193,7 +193,7 @@ def _publish_output_data(self, done_resource_to_org_ids = { resource: [] for resource in resource_to_org_ids} - for resource, res_org_ids in sorted(resource_to_org_ids.iteritems()): + for resource, res_org_ids in sorted(resource_to_org_ids.items()): done_org_ids = done_resource_to_org_ids[resource] output_rk = self.OUTPUT_RK_PATTERN.format( resource=resource, diff --git a/N6Core/n6/utils/comparator.py b/N6Core/n6/utils/comparator.py index 17d2353..5564bb6 100644 --- a/N6Core/n6/utils/comparator.py +++ b/N6Core/n6/utils/comparator.py @@ -2,10 +2,11 @@ import datetime import json -import cPickle +import pickle import os import os.path +from n6lib.common_helpers import open_file from n6lib.config import Config from n6lib.datetime_helpers import parse_iso_datetime_to_utc from n6lib.log_helpers import get_logger, logging_configured @@ -125,7 +126,7 @@ def process_event(self, data): def process_deleted(self): ret_value = [] - for key, event in list(self.blacklist.iteritems()): + for key, event in list(self.blacklist.items()): if event.flag is None: value = event.payload.copy() if key in self.blacklist: @@ -144,7 +145,7 @@ def process_deleted(self): 
return ret_value def clear_flags(self, flag_id): - for key, event in list(self.blacklist.iteritems()): + for key, event in list(self.blacklist.items()): if event.flag == flag_id: event.flag = None self.blacklist[key] = event @@ -182,14 +183,14 @@ def __init__(self, dbpath): def store_state(self): try: - with open(self.dbpath, "w") as f: - cPickle.dump(self.comp_data, f) + with open_file(self.dbpath, "w") as f: + pickle.dump(self.comp_data, f, 2) except IOError: LOGGER.error("Error saving state to: %r", self.dbpath) def restore_state(self): - with open(self.dbpath, "r") as f: - self.comp_data = cPickle.load(f) + with open_file(self.dbpath, "r") as f: + self.comp_data = pickle.load(f) def process_new_message(self, data): """Processes a message and validates agains db to detect new/change/update. @@ -319,7 +320,7 @@ def __init__(self, **kwargs): self.comparator_config["dbpath"] = os.path.expanduser(self.comparator_config["dbpath"]) dbpath_dirname = os.path.dirname(self.comparator_config["dbpath"]) try: - os.makedirs(dbpath_dirname, 0700) + os.makedirs(dbpath_dirname, 0o700) except OSError: pass super(Comparator, self).__init__(**kwargs) @@ -363,7 +364,7 @@ def finalize_series(self, series_id, bl_name): def _cleanup_data(self, data): """Removes artifacts from earlier processing (_bl-series-no, _bl-series-total, _bl-series-id) """ - for k in data.keys(): + for k in list(data.keys()): if k.startswith("_bl-series"): del data[k] return data diff --git a/N6Core/n6/utils/enrich.py b/N6Core/n6/utils/enrich.py index e415198..e9dd165 100644 --- a/N6Core/n6/utils/enrich.py +++ b/N6Core/n6/utils/enrich.py @@ -244,7 +244,7 @@ def _final_sanity_assertions(self, data): for name in enriched_keys) assert all( set(addr_keys).issubset(ip_to_addr[ip]) - for ip, addr_keys in list(ip_to_enriched_address_keys.items())) + for ip, addr_keys in ip_to_enriched_address_keys.items()) # # Resolution helpers diff --git a/N6CoreLib/README.md b/N6CoreLib/README.md new file mode 100644 index 0000000..74abf25 --- /dev/null +++ b/N6CoreLib/README.md @@ -0,0 +1,2 @@ +**Note:** `N6CoreLib` contains *Python-2-only* helpers, used by the +`N6Core` legacy stuff. 
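Side note on the state-persistence changes above (`cPickle` -> `pickle`, explicit protocol 2): the sketch below uses plain `open()` and a temporary path rather than the n6 `open_file` helper, whose exact semantics are not shown in this diff. In Python 3 the file must be opened in binary mode, and protocol 2 keeps the pickled state readable by the remaining Python-2 (N6Core) components during the migration.

import os
import pickle
import tempfile

state = {'blacklist': {}, 'time': None}                 # made-up comparator-like state
path = os.path.join(tempfile.mkdtemp(), 'comparator_db.pickle')
with open(path, 'wb') as f:                             # binary mode is required in Python 3
    pickle.dump(state, f, 2)                            # protocol 2 is Python-2-compatible
with open(path, 'rb') as f:
    assert pickle.load(f) == state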
diff --git a/N6DataPipeline/console_scripts b/N6DataPipeline/console_scripts index fb2e4ba..337eec6 100644 --- a/N6DataPipeline/console_scripts +++ b/N6DataPipeline/console_scripts @@ -1,10 +1,10 @@ #n6archiveraw = n6datapipeline.archive_raw:main n6aggregator = n6datapipeline.aggregator:main n6enrich = n6datapipeline.enrich:main -#n6comparator = n6datapipeline.comparator:main -#n6filter = n6datapipeline.filter:main -#n6recorder = n6datapipeline.recorder:main -#n6anonymizer = n6datapipeline.aux.anonymizer:main +n6comparator = n6datapipeline.comparator:main +n6filter = n6datapipeline.filter:main +n6recorder = n6datapipeline.recorder:main +n6anonymizer = n6datapipeline.aux.anonymizer:main #n6manage = n6datapipeline.aux.pkitools.n6manage:main n6run_bot = n6datapipeline.intelmq.wrapper:main n6run_adapter_to_n6 = n6datapipeline.intelmq.utils.intelmq_adapter:run_intelmq_to_n6 diff --git a/N6DataPipeline/n6datapipeline/aggregator.py b/N6DataPipeline/n6datapipeline/aggregator.py index 8798eca..9695d5e 100644 --- a/N6DataPipeline/n6datapipeline/aggregator.py +++ b/N6DataPipeline/n6datapipeline/aggregator.py @@ -281,11 +281,12 @@ class Aggregator(ConfigMixin, LegacyQueuedBase): [aggregator] dbpath time_tolerance :: int - time_tolerance_per_source = {} :: json + time_tolerance_per_source = {} :: py_namespaces_dict ''' def __init__(self, **kwargs): self.aggregator_config = self.get_config_section() + self.aggregator_config['dbpath'] = os.path.expanduser(self.aggregator_config['dbpath']) dbpath_dirname = os.path.dirname(self.aggregator_config['dbpath']) try: os.makedirs(dbpath_dirname, 0o700) diff --git a/N6DataPipeline/n6datapipeline/aux/anonymizer.py b/N6DataPipeline/n6datapipeline/aux/anonymizer.py new file mode 100644 index 0000000..74d1812 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/aux/anonymizer.py @@ -0,0 +1,240 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. + +""" +Anonymizer -- performs validation and anonymization of event data +before publishing them using the (STOMP-based) Stream API. 
+""" + +import json + +from n6datapipeline.base import LegacyQueuedBase +from n6lib.auth_api import AuthAPI +from n6lib.const import TYPE_ENUMS +from n6lib.context_helpers import force_exit_on_any_remaining_entered_contexts +from n6lib.data_spec import N6DataSpec +from n6lib.db_filtering_abstractions import RecordFacadeForPredicates +from n6lib.log_helpers import get_logger, logging_configured +from n6sdk.pyramid_commons.renderers import data_dict_to_json + + +LOGGER = get_logger(__name__) + + +class Anonymizer(LegacyQueuedBase): + + # note: here `resource` denotes a *Stream API resource*: + # "inside" (corresponding to the "inside" access zone) or + # "threats" (corresponding to the "threats" access zone) + # -- see the _get_resource_to_org_ids() method below + OUTPUT_RK_PATTERN = '{resource}.{category}.{anon_source}' + + input_queue = { + 'exchange': 'event', + 'exchange_type': 'topic', + 'queue_name': 'anonymizer', + } + + output_queue = { + 'exchange': 'clients', + 'exchange_type': 'headers', + } + + basic_prop_kwargs = {'delivery_mode': 1} # non-persistent + + supports_n6recovery = False + + _VALID_EVENT_TYPES = frozenset(TYPE_ENUMS) + + def __init__(self, **kwargs): + LOGGER.info("Anonymizer Start") + super(Anonymizer, self).__init__(**kwargs) + self.auth_api = AuthAPI() + self.data_spec = N6DataSpec() + + def input_callback(self, routing_key, body, properties): + # NOTE: we do not need to use n6lib.record_dict.RecordDict here, + # because: + # * previous components (such as filter) have already done the + # preliminary validation (using RecordDict's mechanisms); + # * we are doing the final validation anyway using + # N6DataSpec.clean_result_dict() (below -- in the + # _get_result_dicts_and_output_body() method) + event_data = json.loads(body) + with self.setting_error_event_info(event_data): + event_type = routing_key.split('.', 1)[0] + self._process_input(event_type, event_data) + + def _process_input(self, event_type, event_data): + self._check_event_type(event_type, event_data) + force_exit_on_any_remaining_entered_contexts(self.auth_api) + with self.auth_api: + resource_to_org_ids = self._get_resource_to_org_ids(event_type, event_data) + if any(org_ids for org_ids in resource_to_org_ids.values()): + (raw_result_dict, + cleaned_result_dict, + output_body) = self._get_result_dicts_and_output_body( + event_type, + event_data, + resource_to_org_ids) + + self._publish_output_data( + event_type, + resource_to_org_ids, + raw_result_dict, + cleaned_result_dict, + output_body) + else: + LOGGER.debug('no recipients for event #%s', event_data['id']) + + def _check_event_type(self, event_type, event_data): + if event_type != event_data.get('type', 'event'): + raise ValueError( + "event type from rk ({!a}) does " + "not match the 'type' item ({!a})" + .format(event_type, event_data.get('type'))) + if event_type not in self._VALID_EVENT_TYPES: + raise ValueError('illegal event type tag: {!a}'.format(event_type)) + + def _get_resource_to_org_ids(self, + event_type, + event_data): + subsource_refint = None + try: + inside_org_ids = set() + threats_org_ids = set() + source = event_data['source'] + subsource_to_saa_info = ( + self.auth_api.get_source_ids_to_subs_to_stream_api_access_infos().get(source)) + if subsource_to_saa_info: + predicate_ready_dict = RecordFacadeForPredicates(event_data, self.data_spec) + client_org_ids = set(event_data.get('client', ())) + assert all(isinstance(s, str) for s in client_org_ids) + for subsource_refint, ( + predicate, res_to_org_ids) in 
subsource_to_saa_info.items(): + assert all(isinstance(s, str) for s in res_to_org_ids['inside']) + subs_inside_org_ids = res_to_org_ids['inside'] & client_org_ids + subs_threats_org_ids = res_to_org_ids['threats'] + if not subs_inside_org_ids and not subs_threats_org_ids: + continue + if not predicate(predicate_ready_dict): + continue + inside_org_ids.update(subs_inside_org_ids) + threats_org_ids.update(subs_threats_org_ids) + assert all(isinstance(s, str) for s in inside_org_ids) + assert all(isinstance(s, str) for s in threats_org_ids) + return { + 'inside': sorted(inside_org_ids), + 'threats': sorted(threats_org_ids), + } + except: + LOGGER.error( + 'Could not determine org ids per resources' + '(event type: %a; event data: %a%s)', + event_type, + event_data, + ('' if subsource_refint is None else ( + "; lately processed subsource's refint: {!a}".format(subsource_refint)))) + raise + + def _get_result_dicts_and_output_body(self, + event_type, + event_data, + resource_to_org_ids): + raw_result_dict = cleaned_result_dict = None + try: + raw_result_dict = { + k: v for k, v in event_data.items() + if (k in self.data_spec.all_result_keys and + # eliminating empty `address` and `client` sequences + # (as the data spec will not accept them empty): + not (k in ('address', 'client') and not v))} + cleaned_result_dict = self.data_spec.clean_result_dict( + raw_result_dict, + auth_api=self.auth_api, + full_access=False, + opt_primary=False) + cleaned_result_dict['type'] = event_type + # note: the output body will be a cleaned result dict, + # being an ordinary dict (not a RecordDict instance), + # with the 'type' item added, serialized to a string + # using n6sdk.pyramid_commons.renderers.data_dict_to_json() + output_body = data_dict_to_json(cleaned_result_dict) + return ( + raw_result_dict, + cleaned_result_dict, + output_body, + ) + except: + LOGGER.error( + 'Could not prepare an anonymized data record ' + '(event type: %a; raw result dict: %a; ' + 'cleaned result dict: %a; %s)', + event_type, + raw_result_dict, + cleaned_result_dict, + '; '.join( + '`{0}` org ids: {1}'.format( + r, + ', '.join(map(repr, resource_to_org_ids[r])) or 'none') + for r in sorted(resource_to_org_ids))) + raise + + def _publish_output_data(self, + event_type, + resource_to_org_ids, + raw_result_dict, + cleaned_result_dict, + output_body): + done_resource_to_org_ids = { + resource: [] + for resource in resource_to_org_ids} + for resource, res_org_ids in sorted(resource_to_org_ids.items()): + done_org_ids = done_resource_to_org_ids[resource] + output_rk = self.OUTPUT_RK_PATTERN.format( + resource=resource, + category=cleaned_result_dict['category'], + anon_source=cleaned_result_dict['source']) + while res_org_ids: + org_id = res_org_ids[-1] + try: + self.publish_output( + routing_key=output_rk, + body=output_body, + prop_kwargs={'headers': {'n6-client-id': org_id}}) + except: + LOGGER.error( + 'Could not send an anonymized data record, for ' + 'the resource %a, to the client %a (event type: ' + '%a; raw result dict: %a; routing key %a; ' + 'body: %a; %s)', + resource, + org_id, + event_type, + raw_result_dict, + output_rk, + output_body, + '; '.join( + 'for the resource {0!a} -- ' + '* skipped for the org ids: {1}; ' + '* done for the org ids: {2}'.format( + r, + ', '.join(map(repr, resource_to_org_ids[r])) or 'none', + ', '.join(map(repr, done_resource_to_org_ids[r])) or 'none') + for r in sorted(resource_to_org_ids))) + raise + else: + done_org_ids.append(org_id) + del res_org_ids[-1] + + +def main(): + with 
logging_configured(): + d = Anonymizer() + try: + d.run() + except KeyboardInterrupt: + d.stop() + + +if __name__ == "__main__": + main() diff --git a/N6DataPipeline/n6datapipeline/base.py b/N6DataPipeline/n6datapipeline/base.py index ad4bf34..dda6397 100644 --- a/N6DataPipeline/n6datapipeline/base.py +++ b/N6DataPipeline/n6datapipeline/base.py @@ -5,6 +5,7 @@ # some of the docstrings are taken from or contain fragments of the # docs of the `pika` library. +import argparse import collections import contextlib import copy @@ -99,24 +100,21 @@ class LegacyQueuedBase(object): AMQP_SETUP_TIMEOUT = 60 # the name of the config section the RabbitMQ settings shall be taken from - rabbitmq_config_section = 'rabbitmq' + rabbitmq_config_section: str = 'rabbitmq' - # (see: the __new__() class method below) - input_queue = None - output_queue = None + # (see: the __new__() special method below) + input_queue: Optional[dict] = None + output_queue: Optional[Union[dict, list[dict]]] = None - # if a script should run only in one instance - used to set basic_consume(exclusive=) flag - single_instance = True + # used to set the value of the `exclusive` flag argument + # passed to `pika` input channel's `basic_consume(...)` + single_instance: bool = True - # in a subclass, it should be set to False if the component should not - # accept --n6recovery argument option (see: the get_arg_parser() method) - supports_n6recovery = True + # in concrete subclasses, it should be set to False *if* the + # component should not accept the `--n6recovery argument` option + # (see: the `get_arg_parser()` method) + supports_n6recovery: bool = True - # it is set on a new instance by __new__() (which is called - # automatically before __init__()) to an argparse.Namespace instance - cmdline_args = None - - # parameter prefetch_count # Specifies a prefetch window in terms of whole messages. # This field may be used in combination with the prefetch-size field # (although the prefetch-size limit is not implemented @@ -124,10 +122,14 @@ class LegacyQueuedBase(object): # if both prefetch windows (and those at the channel # and connection level) allow it. The prefetch-count is ignored # if the no-ack option is set. - prefetch_count = 20 + prefetch_count: int = 20 + + # basic kwargs for `pika.BasicProperties` (message-publishing-related) + basic_prop_kwargs: KwargsDict = {'delivery_mode': 2} - # basic kwargs for pika.BasicProperties (message-publishing-related) - basic_prop_kwargs = {'delivery_mode': 2} + # it is set automatically on a new instance by __new__() (which is + # always called on instantiation, before __init__()) + cmdline_args: argparse.Namespace = None # diff --git a/N6DataPipeline/n6datapipeline/comparator.py b/N6DataPipeline/n6datapipeline/comparator.py new file mode 100644 index 0000000..dcce861 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/comparator.py @@ -0,0 +1,442 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. 
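Side note on the annotated class attributes added to `LegacyQueuedBase` above (the class below is a made-up, minimal analogue, not n6 code): these are ordinary PEP 526 annotations, and `list[dict]` -- a parameterized builtin per PEP 585 -- is evaluated at class-definition time, which requires Python 3.9 or newer, in line with the Python 3.9 target visible elsewhere in this changeset.

from typing import Optional, Union

class QueueUsingComponent:                       # hypothetical example class
    rabbitmq_config_section: str = 'rabbitmq'
    input_queue: Optional[dict] = None
    output_queue: Optional[Union[dict, list[dict]]] = None   # list[dict] needs Python >= 3.9
    prefetch_count: int = 20

assert QueueUsingComponent.input_queue is None
assert 'prefetch_count' in QueueUsingComponent.__annotations__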
+ +import datetime +import json +import pickle +import os +import os.path + +from n6lib.common_helpers import open_file +from n6lib.config import Config +from n6lib.datetime_helpers import parse_iso_datetime_to_utc +from n6lib.log_helpers import get_logger, logging_configured +from n6datapipeline.base import ( + LegacyQueuedBase, + n6QueueProcessingException, +) + + +LOGGER = get_logger(__name__) + + +class BlackListData(object): + + def __init__(self, payload): + self.id = payload.get("id") + self.source = payload.get("source") + self.url = payload.get("url") + self.fqdn = payload.get("fqdn") + self.ip = [str(addr["ip"]) for addr in payload.get("address")] if payload.get("address") is not None else [] + self.flag = payload.get("flag") + self.expires = parse_iso_datetime_to_utc(payload.get("expires")) + self.payload = payload.copy() + + def to_dict(self): + return self.payload + + def update_payload(self, update_dict): + tmp = self.payload.copy() + tmp.update(update_dict) + self.payload = tmp + + +class SourceData(object): + + def __init__(self): + self.time = None # current time tracked for source (based on event _bl-time) + # real time of the last event (used to trigger cleanup if source is inactive) + self.last_event = None + self.blacklist = {} # current state of black list + + def update_time(self, event_time): + if event_time > self.time: + self.time = event_time + self.last_event = datetime.datetime.now() ## FIXME unused variable ? + + def _are_ips_different(self, ips_old, ips_new): + """ + Compare lists of ips. + + Returns: + True if lists different + False if lists the same + """ + if ips_old is None and ips_new is None: + return False + if (ips_old is None and ips_new is not None) or (ips_old is not None and ips_new is None): + return True + if sorted(ips_old) == sorted(ips_new): + return False + else: + return True + + def get_event_key(self, data): + if data.get("url") is not None: + return data.get("url") + elif data.get("fqdn") is not None: + return data.get("fqdn") + elif data.get("address") is not None: + ips = tuple(sorted([str(addr["ip"]) for addr in data.get("address")])) + return ips + else: + raise n6QueueProcessingException('Unable to determine event key for source: {}. Event ' + 'must have at least one of `url`, `fqdn`, ' + '`address`, data: {}'.format(data['source'], data) ) + + def process_event(self, data): + event_time = parse_iso_datetime_to_utc(data['_bl-time']) + + if self.time is None: + self.time = event_time + if event_time < self.time: + LOGGER.error('Event out of order. 
Ignoring.\nData: %s', data) + raise n6QueueProcessingException('Event belongs to blacklist' + ' older than the last one processed.') + + event_key = self.get_event_key(data) + event = self.blacklist.get(event_key) + + if event is None: + # new bl event + new_event = BlackListData(data) + new_event.flag = data.get("_bl-series-id") + self.blacklist[event_key] = new_event + return 'bl-new', new_event.payload + else: + # existing + ips_old = event.ip + ips_new = [x["ip"] for x in data.get("address")] if data.get("address") is not None else [] + if self._are_ips_different(ips_old, ips_new): + data["replaces"] = event.id + new_event = BlackListData(data) + new_event.flag = data.get("_bl-series-id") + self.blacklist[event_key] = new_event + return "bl-change", new_event.payload + elif parse_iso_datetime_to_utc(data.get("expires")) != event.expires: + event.expires = parse_iso_datetime_to_utc(data.get("expires")) + event.flag = data.get("_bl-series-id") + event.update_payload({"expires": data.get("expires")}) + self.blacklist[event_key] = event + return "bl-update", event.payload + else: + event.flag = data.get("_bl-series-id") + self.blacklist[event_key] = event + return None, event.payload + + def process_deleted(self): + + ret_value = [] + for key, event in list(self.blacklist.items()): + if event.flag is None: + value = event.payload.copy() + if key in self.blacklist: + del self.blacklist[key] + # yield "bl-delist", value + ret_value.append(["bl-delist", value]) + continue + if event.expires < self.time: + value = event.payload.copy() + if key in self.blacklist: + del self.blacklist[key] + # yield "bl-expire", value + ret_value.append(["bl-expire", value]) + continue + event.flag = None + return ret_value + + def clear_flags(self, flag_id): + for key, event in list(self.blacklist.items()): + if event.flag == flag_id: + event.flag = None + self.blacklist[key] = event + + def __repr__(self): + return repr(self.groups) + + +class ComparatorData(object): + + def __init__(self): + self.sources = {} + + def get_or_create_sourcedata(self, source_name): + sd = self.sources.get(source_name) + if sd is None: + sd = SourceData() + self.sources[source_name] = sd + return sd + + def __repr__(self): + return repr(self.sources) + + +class ComparatorDataWrapper(object): + + def __init__(self, dbpath): + self.comp_data = None + self.dbpath = dbpath + try: + self.restore_state() + except: + LOGGER.error("Error restoring state from: %r", self.dbpath) + self.comp_data = ComparatorData() + + def store_state(self): + try: + with open_file(self.dbpath, "wb") as f: + pickle.dump(self.comp_data, f, 2) + except IOError: + LOGGER.error("Error saving state to: %r", self.dbpath) + + def restore_state(self): + with open_file(self.dbpath, "rb") as f: + self.comp_data = pickle.load(f) + + def process_new_message(self, data): + """Processes a message and validates agains db to detect new/change/update. + Adds new entry to db if necessary (new) or updates entry (change/update) and + stores flag in db for processed event. 
+ """ + source_data = self.comp_data.get_or_create_sourcedata(data['source']) + result = source_data.process_event(data) + source_data.update_time(parse_iso_datetime_to_utc(data['_bl-time'])) + return result + + def clear_flags(self, source, flag_id): + """Cleans up flags in the db after processing complete blacklist + """ + source_data = self.comp_data.get_or_create_sourcedata(source) + source_data.clear_flags(flag_id) + self.store_state() + + def process_deleted(self, source): + """Finds unflagged and expired messages for a bl_name (deleted) and generates delist/expire messages. + Removes entries from db. + """ + + source_data = self.comp_data.get_or_create_sourcedata(source) + for event in source_data.process_deleted(): + yield event + self.store_state() + + +class ComparatorState(object): + + def __init__(self, cleanup_time): + """ + closed_series = {series-id: expires_time, #time.time when the closed series expires and should be removed + ...} + open_series = {series-id: {"total": int, #total number of messages in a series + "msg-count": int #number of messages seen so far + "msg-nums": [int, ...], #message numbers of seen messages + "msg-ids": [str, ...], #ids of the seen messages + "timeout-id": str, #id of the created timeout for a serie + } + ...} + """ + self.open_series = dict() + self.cleanup_time = cleanup_time + + def is_series_complete(self, series_id): + """Verify if the series is complete""" + assert series_id in self.open_series + if self.open_series[series_id]["total"] == self.open_series[series_id]["msg-count"]: + return True + else: + return False + + def is_message_valid(self, message): + """Check if message belongs to open series and it was not seen earlier + (i.e. is not a duplicate) + """ + if message["_bl-series-id"] in self.open_series: + #if message["id"] in self.open_series[message["_bl-series-id"]]["msg-ids"]: + # return False + if message["_bl-series-total"] != self.open_series[message["_bl-series-id"]]["total"]: + return False + if message["_bl-series-no"] in self.open_series[message["_bl-series-id"]]["msg-nums"]: + return False + if self.open_series[message["_bl-series-id"]]["msg-count"] + 1 > self.open_series[message["_bl-series-id"]]["total"]: + return False + return True + + def update_series(self, message): + """Update series state based on message: + - create new series if necessary + - update message count for a series + - store message id and msg num + """ + if message["_bl-series-id"] not in self.open_series: + self.open_series[message["_bl-series-id"]] = {"total": int(message["_bl-series-total"]), + "timeout-id": None, + "msg-count": 0, + "msg-nums": [], + "msg-ids": [] + } + self.open_series[message["_bl-series-id"]]["msg-count"] += 1 + self.open_series[message["_bl-series-id"]]["msg-nums"].append(int(message["_bl-series-no"])) + self.open_series[message["_bl-series-id"]]["msg-ids"].append(message["id"]) + # print "received message series %s: %d of %d" % (message["_bl-series-id"], + # self.open_series[message["_bl-series-id"]]["msg-count"], + # self.open_series[message["_bl-series-id"]]["total"]) + + def close_series(self, series_id): + """Close given series + """ + assert series_id in self.open_series + del self.open_series[series_id] + + def save_timeout(self, series_id, timeout_id): + """Save timeout id for a given series + """ + assert series_id in self.open_series + self.open_series[series_id]["timeout-id"] = timeout_id + + def get_timeout(self, series_id): + """Get timeout id for a given series + """ + assert series_id in self.open_series + 
return self.open_series[series_id]["timeout-id"] + + +class Comparator(LegacyQueuedBase): + + input_queue = { + "exchange": "event", + "exchange_type": "topic", + "queue_name": "comparator", + "accepted_event_types": [ + "bl", + ], + } + output_queue = { + "exchange": "event", + "exchange_type": "topic", + } + + def __init__(self, **kwargs): + config = Config(required={"comparator": ("dbpath", "series_timeout", "cleanup_time")}) + self.comparator_config = config["comparator"] + self.comparator_config["dbpath"] = os.path.expanduser(self.comparator_config["dbpath"]) + dbpath_dirname = os.path.dirname(self.comparator_config["dbpath"]) + try: + os.makedirs(dbpath_dirname, 0o700) + except OSError: + pass + super(Comparator, self).__init__(**kwargs) + # store dir doesn't exist, stop comparator + if not os.path.isdir(dbpath_dirname): + raise Exception('store dir does not exist, stop comparator, path:', + self.comparator_config["dbpath"]) + # store directory exists, but it has no rights to write + if not os.access(dbpath_dirname, os.W_OK): + raise Exception('stop comparator, remember to set the rights' + ' for user, which runs comparator, path:', + self.comparator_config["dbpath"]) + self.state = ComparatorState(int(self.comparator_config["cleanup_time"])) + self.db = ComparatorDataWrapper(self.comparator_config["dbpath"]) + + def on_series_timeout(self, source, series_id): + """Callback called when the messages for a given series have + not arrived within series_timeout from the last msg. + Cleans up the flags in the db and closes the series in ComparatorState + """ + self.db.clear_flags(source, series_id) + self.state.close_series(series_id) + + def process_event(self, data): + """Processes the event by querying the blacklist db and generating + bl-new, bl-change, bl-update messages + """ + event = self.db.process_new_message(data) + self.publish_event(event) + + def finalize_series(self, series_id, bl_name): + """If all the messages for a series have arrived it finalizes the series: + generates bl-delist, bl-expire messages from db. 
+ Close the series in ComparatorState + """ + self.remove_timeout(series_id) + for event in self.db.process_deleted(bl_name): + self.publish_event(event) + self.state.close_series(series_id) + + def _cleanup_data(self, data): + """Removes artifacts from earlier processing (_bl-series-no, _bl-series-total, _bl-series-id) + """ + for k in list(data.keys()): + if k.startswith("_bl-series"): + del data[k] + return data + + def publish_event(self, data): + """Publishes event to the output queue + """ + type_, payload = data + if type_ is None: + return + payload = self._cleanup_data(payload) + payload["type"] = type_ + source, channel = payload["source"].split(".") + rk = "{}.{}.{}.{}".format(type_, "compared", source, channel) + body = json.dumps(payload) + self.publish_output(routing_key=rk, body=body) + + def set_timeout(self, source, series_id): + self.remove_timeout(series_id) + timeout_id = self._connection.add_timeout( + int(self.comparator_config['series_timeout']), + lambda: self.on_series_timeout(source, series_id)) + self.state.save_timeout(series_id, timeout_id) + + def remove_timeout(self, series_id): + timeout_id = self.state.get_timeout(series_id) + if timeout_id is not None: + self._connection.remove_timeout(timeout_id) + + def validate_bl_headers(self, message): + if ('_bl-series-id' not in message or + '_bl-series-total' not in message or + '_bl-series-no' not in message or + '_bl-time' not in message or + 'expires' not in message): + raise n6QueueProcessingException("Invalid message for a black list") + try: + parse_iso_datetime_to_utc(message["expires"]) + except ValueError: + raise n6QueueProcessingException("Invalid expiry date") + + def input_callback(self, routing_key, body, properties): + data = json.loads(body) + ## FIXME:? ^ shouldn't `data` be deserialized to a + ## RecordDict (BLRecordDict) instance? (for consistency etc.) + with self.setting_error_event_info(data): + self._process_input(data) + + def _process_input(self, data): + self.validate_bl_headers(data) + if not self.state.is_message_valid(data): + raise n6QueueProcessingException("Invalid message for a series: {}".format(data)) + self.state.update_series(data) + self.set_timeout(data["source"], data["_bl-series-id"]) + self.process_event(data) + if self.state.is_series_complete(data["_bl-series-id"]): + LOGGER.info("Finalizing series: %r", data["_bl-series-id"]) + self.finalize_series(data["_bl-series-id"], data["source"]) + + def stop(self): + self.db.store_state() + super(Comparator, self).stop() + + +def main(): + with logging_configured(): + c = Comparator() + try: + c.run() + except KeyboardInterrupt: + c.stop() + + +if __name__ == '__main__': + main() diff --git a/N6DataPipeline/n6datapipeline/data/conf/00_pipeline.conf b/N6DataPipeline/n6datapipeline/data/conf/00_pipeline.conf new file mode 100644 index 0000000..9ffcabc --- /dev/null +++ b/N6DataPipeline/n6datapipeline/data/conf/00_pipeline.conf @@ -0,0 +1,28 @@ +# The n6 components use the 'pipeline' section to configure their +# "place" in the RabbitMQ pipeline. To configure a component, create +# the option, which name equals to the component's lowercase class +# name. Each option can be assigned a list of values (each value being +# a string, separated by commas). These values, called "routing states" +# here, are then used to generate their binding keys - keys that +# assign messages sent by other components within the same exchange +# to the component's inner queue. 
+# +# Routing states that components' output messages are sent with: +# * Parsers: parsed +# * Aggregator: aggregated +# * Enricher: enriched +# * Comparator: compared +# * Filter: filtered +# * Recorder: recorded +# +# Values in this configuration template create a default order +# of components in n6 pipeline. + +[pipeline] +aggregator = parsed +enricher = parsed, aggregated +comparator = enriched +filter = enriched, compared +anonymizer = filtered +recorder = filtered +counter= recorded diff --git a/N6DataPipeline/n6datapipeline/data/conf/02_archiveraw.conf b/N6DataPipeline/n6datapipeline/data/conf/02_archiveraw.conf new file mode 100644 index 0000000..2a49a07 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/data/conf/02_archiveraw.conf @@ -0,0 +1,19 @@ +[archiveraw] + +## MongoDB server hostname or IP +mongohost = 127.0.0.1 + +## MongoDB server port +mongoport = 27017 + +## name of database in MongoDB +mongodb = n6 + +## please replace USER and PASSWORD with the actual credentials +## but do not change the rest of the value of `uri` +#uri = mongodb://USER:PASSWORD@%(mongohost)s:%(mongoport)s/?authSource=source_database + +## retry connection options +time_sleep_between_try_connect=5 ; sleep time (in seconds) between connection attempts +count_try_connection=1000 ; the number of connection attempts +## (so total time of attempts == time_sleep_between_try_connect * count_try_connection) diff --git a/N6DataPipeline/n6datapipeline/data/conf/05_enrich.conf b/N6DataPipeline/n6datapipeline/data/conf/05_enrich.conf new file mode 100644 index 0000000..3f988ea --- /dev/null +++ b/N6DataPipeline/n6datapipeline/data/conf/05_enrich.conf @@ -0,0 +1,15 @@ +[enrich] +dnshost=8.8.8.8 +dnsport=53 + +# options below are optional; if they are provided, IP addresses +# in 'address' field in processed data will be looked up against +# one of or both GeoIP databases (i.e., the 'asndatabasefilename' +# and 'citydatabasefilename' options) +geoippath= +asndatabasefilename= +citydatabasefilename= + +# optional - a list of IP addresses that should be excluded +# within enriched data +#excluded_ips=0.0.0.0, 255.255.255.255,127.0.0.0/8 diff --git a/N6DataPipeline/n6datapipeline/data/conf/07_aggregator.conf b/N6DataPipeline/n6datapipeline/data/conf/07_aggregator.conf new file mode 100644 index 0000000..c5010ff --- /dev/null +++ b/N6DataPipeline/n6datapipeline/data/conf/07_aggregator.conf @@ -0,0 +1,17 @@ +[aggregator] + +# path to the local aggregator's database file +# (the database file will be created automatically +# on the 1st aggregator run, if possible) +dbpath=~/.n6aggregator/aggregator_db.pickle + +# time interval (in seconds) within which non-monotonic times of +# events are tolerated +time_tolerance=600 + +# time interval like `time_tolerance`, but defined for specific source +# (if it is not defined for the current source, +# `time_tolerance` is used) +;time_tolerance_per_source={ +; "some-src.its-channel": 1200, +; "other-src.foobar": 900} diff --git a/N6DataPipeline/n6datapipeline/data/conf/07_comparator.conf b/N6DataPipeline/n6datapipeline/data/conf/07_comparator.conf new file mode 100644 index 0000000..2a34409 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/data/conf/07_comparator.conf @@ -0,0 +1,9 @@ +[comparator] + +## path to the local comparator's database file +## (the database file will be created automatically +## on the 1st comparator run, if possible) +dbpath=~/.n6comparator/comparator_db.pickle + +series_timeout=300 +cleanup_time=6000 diff --git 
a/N6DataPipeline/n6datapipeline/data/conf/09_auth_db.conf b/N6DataPipeline/n6datapipeline/data/conf/09_auth_db.conf index c74d234..3f97a70 100644 --- a/N6DataPipeline/n6datapipeline/data/conf/09_auth_db.conf +++ b/N6DataPipeline/n6datapipeline/data/conf/09_auth_db.conf @@ -1,34 +1,34 @@ [auth_db] -## connection URL, e.g.: mysql+mysqldb://n6:somepassword@localhost/n6 -## it must start with `mysql+mysqldb:` (or just `mysql:`) because other -## dialects/drivers are not supported -## (see also: http://docs.sqlalchemy.org/en/rel_0_9/core/engines.html) -#url = mysql://user:password@host/dbname +# connection URL, e.g.: mysql+mysqldb://n6:somepassword@localhost/n6 +# it must start with `mysql+mysqldb:` (or just `mysql:`) because other +# dialects/drivers are not supported +# (see also: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) +;url = mysql://user:password@host/dbname -## if you want to use SSL, the following options must be set to -## appropriate file paths: -#ssl_cacert = /some/path/to/CACertificatesFile.pem -#ssl_cert = /some/path/to/ClientCertificateFile.pem -#ssl_key = /some/path/to/private/ClientCertificateKeyFile.pem +# if you want to use SSL, the following options must be set to +# appropriate file paths: +;ssl_cacert = /some/path/to/CACertificatesFile.pem +;ssl_cert = /some/path/to/ClientCertificateFile.pem +;ssl_key = /some/path/to/private/ClientCertificateKeyFile.pem [auth_db_session_variables] -## all MySQL variables specified within this section will be set by -## executing "SET SESSION = , ...". -## WARNING: for simplicity, the variable names and values are inserted -## into SQL code "as is", *without* any escaping (we assume we can treat -## configuration files as a *trusted* source of data). +# all MySQL variables specified within this section will be set by +# executing "SET SESSION = , ...". +# WARNING: for simplicity, the variable names and values are inserted +# into SQL code "as is", *without* any escaping (we assume we can treat +# configuration files as a *trusted* source of data). -## (`[auth_db_session_variables].wait_timeout` should be -## greater than `[auth_db_connection_pool].pool_recycle`) +# (`[auth_db_session_variables].wait_timeout` should be +# greater than `[auth_db_connection_pool].pool_recycle`) wait_timeout = 7200 [auth_db_connection_pool] -## (generally, the defaults should be OK in most cases; if you are -## interested in technical details -- see: SQLAlchemy docs...) +# (generally, the defaults should be OK in most cases; if you are +# interested in technical details -- see: SQLAlchemy docs...) 
pool_recycle = 3600
pool_timeout = 20
pool_size = 15
diff --git a/N6DataPipeline/n6datapipeline/data/conf/21_recorder.conf b/N6DataPipeline/n6datapipeline/data/conf/21_recorder.conf
new file mode 100644
index 0000000..d195212
--- /dev/null
+++ b/N6DataPipeline/n6datapipeline/data/conf/21_recorder.conf
@@ -0,0 +1,12 @@
+[recorder]
+
+# Uncomment and adjust this option but DO NOT change the `mysql://` prefix:
+;uri = mysql://dbuser:dbpassword@dbhost/dbname
+
+# DO NOT change this option unless you also adjusted appropriately your database:
+;connect_charset = utf8
+
+# see: https://docs.sqlalchemy.org/en/13/core/engines.html#more-on-the-echo-flag
+;echo = 0
+
+;wait_timeout = 28800
diff --git a/N6DataPipeline/n6datapipeline/data/conf/23_filter.conf b/N6DataPipeline/n6datapipeline/data/conf/23_filter.conf
new file mode 100644
index 0000000..adf04f7
--- /dev/null
+++ b/N6DataPipeline/n6datapipeline/data/conf/23_filter.conf
@@ -0,0 +1,2 @@
+[filter]
+#categories_filtered_through_fqdn_only=leak
diff --git a/N6DataPipeline/n6datapipeline/data/conf/logging.conf b/N6DataPipeline/n6datapipeline/data/conf/logging.conf
new file mode 100644
index 0000000..8bbeffd
--- /dev/null
+++ b/N6DataPipeline/n6datapipeline/data/conf/logging.conf
@@ -0,0 +1,89 @@
+# See: https://docs.python.org/library/logging.config.html#configuration-file-format
+
+#
+# Declarations
+
+[loggers]
+keys = root,nonstandard_names
+
+[handlers]
+keys = syslog,stream,amqp
+#keys = syslog,stream,amqp,file
+
+[formatters]
+keys = standard,cut_notraceback,cut,n6_syslog_handler
+
+#
+# Loggers
+
+# the top-level (root) logger
+# (gathers messages from almost all its sub-loggers...)
+[logger_root]
+level = INFO
+handlers = syslog,stream,amqp
+
+# the parent logger for *non-standard-names-dedicated*
+# loggers -- each named according to the pattern:
+# 'NONSTANDARD_NAMES.';
+# parser components use these loggers to report any
+# non-standard values of the `name` attribute of events
+# (see: n6lib.const.CATEGORY_TO_NORMALIZED_NAME)
+[logger_nonstandard_names]
+level = INFO
+handlers = amqp
+propagate = 0
+qualname = NONSTANDARD_NAMES
+
+#
+# Handlers
+
+[handler_syslog]
+class = n6lib.log_helpers.N6SysLogHandler
+level = ERROR
+formatter = n6_syslog_handler
+args = ('/dev/log',)
+
+[handler_stream]
+class = StreamHandler
+level = INFO
+formatter = standard
+args = ()
+
+[handler_amqp]
+class = n6lib.log_helpers.AMQPHandler
+level = INFO
+args = (None, 'logging', {'exchange_type': 'topic', 'durable': True})
+# args = (, , )
+# ^ first item of args should be one of:
+# * None -- then AMQP connection settings will be taken from the global N6Core config,
+# see: 00_global.conf (this possibility is available only for N6Core components,
+# not for N6Portal/N6RestApi components)
+# * a dict -- containing AMQP connection settings ('host', 'port', 'ssl', 'ssl_options'...)
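(Editor's note, not part of the committed logging.conf above: if the first item of `args` for [handler_amqp] is given explicitly instead of None, the line could look roughly like the example below. The host name and certificate paths are invented, and the exact keys expected inside 'ssl_options' depend on n6lib.log_helpers.AMQPHandler and the underlying AMQP client, so treat them as placeholders only.)

args = ({'host': 'rabbit.example.local',
         'port': 5671,
         'ssl': True,
         'ssl_options': {'ca_certs': '/some/path/to/CACertificatesFile.pem',
                         'certfile': '/some/path/to/ClientCertificateFile.pem',
                         'keyfile': '/some/path/to/private/ClientCertificateKeyFile.pem'}},
        'logging',
        {'exchange_type': 'topic', 'durable': True})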
+ +#[handler_file] +#class = FileHandler +#level = NOTSET +#formatter = cut +#args = ('/home/somebody/log_all',) + +# +# Formatters + +# full information +[formatter_standard] +format = n6: %(levelname) -10s %(asctime)s %(name) -25s in %(funcName)s() (#%(lineno)d): %(message)s + +# brief information: no tracebacks, messages no longer than ~2k +[formatter_cut_notraceback] +format = n6: %(levelname) -10s %(asctime)s %(name) -25s in %(funcName)s() (#%(lineno)d): %(message)s +class = n6lib.log_helpers.NoTracebackCutFormatter + +# semi-brief information: with tracebacks but messages no longer than: ~2k + traceback length +[formatter_cut] +format = n6: %(levelname) -10s %(asctime)s %(name) -25s in %(funcName)s() (#%(lineno)d): %(message)s +class = n6lib.log_helpers.CutFormatter + +# same as formatter_cut_notraceback but with N6SysLogHandler's `script_basename` field included +[formatter_n6_syslog_handler] +format = n6: %(levelname) -10s %(asctime)s %(script_basename)s, %(name)s in %(funcName)s() (#%(lineno)d): %(message)s +class = n6lib.log_helpers.NoTracebackCutFormatter diff --git a/N6DataPipeline/n6datapipeline/enrich.py b/N6DataPipeline/n6datapipeline/enrich.py index 0b91d35..94f5a35 100644 --- a/N6DataPipeline/n6datapipeline/enrich.py +++ b/N6DataPipeline/n6datapipeline/enrich.py @@ -244,7 +244,7 @@ def _final_sanity_assertions(self, data): for name in enriched_keys) assert all( set(addr_keys).issubset(ip_to_addr[ip]) - for ip, addr_keys in list(ip_to_enriched_address_keys.items())) + for ip, addr_keys in ip_to_enriched_address_keys.items()) # # Resolution helpers diff --git a/N6DataPipeline/n6datapipeline/filter.py b/N6DataPipeline/n6datapipeline/filter.py new file mode 100644 index 0000000..4c61562 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/filter.py @@ -0,0 +1,99 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. + +""" +The Filter component, responsible for assigning events to the right +client organizations -- by adding the `client` item (and also +`urls_matched` if needed) to each processed record dict. 
+""" + +from n6datapipeline.base import LegacyQueuedBase +from n6lib.auth_api import AuthAPI +from n6lib.common_helpers import replace_segment +from n6lib.config import ConfigMixin +from n6lib.log_helpers import get_logger, logging_configured +from n6lib.record_dict import RecordDict + + +LOGGER = get_logger(__name__) + + +class Filter(ConfigMixin, LegacyQueuedBase): + + input_queue = { + 'exchange': 'event', + 'exchange_type': 'topic', + 'queue_name': 'filter', + 'accepted_event_types': [ + 'event', + 'bl-new', + 'bl-update', + 'bl-change', + 'bl-delist', + 'bl-expire', + 'suppressed', + ], + } + + output_queue = { + 'exchange': 'event', + 'exchange_type': 'topic', + } + + config_spec = ''' + [filter] + categories_filtered_through_fqdn_only = :: list_of_str + ''' + + single_instance = False + + def __init__(self, **kwargs): + LOGGER.info("Filter Start") + self.auth_api = AuthAPI() + self.config = self.get_config_section() + self.fqdn_only_categories = frozenset(self.config['categories_filtered_through_fqdn_only']) + super(Filter, self).__init__(**kwargs) + + def input_callback(self, routing_key, body, properties): + record_dict = RecordDict.from_json(body) + with self.setting_error_event_info(record_dict): + client, urls_matched = self.get_client_and_urls_matched( + record_dict, + self.fqdn_only_categories) + record_dict['client'] = client + if urls_matched: + record_dict['urls_matched'] = urls_matched + self.publish_event(record_dict, routing_key) + + def get_client_and_urls_matched(self, record_dict, fqdn_only_categories): + resolver = self.auth_api.get_inside_criteria_resolver() + client_org_ids, urls_matched = resolver.get_client_org_ids_and_urls_matched( + record_dict, + fqdn_only_categories) + return sorted(client_org_ids), urls_matched + + def publish_event(self, data, rk): + """ + Push the given event into the output queue. + + Args: + `data` (RecordDict instance): + The event data. + `rk` (string): + The *input* routing key. + """ + output_rk = replace_segment(rk, 1, 'filtered') + body = data.get_ready_json() + self.publish_output(routing_key=output_rk, body=body) + + +def main(): + with logging_configured(): + d = Filter() + try: + d.run() + except KeyboardInterrupt: + d.stop() + + +if __name__ == "__main__": + main() diff --git a/N6DataPipeline/n6datapipeline/intelmq/__init__.py b/N6DataPipeline/n6datapipeline/intelmq/__init__.py index 1db0d6d..fd60817 100644 --- a/N6DataPipeline/n6datapipeline/intelmq/__init__.py +++ b/N6DataPipeline/n6datapipeline/intelmq/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021 NASK. All rights reserved. + import logging from inspect import isroutine diff --git a/N6DataPipeline/n6datapipeline/intelmq/bots_config.py b/N6DataPipeline/n6datapipeline/intelmq/bots_config.py index 396c9c2..8546cc7 100644 --- a/N6DataPipeline/n6datapipeline/intelmq/bots_config.py +++ b/N6DataPipeline/n6datapipeline/intelmq/bots_config.py @@ -1,5 +1,4 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- +# Copyright (c) 2021 NASK. All rights reserved. 
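(Editor's aside on the Filter component introduced earlier in this patch -- a minimal sketch, not n6lib code: the only change `Filter.publish_event()` makes to the routing key is swapping its second dot-separated segment for 'filtered' via `replace_segment()`. Under that assumption, the observable behaviour is:)

def replace_segment_sketch(rk, index, new_segment):
    # Minimal re-implementation of the behaviour relied upon by Filter above;
    # the real helper lives in n6lib.common_helpers and may differ in details.
    segments = rk.split('.')
    segments[index] = new_segment
    return '.'.join(segments)

# e.g. an event coming out of the enricher...
assert replace_segment_sketch('event.enriched.some-src.some-channel', 1, 'filtered') == \
       'event.filtered.some-src.some-channel'
# ...i.e. the kind of key that the '*.filtered.*.*' bindings used by the
# Recorder variants later in this patch are meant to match.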
import inspect import logging @@ -50,7 +49,7 @@ class DummyParameter: def __init__(self, name): if name in self.exceptional_attrs: raise AttributeError - LOGGER.warning('The parameter {!r} should not be used!'.format(name)) + LOGGER.warning('The parameter %a should not be used!', name) class BotParameterProvider: @@ -63,7 +62,7 @@ def __getattribute__(self, item): try: return config_dict[item] except KeyError: - LOGGER.warning('Nonexistent attribute `%s` was tried to be accessed.', item) + LOGGER.warning('Nonexistent attribute %a was tried to be accessed.', item) raise AttributeError diff --git a/N6DataPipeline/n6datapipeline/intelmq/helpers.py b/N6DataPipeline/n6datapipeline/intelmq/helpers.py index b6f2820..88ca7a8 100644 --- a/N6DataPipeline/n6datapipeline/intelmq/helpers.py +++ b/N6DataPipeline/n6datapipeline/intelmq/helpers.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021 NASK. All rights reserved. + import hashlib import json import re @@ -13,6 +15,7 @@ from n6datapipeline.base import LegacyQueuedBase from n6datapipeline.intelmq import bots_config from n6datapipeline.intelmq.utils.intelmq_converter import IntelToN6Converter +from n6lib.const import RAW_TYPE_ENUMS from n6lib.datetime_helpers import parse_iso_datetime_to_utc @@ -36,7 +39,7 @@ def getattribute(self, item): and is not adapted to work inside n6 pipeline. """ if item in original_class_objects and item not in allowed_objects: - raise NotImplementedError(repr(item)) + raise NotImplementedError(ascii(item)) else: return super(cls, self).__getattribute__(item) @@ -288,7 +291,6 @@ class BaseCollectorExtended(QueuedBaseExtended): DEFAULT_SOURCE_CHANNEL = 'intelmq-collector' type = 'stream' - limits_type_of = ('stream', 'file', 'blacklist') input_queue = None bot_group_name = 'intelmq-collectors' @@ -322,9 +324,9 @@ def _validate_type(self): Validate type of message to be archived in MongoDB. It should be one of: 'stream', 'file', 'blacklist. """ - if self.type not in self.limits_type_of: + if self.type not in RAW_TYPE_ENUMS: raise Exception(f'Wrong type of data being archived in MongoDB: {self.type}, ' - f'should be one of: {self.limits_type_of}') + f'should be one of: {RAW_TYPE_ENUMS}') def _get_output_message_id(self, timestamp, output_data_body): return hashlib.md5(('\0'.join((self.source_label, @@ -418,9 +420,9 @@ def _set_constant_items(self): value = self.config[opt_name] except KeyError: value = default_value - LOGGER.warning("The %r option has not been found in the 'n6config' section of " - "the runtime config for the parser bot with ID: %r. Using " - "default value: %r", opt_name, self.bot_id, value) + LOGGER.warning("The %a option has not been found in the 'n6config' section of " + "the runtime config for the parser bot with ID: %a. Using " + "default value: %a", opt_name, self.bot_id, value) setattr(self, opt_name, value) def _get_bot_rk(self): diff --git a/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_adapter.py b/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_adapter.py index 2711b3a..708cbb1 100644 --- a/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_adapter.py +++ b/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_adapter.py @@ -1,7 +1,6 @@ -""" -Copyright (c) 2017-2021 NASK -Software Development Department +# Copyright (c) 2017-2021 NASK. All rights reserved. +""" The IntelMQ Adapter component, responsible for communicating with IntelMQ System. 
""" diff --git a/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_converter.py b/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_converter.py index 3de19ea..2f51add 100644 --- a/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_converter.py +++ b/N6DataPipeline/n6datapipeline/intelmq/utils/intelmq_converter.py @@ -1,7 +1,6 @@ -""" -Copyright (c) 2017-2021 NASK -Software Development Department +# Copyright (c) 2017-2021 NASK. All rights reserved. +""" The IntelMQ Adapter component, responsible for converting data between N6 and IntelMQ Systems. diff --git a/N6DataPipeline/n6datapipeline/intelmq/wrapper.py b/N6DataPipeline/n6datapipeline/intelmq/wrapper.py index 3058207..7dd80f7 100644 --- a/N6DataPipeline/n6datapipeline/intelmq/wrapper.py +++ b/N6DataPipeline/n6datapipeline/intelmq/wrapper.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021 NASK. All rights reserved. + import importlib import logging @@ -31,7 +33,7 @@ class IntelMQWrapper: "'/etc/intelmq/runtime.yaml' etc. The bot ID section must provide the " "'module' option - its value must be the path of the bot's module. " "You can, or for some types of bots, like parser bots, you must add the " - f"{RUNTIME_CONF_N6_SECTION_NAME!r} option in bot's section. The option " + f"{RUNTIME_CONF_N6_SECTION_NAME!a} option in bot's section. The option " f"is not recognized by IntelMQ, but allows to configure bot's instance " f"running within n6 pipeline.") diff --git a/N6DataPipeline/n6datapipeline/recorder.py b/N6DataPipeline/n6datapipeline/recorder.py new file mode 100644 index 0000000..f8b54d5 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/recorder.py @@ -0,0 +1,596 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. + +""" +The *recorder* component -- adds n6 events to the Event DB. +""" + +### TODO: this module is to be replaced with a new implementation... + +import datetime +import logging +import os +import sys + +import MySQLdb.cursors +import sqlalchemy.event +from sqlalchemy import ( + create_engine, + text as sqla_text, +) +from sqlalchemy.exc import ( + IntegrityError, + OperationalError, + SQLAlchemyError, +) + +from n6datapipeline.base import LegacyQueuedBase +from n6lib.common_helpers import str_to_bool +from n6lib.config import Config +from n6lib.data_backend_api import ( + N6DataBackendAPI, + transact, +) +from n6lib.data_spec.fields import SourceField +from n6lib.datetime_helpers import parse_iso_datetime_to_utc +from n6lib.db_events import ( + n6ClientToEvent, + n6NormalizedData, +) +from n6lib.log_helpers import ( + get_logger, + logging_configured, +) +from n6lib.record_dict import ( + RecordDict, + BLRecordDict, +) +from n6lib.common_helpers import ( + ascii_str, + make_exc_ascii_str, + replace_segment, +) + + +LOGGER = get_logger(__name__) + +DB_WARN_LOGGER_LEGACY_NAME = 'n6.archiver.mysqldb_patch' # TODO later: change or make configurable +DB_WARN_LOGGER = get_logger(DB_WARN_LOGGER_LEGACY_NAME) + + +class N6RecorderCursor(MySQLdb.cursors.Cursor): + + # Note: the places where our `__log_warnings_from_database()` method + # is invoked are analogous to the places where appropriate warnings- + # -related stuff was invoked by the default cursor (a client-side one) + # provided by the version 1.3.14 of the *mysqlclient* library -- see: + # https://github.com/PyMySQL/mysqlclient/blob/1.3.14/MySQLdb/cursors.py + # -- which was the last version before removing the warnings-related + # stuff from that library (now we use a newer version, without that + # stuff). 
+ # + # In practice, from the following three methods extended by us, + # rather only `execute()` is really relevant for us (note that it is + # also called by the `executemany()` method which, therefore, does + # not need to be extended). + + def execute(self, query, args=None): + ret = super(N6RecorderCursor, self).execute(query, args) + self.__log_warnings_from_database('QUERY', query, args) + return ret + + def callproc(self, procname, args=()): + ret = super(N6RecorderCursor, self).callproc(procname, args) + self.__log_warnings_from_database('PROCEDURE CALL', procname, args) + return ret + + def nextset(self): + ret = super(N6RecorderCursor, self).nextset() + if ret is not None: + self.__log_warnings_from_database('ON NEXT RESULT SET') + return ret + + __cur_warning_count = None + + def __log_warnings_from_database(self, caption, query_or_proc=None, args=None): + conn = self.connection + if conn is None or not conn.warning_count(): + return + for level, code, msg in conn.show_warnings(): + log_msg = '[{}] {} (code: {}), {}'.format(ascii_str(level), + ascii_str(msg), + ascii_str(code), + caption) + if query_or_proc or args: + log_msg_format = (log_msg.replace('%', '%%') + + ': %a, ' + + 'ARGS: %a') + DB_WARN_LOGGER.warning(log_msg_format, query_or_proc, args) + else: + DB_WARN_LOGGER.warning(log_msg) + + +class Recorder(LegacyQueuedBase): + """Save record in zbd queue.""" + + _MIN_WAIT_TIMEOUT = 3600 + _MAX_WAIT_TIMEOUT = _DEFAULT_WAIT_TIMEOUT = 28800 + + input_queue = { + "exchange": "event", + "exchange_type": "topic", + "queue_name": "zbd", + "accepted_event_types": [ + "event", + "bl-new", + "bl-change", + "bl-delist", + "bl-expire", + "bl-update", + "suppressed", + ], + } + + output_queue = { + "exchange": "event", + "exchange_type": "topic", + } + + def __init__(self, **kwargs): + LOGGER.info("Recorder Start") + config = Config(required={"recorder": ("uri",)}) + self.config = config["recorder"] + self.record_dict = None + self.records = None + self.routing_key = None + self.session_db = self._setup_db() + self.dict_map_fun = { + "event.filtered": (RecordDict.from_json, self.new_event), + "bl-new.filtered": (BLRecordDict.from_json, self.blacklist_new), + "bl-change.filtered": (BLRecordDict.from_json, self.blacklist_change), + "bl-delist.filtered": (BLRecordDict.from_json, self.blacklist_delist), + "bl-expire.filtered": (BLRecordDict.from_json, self.blacklist_expire), + "bl-update.filtered": (BLRecordDict.from_json, self.blacklist_update), + "suppressed.filtered": (RecordDict.from_json, self.suppressed_update), + } + # keys in each of the tuples being values of `dict_map_fun` + self.FROM_JSON = 0 + self.HANDLE_EVENT = 1 + super(Recorder, self).__init__(**kwargs) + + def _setup_db(self): + wait_timeout = int(self.config.get("wait_timeout", self._DEFAULT_WAIT_TIMEOUT)) + wait_timeout = min(max(wait_timeout, self._MIN_WAIT_TIMEOUT), self._MAX_WAIT_TIMEOUT) + # (`pool_recycle` should be significantly less than `wait_timeout`) + pool_recycle = wait_timeout // 2 + engine = create_engine( + self.config["uri"], + connect_args=dict( + charset=self.config.get( + "connect_charset", + N6DataBackendAPI.EVENT_DB_LEGACY_CHARSET), + use_unicode=True, + binary_prefix=True, + cursorclass=N6RecorderCursor), + pool_recycle=pool_recycle, + echo=str_to_bool(self.config.get("echo", "false"))) + self._install_session_variables_setter( + engine, + wait_timeout=wait_timeout, + time_zone="'+00:00'") + session_db = N6DataBackendAPI.configure_db_session(engine) + session_db.execute(sqla_text("SELECT 1")) # 
Let's crash early if db is misconfigured. + return session_db + + def _install_session_variables_setter(self, engine, **session_variables): + setter_sql = 'SET ' + ' , '.join( + 'SESSION {} = {}'.format(name, value) + for name, value in session_variables.items()) + + @sqlalchemy.event.listens_for(engine, 'connect') + def set_session_variables(dbapi_connection, connection_record): + """ + Execute + "SET SESSION = , SESSION = , ..." + to set the specified variables. + + To be called automatically whenever a new low-level + connection to the database is established. + + WARNING: for simplicity, the variable names and values are + inserted "as is", *without* any escaping -- we assume we + can treat them as *trusted* data. + """ + with dbapi_connection.cursor() as cursor: + cursor.execute(setter_sql) + + @classmethod + def get_arg_parser(cls): + parser = super(Recorder, cls).get_arg_parser() + parser.add_argument("--n6recorder-blacklist", type=SourceField().clean_result_value, + help="the identifier of a blacklist source (in the " + "format: 'source-label.source-channel'); if given, " + "this recorder instance will consume and store " + "*only* events from this blacklist source") + parser.add_argument("--n6recorder-non-blacklist", action="store_true", + help="if given, this recorder instance will consume " + "and store *only* events from *all* non-blacklist " + "sources (note: then the '--n6recorder-blacklist' " + "option, if given, is just ignored)") + return parser + + def ping_connection(self): + """ + Required to maintain the connection to MySQL. + Perform ping before each query to the database. + OperationalError if an exception occurs, remove sessions, and connects again. + """ + try: + self.session_db.execute(sqla_text("SELECT 1")) + except OperationalError as exc: + # OperationalError: (2006, 'MySQL server has gone away') + LOGGER.warning("Database server went away: %a", exc) + LOGGER.info("Reconnect to server") + self.session_db.remove() + try: + self.session_db.execute(sqla_text("SELECT 1")) + except SQLAlchemyError as exc: + LOGGER.error( + "Could not reconnect to the MySQL database: %s", + make_exc_ascii_str(exc)) + sys.exit(1) + + @staticmethod + def get_truncated_rk(rk, parts): + """ + Get only a part of the given routing key. + + Args: + `rk`: routing key. + `parts`: number of dot-separated parts (segments) to be kept. + + Returns: + Truncated `rk` (containing only first `parts` segments). 
+ + >>> Recorder.get_truncated_rk('111.222.333.444', 0) + '' + >>> Recorder.get_truncated_rk('111.222.333.444', 1) + '111' + >>> Recorder.get_truncated_rk('111.222.333.444', 2) + '111.222' + >>> Recorder.get_truncated_rk('111.222.333.444', 3) + '111.222.333' + >>> Recorder.get_truncated_rk('111.222.333.444', 4) + '111.222.333.444' + >>> Recorder.get_truncated_rk('111.222.333.444', 5) # with log warning + '111.222.333.444' + """ + rk = rk.split('.') + parts_rk = [] + try: + for i in range(parts): + parts_rk.append(rk[i]) + except IndexError: + LOGGER.warning("routing key %a contains less than %a segments", rk, parts) + return '.'.join(parts_rk) + + def input_callback(self, routing_key, body, properties): + """ Channel callback method """ + # first let's try ping mysql server + self.ping_connection() + self.records = {'event': [], 'client': []} + self.routing_key = routing_key + + # take the first two parts of the routing key + truncated_rk = self.get_truncated_rk(self.routing_key, 2) + + # run BLRecordDict.from_json() or RecordDict.from_json() + # depending on the routing key + from_json = self.dict_map_fun[truncated_rk][self.FROM_JSON] + self.record_dict = from_json(body) + # add modified time, set microseconds to 0, because the database + # does not have microseconds, and it is not known if the base is not rounded + self.record_dict['modified'] = datetime.datetime.utcnow().replace(microsecond=0) + # run the handler method corresponding to the routing key + handle_event = self.dict_map_fun[truncated_rk][self.HANDLE_EVENT] + with self.setting_error_event_info(self.record_dict): + handle_event() + + assert 'source' in self.record_dict + LOGGER.debug("source: %a", self.record_dict['source']) + LOGGER.debug("properties: %a", properties) + #LOGGER.debug("body: %a", body) + + def json_to_record(self, rows): + """ + Deserialize json to record db.append. + + Args: `rows`: row from RecordDict + """ + if 'client' in rows[0]: + for client in rows[0]['client']: + tmp_rows = rows[0].copy() + tmp_rows['client'] = client + self.records['client'].append(tmp_rows) + + def insert_new_event(self, items, with_transact=True, recorded=False): + """ + New events and new blacklist add to database, + default in the transaction, or the outer transaction(with_transact=False). + """ + try: + if with_transact: + with transact: + self.session_db.add_all(items) + else: + assert transact.is_entered + self.session_db.add_all(items) + except IntegrityError as exc: + str_exc = make_exc_ascii_str(exc) + LOGGER.warning(str_exc) + else: + if recorded and not self.cmdline_args.n6recovery: + rk = replace_segment(self.routing_key, 1, 'recorded') + LOGGER.debug( + 'Publish for email notifications ' + '-- rk: %a, record_dict: %a', + rk, self.record_dict) + self.publish_event(self.record_dict, rk) + + def publish_event(self, data, rk): + """ + Publishes event to the output queue. + + Args: + `data`: data from recorddict + `rk` : routing key + """ + body = data.get_ready_json() + self.publish_output(routing_key=rk, body=body) + + def new_event(self, _is_blacklist=False): + """ + Add new event to n6 database. 
+ """ + LOGGER.debug('* new_event() %a', self.record_dict) + + # add event records from RecordDict + for event_record in self.record_dict.iter_db_items(): + if _is_blacklist: + event_record["status"] = "active" + self.records['event'].append(event_record) + + self.json_to_record(self.records['event']) + items = [] + for record in self.records['event']: + event = n6NormalizedData(**record) + items.append(event) + + for record in self.records['client']: + client = n6ClientToEvent(**record) + items.append(client) + + LOGGER.debug("insert new events, count.: %a", len(items)) + self.insert_new_event(items, recorded=True) + + def blacklist_new(self): + self.new_event(_is_blacklist=True) + + def blacklist_change(self): + """ + Black list change(change status to replaced in existing blacklist event, + and add new event in changing values(new id, and old replaces give comparator)). + """ + # add event records from RecordDict + for event_record in self.record_dict.iter_db_items(): + self.records['event'].append(event_record) + + self.json_to_record(self.records['event']) + id_db = self.records['event'][0]["id"] + id_replaces = self.records['event'][0]["replaces"] + LOGGER.debug("ID: %a REPLACES: %a", id_db, id_replaces) + + try: + with transact: + rec_count = (self.session_db.query(n6NormalizedData). + filter(n6NormalizedData.id == id_replaces). + update({'status': 'replaced', + 'modified': datetime.datetime.utcnow().replace(microsecond=0) + })) + + with transact: + items = [] + for record in self.records['event']: + record["status"] = "active" + event = n6NormalizedData(**record) + items.append(event) + + for record in self.records['client']: + client = n6ClientToEvent(**record) + items.append(client) + + if rec_count: + LOGGER.debug("insert new events, count.: %a", len(items)) + else: + LOGGER.debug("bl-change, records with id %a DO NOT EXIST!", id_replaces) + LOGGER.debug("inserting new events anyway, count.: %a", len(items)) + self.insert_new_event(items, with_transact=False, recorded=True) + + except IntegrityError as exc: + LOGGER.warning("IntegrityError: %a", exc) + + def blacklist_delist(self): + """ + Black list delist (change status to delisted in existing blacklist event). + """ + # add event records from RecordDict + for event_record in self.record_dict.iter_db_items(): + self.records['event'].append(event_record) + + self.json_to_record(self.records['event']) + id_db = self.records['event'][0]["id"] + LOGGER.debug("ID: %a STATUS: %a", id_db, 'delisted') + + with transact: + (self.session_db.query(n6NormalizedData). + filter(n6NormalizedData.id == id_db). + update( + { + 'status': 'delisted', + 'modified': datetime.datetime.utcnow().replace(microsecond=0), + })) + + def blacklist_expire(self): + """ + Black list expire (change status to expired in existing blacklist event). + """ + # add event records from RecordDict + for event_record in self.record_dict.iter_db_items(): + self.records['event'].append(event_record) + + self.json_to_record(self.records['event']) + + id_db = self.records['event'][0]["id"] + LOGGER.debug("ID: %a STATUS: %a", id_db, 'expired') + + with transact: + (self.session_db.query(n6NormalizedData). + filter(n6NormalizedData.id == id_db). + update( + { + 'status': 'expired', + 'modified': datetime.datetime.utcnow().replace(microsecond=0), + })) + + def blacklist_update(self): + """ + Black list update (change expires to new value in existing blacklist event). 
+ """ + # add event records from RecordDict + for event_record in self.record_dict.iter_db_items(): + self.records['event'].append(event_record) + + self.json_to_record(self.records['event']) + id_event = self.records['event'][0]["id"] + expires = self.records['event'][0]["expires"] + LOGGER.debug("ID: %a NEW_EXPIRES: %a", id_event, expires) + + with transact: + rec_count = (self.session_db.query(n6NormalizedData). + filter(n6NormalizedData.id == id_event). + update({'expires': expires, + 'modified': datetime.datetime.utcnow().replace(microsecond=0), + })) + if rec_count: + LOGGER.debug("records with the same id %a exist: %a", + id_event, rec_count) + else: + items = [] + for record in self.records['event']: + record["status"] = "active" + event = n6NormalizedData(**record) + items.append(event) + + for record in self.records['client']: + client = n6ClientToEvent(**record) + items.append(client) + LOGGER.debug("bl-update, records with id %a DO NOT EXIST!", id_event) + LOGGER.debug("insert new events,::count:: %a", len(items)) + self.insert_new_event(items, with_transact=False) + + def suppressed_update(self): + """ + Agregated event update(change fields: until and count, to the value of suppressed event). + """ + LOGGER.debug('* suppressed_update() %a', self.record_dict) + + # add event records from RecordDict + for event_record in self.record_dict.iter_db_items(): + self.records['event'].append(event_record) + + self.json_to_record(self.records['event']) + id_event = self.records['event'][0]["id"] + until = self.records['event'][0]["until"] + count = self.records['event'][0]["count"] + + # optimization: we can limit time => searching within one partition, not all; + # it seems that mysql (and/or sqlalchemy?) truncates times to seconds, + # we are also not 100% sure if other time data micro-distortions are not done + # -- that's why here we use a 1-second-range instead of an exact value + first_time_min = parse_iso_datetime_to_utc( + self.record_dict["_first_time"]).replace(microsecond=0) + first_time_max = first_time_min + datetime.timedelta(days=0, seconds=1) + + with transact: + rec_count = (self.session_db.query(n6NormalizedData) + .filter( + n6NormalizedData.time >= first_time_min, + n6NormalizedData.time <= first_time_max, + n6NormalizedData.id == id_event) + .update({'until': until, 'count': count})) + if rec_count: + LOGGER.debug("records with the same id %a exist: %a", + id_event, rec_count) + else: + items = [] + for record in self.records['event']: + event = n6NormalizedData(**record) + items.append(event) + + for record in self.records['client']: + client = n6ClientToEvent(**record) + items.append(client) + LOGGER.warning("suppressed_update, records with id %a DO NOT EXIST!", id_event) + LOGGER.debug("insert new events,,::count:: %a", len(items)) + self.insert_new_event(items, with_transact=False) + + +def main(): + parser = Recorder.get_arg_parser() + args = Recorder.parse_only_n6_args(parser) + if args.n6recorder_non_blacklist: + monkey_patch_non_bl_recorder() + elif args.n6recorder_blacklist is not None: + monkey_patch_bl_recorder(args.n6recorder_blacklist) + with logging_configured(): + if os.environ.get('n6integration_test'): + # for debugging only + LOGGER.setLevel(logging.DEBUG) + LOGGER.addHandler(logging.StreamHandler(stream=sys.__stdout__)) + d = Recorder() + try: + d.run() + except KeyboardInterrupt: + d.stop() + + +def monkey_patch_non_bl_recorder(): + Recorder.input_queue = { + "exchange": "event", + "exchange_type": "topic", + "queue_name": 'zbd-non-blacklist', + 
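        # (Editor's note, not part of the patch: with this patched input queue a
        # Recorder instance consumes only ordinary and suppressed events -- the
        # binding keys that follow match e.g. 'event.filtered.some-src.some-channel'
        # -- while blacklist sources are handled by instances patched with
        # monkey_patch_bl_recorder() below.)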
"binding_keys": [ + 'event.filtered.*.*', + 'suppressed.filtered.*.*', + ] + } + + +def monkey_patch_bl_recorder(source): + Recorder.input_queue = { + "exchange": "event", + "exchange_type": "topic", + "queue_name": 'zbd-bl-{}'.format(source.replace(".", "-")), + "binding_keys": [ + x.format(source) for x in [ + 'bl-new.filtered.{}', + 'bl-change.filtered.{}', + 'bl-delist.filtered.{}', + 'bl-expire.filtered.{}', + 'bl-update.filtered.{}', + ] + ] + } + + +if __name__ == "__main__": + main() diff --git a/N6DataPipeline/n6datapipeline/tests/test_anonymizer.py b/N6DataPipeline/n6datapipeline/tests/test_anonymizer.py new file mode 100644 index 0000000..190a41a --- /dev/null +++ b/N6DataPipeline/n6datapipeline/tests/test_anonymizer.py @@ -0,0 +1,910 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. + +import datetime +import json +import unittest +from unittest.mock import ( + ANY, + MagicMock, + call, + patch, + sentinel as sen, +) + +from unittest_expander import ( + expand, + foreach, + param, +) + +from n6datapipeline.aux.anonymizer import Anonymizer +from n6lib.const import TYPE_ENUMS +from n6lib.data_spec import N6DataSpec +from n6lib.db_filtering_abstractions import RecordFacadeForPredicates +from n6lib.unit_test_helpers import TestCaseMixin, MethodProxy +from n6sdk.exceptions import ( + ResultKeyCleaningError, + ResultValueCleaningError, +) + + + +@expand +class TestAnonymizer__input_callback(TestCaseMixin, unittest.TestCase): + + def setUp(self): + self.event_type = 'bl-update' + self.event_data = {'some...': 'content...', 'id': 'some id...'} + self.routing_key = self.event_type + '.filtered.*.*' + self.body = json.dumps(self.event_data) + self.resource_to_org_ids = {} + + self.mock = MagicMock(__class__=Anonymizer) + self.meth = MethodProxy(Anonymizer, self.mock, '_process_input') + + self.mock._get_resource_to_org_ids.return_value = self.resource_to_org_ids + self.mock._get_result_dicts_and_output_body.return_value = ( + sen.raw_result_dict, + sen.cleaned_result_dict, + sen.output_body, + ) + self.force_exit_on_any_remaining_entered_contexts_mock = self.patch( + 'n6datapipeline.aux.anonymizer.force_exit_on_any_remaining_entered_contexts') + + + @foreach( + param(resource_to_org_ids_items={ + 'foo': [sen.o1, sen.o2], + }), + param(resource_to_org_ids_items={ + 'foo': [sen.o1, sen.o2], + 'bar': [], + }), + param(resource_to_org_ids_items={ + 'foo': [], + 'bar': [sen.o3, sen.o4, sen.o5], + }), + param(resource_to_org_ids_items={ + 'foo': [sen.o1, sen.o2], + 'bar': [sen.o3, sen.o4, sen.o5], + }), + ) + def test_with_some_org_ids(self, resource_to_org_ids_items): + self.resource_to_org_ids.update(resource_to_org_ids_items) + + self.meth.input_callback( + self.routing_key, + self.body, + sen.properties) + + self.assertEqual(self.force_exit_on_any_remaining_entered_contexts_mock.mock_calls, [ + call(self.mock.auth_api), + ]) + self.assertEqual(self.mock.mock_calls, [ + call.setting_error_event_info(self.event_data), + call.setting_error_event_info().__enter__(), + call._check_event_type( + self.event_type, + self.event_data), + call.auth_api.__enter__(), + call._get_resource_to_org_ids( + self.event_type, + self.event_data), + call._get_result_dicts_and_output_body( + self.event_type, + self.event_data, + self.resource_to_org_ids), + call._publish_output_data( + self.event_type, + self.resource_to_org_ids, + sen.raw_result_dict, + sen.cleaned_result_dict, + sen.output_body), + call.auth_api.__exit__(None, None, None), + call.setting_error_event_info().__exit__(None, None, None), + 
]) + + + @foreach( + param(resource_to_org_ids_items={}), + param(resource_to_org_ids_items={ + 'foo': [], + }), + param(resource_to_org_ids_items={ + 'foo': [], + 'bar': [], + }), + ) + def test_without_org_ids(self, resource_to_org_ids_items): + self.resource_to_org_ids.update(resource_to_org_ids_items) + + self.meth.input_callback( + self.routing_key, + self.body, + sen.properties) + + self.assertEqual(self.force_exit_on_any_remaining_entered_contexts_mock.mock_calls, [ + call(self.mock.auth_api), + ]) + self.assertEqual(self.mock.mock_calls, [ + call.setting_error_event_info(self.event_data), + call.setting_error_event_info().__enter__(), + call._check_event_type( + self.event_type, + self.event_data), + call.auth_api.__enter__(), + call._get_resource_to_org_ids( + self.event_type, + self.event_data), + call.auth_api.__exit__(None, None, None), + call.setting_error_event_info().__exit__(None, None, None), + ]) + + + def test_with_some_error(self): + self.resource_to_org_ids.update({ + 'foo': [sen.o1, sen.o2], + 'bar': [sen.o3, sen.o4, sen.o5], + }) + exc_type = ZeroDivisionError # (just an example exception class) + self.mock._get_result_dicts_and_output_body.side_effect = exc_type + + with self.assertRaises(exc_type) as exc_context: + self.meth.input_callback( + self.routing_key, + self.body, + sen.properties) + + self.assertEqual(self.force_exit_on_any_remaining_entered_contexts_mock.mock_calls, [ + call(self.mock.auth_api), + ]) + self.assertEqual(self.mock.mock_calls, [ + call.setting_error_event_info(self.event_data), + call.setting_error_event_info().__enter__(), + call._check_event_type( + self.event_type, + self.event_data), + call.auth_api.__enter__(), + call._get_resource_to_org_ids( + self.event_type, + self.event_data), + call._get_result_dicts_and_output_body( + self.event_type, + self.event_data, + self.resource_to_org_ids), + call.auth_api.__exit__(exc_type, exc_context.exception, ANY), + call.setting_error_event_info().__exit__(exc_type, exc_context.exception, ANY), + ]) + + + +@expand +class TestAnonymizer___check_event_type(TestCaseMixin, unittest.TestCase): + + def setUp(self): + self.mock = MagicMock(__class__=Anonymizer) + self.meth = MethodProxy(Anonymizer, self.mock, '_VALID_EVENT_TYPES') + + + @foreach( + param( + event_type='event', + event_data={ + 'some_key': sen.some_value, + }, + ).label('no type in event data'), + param( + event_type='event', + event_data={ + 'type': 'event', + 'some_key': sen.some_value, + }, + ).label('type "event" in event data'), + param( + event_type='bl-update', + event_data={ + 'type': 'bl-update', + 'some_key': sen.some_value, + }, + ).label('another type in event data'), + ) + def test_matching_and_valid(self, event_type, event_data): + assert (event_type == event_data.get('type', 'event') and + event_type in TYPE_ENUMS) # (test case self-test) + + self.meth._check_event_type(event_type, event_data) + + # the _check_event_type() method is called outside the AuthAPI + # context (outside its `with` statement) -- so we want to ensure + # that no AuthAPI methods are called: + self.assertEqual(self.mock.auth_api.mock_calls, []) + + + @foreach( + param( + event_type='event', + event_data={ + 'type': 'bl-update', + 'some_key': sen.some_value, + }, + ).label('type "event" does not match another one'), + param( + event_type='bl-update', + event_data={ + 'type': 'event', + 'some_key': sen.some_value, + }, + ).label('another type does not match "event"'), + ) + def test_not_matching(self, event_type, event_data): + assert (event_type != 
event_data.get('type', 'event') and + event_type in TYPE_ENUMS) # (test case self-test) + + with self.assertRaises(ValueError): + self.meth._check_event_type(event_type, event_data) + + # the _check_event_type() method is called outside the AuthAPI + # context (outside its `with` statement) -- so we want to ensure + # that no AuthAPI methods are called: + self.assertEqual(self.mock.auth_api.mock_calls, []) + + + def test_matching_but_not_valid(self): + event_type = 'illegal' + event_data = { + 'type': event_type, + 'some_key': sen.some_value, + } + assert event_type not in TYPE_ENUMS # (test case self-test) + + with self.assertRaises(ValueError): + self.meth._check_event_type(event_type, event_data) + + # the _check_event_type() method is called outside the AuthAPI + # context (outside its `with` statement) -- so we want to ensure + # that no AuthAPI methods are called: + self.assertEqual(self.mock.auth_api.mock_calls, []) + + +@expand +class TestAnonymizer___get_resource_to_org_ids(TestCaseMixin, unittest.TestCase): + + def setUp(self): + self.event_type = 'bl-update' + + def YES_predicate(record): + self.assertIsInstance(record, RecordFacadeForPredicates) + return True + + def NO_predicate(record): + self.assertIsInstance(record, RecordFacadeForPredicates) + return False + + self.mock = MagicMock(__class__=Anonymizer) + self.meth = MethodProxy(Anonymizer, self.mock) + + self.mock.data_spec = N6DataSpec() + self.mock.auth_api.get_source_ids_to_subs_to_stream_api_access_infos.return_value = \ + self.s_to_s_to_saai = { + 'src.empty': {}, + 'src.some-1': { + sen.something_1: ( + YES_predicate, + { + 'inside': set(), + 'threats': set(), + 'search': set(), + } + ), + sen.something_2: ( + YES_predicate, + { + 'inside': {'o4'}, + 'threats': set(), + 'search': {'o1', 'o2', 'o3', 'o4', 'o5', 'o6'}, + } + ), + sen.something_3: ( + NO_predicate, + { + 'inside': {'o2'}, + 'threats': {'o3'}, + 'search': set(), + } + ), + sen.something_4: ( + NO_predicate, + { + 'inside': {'o1', 'o3', 'o9'}, + 'threats': {'o3', 'o5', 'o6'}, + 'search': {'o3', 'o4', 'o5', 'o6'}, + } + ), + }, + 'src.some-2': { + sen.something_5: ( + YES_predicate, + { + 'inside': {'o1', 'o3', 'o9'}, + 'threats': {'o3', 'o5', 'o6'}, + 'search': {'o3', 'o4', 'o5', 'o6'}, + } + ), + sen.something_6: ( + YES_predicate, + { + 'inside': {'o2'}, + 'threats': {'o2'}, + 'search': set(), + } + ), + sen.something_7: ( + YES_predicate, + { + 'inside': set(), + 'threats': {'o8'}, + 'search': set(), + } + ), + sen.something_8: ( + YES_predicate, + { + 'inside': set(), + 'threats': set(), + 'search': set(), + } + ), + sen.something_9: ( + NO_predicate, + { + 'inside': {'o1', 'o5', 'o4', 'o9'}, + 'threats': {'o3', 'o4', 'o5', 'o9'}, + 'search': {'o1', 'o2', 'o3', 'o4'}, + } + ), + }, + } + + + @foreach( + param( + event_data=dict( + source='src.not-found', + client=['o5', 'o1', 'o3', 'o2'], + ), + expected_result=dict( + inside=[], + threats=[], + ), + ).label('no such source'), + param( + event_data=dict( + source='src.empty', + client=['o5', 'o1', 'o3', 'o2'], + ), + expected_result=dict( + inside=[], + threats=[], + ), + ).label('no subsources'), + param( + event_data=dict( + source='src.some-1', + client=['o5', 'o1', 'o3', 'o2'], + ), + expected_result=dict( + inside=[], + threats=[], + ), + ).label('no matching subsources/organizations'), + param( + event_data=dict( + source='src.some-2', + client=['o5', 'o1', 'o3', 'o2'], + ), + expected_result=dict( + inside=['o1', 'o2', 'o3'], + threats=['o2', 'o3', 'o5', 'o6', 'o8'], + ), + ).label('some 
matching subsources and organizations (1)'), + param( + event_data=dict( + source='src.some-2', + client=['o2', 'o4', 'o9'], + ), + expected_result=dict( + inside=['o2', 'o9'], + threats=['o2', 'o3', 'o5', 'o6', 'o8'], + ), + ).label('some matching subsources and organizations (2)'), + param( + event_data=dict( + source='src.some-2', + client=['o4'], + ), + expected_result=dict( + inside=[], + threats=['o2', 'o3', 'o5', 'o6', 'o8'], + ), + ).label('some matching subsources and organizations (only "threats")'), + ) + def test_normal(self, event_data, expected_result): + expected_mock_calls = [ + call.auth_api.get_source_ids_to_subs_to_stream_api_access_infos(), + ] + + with patch('n6datapipeline.aux.anonymizer.LOGGER') as LOGGER_mock: + result = self.meth._get_resource_to_org_ids(self.event_type, event_data) + + self.assertEqual(result, expected_result) + self.assertEqual(self.mock.mock_calls, expected_mock_calls) + self.assertFalse(LOGGER_mock.error.mock_calls) + + + def test_error(self): + event_data = dict( + source='src.some-2', + client=['o5', 'o1', 'o3', 'o2'], + ) + res_to_org_ids = { + 'inside': set(), + 'threats': {'o8'}, + 'search': set(), + } + exc_type = ZeroDivisionError # (just an example exception class) + def raise_exc(rec): + raise exc_type('blablabla') + self.s_to_s_to_saai['src.some-2'][sen.something_7] = raise_exc, res_to_org_ids + + with patch('n6datapipeline.aux.anonymizer.LOGGER') as LOGGER_mock, \ + self.assertRaises(exc_type): + self.meth._get_resource_to_org_ids(self.event_type, event_data) + + self.assertEqual(len(LOGGER_mock.error.mock_calls), 1) + + + +@expand +class TestAnonymizer___get_result_dicts_and_output_body(TestCaseMixin, unittest.TestCase): + + forward_source_mapping = { + 'some.source': 'hidden.42', + } + + event_raw_base = dict( + id=(32 * '3'), + rid=(32 * '4'), # (restricted - to be skipped before *value* cleaning) + source='some.source', # (to be anonymized) + restriction='public', # (restricted - to be skipped before *value* cleaning) + confidence='low', + category='malurl', + time='2013-07-12 11:30:00', + ) + + cleaned_base = dict( + id=(32 * '3'), + source='hidden.42', # (after anonymization) + confidence='low', + category='malurl', + time=datetime.datetime(2013, 7, 12, 11, 30, 00), + type=sen.TO_BE_SET, + ) + + + def setUp(self): + self.mock = MagicMock(__class__=Anonymizer) + self.meth = MethodProxy(Anonymizer, self.mock) + + self.mock.data_spec = N6DataSpec() + self.mock.auth_api.get_anonymized_source_mapping.return_value = { + 'forward_mapping': self.forward_source_mapping, + } + self.mock.auth_api.get_dip_anonymization_disabled_source_ids.return_value = frozenset() + + + @foreach( + param( + event_type='event', + event_data=dict( + event_raw_base, + client=[], # (empty `client` -- to be skipped before *any* cleaning) + ), + expected_raw=dict( + event_raw_base, + ), + expected_cleaned=dict( + cleaned_base, + type='event', # (event_type value set *after* cleaning) + ), + ), + param( + event_type='event', + event_data=dict( + event_raw_base, + client=['o1', 'o3', 'o2'], + address=[], # (empty `address` -- to be skipped before *any* cleaning) + dip='192.168.0.1', + fqdn='www.example.com', + type='foobar', # (not a result key -- to be skipped before *any* cleaning) + blabla='foooo', # (not a result key -- to be skipped before *any* cleaning) + until='spamspam', + min_amplification=4000*'foo bar', + rid='xxxxx', + ), + expected_raw=dict( + event_raw_base, + client=['o1', 'o3', 'o2'], # (restricted -- to be skipped before *value* cleaning) + 
dip='192.168.0.1', # (to be anonymized -> as 'adip') + fqdn='www.example.com', + until='spamspam', # (restricted -- to be skipped before *value* cleaning) + min_amplification=4000*'foo bar', # (restricted [custom] -- as above) + rid='xxxxx', # (restricted [+required] -- as above) + ), + expected_cleaned=dict( + cleaned_base, + adip='x.x.0.1', # ('dip' value after anonymization) + fqdn='www.example.com', + type='event', # (event_type value set *after* cleaning) + ), + ), + param( + event_type='bl-update', + event_data=dict(event_raw_base, **{ + 'client': [], # (empty `client` -- to be skipped before *any* cleaning) + 'address': [{'ip': '1.2.3.4', 'cc': 'pl', 'asn': '1.1'}], + 'adip': 'x.10.20.30', + 'dip': '192.168.0.1', + '_bl-series-no': 42, # (not a result field -- to be skipped before *any* cleaning) + 'type': 'barfoo', # (not a result field -- to be skipped before *any* cleaning) + }), + expected_raw=dict(event_raw_base, **{ + 'address': [{'ip': '1.2.3.4', 'cc': 'pl', 'asn': '1.1'}], + 'adip': 'x.10.20.30', + 'dip': '192.168.0.1', # (to be just omitted -- 'adip' is explicitly specified) + }), + expected_cleaned=dict( + cleaned_base, + address=[{'ip': '1.2.3.4', 'cc': 'PL', 'asn': 65537}], + adip='x.10.20.30', # (just given 'adip') + type='bl-update', # (event_type value set *after* cleaning) + ), + ), + # below -- the same two as above but with dip anonymization disabled + param( + event_type='event', + event_data=dict( + event_raw_base, + client=['o1', 'o3', 'o2'], + address=[], # (empty `address` -- to be skipped before *any* cleaning) + dip='192.168.0.1', + fqdn='www.example.com', + type='foobar', # (not a result key -- to be skipped before *any* cleaning) + blabla='foooo', # (not a result key -- to be skipped before *any* cleaning) + until='spamspam', + min_amplification=4000*'foo bar', + rid='xxxxx', + ), + expected_raw=dict( + event_raw_base, + client=['o1', 'o3', 'o2'], # (restricted -- to be skipped before *value* cleaning) + dip='192.168.0.1', # (to be *not* anonymized [sic]) + fqdn='www.example.com', + until='spamspam', # (restricted -- to be skipped before *value* cleaning) + min_amplification=4000*'foo bar', # (restricted [custom] -- as above) + rid='xxxxx', # (restricted [+required] -- as above) + ), + expected_cleaned=dict( + cleaned_base, + dip='192.168.0.1', # (*not* anonymized [sic]) + fqdn='www.example.com', + type='event', # (event_type value set *after* cleaning) + ), + dip_anonymization_disabled_source_ids=frozenset(['some.source']), + ), + param( + event_type='bl-update', + event_data=dict(event_raw_base, **{ + 'client': [], # (empty `client` -- to be skipped before *any* cleaning) + 'address': [{'ip': '1.2.3.4', 'cc': 'pl', 'asn': '1.1'}], + 'adip': 'x.10.20.30', + 'dip': '192.168.0.1', + '_bl-series-no': 42, # (not a result field -- to be skipped before *any* cleaning) + 'type': 'barfoo', # (not a result field -- to be skipped before *any* cleaning) + }), + expected_raw=dict(event_raw_base, **{ + 'address': [{'ip': '1.2.3.4', 'cc': 'pl', 'asn': '1.1'}], + 'adip': 'x.10.20.30', + 'dip': '192.168.0.1', + }), + expected_cleaned=dict( + cleaned_base, + address=[{'ip': '1.2.3.4', 'cc': 'PL', 'asn': 65537}], + adip='x.10.20.30', # (just given 'adip') + dip='192.168.0.1', # (just given 'dip' [sic]) + type='bl-update', # (event_type value set *after* cleaning) + ), + dip_anonymization_disabled_source_ids=frozenset(['some.source']), + ), + ) + def test_normal(self, event_type, event_data, expected_raw, expected_cleaned, + 
dip_anonymization_disabled_source_ids=frozenset()): + expected_auth_api_calls = [call.get_anonymized_source_mapping()] + if 'dip' in event_data: + expected_auth_api_calls.append(call.get_dip_anonymization_disabled_source_ids()) + self.mock.auth_api.get_dip_anonymization_disabled_source_ids.return_value = ( + dip_anonymization_disabled_source_ids) + + with patch('n6datapipeline.aux.anonymizer.LOGGER') as LOGGER_mock: + (raw_result_dict, + cleaned_result_dict, + output_body) = self.meth._get_result_dicts_and_output_body( + event_type, + event_data, + sen.resource_to_org_ids) + + self.assertEqual(raw_result_dict, expected_raw) + self.assertEqual(cleaned_result_dict, expected_cleaned) + self.assertEqual( + json.loads(output_body), + self._get_expected_body_content(expected_cleaned)) + self.assertCountEqual(self.mock.auth_api.mock_calls, expected_auth_api_calls) + self.assertFalse(LOGGER_mock.error.mock_calls) + + @staticmethod + def _get_expected_body_content(expected_cleaned): + formatted_time = expected_cleaned['time'].isoformat() + 'Z' + assert formatted_time[10] == 'T' and formatted_time[-1] == 'Z' + return dict( + expected_cleaned, + time=formatted_time) + + + @foreach( + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + ), + without_keys={'id'}, + exc_type=ResultKeyCleaningError, + ).label('missing key: required and unrestricted'), + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + ), + without_keys={'source'}, + exc_type=ResultKeyCleaningError, + ).label('missing key: required and anonymized'), + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + ), + without_keys={'rid'}, + exc_type=ResultKeyCleaningError, + ).label('missing key: required and restricted'), + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + id='spam', + ), + exc_type=ResultValueCleaningError, + ).label('illegal value for required and unrestricted key'), + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + fqdn='foo..bar', + ), + exc_type=ResultValueCleaningError, + ).label('illegal value for optional and unrestricted key'), + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + dip='spam', + ), + exc_type=ResultValueCleaningError, + ).label('illegal value for optional and anonymized-source key'), + param( + event_data=dict( + event_raw_base, + client=['o3', 'o1', 'o2'], + adip='spam', + ), + exc_type=ResultValueCleaningError, + ).label('illegal value for optional and anonymized-target key'), + ) + def test_error(self, event_data, exc_type, without_keys=()): + event_type = 'event' + event_data = event_data.copy() + for key in without_keys: + del event_data[key] + resource_to_org_ids = {'foo': {'bar'}, 'baz': {'spam', 'ham'}} + with patch('n6datapipeline.aux.anonymizer.LOGGER') as LOGGER_mock, \ + self.assertRaises(exc_type): + self.meth._get_result_dicts_and_output_body( + event_type, + event_data, + resource_to_org_ids) + self.assertEqual(len(LOGGER_mock.error.mock_calls), 1) + + + +@expand +class TestAnonymizer___publish_output_data(TestCaseMixin, unittest.TestCase): + + def setUp(self): + self.cleaned_result_dict = { + 'category': 'bots', + 'source': 'hidden.42', + } + self.mock = MagicMock(__class__=Anonymizer) + self.meth = MethodProxy(Anonymizer, self.mock, 'OUTPUT_RK_PATTERN') + + + @foreach( + param( + resource_to_org_ids={ + 'inside': ['o2', 'o3'], + 'threats': ['o3', 'o5', 'o8'], + }, + expected_publish_output_calls=[ + call( + routing_key='inside.bots.hidden.42', + 
body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o3'}}, + ), + call( + routing_key='inside.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o2'}}, + ), + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o8'}}, + ), + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o5'}}, + ), + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o3'}}, + ), + ], + ).label('for both resources'), + param( + resource_to_org_ids={ + 'inside': ['o2', 'o3'], + 'threats': [], + }, + expected_publish_output_calls=[ + call( + routing_key='inside.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o3'}}, + ), + call( + routing_key='inside.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o2'}}, + ), + ], + ).label('for "inside" only'), + param( + resource_to_org_ids={ + 'inside': [], + 'threats': ['o3', 'o5', 'o8'], + }, + expected_publish_output_calls=[ + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o8'}}, + ), + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o5'}}, + ), + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o3'}}, + ), + ], + ).label('for "threats" only'), + param( + resource_to_org_ids={ + 'inside': [], + 'threats': [], + }, + expected_publish_output_calls=[], + ).label('for no resources'), + ) + def test_normal(self, resource_to_org_ids, expected_publish_output_calls): + with patch('n6datapipeline.aux.anonymizer.LOGGER') as LOGGER_mock: + self.meth._publish_output_data( + sen.event_type, + resource_to_org_ids, + sen.raw_result_dict, + self.cleaned_result_dict, + sen.output_body) + + self.assertEqual( + self.mock.publish_output.mock_calls, + expected_publish_output_calls) + self.assertFalse(LOGGER_mock.error.mock_calls) + + + def test_error(self): + resource_to_org_ids = { + 'inside': ['o2', 'o3'], + 'threats': ['o3', 'o5', 'o8'], + } + expected_publish_output_calls = [ + call( + routing_key='inside.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o3'}}, + ), + call( + routing_key='inside.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o2'}}, + ), + call( + routing_key='threats.bots.hidden.42', + body=sen.output_body, + prop_kwargs={'headers': {'n6-client-id': 'o8'}}, + ), + ] + exc_type = ZeroDivisionError # (just an example exception class) + self.mock.publish_output.side_effect = [ + None, + None, + exc_type, + ] + + with patch('n6datapipeline.aux.anonymizer.LOGGER') as LOGGER_mock, \ + self.assertRaises(exc_type): + self.meth._publish_output_data( + sen.event_type, + resource_to_org_ids, + sen.raw_result_dict, + self.cleaned_result_dict, + sen.output_body) + + self.assertEqual( + self.mock.publish_output.mock_calls, + expected_publish_output_calls) + self.assertEqual(LOGGER_mock.error.mock_calls, [ + call( + ANY, + 'threats', + 'o8', + sen.event_type, + sen.raw_result_dict, + 'threats.bots.hidden.42', + sen.output_body, + ( + "for the resource 'inside' -- " + "* skipped for the org ids: none; " + "* done for the org ids: 'o3', 'o2'; " + "for the resource 'threats' -- " + "* skipped for the org ids: 'o3', 'o5', 'o8'; " + 
"* done for the org ids: none" + ), + ), + ]) + + + +if __name__ == '__main__': + unittest.main() diff --git a/N6DataPipeline/n6datapipeline/tests/test_comparator.py b/N6DataPipeline/n6datapipeline/tests/test_comparator.py new file mode 100644 index 0000000..a275dc3 --- /dev/null +++ b/N6DataPipeline/n6datapipeline/tests/test_comparator.py @@ -0,0 +1,503 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. + +import unittest +import json + +from unittest.mock import ( + MagicMock, + call, + sentinel as sen, +) + +from n6datapipeline.comparator import ( + Comparator, + ComparatorData, + ComparatorDataWrapper, + ComparatorState, +) +from n6lib.unit_test_helpers import TestCaseMixin + + +class TestComparator__message_flow(TestCaseMixin, unittest.TestCase): + + def setUp(self): + + self.comparator = Comparator.__new__(Comparator) + self.comparator.comparator_config = MagicMock() + self.comparator._connection = MagicMock() + + self.comparator.state = ComparatorState(sen.irrelevant) + + self.patch_object(ComparatorDataWrapper, 'store_state') + self.comparator.db = ComparatorDataWrapper.__new__(ComparatorDataWrapper) + self.comparator.db.comp_data = ComparatorData() + + self.comparator.publish_output = MagicMock() + + def test_message_flow_basic(self): + + routing_key = 'bl-new.compared.source_test1.channel_test1' + + input_data_1_1 = { + '_bl-time': '2017-01-19 12:07:32', + '_bl-series-total': 1, + '_bl-series-no': 1, + '_bl-series-id': '11111111111111111111111111111111', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '111111111119d9ab98f08761e7168ebd', + } + + expected_event_1_1 = { + '_bl-time': '2017-01-19 12:07:32', + 'type': 'bl-new', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '111111111119d9ab98f08761e7168ebd', + } + + expected_calls_list = [ + call(body=expected_event_1_1, routing_key=routing_key) + ] + self.comparator._process_input(input_data_1_1) + + publish_output_call_args_list = self._get_deserialized_calls(self.comparator.publish_output.call_args_list) + self.assertEqual(publish_output_call_args_list, expected_calls_list) + + def test_message_flow_new_update_change_delist(self): + + # name pattern: data_type_runNo_srcNo_eventNo + # first run, + routing_key_1_1_1 = 'bl-new.compared.source_test1.channel_test1' + routing_key_1_1_2 = 'bl-new.compared.source_test1.channel_test1' + + input_data_1_1_1 = { + '_bl-time': '2017-01-19 12:07:32', + '_bl-series-total': 2, + '_bl-series-no': 1, + '_bl-series-id': '11111111111111111111111111111111', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '9104c0ad2339d9ab98f08761e7168ebd', + } + + expected_event_1_1_1 = { + 'type': 'bl-new', + '_bl-time': '2017-01-19 12:07:32', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'id': '9104c0ad2339d9ab98f08761e7168ebd', + 'source': 'source_test1.channel_test1', + } + + input_data_1_1_2 = { + '_bl-time': '2017-01-19 12:07:32', + '_bl-series-total': 2, + '_bl-series-no': 2, + '_bl-series-id': '11111111111111111111111111111111', + 'expires': '2017-01-20 19:19:19', + 'time': '2017-01-18 19:19:19', + 'url': 'http://www.tests.pl', + 'address': [{ + 'cc': 'XX', 
+ 'ip': '2.2.2.2', + 'asn': 3215 + }], + 'source': 'source_test1.channel_test1', + 'id': '1c9a2638b51f334da3d2311e01817884', + } + + expected_event_1_1_2 = { + 'type': 'bl-new', + '_bl-time': '2017-01-19 12:07:32', + 'expires': '2017-01-20 19:19:19', + 'time': '2017-01-18 19:19:19', + 'url': 'http://www.tests.pl', + 'address': [{ + 'cc': 'XX', + 'ip': '2.2.2.2', + 'asn': 3215 + }], + 'id': '1c9a2638b51f334da3d2311e01817884', + 'source': 'source_test1.channel_test1', + } + + # Second run, + # 2. 1. msg bl-update, 2. bl-change + routing_key_2_1_1 = 'bl-update.compared.source_test1.channel_test1' + routing_key_2_1_2 = 'bl-change.compared.source_test1.channel_test1' + + input_data_2_1_1 = { + '_bl-time': '2017-01-19 12:13:36', + '_bl-series-total': 2, + '_bl-series-no': 1, + '_bl-series-id': '22222222222222222222222222222222', + 'expires': '2017-01-21 15:15:15', + 'time': '2017-01-19 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '4273a190e57da23c1dee67a7689e115a', + } + + expected_event_2_1_1 = { + 'type': 'bl-update', + '_bl-time': '2017-01-19 12:07:32', + 'expires': '2017-01-21 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'id': '9104c0ad2339d9ab98f08761e7168ebd', + 'source': 'source_test1.channel_test1', + } + input_data_2_1_2 = { + '_bl-time': '2017-01-19 14:14:14', + '_bl-series-total': 2, + '_bl-series-no': 2, + '_bl-series-id': '22222222222222222222222222222222', + 'expires': '2017-01-21 18:18:18', + 'time': '2017-01-19 19:19:19', + 'url': 'http://www.tests.pl', + 'address': [ + { + 'cc': 'XX', + 'ip': '2.2.2.2', + 'asn': 3215 + }, + { + 'cc': 'XX', + 'ip': '22.22.22.22', + 'asn': 3215 + } + ], + 'source': 'source_test1.channel_test1', + 'id': '929c840e0dec26e26410aeeac418067d', + } + + expected_event_2_1_2 = { + 'type': 'bl-change', + '_bl-time': '2017-01-19 14:14:14', + 'expires': '2017-01-21 18:18:18', + 'time': '2017-01-19 19:19:19', + 'url': 'http://www.tests.pl', + 'address': [ + { + 'cc': 'XX', + 'ip': '2.2.2.2', + 'asn': 3215 + }, + { + 'cc': 'XX', + 'ip': '22.22.22.22', + 'asn': 3215 + } + ], + 'id': '929c840e0dec26e26410aeeac418067d', + 'source': 'source_test1.channel_test1', + 'replaces': '1c9a2638b51f334da3d2311e01817884', + } + + # third run, + # one (3.) bl-new and 1., 2. 
bl-delist (old events) + routing_key_3_1_1 = 'bl-delist.compared.source_test1.channel_test1' + routing_key_3_1_2 = 'bl-delist.compared.source_test1.channel_test1' + routing_key_3_1_3 = 'bl-new.compared.source_test1.channel_test1' + + input_data_3_1_3 = { + '_bl-time': '2017-01-20 10:10:10', + '_bl-series-no': 1, + '_bl-series-total': 1, + '_bl-series-id': '33333333333333333333333333333333', + 'expires': '2017-01-22 10:10:10', + 'time': '2017-01-20 10:10:10', + 'address': [{ + 'cc': 'XX', + 'ip': '3.3.3.3' + }], + 'id': 'ed928c2322422b2a8e419b00426fbcb0', + 'source': 'source_test1.channel_test1', + } + + expected_event_3_1_1 = { + 'type': 'bl-delist', + '_bl-time': '2017-01-19 12:07:32', + 'time': '2017-01-18 15:15:15', + 'expires': '2017-01-21 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'id': '9104c0ad2339d9ab98f08761e7168ebd', + 'source': 'source_test1.channel_test1', + } + + expected_event_3_1_2 = { + 'type': 'bl-delist', + 'expires': '2017-01-21 18:18:18', + 'address': [ + { + 'cc': 'XX', + 'ip': '2.2.2.2', + 'asn': 3215, + }, + { + 'cc': 'XX', + 'ip': '22.22.22.22', + 'asn': 3215, + } + ], + 'id': '929c840e0dec26e26410aeeac418067d', + '_bl-time': '2017-01-19 14:14:14', + 'replaces': '1c9a2638b51f334da3d2311e01817884', + 'url': 'http://www.tests.pl', + 'source': 'source_test1.channel_test1', + 'time': '2017-01-19 19:19:19', + } + + expected_event_3_1_3 = { + 'type': 'bl-new', + '_bl-time': '2017-01-20 10:10:10', + 'expires': '2017-01-22 10:10:10', + 'time': '2017-01-20 10:10:10', + 'address': [{ + 'cc': 'XX', + 'ip': '3.3.3.3' + }], + 'id': 'ed928c2322422b2a8e419b00426fbcb0', + 'source': 'source_test1.channel_test1', + } + + self.comparator._process_input(input_data_1_1_1) + self.comparator._process_input(input_data_1_1_2) + self.comparator._process_input(input_data_2_1_1) + self.comparator._process_input(input_data_2_1_2) + self.comparator._process_input(input_data_3_1_3) + + expected_calls_list = [ + call(body=expected_event_1_1_1, routing_key=routing_key_1_1_1), + call(body=expected_event_1_1_2, routing_key=routing_key_1_1_2), + call(body=expected_event_2_1_1, routing_key=routing_key_2_1_1), + call(body=expected_event_2_1_2, routing_key=routing_key_2_1_2), + call(body=expected_event_3_1_3, routing_key=routing_key_3_1_3), + call(body=expected_event_3_1_1, routing_key=routing_key_3_1_1), + call(body=expected_event_3_1_2, routing_key=routing_key_3_1_2), + ] + + publish_output_call_args_list = self._get_deserialized_calls(self.comparator.publish_output.call_args_list) + self.assertListEqual(publish_output_call_args_list, expected_calls_list) + + def test_message_flow_new_but_expired_msg(self): + + # _bl-time > expires -> bl-new, and bl-expires + routing_key_1 = 'bl-new.compared.source_test1.channel_test1' + routing_key_2 = 'bl-expire.compared.source_test1.channel_test1' + + input_data_1_1 = { + '_bl-time': '2017-01-24 10:10:10', + '_bl-series-no': 1, + '_bl-series-total': 1, + '_bl-series-id': '44444444444444444444444444444444', + 'expires': '2017-01-22 10:10:10', + 'time': '2017-01-20 10:10:10', + 'address': [{ + 'cc': 'XX', + 'ip': '4.4.4.4' + }], + 'id': 'ed928c2322422b2a8e419b00426fbcb0', + 'source': 'source_test1.channel_test1', + } + + expected_event_1_1 = { + 'type': 'bl-new', + 'expires': '2017-01-22 10:10:10', + 'address': [{ + 'cc': 'XX', + 'ip': '4.4.4.4' + }], + 'id': 'ed928c2322422b2a8e419b00426fbcb0', + '_bl-time': '2017-01-24 10:10:10', + 'source': 'source_test1.channel_test1', + 'time': '2017-01-20 10:10:10', + } + + expected_event_1_2 = { + 'type': 
'bl-expire', + 'expires': '2017-01-22 10:10:10', + 'address': [{ + 'cc': 'XX', + 'ip': '4.4.4.4' + }], + 'id': 'ed928c2322422b2a8e419b00426fbcb0', + '_bl-time': '2017-01-24 10:10:10', + 'source': 'source_test1.channel_test1', + 'time': '2017-01-20 10:10:10', + } + + expected_calls_list = [ + call(body=expected_event_1_1, routing_key=routing_key_1), + call(body=expected_event_1_2, routing_key=routing_key_2), + ] + + self.comparator._process_input(input_data_1_1) + publish_output_call_args_list = self._get_deserialized_calls(self.comparator.publish_output.call_args_list) + self.assertListEqual(publish_output_call_args_list, expected_calls_list) + + def test_message_flow_two_msgs_with_the_same_ip(self): + + routing_key = 'bl-new.compared.source_test1.channel_test1' + + input_data_1_1 = { + '_bl-time': '2017-01-19 12:07:32', + '_bl-series-total': 2, + '_bl-series-no': 1, + '_bl-series-id': '11111111111111111111111111111111', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '111111111119d9ab98f08761e7168ebd', + } + + input_data_1_2 = { + '_bl-time': '2017-01-19 12:07:32', + '_bl-series-total': 2, + '_bl-series-no': 2, + '_bl-series-id': '11111111111111111111111111111111', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '2222222d2339d9ab98f08761e7168ebd', + } + + expected_event_1_1 = { + '_bl-time': '2017-01-19 12:07:32', + 'type': 'bl-new', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'id': '111111111119d9ab98f08761e7168ebd', + 'source': 'source_test1.channel_test1', + } + + expected_calls_list = [ + call(body=expected_event_1_1, routing_key=routing_key) + ] + + self.comparator._process_input(input_data_1_1) + self.comparator._process_input(input_data_1_2) + + publish_output_call_args_list = self._get_deserialized_calls(self.comparator.publish_output.call_args_list) + self.assertEqual(publish_output_call_args_list, expected_calls_list) + + def test_message_flow_msg_from_different_sources_with_the_same_ip(self): + + routing_key_1 = 'bl-new.compared.source_test1.channel_test1' + routing_key_2 = 'bl-new.compared.source_test2.channel_test2' + + input_data_1_1 = { + '_bl-time': '2017-01-19 12:07:32', + '_bl-series-total': 1, + '_bl-series-no': 1, + '_bl-series-id': '11111111111111111111111111111111', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test1.channel_test1', + 'id': '9104c0ad2339d9ab98f08761e7168ebd', + } + + expected_event_1_1 = { + 'type': 'bl-new', + '_bl-time': '2017-01-19 12:07:32', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'id': '9104c0ad2339d9ab98f08761e7168ebd', + 'source': 'source_test1.channel_test1', + } + + input_data_2_1 = { + '_bl-time': '2017-01-19 10:10:10', + '_bl-series-total': 1, + '_bl-series-no': 1, + '_bl-series-id': '22222222222222222222222222222222', + 'expires': '2017-01-20 15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'source': 'source_test2.channel_test2', + 'id': '23f3b0f7fc3db9ab98f08761e7168ebd', + } + + expected_event_2_1 = { + 'type': 'bl-new', + '_bl-time': '2017-01-19 10:10:10', + 'expires': '2017-01-20 
15:15:15', + 'time': '2017-01-18 15:15:15', + 'address': [{ + 'cc': 'XX', + 'ip': '1.1.1.1' + }], + 'id': '23f3b0f7fc3db9ab98f08761e7168ebd', + 'source': 'source_test2.channel_test2', + } + + expected_calls_list = [ + call(body=expected_event_1_1, routing_key=routing_key_1), + call(body=expected_event_2_1, routing_key=routing_key_2) + ] + + self.comparator._process_input(input_data_1_1) + self.comparator._process_input(input_data_2_1) + + publish_output_call_args_list = self._get_deserialized_calls(self.comparator.publish_output.call_args_list) + self.assertEqual(publish_output_call_args_list, expected_calls_list) + + def _get_deserialized_calls(self, calls_list): + deserialized_call_list = [] + for call_ in calls_list: + _, call_kwargs = call_ + new_body = json.loads(call_kwargs['body']) + deserialized_call_list.append(call(body=new_body, routing_key=call_kwargs['routing_key'])) + return deserialized_call_list diff --git a/N6DataPipeline/n6datapipeline/tests/test_filter.py b/N6DataPipeline/n6datapipeline/tests/test_filter.py new file mode 100644 index 0000000..130fb6b --- /dev/null +++ b/N6DataPipeline/n6datapipeline/tests/test_filter.py @@ -0,0 +1,1068 @@ +# Copyright (c) 2013-2021 NASK. All rights reserved. + +import json +import unittest +from unittest.mock import MagicMock, call + +from n6datapipeline.base import LegacyQueuedBase +from n6datapipeline.filter import Filter +from n6lib.auth_api import InsideCriteriaResolver +from n6lib.record_dict import RecordDict, AdjusterError + + + +## maybe TODO later: clean-up/refactor the stuff in this module... + +# (note that the main job of *filter* -- i.e., determining which +# organization ids the event's `client` attribute should include +# -- is covered also by comprehensive tests implemented within the +# n6lib.tests.test_auth_api.TestInsideCriteriaResolver_initialization and +# n6lib.tests.test_auth_api.TestInsideCriteriaResolver__get_client_and_urls_matched_org_ids +# classes) + + + +TEST_CRITERIA = ( + [ + {'org_id': 'afbc', + 'cc_seq': ['AL'], + 'asn_seq': [43756], + 'fqdn_seq': [u'mycertbridgeonetalamakotawpmikmoknask.org', + u'alamakota.biz', + u'mikmokcertmakabimynask.net'], + 'ip_min_max_seq': [(2334252224, 2334252227)]}, + {'org_id': 'fdc', + 'cc_seq': ['SU', 'RU'], + 'asn_seq': [45975, 13799], + 'fqdn_seq': [u'onetbridgemikmokcert.eu', u'mikmokcertalamakota.org', u'mikmoknaskwp.info'], + 'ip_min_max_seq': [(2589577040, 2589577041)]}, + {'org_id': 'edca', + 'cc_seq': ['DD', 'DD'], + 'asn_seq': [8262, 4079], + 'fqdn_seq': [u'virut.eu', u'bridgealamakotawpvirut.eu', u'bridge.biz'], + 'ip_min_max_seq': [(653221832, 653221835)]}, + {'org_id': 'bdc', + 'cc_seq': ['GU', 'GU'], + 'asn_seq': [10546, 63520], + 'fqdn_seq': [u'certmikmokonetnaskvirutmakabiforcewpmybridgealamakota.org', + u'forcemyonetvirutbridgemikmokwpnaskmakabi.info', + u'virut.net'], + 'ip_min_max_seq': [(494991530, 494991530)]}, + {'org_id': 'befa', + 'cc_seq': ['CX', 'US'], + 'asn_seq': [31110, 26648], + 'fqdn_seq': [u'makabimikmokvirutonet.biz'], + 'ip_min_max_seq': [(3228707569, 3228707578)]}, + {'org_id': 'ebcadf', + 'cc_seq': ['US', 'GU'], + 'asn_seq': [52042], + 'fqdn_seq': [u'alamakotamakabi.pl', u'forcemakabivirutcert.com'], + 'ip_min_max_seq': [(1787298955, 1787298955)]}, + {'org_id': 'cfa', + 'cc_seq': ['DD', 'AI'], + 'asn_seq': [59009, 39165, 43185], + 'fqdn_seq': [u'alamakotabridge.pl', + u'makabibridgevirutmycertnaskonetalamakotawpforcemikmok.biz', + u'naskmikmok.eu'], + 'ip_min_max_seq': [(1378497104, 1378497107)]}, + {'org_id': 'eabf', + 'cc_seq': ['AL'], 
+ 'asn_seq': [33151, 61490, 57963], + 'fqdn_seq': [u'wpbridgemakabialamakota.pl', + u'bridgemakabialamakotamikmokonetforcenaskmywpvirutcert.org', + u'onetmikmokwpbridgecert.ru'], + 'ip_min_max_seq': [(1007811092, 1007811093), (1007811094, 1007811095)]}, + {'org_id': 'caebf', + 'cc_seq': ['DD'], + 'asn_seq': [40051, 39020, 61348], + 'fqdn_seq': [u'naskalamakotaonet.info'], + 'ip_min_max_seq': [(2031422565, 2031422565)]}, + {'org_id': 'decfba', + 'cc_seq': ['SU', 'RU'], + 'asn_seq': [21463], + 'fqdn_seq': [u'mikmokalamakota.eu'], + 'ip_min_max_seq': [(1292036523, 1292036525)]}, + {'org_id': 'cli16bit', + 'cc_seq': ['PL'], + 'asn_seq': [21467], + 'fqdn_seq': [u'ala.eu', u'król.pl'], + 'ip_min_max_seq': [(1292042241, 1292107775), (1292036523, 1292042239)]}, + ]) + + +class TestFilter(unittest.TestCase): + + def setUp(self): + self.filter = Filter.__new__(Filter) + self.per_test_inside_criteria = None # to be set in methods that need it + self.filter.auth_api = self._make_auth_api_mock() + self.fqdn_only_categories = frozenset(['leak']) + + def _make_auth_api_mock(self): + m = MagicMock() + m.get_inside_criteria_resolver.side_effect = ( + lambda: InsideCriteriaResolver(self.per_test_inside_criteria)) + return m + + def tearDown(self): + assert all( + c == call.get_inside_criteria_resolver() + for c in self.filter.auth_api.mock_calls), 'test must be updated?' + + + def test_parameters_queue(self): + """Test parameters queue.""" + self.assertTrue(issubclass(Filter, LegacyQueuedBase)) + self.assertEqual(Filter.input_queue['exchange'], + 'event') + self.assertEqual(Filter.input_queue['exchange_type'], + 'topic') + self.assertEqual(Filter.input_queue['queue_name'], + 'filter') + self.assertEqual(Filter.input_queue['accepted_event_types'], + ['event', + 'bl-new', + 'bl-update', + 'bl-change', + 'bl-delist', + 'bl-expire', + 'suppressed']) + self.assertEqual(Filter.output_queue['exchange'], + 'event') + self.assertEqual(Filter.output_queue['exchange_type'], + 'topic') + + def reset_body(self, d): + d['address'][0]['cc'] = 'XX' + d['address'][1]['cc'] = 'XX' + d['address'][2]['cc'] = 'XX' + d['address'][0]['asn'] = '1' + d['address'][1]['asn'] = '1' + d['address'][2]['asn'] = '1' + d['address'][0]['ip'] = '0.0.0.0' + d['address'][1]['ip'] = '0.0.0.1' + d['address'][2]['ip'] = '0.0.0.2' + return d + + def test__get_client_and_urls_matched__1(self): + body = {"category": "bots", "restriction": "public", "confidence": "medium", + "sha1": "023a00e7c2ef04ee5b0f767ba73ee39734323432", "name": "virut", + "proto": "tcp", "address": [{"cc": "XX", "ip": "139.33.220.192", "asn": "1"}, + {"cc": "XX", "ip": "100.71.83.178", "asn": "1"}, + {"cc": "XX", "ip": "102.71.83.178", "asn": "1"}], + "fqdn": "domain.com", "url": "http://onet.pl", "source": "hpfeeds.dionaea", + "time": "2013-07-01 20:37:20", "dport": "445", + "rid": "023a00e7c2ef04ee5b0f767ba73ee397", + "sport": "2147", "dip": "10.28.71.43", "id": "023a00e7c2ef04ee5b0f767ba73ee397"} + + # tested key:[test_value,[valid value]] + input_data = {'ip': ['139.33.220.192', ['afbc']], + 'cc': ['AL', ['afbc', 'eabf']], + 'asn': ['43756', ['afbc']], + 'fqdn': ['mycertbridgeonetalamakotawpmikmoknask.org', ['afbc']]} + + self.per_test_inside_criteria = TEST_CRITERIA + + for i in input_data: + body = self.reset_body(body) + if i == 'fqdn': + body['fqdn'] = input_data[i][0] + else: + body['address'][0][i] = input_data[i][0] + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, 
self.fqdn_only_categories), + (input_data[i][1], {})) + + def test__get_client_and_urls_matched__2(self): + body = {"category": "bots", "restriction": "public", "confidence": "medium", + "sha1": "023a00e7c2ef04ee5b0f767ba73ee39734323432", "name": "virut", + "proto": "tcp", "address": [{"cc": "XX", "ip": "1.1.1.1", "asn": "1"}, + {"cc": "XX", "ip": "100.71.83.178", "asn": "1"}, + {"cc": "XX", "ip": "102.71.83.178", "asn": "1"}], + "fqdn": "domain.com", "url": "http://onet.pl", "source": "hpfeeds.dionaea", + "time": "2013-07-01 20:37:20", "dport": "445", + "rid": "023a00e7c2ef04ee5b0f767ba73ee397", + "sport": "2147", "dip": "10.28.71.43", "id": "023a00e7c2ef04ee5b0f767ba73ee397"} + + # tested key:[test_value,[valid value]] + input_data = {'ip': ['154.89.207.81', ['fdc']], + 'cc': ['SU', ['decfba', 'fdc']], + 'asn': ['45975', ['fdc']], + 'fqdn': ['onetbridgemikmokcert.eu', ['fdc']]} + + self.per_test_inside_criteria = TEST_CRITERIA + + for i in input_data: + body = self.reset_body(body) + if i == 'fqdn': + body['fqdn'] = input_data[i][0] + else: + body['address'][0][i] = input_data[i][0] + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (input_data[i][1], {})) + + def test__get_client_and_urls_matched__3(self): + body = {"category": "bots", "restriction": "public", "confidence": "medium", + "sha1": "023a00e7c2ef04ee5b0f767ba73ee39734323432", "name": "virut", + "proto": "tcp", "address": [{"cc": "XX", "ip": "1.1.1.1", "asn": "1"}, + {"cc": "XX", "ip": "100.71.83.178", "asn": "1"}, + {"cc": "XX", "ip": "102.71.83.178", "asn": "1"}], + "fqdn": "domain.com", "url": "http://onet.pl", "source": "hpfeeds.dionaea", + "time": "2013-07-01 20:37:20", "dport": "445", + "rid": "023a00e7c2ef04ee5b0f767ba73ee397", + "sport": "2147", "dip": "10.28.71.43", "id": "023a00e7c2ef04ee5b0f767ba73ee397"} + + # tested key:[test_value,[valid value]] + input_data = {'ip': ['192.114.42.241', ['befa']], + 'cc': ['CX', ['befa']], + 'asn': ['31110', ['befa']], + 'fqdn': ['makabimikmokvirutonet.biz', ['befa']]} + + self.per_test_inside_criteria = TEST_CRITERIA + for i in input_data: + body = self.reset_body(body) + if i == 'fqdn': + body['fqdn'] = input_data[i][0] + else: + body['address'][0][i] = input_data[i][0] + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (input_data[i][1], {})) + + # tested key:[test_value,[valid value]] + input_data = {'ip': ['192.114.42.242', ['befa']], + 'cc': ['CX', ['befa']], + 'asn': ['31110', ['befa']], + 'fqdn': ['makabimikmokvirutonet.biz', ['befa']]} + for i in input_data: + body = self.reset_body(body) + if i == 'fqdn': + body['fqdn'] = input_data[i][0] + else: + body['address'][0][i] = input_data[i][0] + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (input_data[i][1], {})) + + def test__get_client_and_urls_matched__empty_cc(self): + test_criteria_local = [ + {'org_id': 'befa', + 'cc_seq': ['CX', 'US'], + 'asn_seq': [31110, 26648], + 'fqdn_seq': [u'makabimikmokvirutonet.biz'], + 'ip_min_max_seq': [(3228707569, 3228707578)], + }, + ] + body = {"category": "bots", "restriction": "public", "confidence": "medium", + "sha1": "023a00e7c2ef04ee5b0f767ba73ee39734323432", "name": "virut", + 
"proto": "tcp", "address": [{"cc": "XX", "ip": "1.1.1.1", "asn": "1"}, + {"cc": "XX", "ip": "100.71.83.178", "asn": "1"}, + {"cc": "XX", "ip": "102.71.83.178", "asn": "1"}], + "fqdn": "domain.com", "url": "http://onet.pl", "source": "hpfeeds.dionaea", + "time": "2013-07-01 20:37:20", "dport": "445", + "rid": "023a00e7c2ef04ee5b0f767ba73ee397", + "sport": "2147", "dip": "10.28.71.43", "id": "023a00e7c2ef04ee5b0f767ba73ee397"} + + # tested key:[test_value,[valid value]] + input_data = {'ip': ['192.114.42.241', ['befa']], + # 'cc':['CX', ['befa']], + 'asn': ['31110', ['befa']], + 'fqdn': ['makabimikmokvirutonet.biz', ['befa']]} + + self.per_test_inside_criteria = test_criteria_local + for i in input_data: + body = self.reset_body(body) + if i == 'fqdn': + body['fqdn'] = input_data[i][0] + else: + body['address'][0][i] = input_data[i][0] + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (input_data[i][1], {})) + + def test__get_client_and_urls_matched__range_ip(self): + test_criteria_local = [ + { + 'org_id': 'cli16bit', + 'cc_seq': ['PL'], + 'asn_seq': [21467], + 'fqdn_seq': [u'ala.eu'], + 'ip_min_max_seq': TEST_CRITERIA[-1]['ip_min_max_seq'], + } + ] + body = self.prepare_mock(test_criteria_local) + + body['address'][0]['ip'] = '77.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['ip'] = '77.2.233.250' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['ip'] = '77.2.254.250' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['ip'] = '77.3.100.1' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['ip'] = '77.3.255.255' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + # test outside the network + body['address'][0]['ip'] = '77.2.233.170' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body['address'][0]['ip'] = '77.2.233.0' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body['address'][0]['ip'] = '77.1.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body['address'][0]['ip'] = '77.3.0.0' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + 
([], {})) + + body['address'][0]['ip'] = '77.4.0.0' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + def test__get_client_and_urls_matched__single_ip(self): + test_criteria_local = [ + { + 'org_id': 'cli16bit', + 'cc_seq': ['PL'], + 'asn_seq': [21467], + 'fqdn_seq': [u'ala.eu'], + 'ip_min_max_seq': [(1292036523, 1292036523)], # '77.2.233.171' + } + ] + data = self.prepare_mock(test_criteria_local) + data['address'][0]['ip'] = '77.2.233.171' + record_dict = RecordDict(data) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + for other_ip in [ + '77.2.233.170', + '77.2.233.172', + '77.2.233.250', + '77.2.233.0', + '77.2.254.250', + '77.3.100.1', + '77.3.255.255', + '77.1.233.171', + '77.3.0.0', + '77.4.0.0', + '1.2.3.4', + '10.20.30.40', + '0.0.0.0', + '255.255.255.255', + ]: + data['address'][0]['ip'] = other_ip + record_dict = RecordDict(data) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + def test__get_client_and_urls_matched__cc(self): + test_criteria_local = [ + {'org_id': 'cli16bit', + 'cc_seq': ['PL'], + 'asn_seq': [21467], + 'fqdn_seq': [u'ala.eu'], + 'ip_min_max_seq': TEST_CRITERIA[-1]['ip_min_max_seq'], }] + body = self.prepare_mock(test_criteria_local) + + body['address'][0]['cc'] = 'PL' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['cc'] = 'pl' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['cc'] = 'Pl' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + # test outside the network + body['address'][0]['cc'] = 'EU' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + def test__get_client_and_urls_matched__asn(self): + test_criteria_local = [ + {'org_id': 'cli16bit', + 'cc_seq': ['PL'], + 'asn_seq': [21467], + 'fqdn_seq': [u'ala.eu'], + 'ip_min_max_seq': TEST_CRITERIA[-1]['ip_min_max_seq'], }] + body = self.prepare_mock(test_criteria_local) + + body['address'][0]['asn'] = '21467' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['asn'] = ' 21467' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['asn'] = '21467 ' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['asn'] = '0021467' + json_msg = json.dumps(body) + record_dict = 
RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['address'][0]['asn'] = 21467 + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + # test outside the network + body['address'][0]['asn'] = '21466' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body['address'][0]['asn'] = '21468' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body['address'][0]['asn'] = '21468' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body['address'][0]['asn'] = '0' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + def test__get_client_and_urls_matched__fqdn_seq(self): + test_criteria_local = [ + {'org_id': 'cli16bit', + 'cc_seq': ['PL'], + 'asn_seq': [21467], + 'fqdn_seq': [u'ala.eu', u'xxx.org', u'aaa.aa'], + 'ip_min_max_seq': TEST_CRITERIA[-1]['ip_min_max_seq'], }] + body = self.prepare_mock(test_criteria_local) + + body['fqdn'] = 'ala.eu' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['fqdn'] = 'xxx.org' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['fqdn'] = 'aaa.aa' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + body['fqdn'] = u'aaa.aa' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['cli16bit'], {})) + + # test outside the network + body['fqdn'] = 'xxx.eu' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + def test__get_client_and_urls_matched__empty_fileds_asn_ip_cc_fqdn_address(self): + test_criteria_local = [{'org_id': 'org1', + 'cc_seq': ["PL", "DE", "US"], + 'asn_seq': [42, 555, 12312], + 'fqdn_seq': [u"nask.pl", u"onet.pl"], + 'ip_min_max_seq': TEST_CRITERIA[-1]['ip_min_max_seq'], }, + {'org_id': 'org2', + 'cc_seq': ["RU", "DE", "US"], + 'asn_seq': [4235], + 'fqdn_seq': [u"nask.pl", u"cert.pl"], + 'ip_min_max_seq': [(0, 4194303), + (4294901760, 4294901792)], }] + + body = {"category": "bots", "restriction": "public", "confidence": "medium", + "sha1": "023a00e7c2ef04ee5b0f767ba73ee39734323432", "name": "virut", + "proto": "tcp", "address": [{"cc": "XX", "ip": "1.1.1.1", "asn": "1"}], + "fqdn": "domain.com", 
"url": "http://onet.pl", "source": "hpfeeds.dionaea", + "time": "2013-07-01 20:37:20", "dport": "445", + "rid": "023a00e7c2ef04ee5b0f767ba73ee397", + "sport": "2147", "dip": "10.28.71.43", "id": "023a00e7c2ef04ee5b0f767ba73ee397"} + + # test_all_fields + self.per_test_inside_criteria = test_criteria_local + body['fqdn'] = 'onet.pl' + body['address'][0]['cc'] = 'GH' + body['address'][0]['asn'] = '1234' + body['address'][0]['ip'] = '73.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_fqdn + if 'fqdn' in body: + del body['fqdn'] + body['address'][0]['cc'] = 'GH' + body['address'][0]['asn'] = '1234' + body['address'][0]['ip'] = '73.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + # test test_empty_cc + if 'cc' in body['address'][0]: + del body['address'][0]['cc'] + body['fqdn'] = 'onet.pl' + body['address'][0]['asn'] = '4235' + body['address'][0]['ip'] = '73.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1', 'org2'], {})) + + # test test_empty_asn + if 'asn' in body['address'][0]: + del body['address'][0]['asn'] + body['fqdn'] = 'www.onet.pl' + body['address'][0]['cc'] = 'XX' + body['address'][0]['ip'] = '73.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_ip + if 'ip' in body['address'][0]: + del body['address'][0]['ip'] + body['fqdn'] = 'www.onet.com' + body['address'][0]['cc'] = 'XX' + body['address'][0]['asn'] = '1234' + json_msg = json.dumps(body) + self.assertRaises(AdjusterError, RecordDict.from_json, json_msg) + + # test test_empty_address + if 'address' in body: + del body['address'] + body['fqdn'] = 'www.onet.pl' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_all + if 'address' in body: + del body['address'] + if 'fqdn' in body: + del body['fqdn'] + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertCountEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + + body = {"category": "bots", "restriction": "public", "confidence": "medium", + "sha1": "023a00e7c2ef04ee5b0f767ba73ee39734323432", "name": "virut", + "proto": "tcp", "address": [{"cc": "XX", "ip": "1.1.1.1", "asn": "1"}], + "fqdn": "domain.com", "url": "http://onet.pl", "source": "hpfeeds.dionaea", + "time": "2013-07-01 20:37:20", "dport": "445", + "rid": "023a00e7c2ef04ee5b0f767ba73ee397", + "sport": "2147", "dip": "10.28.71.43", "id": "023a00e7c2ef04ee5b0f767ba73ee397"} + # test test_empty_ip_asn + if 'ip' in body['address'][0]: + del body['address'][0]['ip'] + if 'asn' in body['address']: + del body['address'][0]['asn'] + body['fqdn'] = 'www.onet.pl' + body['address'][0]['cc'] = 'PL' + json_msg = json.dumps(body) + self.assertRaises(AdjusterError, RecordDict.from_json, json_msg) + + # 
test test_empty_ip_cc + if 'ip' in body['address'][0]: + del body['address'][0]['ip'] + if 'cc' in body['address'][0]: + del body['address'][0]['cc'] + body['fqdn'] = 'www.onet.pl' + body['address'][0]['asn'] = '1234' + json_msg = json.dumps(body) + self.assertRaises(AdjusterError, RecordDict.from_json, json_msg) + + # test test_empty_ip_fqdn + if 'ip' in body['address'][0]: + del body['address'][0]['ip'] + if 'fqdn' in body: + del body['fqdn'] + body['address'][0]['asn'] = '1234' + body['address'][0]['cc'] = 'PL' + json_msg = json.dumps(body) + self.assertRaises(AdjusterError, RecordDict.from_json, json_msg) + + # test test_empty_asn_cc + if 'asn' in body['address'][0]: + del body['address'][0]['asn'] + if 'cc' in body['address'][0]: + del body['address'][0]['cc'] + body['fqdn'] = 'www' + body['address'][0]['ip'] = '77.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_asn_cc + if 'asn' in body['address'][0]: + del body['address'][0]['asn'] + if 'fqdn' in body: + del body['fqdn'] + body['address'][0]['cc'] = 'PL' + body['address'][0]['ip'] = '77.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_fqdn_cc + if 'cc' in body['address'][0]: + del body['address'][0]['cc'] + if 'fqdn' in body: + del body['fqdn'] + body['address'][0]['asn'] = '42' + body['address'][0]['ip'] = '7.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_asn_ip_fqdn + if 'asn' in body['address'][0]: + del body['address'][0]['asn'] + if 'fqdn' in body: + del body['fqdn'] + if 'ip' in body['address'][0]: + del body['address'][0]['ip'] + body['address'][0]['cc'] = 'PL' + json_msg = json.dumps(body) + self.assertRaises(AdjusterError, RecordDict.from_json, json_msg) + + # test test_empty_asn_cc_fqdn + if 'asn' in body['address'][0]: + del body['address'][0]['asn'] + if 'fqdn' in body: + del body['fqdn'] + if 'cc' in body['address'][0]: + del body['address'][0]['cc'] + body['address'][0]['ip'] = '77.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + # test test_empty_asn_cc_fqdn + if 'asn' in body['address'][0]: + del body['address'][0]['asn'] + if 'fqdn' in body: + del body['fqdn'] + if 'cc' in body['address'][0]: + del body['address'][0]['cc'] + body['address'][0]['ip'] = '77.2.233.171' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org1'], {})) + + def test__get_client_and_urls_matched__no_fqdn_seq(self): + test_criteria_local = [ + {'org_id': 'org4', + 'asn_seq': [21467], + 'fqdn_seq': [''], }] + body = self.prepare_mock(test_criteria_local) + # test outside network + body['address'][0]['cc'] = 'GU' + body['address'][0]['asn'] = 21467 + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, 
self.fqdn_only_categories), + (['org4'], {})) + + def test__get_client_and_urls_matched__url_pattern(self): + test_criteria_local = [ + {'org_id': 'org4', + 'asn_seq': [], + 'fqdn_seq': [''], + 'url_seq': ['wp.pl', u'wpą.pl'], }] + body = self.prepare_mock(test_criteria_local) + # test glob mach + body[u'url_pattern'] = u'*.*' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org4'], {'org4': ['wp.pl', u'wpą.pl']})) + # test regexp match + body[u'url_pattern'] = u'^w.*\\.[pu][ls]' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org4'], {'org4': ['wp.pl', u'wpą.pl']})) + # test glob, not match + body[u'url_pattern'] = u'*/*' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + # test regexp not mach + body[u'url_pattern'] = u'^w.*\\.[au][ls]' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + # test bad regexp + body[u'url_pattern'] = u'^w.*\\.[au][ls' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + # test bad glob + body[u'url_pattern'] = u'??!xx%$2ąść„ŋ…' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + ([], {})) + # test glob with unicode, match + body[u'url_pattern'] = u'??ą.pl' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org4'], {'org4': [u'wpą.pl']})) + # test regexp with unicode, match + body[u'url_pattern'] = u'..*ą\\.pl' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org4'], {'org4': [u'wpą.pl']})) + # test regexp with unicode, match + body[u'url_pattern'] = r'\w+(?<=ą)\.[p]' + json_msg = json.dumps(body) + record_dict = RecordDict.from_json(json_msg) + self.assertEqual( + self.filter.get_client_and_urls_matched(record_dict, self.fqdn_only_categories), + (['org4'], {'org4': [u'wpą.pl']})) + # test regexp with unicode, match + body[u'url_pattern'] = r'\w+(? None: if self.is_config_spec_or_group_declared(): self.config = self.get_config_section(**self.get_config_spec_format_kwargs()) else: @@ -56,8 +63,12 @@ def set_configuration(self): # time -- no `config_spec`/`config_spec_pattern` self.config = ConfigSection('') + # A hook method (can be extended in subclasses...) + def get_config_spec_format_kwargs(self) -> KwargsDict: + return {} + -class CollectorWithStateMixin(object): +class CollectorWithStateMixin: """ Mixin for tracking state of an inheriting collector. 
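For readers of this changeset: the hunk above makes `set_configuration()` pass the result of a new
`get_config_spec_format_kwargs()` hook to `get_config_section()`. Below is a minimal sketch of how a
concrete collector subclass could use that hook together with `config_spec_pattern`; the class name,
section name and options are invented for illustration, and the exact `n6lib.config.ConfigMixin`
behaviour is assumed here, so treat this as a sketch rather than project code.

    class ExampleCollector(BaseCollector):   # hypothetical subclass of the BaseCollector patched above

        # `{section_name}` is interpolated using the values returned
        # by get_config_spec_format_kwargs()
        config_spec_pattern = '''
            [{section_name}]
            source :: str
            url :: str
        '''

        # hypothetical attribute, used only by this illustration
        example_section_name = 'example_collector'

        def get_config_spec_format_kwargs(self):
            # extend the base hook (which returns an empty dict) with the
            # value(s) to be substituted into `config_spec_pattern`
            base = super().get_config_spec_format_kwargs()
            return dict(base, section_name=self.example_section_name)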
@@ -86,11 +97,11 @@ def load_state(self): except (OSError, ValueError, EOFError) as exc: state = self.make_default_state() LOGGER.warning( - "Could not load state (%s), returning: %r", + "Could not load state (%s), returning: %a", make_exc_ascii_str(exc), state) else: - LOGGER.info("Loaded state: %r", state) + LOGGER.info("Loaded state: %a", state) return state def save_state(self, state): @@ -108,7 +119,7 @@ def save_state(self, state): with AtomicallySavedFile(self._cache_file_path, 'wb') as f: pickle.dump(state, f, self.pickle_protocol) - LOGGER.info("Saved state: %r", state) + LOGGER.info("Saved state: %a", state) def get_cache_file_name(self): source_channel = self.get_source_channel() @@ -122,7 +133,7 @@ def make_default_state(self): # # Base classes -class AbstractBaseCollector(object): +class AbstractBaseCollector: """ Abstract base class for a collector script implementations. @@ -172,57 +183,90 @@ class BaseCollector(CollectorConfigMixin, LegacyQueuedBase, AbstractBaseCollecto The standard "root" base class for collectors. """ - output_queue = { + output_queue: Optional[Union[dict, list[dict]]] = { 'exchange': 'raw', 'exchange_type': 'topic', } # None or a string being the tag of the raw data format version # (can be set in a subclass) - raw_format_version_tag = None - - # the name of the config group - # (it does not have to be implemented if one of the `config_spec` - # or the `config_spec_pattern` attribute is set in a subclass, - # containing a declaration of exactly *one* config section) - config_group = None - - # a sequence of required config fields (can be extended in + raw_format_version_tag: Optional[str] = None + + # at most *one* of the following two attributes can + # be set in a subclass to a non-None value (see: + # `n6lib.config.ConfigMixin`...) + config_spec: Optional[str] + config_spec_pattern: Optional[str] + + # the config section name, to be set in concrete subclasses + # (it does not have to be provided if `config_spec` + # or `config_spec_pattern` is provided in a subclass and + # contains a declaration of exactly *one* config section; + # see: `n6lib.config.ConfigMixin`...) + config_group: Optional[str] = None + + # a sequence of required config options (can be extended in # subclasses; typically, 'source' should be included there!) - config_required = ('source',) + # [TODO: let's get rid of this legacy attribute; note that + # `config_spec`/`config_spec_pattern` should be sufficient; + # see: `n6lib.config.ConfigMixin`...] + config_required: Sequence[str] = ('source',) # (NOTE: the `source` setting value in the config is only # the first part -- the `label` part -- of the actual # source specification string '