Skip to content

Commit

Permalink
Assume scalars are not expected
Browse files Browse the repository at this point in the history
  • Loading branch information
Mauko Quiroga committed Oct 12, 2020
1 parent de142e0 commit 1c14773
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 21 deletions.
9 changes: 2 additions & 7 deletions openfisca_core/indexed_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
logical_not as not_,
ndarray,
select,
size,
take,
)

ENUM_ARRAY_DTYPE = int16
Expand Down Expand Up @@ -85,11 +83,8 @@ def encode(
# gives them a different identity from the ones imported in the usual way.
# So, instead of relying on the "cls" passed in, we use only its name to
# check that the values in the array, if non-empty, are of the right type.
array_size: int = size(array)
array_name: str = take(array, 0).__class__.__name__

if array_size > 0 and array_name is cls.__name__:
cls = take(array, 0).__class__
if len(array) > 0 and array[0].__class__.__name__ is cls.__name__:
cls = array[0].__class__

array = select(
[array == item for item in cls],
Expand Down
27 changes: 13 additions & 14 deletions tests/core/test_indexed_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def test_enum_encode_when_array_is_array_of_enums(my_enum):
def test_enum_encode_when_array_is_scalar_array_of_enum(my_enum):
values = array(my_enum.bar)

result = my_enum.encode(values)

assert result == 1
with pytest.raises(TypeError):
my_enum.encode(values)


def test_enum_encode_when_array_is_array_of_indices(my_enum):
Expand Down Expand Up @@ -85,32 +84,32 @@ def test_enum_encode_when_array_is_not_a_data_structure(my_enum):


def test_enum_array___eq__(my_enum):
enum_array1 = EnumArray(array(1), my_enum)
enum_array2 = EnumArray(array(1), my_enum)
enum_array1 = EnumArray([array(1)], my_enum)
enum_array2 = EnumArray([array(1)], my_enum)

result = enum_array1 == enum_array2

assert result


def test_enum_array___ne__(my_enum):
enum_array1 = EnumArray(array(0), my_enum)
enum_array2 = EnumArray(array(1), my_enum)
enum_array1 = EnumArray([array(0)], my_enum)
enum_array2 = EnumArray([array(1)], my_enum)

result = enum_array1 != enum_array2

assert result


def test_enum_array__forbidden_operation(my_enum):
enum_array = EnumArray(array(1), my_enum)
enum_array = EnumArray([array(1)], my_enum)

with pytest.raises(TypeError):
enum_array * 1


def test_enum_array_decode(my_enum):
values = array(my_enum.bar)
values = array([my_enum.bar])
enum_array = my_enum.encode(values)

result = enum_array.decode()
Expand All @@ -119,7 +118,7 @@ def test_enum_array_decode(my_enum):


def test_enum_array_decode_to_str(my_enum):
values = array(my_enum.bar.value)
values = array([my_enum.bar.value])
enum_array = my_enum.encode(values)

result = enum_array.decode_to_str()
Expand All @@ -128,16 +127,16 @@ def test_enum_array_decode_to_str(my_enum):


def test_enum_array___repr__(my_enum):
enum_array = EnumArray(array(1), my_enum)
enum_array = EnumArray([array(1)], my_enum)

result = repr(enum_array)

assert result == "EnumArray(MyEnum.bar)"
assert result == "EnumArray([<MyEnum.bar: 'bar'>])"


def test_enum_array___str__(my_enum):
enum_array = EnumArray(array(1), my_enum)
enum_array = EnumArray([array(1)], my_enum)

result = str(enum_array)

assert result == "bar"
assert result == "['bar']"

0 comments on commit 1c14773

Please sign in to comment.