diff --git a/openfisca_core/indexed_enums.py b/openfisca_core/indexed_enums.py index c4091aee0d..736dd8a8e8 100644 --- a/openfisca_core/indexed_enums.py +++ b/openfisca_core/indexed_enums.py @@ -9,8 +9,6 @@ logical_not as not_, ndarray, select, - size, - take, ) ENUM_ARRAY_DTYPE = int16 @@ -85,11 +83,8 @@ def encode( # gives them a different identity from the ones imported in the usual way. # So, instead of relying on the "cls" passed in, we use only its name to # check that the values in the array, if non-empty, are of the right type. - array_size: int = size(array) - array_name: str = take(array, 0).__class__.__name__ - - if array_size > 0 and array_name is cls.__name__: - cls = take(array, 0).__class__ + if len(array) > 0 and array[0].__class__.__name__ is cls.__name__: + cls = array[0].__class__ array = select( [array == item for item in cls], diff --git a/tests/core/test_indexed_enums.py b/tests/core/test_indexed_enums.py index 6e5cde12c1..2d2347bb53 100644 --- a/tests/core/test_indexed_enums.py +++ b/tests/core/test_indexed_enums.py @@ -49,9 +49,8 @@ def test_enum_encode_when_array_is_array_of_enums(my_enum): def test_enum_encode_when_array_is_scalar_array_of_enum(my_enum): values = array(my_enum.bar) - result = my_enum.encode(values) - - assert result == 1 + with pytest.raises(TypeError): + my_enum.encode(values) def test_enum_encode_when_array_is_array_of_indices(my_enum): @@ -85,8 +84,8 @@ def test_enum_encode_when_array_is_not_a_data_structure(my_enum): def test_enum_array___eq__(my_enum): - enum_array1 = EnumArray(array(1), my_enum) - enum_array2 = EnumArray(array(1), my_enum) + enum_array1 = EnumArray([array(1)], my_enum) + enum_array2 = EnumArray([array(1)], my_enum) result = enum_array1 == enum_array2 @@ -94,8 +93,8 @@ def test_enum_array___eq__(my_enum): def test_enum_array___ne__(my_enum): - enum_array1 = EnumArray(array(0), my_enum) - enum_array2 = EnumArray(array(1), my_enum) + enum_array1 = EnumArray([array(0)], my_enum) + enum_array2 = EnumArray([array(1)], my_enum) result = enum_array1 != enum_array2 @@ -103,14 +102,14 @@ def test_enum_array___ne__(my_enum): def test_enum_array__forbidden_operation(my_enum): - enum_array = EnumArray(array(1), my_enum) + enum_array = EnumArray([array(1)], my_enum) with pytest.raises(TypeError): enum_array * 1 def test_enum_array_decode(my_enum): - values = array(my_enum.bar) + values = array([my_enum.bar]) enum_array = my_enum.encode(values) result = enum_array.decode() @@ -119,7 +118,7 @@ def test_enum_array_decode(my_enum): def test_enum_array_decode_to_str(my_enum): - values = array(my_enum.bar.value) + values = array([my_enum.bar.value]) enum_array = my_enum.encode(values) result = enum_array.decode_to_str() @@ -128,16 +127,16 @@ def test_enum_array_decode_to_str(my_enum): def test_enum_array___repr__(my_enum): - enum_array = EnumArray(array(1), my_enum) + enum_array = EnumArray([array(1)], my_enum) result = repr(enum_array) - assert result == "EnumArray(MyEnum.bar)" + assert result == "EnumArray([])" def test_enum_array___str__(my_enum): - enum_array = EnumArray(array(1), my_enum) + enum_array = EnumArray([array(1)], my_enum) result = str(enum_array) - assert result == "bar" + assert result == "['bar']"