diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6483dc04..572fccec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: pretty-format-toml args: [--autofix, --trailing-commas] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.14 + rev: v0.2.0 hooks: - id: ruff args: [--fix] @@ -22,7 +22,7 @@ repos: - id: ruff-format types_or: [python, pyi, jupyter] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.348 + rev: v1.1.349 hooks: - id: pyright - repo: https://github.com/doublify/pre-commit-rust @@ -31,7 +31,3 @@ repos: - id: fmt - id: cargo-check - id: clippy -- repo: https://github.com/keewis/blackdoc - rev: v0.3.9 - hooks: - - id: blackdoc diff --git a/pyproject.toml b/pyproject.toml index b26f099e..72305251 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,16 +134,21 @@ filterwarnings = [ ] [tool.ruff] +extend-include = ["*.ipynb"] +target-version = "py39" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] extend-ignore = [ "D203", # no-blank-line-before-class "D212", # multi-line-summary-second-line "D407", # Missing dashed underline after section "F722", # Syntax error in forward annotation ] -extend-include = ["*.ipynb"] extend-select = ["B", "C90", "D", "I", "N", "RUF", "UP", "T"] isort = {known-first-party = ["differt", "tests"]} -target-version = "py39" [tool.ruff.lint.pep8-naming] extend-ignore-names = ["T"] @@ -153,5 +158,5 @@ extend-ignore-names = ["T"] "**/{tests,docs}/*" = ["D"] "python/differt/conftest.py" = ["D"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" diff --git a/python/differt/geometry/utils.py b/python/differt/geometry/utils.py index 9a25aa2a..8c5b20b8 100644 --- a/python/differt/geometry/utils.py +++ b/python/differt/geometry/utils.py @@ -36,7 +36,9 @@ def normalize( :Examples: - >>> from differt.geometry.utils import normalize + >>> from differt.geometry.utils import ( + ... normalize, + ... ) >>> >>> vector = jnp.array([1.0, 1.0, 1.0]) >>> normalize(vector) # [1., 1., 1.] / sqrt(3), sqrt(3) diff --git a/python/differt/plotting/_utils.py b/python/differt/plotting/_utils.py index 5c378c9d..41021a69 100644 --- a/python/differt/plotting/_utils.py +++ b/python/differt/plotting/_utils.py @@ -51,17 +51,14 @@ def use(backend: str) -> None: >>> @dplt.dispatch ... def my_plot(): ... pass - ... >>> >>> @my_plot.register("vispy") ... def _(): ... print("Using vispy backend") - ... >>> >>> @my_plot.register("matplotlib") ... def _(): ... print("Using matplotlib backend") - ... >>> >>> my_plot() # When not specified, use default backend Using vispy backend @@ -74,9 +71,7 @@ def use(backend: str) -> None: >>> my_plot() # So that now it defaults to 'matplotlib' Using matplotlib backend >>> - >>> my_plot( - ... backend="vispy" - ... ) # Of course, the 'vispy' backend is still available + >>> my_plot(backend="vispy") # Of course, the 'vispy' backend is still available Using vispy backend """ if backend not in SUPPORTED_BACKENDS: @@ -126,25 +121,34 @@ def dispatch(fun: Callable[P, T]) -> Dispatcher[P, T]: >>> @dplt.dispatch ... def plot_line(vertices, color): ... pass - ... >>> >>> @plot_line.register("matplotlib") ... def _(vertices, color): ... print("Using matplotlib backend") - ... >>> >>> @plot_line.register("plotly") ... def _(vertices, color): ... print("Using plotly backend") - ... >>> - >>> plot_line(_, _, backend="matplotlib") + >>> plot_line( + ... _, + ... _, + ... backend="matplotlib", + ... ) Using matplotlib backend >>> - >>> plot_line(_, _, backend="plotly") + >>> plot_line( + ... _, + ... _, + ... backend="plotly", + ... ) Using plotly backend >>> - >>> plot_line(_, _, backend="vispy") # doctest: +IGNORE_EXCEPTION_DETAIL + >>> plot_line( + ... _, + ... _, + ... backend="vispy", + ... ) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): NotImplementedError: No backend implementation for 'vispy' >>> @@ -156,7 +160,6 @@ def dispatch(fun: Callable[P, T]) -> Dispatcher[P, T]: >>> @plot_line.register("numpy") # doctest: +IGNORE_EXCEPTION_DETAIL ... def _(vertices, color): ... pass - ... Traceback (most recent call last): ValueError: Unsupported backend 'numpy', allowed values are: ... """ diff --git a/python/differt/rt/image_method.py b/python/differt/rt/image_method.py index 34b6f6c9..2fd606e3 100644 --- a/python/differt/rt/image_method.py +++ b/python/differt/rt/image_method.py @@ -50,13 +50,28 @@ def image_of_vertices_with_respect_to_mirrors( ... ) >>> >>> key = jax.random.PRNGKey(0) - >>> key0, key1, key2 = jax.random.split(key, 3) + >>> ( + ... key0, + ... key1, + ... key2, + ... ) = jax.random.split(key, 3) >>> batch = (10, 20, 30) - >>> vertices = jax.random.uniform(key0, (*batch, 3)) - >>> mirror_vertices = jax.random.uniform(key1, (*batch, 3)) - >>> mirror_normals = jax.random.uniform(key2, (*batch, 3)) + >>> vertices = jax.random.uniform( + ... key0, + ... (*batch, 3), + ... ) + >>> mirror_vertices = jax.random.uniform( + ... key1, + ... (*batch, 3), + ... ) + >>> mirror_normals = jax.random.uniform( + ... key2, + ... (*batch, 3), + ... ) >>> images = image_of_vertices_with_respect_to_mirrors( - ... vertices, mirror_vertices, mirror_normals + ... vertices, + ... mirror_vertices, + ... mirror_normals, ... ) >>> images.shape (10, 20, 30, 3) @@ -137,11 +152,24 @@ def image_method( .. code-block:: python paths = image_method( - from_vertices, to_vertices, mirror_vertices, mirror_normals + from_vertices, + to_vertices, + mirror_vertices, + mirror_normals, ) full_paths = jnp.concatenate( - (from_vertices[None, ...], paths, to_vertices[None, ...]) + ( + from_vertices[ + None, + ..., + ], + paths, + to_vertices[ + None, + ..., + ], + ) ) """ diff --git a/python/differt/utils.py b/python/differt/utils.py index 9c4d3bec..84dac3d3 100644 --- a/python/differt/utils.py +++ b/python/differt/utils.py @@ -18,11 +18,16 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]: Examples: The following example shows how the sorting works. - >>> from differt.utils import sorted_array2 + >>> from differt.utils import ( + ... sorted_array2, + ... ) >>> >>> arr = jnp.arange(10).reshape(5, 2) >>> key = jax.random.PRNGKey(1234) - >>> key1, key2 = jax.random.split(key, 2) + >>> ( + ... key1, + ... key2, + ... ) = jax.random.split(key, 2) >>> arr = jax.random.permutation(key1, arr) >>> arr Array([[4, 5], @@ -38,7 +43,12 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]: [6, 7], [8, 9]], dtype=int32) >>> - >>> arr = jax.random.randint(key2, (5, 5), 0, 2) + >>> arr = jax.random.randint( + ... key2, + ... (5, 5), + ... 0, + ... 2, + ... ) >>> arr Array([[1, 1, 1, 0, 1], [1, 0, 1, 1, 1],