Skip to content

Commit

Permalink
Add some misc typing to zarr.util
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby committed Dec 24, 2023
1 parent b5f79dd commit a829b3d
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions zarr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
Union,
Iterable,
cast,
List,
)
from types import EllipsisType

import numpy as np
from asciitree import BoxStyle, LeftAligned
Expand Down Expand Up @@ -54,7 +56,7 @@ def flatten(arg: Iterable) -> Iterable:


class NumberEncoder(json.JSONEncoder):
def default(self, o):
def default(self, o: numbers.Number) -> float:
# See json.JSONEncoder.default docstring for explanation
# This is necessary to encode numpy dtype
if isinstance(o, numbers.Integral):
Expand Down Expand Up @@ -219,7 +221,7 @@ def normalize_dtype(dtype: Union[str, np.dtype], object_codec) -> Tuple[np.dtype


# noinspection PyTypeChecker
def is_total_slice(item, shape: Tuple[int]) -> bool:
def is_total_slice(item: Union[EllipsisType, slice, tuple], shape: Tuple[int]) -> bool:
"""Determine whether `item` specifies a complete slice of array with the
given `shape`. Used to optimize __setitem__ operations on the Chunk
class."""
Expand Down Expand Up @@ -263,7 +265,7 @@ def normalize_resize_args(old_shape, *args):
return new_shape


def human_readable_size(size) -> str:
def human_readable_size(size: float) -> str:
if size < 2**10:
return "%s" % size
elif size < 2**20:
Expand Down Expand Up @@ -391,7 +393,7 @@ def info_text_report(items: Dict[Any, Any]) -> str:
return report


def info_html_report(items) -> str:
def info_html_report(items: dict) -> str:
report = '<table class="zarr-info">'
report += "<tbody>"
for k, v in items:
Expand Down Expand Up @@ -420,25 +422,25 @@ def _repr_html_(self):


class TreeNode:
def __init__(self, obj, depth=0, level=None):
def __init__(self, obj, depth: int = 0, level: Optional[int] = None):
self.obj = obj
self.depth = depth
self.level = level

def get_children(self):
def get_children(self) -> List["TreeNode"]:
if hasattr(self.obj, "values"):
if self.level is None or self.depth < self.level:
depth = self.depth + 1
return [TreeNode(o, depth=depth, level=self.level) for o in self.obj.values()]
return []

def get_text(self):
def get_text(self) -> str:
name = self.obj.name.split("/")[-1] or "/"
if hasattr(self.obj, "shape"):
name += " {} {}".format(self.obj.shape, self.obj.dtype)
return name

def get_type(self):
def get_type(self) -> str:
return type(self.obj).__name__


Expand Down Expand Up @@ -466,7 +468,7 @@ def tree_get_icon(stype: str) -> str:
raise ValueError("Unknown type: %s" % stype)


def tree_widget_sublist(node, root=False, expand=False):
def tree_widget_sublist(node, root: bool = False, expand: Union[bool, int] = False):
import ipytree

result = ipytree.Node()
Expand All @@ -482,7 +484,7 @@ def tree_widget_sublist(node, root=False, expand=False):
return result


def tree_widget(group, expand, level):
def tree_widget(group, expand: Union[bool, int], level: int):
try:
import ipytree
except ImportError as error:
Expand All @@ -501,7 +503,7 @@ def tree_widget(group, expand, level):


class TreeViewer:
def __init__(self, group, expand=False, level=None):
def __init__(self, group, expand: Union[bool, int] = False, level: Optional[int] = None):
self.group = group
self.expand = expand
self.level = level
Expand All @@ -519,7 +521,7 @@ def __init__(self, group, expand=False, level=None):
VERTICAL_AND_RIGHT="\u251C",
)

def __bytes__(self):
def __bytes__(self) -> bytes:
drawer = LeftAligned(
traverse=TreeTraversal(), draw=BoxStyle(gfx=self.bytes_kwargs, **self.text_kwargs)
)
Expand All @@ -532,22 +534,22 @@ def __bytes__(self):

return result

def __unicode__(self):
def __unicode__(self) -> str:
drawer = LeftAligned(
traverse=TreeTraversal(), draw=BoxStyle(gfx=self.unicode_kwargs, **self.text_kwargs)
)
root = TreeNode(self.group, level=self.level)
return drawer(root)

def __repr__(self):
def __repr__(self) -> str:
return self.__unicode__()

def _repr_mimebundle_(self, **kwargs):
tree = tree_widget(self.group, expand=self.expand, level=self.level)
return tree._repr_mimebundle_(**kwargs)


def check_array_shape(param, array, shape):
def check_array_shape(param, array, shape) -> None:
if not hasattr(array, "shape"):
raise TypeError(
"parameter {!r}: expected an array-like object, got {!r}".format(param, type(array))
Expand All @@ -560,7 +562,7 @@ def check_array_shape(param, array, shape):
)


def is_valid_python_name(name):
def is_valid_python_name(name: str) -> bool:
from keyword import iskeyword

return name.isidentifier() and not iskeyword(name)
Expand All @@ -569,10 +571,10 @@ def is_valid_python_name(name):
class NoLock:
"""A lock that doesn't lock."""

def __enter__(self):
def __enter__(self) -> None:
pass

def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
pass


Expand Down Expand Up @@ -600,7 +602,7 @@ def __init__(self, store_key, chunk_store):
_key_path = "/".join(_key_path[:-1] + _chunk_path)
self.key_path = _key_path

def prepare_chunk(self):
def prepare_chunk(self) -> None:
assert self.buff is None
header = self.fs.read_block(self.key_path, 0, 16)
nbytes, self.cbytes, blocksize = cbuffer_sizes(header)
Expand All @@ -620,7 +622,7 @@ def prepare_chunk(self):
self.buff[16 : (16 + (self.nblocks * 4))] = start_points_buffer
self.n_per_block = blocksize / typesize

def read_part(self, start, nitems):
def read_part(self, start, nitems) -> None:
assert self.buff is not None
if self.nblocks == 1:
return
Expand Down Expand Up @@ -654,7 +656,7 @@ def __init__(self, store_key, chunk_store, itemsize):
self.store_key = store_key
self.itemsize = itemsize

def prepare_chunk(self):
def prepare_chunk(self) -> None:
pass

def read_part(self, start, nitems):
Expand Down Expand Up @@ -695,7 +697,7 @@ def retry_call(
raise


def all_equal(value: Any, array: Any):
def all_equal(value: Any, array: Any) -> bool:
"""
Test if all the elements of an array are equivalent to a value.
If `value` is None, then this function does not do any comparison and
Expand All @@ -720,11 +722,11 @@ def all_equal(value: Any, array: Any):
# Numpy errors if you call np.isnan on custom dtypes, so ensure
# we are working with floats before calling isnan
if np.issubdtype(array.dtype, np.floating) and np.isnan(value):
return np.all(np.isnan(array))
return bool(np.all(np.isnan(array)))
else:
# using == raises warnings from numpy deprecated pattern, but
# using np.equal() raises type errors for structured dtypes...
return np.all(value == array)
return bool(np.all(value == array))


def ensure_contiguous_ndarray_or_bytes(buf) -> Union[NDArrayLike, bytes]:
Expand Down

0 comments on commit a829b3d

Please sign in to comment.