Skip to content
This repository has been archived by the owner on Sep 17, 2019. It is now read-only.

Mock return text and increase_count_on_error option #295

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions napalm_base/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,47 @@
from napalm_base.base import NetworkDriver
import napalm_base.exceptions

import ast
import inspect
import json
import os
import re


from functools import wraps
from pydoc import locate


def count_calls(name=None, pass_self=True):
def real_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
funcname = name or func.__name__
self = args[0]
args = args if pass_self else args[1:]
try:
self.current_count = self.calls[funcname]
except KeyError:
self.calls[funcname] = 1
self.current_count = 1
r = func(*args, **kwargs)
self.calls[funcname] += 1
except Exception:
if self.increase_count_on_error:
self.calls[funcname] += 1
raise
return r
return wrapper
return real_decorator


def raise_exception(result):
exc = locate(result["exception"])
if exc:
raise exc(*result.get("args", []), **result.get("kwargs", {}))
else:
raise TypeError("Couldn't resolve exception {}", result["exception"])
raise TypeError("Couldn't resolve exception {}".format(result["exception"]))


def is_mocked_method(method):
Expand All @@ -43,7 +69,7 @@ def is_mocked_method(method):
return False


def mocked_method(path, name, count):
def mocked_method(self, name):
parent_method = getattr(NetworkDriver, name)
parent_method_args = inspect.getargspec(parent_method)
modifier = 0 if 'self' not in parent_method_args.args else 1
Expand All @@ -60,7 +86,8 @@ def _mocked_method(*args, **kwargs):
if unexpected:
raise TypeError("{} got an unexpected keyword argument '{}'".format(name,
unexpected[0]))
return mocked_data(path, name, count)
return count_calls(name, pass_self=False)(
mocked_data)(self, self.path, name, self.calls.get(name, 1))

return _mocked_method

Expand All @@ -75,6 +102,10 @@ def mocked_data(path, name, count):

if "exception" in result:
raise_exception(result)
elif "plain_text" in result:
return result["plain_text"]
elif "direct_value" in result:
return ast.literal_eval(result["direct_value"])
else:
return result

Expand Down Expand Up @@ -116,10 +147,7 @@ def __init__(self, hostname, username, password, timeout=60, optional_args=None)
self.filename = None
self.config = None

def _count_calls(self, name):
current_count = self.calls.get(name, 0)
self.calls[name] = current_count + 1
return self.calls[name]
self.increase_count_on_error = optional_args.get("increase_count_on_error", True)

def _raise_if_closed(self):
if not self.opened:
Expand All @@ -134,54 +162,54 @@ def close(self):
def is_alive(self):
return {"is_alive": self.opened}

@count_calls()
def cli(self, commands):
count = self._count_calls("cli")
result = {}
regexp = re.compile('[^a-zA-Z0-9]+')
for i, c in enumerate(commands):
sanitized = re.sub(regexp, '_', c)
name = "cli.{}.{}".format(count, sanitized)
name = "cli.{}.{}".format(self.current_count, sanitized)
filename = "{}.{}".format(os.path.join(self.path, name), i)
with open(filename, 'r') as f:
result[c] = f.read()
return result

@count_calls()
def load_merge_candidate(self, filename=None, config=None):
count = self._count_calls("load_merge_candidate")
self._raise_if_closed()
self.merge = True
self.filename = filename
self.config = config
mocked_data(self.path, "load_merge_candidate", count)
mocked_data(self.path, "load_merge_candidate", self.current_count)

@count_calls()
def load_replace_candidate(self, filename=None, config=None):
count = self._count_calls("load_replace_candidate")
self._raise_if_closed()
self.merge = False
self.filename = filename
self.config = config
mocked_data(self.path, "load_replace_candidate", count)
mocked_data(self.path, "load_replace_candidate", self.current_count)

@count_calls()
def compare_config(self, filename=None, config=None):
count = self._count_calls("compare_config")
self._raise_if_closed()
return mocked_data(self.path, "compare_config", count)["diff"]
return mocked_data(self.path, "compare_config", self.current_count)["diff"]

@count_calls()
def commit_config(self):
count = self._count_calls("commit_config")
self._raise_if_closed()
self.merge = None
self.filename = None
self.config = None
mocked_data(self.path, "commit_config", count)
mocked_data(self.path, "commit_config", self.current_count)

@count_calls()
def discard_config(self):
count = self._count_calls("commit_config")
self._raise_if_closed()
self.merge = None
self.filename = None
self.config = None
mocked_data(self.path, "discard_config", count)
mocked_data(self.path, "discard_config", self.current_count)

def _rpc(self, get):
"""This one is only useful for junos."""
Expand All @@ -193,7 +221,6 @@ def _rpc(self, get):
def __getattribute__(self, name):
if is_mocked_method(name):
self._raise_if_closed()
count = self._count_calls(name)
return mocked_method(self.path, name, count)
return mocked_method(self, name)
else:
return object.__getattribute__(self, name)
33 changes: 32 additions & 1 deletion test/unit/TestMockDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import pytest

import copy
import os


Expand Down Expand Up @@ -101,7 +102,7 @@ def test_mock_error(self):

with pytest.raises(TypeError) as excinfo:
d.get_bgp_neighbors()
assert "Couldn't resolve exception NoIdeaException" in excinfo.value
assert "Couldn't resolve exception NoIdeaException" in str(excinfo.value)

d.close()

Expand Down Expand Up @@ -129,3 +130,33 @@ def test_configuration_replace(self):
d.compare_config() == "a_diff"
d.commit_config()
d.close()

def test_count_on_error(self):
optargs = copy.deepcopy(optional_args)
optargs["increase_count_on_error"] = True
d = driver("blah", "bleh", "blih", optional_args=optargs)
d.open()
try:
d.get_ntp_peers()
except Exception:
pass

with pytest.raises(NotImplementedError) as excinfo:
d.get_ntp_peers()
assert "get_ntp_peers.2" in str(excinfo.value)
d.close()

def test_dont_count_on_error(self):
optargs = copy.deepcopy(optional_args)
optargs["increase_count_on_error"] = False
d = driver("blah", "bleh", "blih", optional_args=optargs)
d.open()
try:
d.get_ntp_peers()
except Exception:
pass

with pytest.raises(NotImplementedError) as excinfo:
d.get_ntp_peers()
assert "get_ntp_peers.1" in str(excinfo.value)
d.close()