diff --git a/docs/api/collective.rst b/docs/api/collective.rst new file mode 100644 index 0000000..67fe003 --- /dev/null +++ b/docs/api/collective.rst @@ -0,0 +1,14 @@ +Collective Geometry +=================== + +Rod-like geometry +----------------- + +.. automodule:: bsr.rod + :members: + +Stacked geometry +---------------- + +.. automodule:: bsr.stack + :members: diff --git a/docs/conf.py b/docs/conf.py index da5aaae..439b0b2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,6 +25,7 @@ "sphinx.ext.viewcode", "sphinx.ext.mathjax", "sphinx.ext.autosectionlabel", + "sphinx.ext.viewcode", "sphinx_autodoc_typehints", "sphinx_click", "numpydoc", @@ -38,6 +39,7 @@ "undoc-members", "private-members", "special-members", + "inherited-members", ] source_suffix = [".rst", ".md"] diff --git a/docs/index.rst b/docs/index.rst index ed215a8..4fb7860 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,6 +22,7 @@ TODO: Description of the project. api/file api/geometry + api/collective .. toctree:: :maxdepth: 2 diff --git a/poetry.lock b/poetry.lock index eee72fc..2de4302 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "alabaster" @@ -90,13 +90,13 @@ zstandard = "*" [[package]] name = "certifi" -version = "2024.2.2" +version = "2024.6.2" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, - {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, + {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, + {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, ] [[package]] @@ -1042,13 +1042,13 @@ files = [ [[package]] name = "nodeenv" -version = "1.9.0" +version = "1.9.1" description = "Node.js virtual environment builder" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ - {file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"}, - {file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"}, + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, ] [[package]] @@ -1147,13 +1147,13 @@ test = ["matplotlib", "pytest", "pytest-cov"] [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -1286,13 +1286,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pooch" -version = "1.8.1" -description = "\"Pooch manages your Python library's sample data files: it automatically downloads and stores them in a local directory, with support for versioning and corruption checks.\"" +version = "1.8.2" +description = "A friend to fetch your data files" optional = false python-versions = ">=3.7" files = [ - {file = "pooch-1.8.1-py3-none-any.whl", hash = "sha256:6b56611ac320c239faece1ac51a60b25796792599ce5c0b1bb87bf01df55e0a9"}, - {file = "pooch-1.8.1.tar.gz", hash = "sha256:27ef63097dd9a6e4f9d2694f5cfbf2f0a5defa44fccafec08d601e731d746270"}, + {file = "pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47"}, + {file = "pooch-1.8.2.tar.gz", hash = "sha256:76561f0de68a01da4df6af38e9955c4c9d1a5c90da73f7e40276a5728ec83d10"}, ] [package.dependencies] @@ -1387,13 +1387,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "8.2.1" +version = "8.2.2" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, - {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [package.dependencies] @@ -1456,13 +1456,13 @@ six = ">=1.5" [[package]] name = "pyupgrade" -version = "3.15.2" +version = "3.16.0" description = "A tool to automatically upgrade syntax for newer versions." optional = false python-versions = ">=3.8.1" files = [ - {file = "pyupgrade-3.15.2-py2.py3-none-any.whl", hash = "sha256:ce309e0ff8ecb73f56a45f12570be84bbbde9540d13697cacb261a7f595fb1f5"}, - {file = "pyupgrade-3.15.2.tar.gz", hash = "sha256:c488d6896c546d25845712ef6402657123008d56c1063174e27aabe15bd6b4e5"}, + {file = "pyupgrade-3.16.0-py2.py3-none-any.whl", hash = "sha256:7a54ee28f3024d027048d49d101e5c702e88c85edc3a1d08b636c50ebef2a97d"}, + {file = "pyupgrade-3.16.0.tar.gz", hash = "sha256:237893a05d5b117259b31b423f23cbae4bce0b7eae57ba9a52c06098c2ddd76f"}, ] [package.dependencies] @@ -1506,6 +1506,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1513,8 +1514,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1531,6 +1540,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1538,6 +1548,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1678,13 +1689,13 @@ test = ["cython (>=3.0)", "defusedxml (>=0.7.1)", "pytest (>=6.0)", "setuptools [[package]] name = "sphinx-autodoc-typehints" -version = "2.1.0" +version = "2.1.1" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" optional = false python-versions = ">=3.9" files = [ - {file = "sphinx_autodoc_typehints-2.1.0-py3-none-any.whl", hash = "sha256:46f1a710b3ed35904f63a77c5e68334c5ee1c2e22828b75fdcd147f1c52c199b"}, - {file = "sphinx_autodoc_typehints-2.1.0.tar.gz", hash = "sha256:51bf8dc77c4fba747e32f0735002a91500747d0553cae616863848e8f5e49fe8"}, + {file = "sphinx_autodoc_typehints-2.1.1-py3-none-any.whl", hash = "sha256:22427d74786274add2b6d4afccb8b3c8c1843f48a704550f15a35fd948f8a4de"}, + {file = "sphinx_autodoc_typehints-2.1.1.tar.gz", hash = "sha256:0072b65f5ab2818c229d6d6c2cc993770af55d36bb7bfb16001e2fce4d14880c"}, ] [package.dependencies] @@ -1885,13 +1896,13 @@ telegram = ["requests"] [[package]] name = "typing-extensions" -version = "4.12.0" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, - {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 3e42027..81807f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,4 +173,6 @@ branch = true omit = [ "*/.local/*", "setup.py", + "*/protocol.py", + "macros.py", ] diff --git a/src/bsr/__init__.py b/src/bsr/__init__.py index b8405cf..11eb514 100644 --- a/src/bsr/__init__.py +++ b/src/bsr/__init__.py @@ -1,10 +1,10 @@ import sys from importlib import metadata as importlib_metadata -from .collections import * from .file import * from .macros import * from .rod import * +from .stack import * def get_version() -> str: diff --git a/src/bsr/collections.py b/src/bsr/collections.py deleted file mode 100644 index 5b14255..0000000 --- a/src/bsr/collections.py +++ /dev/null @@ -1,66 +0,0 @@ -__all__ = ["GeometryCollectionP", "RodCollection", "create_rod_collection"] - -from typing import TYPE_CHECKING, Protocol -from typing_extensions import Self - -import numpy as np - -from .rod import Rod - - -# TODO -class GeometryCollectionP(Protocol): - @property - def tag(self) -> str: ... - - def __len__(self) -> int: ... - - def __getitem__(self, index: int) -> "Geometry": ... - - def __iter__(self): ... - - @classmethod - def create_collection(cls, number: int, tag: str = None) -> Self: - """ - Create a collection of geometries - """ - ... - - -# TODO -class RodCollection: - def __init__(self): - self._tag: str = "" - self._rods: list["Rod"] = [] - - @property - def tag(self) -> str: - return self._tag - - def __len__(self) -> int: - return len(self._rods) - - def __getitem__(self, index: int) -> "Rod": - return self._rods[index] - - def __iter__(self): - return iter(self._rods) - - @classmethod - def create_collection( - cls, num_rods: int, num_nodes: int, tag: str = None - ) -> Self: - pass - - def update_history( - self, keyframes: np.ndarray, position: np.ndarray, radius: np.ndarray - ): - pass - - -# Alias for factory functions -create_rod_collection = RodCollection.create_collection - - -if TYPE_CHECKING: - _: GeometryCollectionP = RodCollection(...) diff --git a/src/bsr/geometry.py b/src/bsr/geometry.py index f163a48..54605f4 100644 --- a/src/bsr/geometry.py +++ b/src/bsr/geometry.py @@ -1,141 +1,422 @@ -import colorsys +__doc__ = """ +This module provides a set of geometry-mesh interfaces for blender objects. +""" +__all__ = ["Sphere", "Cylinder"] + +from typing import TYPE_CHECKING, cast + +import warnings +from numbers import Number import bpy import numpy as np +from numpy.typing import NDArray + +from .mixin import KeyFrameControlMixin +from .protocol import BlenderMeshInterfaceProtocol, MeshDataType + + +def calculate_cylinder_orientation( + position_1: NDArray, position_2: NDArray +) -> tuple[float, NDArray, NDArray]: + """ + Calculates the centerpoint, depth, and rotational angle of the cylinder object. + + Parameters + ---------- + position_1 : NDArray + One endpoint position of the cylinder object. (3D) + position_2: NDArray + Other endpoint position of the cylinder object. (3D) + + Returns + ------- + tuple: float, NDArray, NDArray + Tuple containing the values for the depth, centerpoint and rotation angle (3D) + + """ + + depth = np.linalg.norm(position_2 - position_1) + dz = position_2[2] - position_1[2] + dy = position_2[1] - position_1[1] + dx = position_2[0] - position_1[0] + center = (position_1 + position_2) / 2.0 + phi = np.arctan2(dy, dx) + theta = np.arccos(dz / depth) + angles = np.array([phi, theta]) + return float(depth), center, angles + + +def _validate_position(position: NDArray) -> None: + """ + Checks if inputted position values are valid + + Paramters + --------- + position: NDArray + Position input (endpoint or centerpoint depending on Object type) + + Raises + ------ + ValueError + If the position is the wrong shape or contains NaN values + """ + + if position.shape != (3,): + raise ValueError("The shape of the position is incorrect.") + if np.isnan(position).any(): + raise ValueError("The position contains NaN values.") + + +def _validate_radius(radius: float) -> None: + """ + Checks if inputted radius value is valid + + Parameters: + ----------- + radius: Float + Radius input + + Raises + ------ + ValueError + If the radius is not positive, or contains NaN values + """ + + if not isinstance(radius, Number) or radius <= 0: + raise ValueError("The radius must be a positive float.") + if np.isnan(radius): + raise ValueError("The radius contains NaN values.") + + +class Sphere(KeyFrameControlMixin): + """ + This class provides a mesh interface for Blender Sphere objects. + Sphere objects are created with the given position and radius. + + Parameters + ---------- + position : NDArray + The position of the sphere object. (3D) + radius : float + The radius of the sphere object. + + """ + + input_states = {"position", "radius"} + + def __init__(self, position: NDArray, radius: float) -> None: + """ + Sphere class constructor + """ + + self._obj = self._create_sphere() + self.update_states(position, radius) + + @classmethod + def create(cls, states: MeshDataType) -> "Sphere": + """ + Basic factory method to create a new Sphere object. + """ + + remaining_keys = set(states.keys()) - cls.input_states + if len(remaining_keys) > 0: + warnings.warn( + f"{list(remaining_keys)} are not used as a part of the state definition." + ) + return cls(states["position"], states["radius"]) + + @property + def object(self) -> bpy.types.Object: + """ + Access the Blender object. + """ + + return self._obj + def update_states( + self, position: NDArray | None = None, radius: float | None = None + ) -> None: + """ + Updates the position and radius of the sphere object. -class Sphere: - def __init__(self, location, radius=0.005): - self.obj = self.create_sphere(location, radius) + Parameters + ---------- + position : NDArray + The new position of the sphere object. + radius : float + The new radius of the sphere object. - def create_sphere(self, location, radius): - bpy.ops.mesh.primitive_uv_sphere_add(radius=radius, location=location) + Raises + ------ + ValueError + If the shape of the position or radius is incorrect, or if the data is NaN. + """ + + if position is not None: + _validate_position(position) + self.object.location.x = position[0] + self.object.location.y = position[1] + self.object.location.z = position[2] + if radius is not None: + _validate_radius(radius) + self.object.scale = (radius, radius, radius) + + def _create_sphere(self) -> bpy.types.Object: + """ + Creates a new sphere object with the given position and radius. + """ + bpy.ops.mesh.primitive_uv_sphere_add() return bpy.context.active_object - def update_position(self, location): - self.obj.location.z = location[2] - self.obj.location.y = location[1] - self.obj.location.x = location[0] + def set_keyframe(self, keyframe: int) -> None: + """ + Sets a keyframe at the given frame. + Parameters + ---------- + keyframe : int + """ + self.object.keyframe_insert(data_path="location", frame=keyframe) -class Cylinder: - def __init__(self, pos1, pos2): - self.obj = self.create_cylinder(pos1, pos2) - self.mat = bpy.data.materials.new(name="cyl_mat") - self.obj.active_material = self.mat - def create_cylinder(self, pos1, pos2): - depth, center, angles = self.calc_cyl_orientation(pos1, pos2) - bpy.ops.mesh.primitive_cylinder_add( - radius=0.005, depth=1, location=center +class Cylinder(KeyFrameControlMixin): + """ + This class provides a mesh interface for Blender Cylinder objects. + Cylinder objects are created with the given endpoint positions and radius. + + Parameters + ---------- + position_1 : NDArray + The first endpoint position of the cylinder object. (3D) + position_2 : NDArray + The second enspoint position of the cylinder object. (3D) + radius : float + The radius of the cylinder object. + """ + + input_keys = {"position_1", "position_2", "radius"} + + def __init__( + self, + position_1: NDArray, + position_2: NDArray, + radius: float, + ) -> None: + """ + Cylinder class constructor + """ + + self._obj = self._create_cylinder() + # FIXME: This is a temporary solution + # Ideally, these modules should not contain any data + self._states_position_1 = position_1 + self._states_position_2 = position_2 + self._states_radius = radius + self.update_states(position_1, position_2, radius) + + @classmethod + def create(cls, states: MeshDataType) -> "Cylinder": + """ + Basic factory method to create a new Cylinder object. + """ + + remaining_keys = set(states.keys()) - cls.input_keys + if len(remaining_keys) > 0: + warnings.warn( + f"{list(remaining_keys)} are not used as a part of the state definition." + ) + return cls(states["position_1"], states["position_2"], states["radius"]) + + @property + def object(self) -> bpy.types.Object: + """ + Access the Blender object. + """ + + return self._obj + + def update_states( + self, + position_1: NDArray | None = None, + position_2: NDArray | None = None, + radius: float | None = None, + ) -> None: + """ + Updates the position and radius of the cylinder object. + + Parameters + ---------- + position_1 : NDArray + The first new endpoint position of the cylinder object. + position_2 : NDArray + The second new endpoint position of the cylinder object. + radius : float + The new radius of the cylinder object. + + Raises + ------ + ValueError + If the shape of the positions or radius is incorrect, or if the data is NaN. + """ + if position_1 is None and position_2 is None and radius is None: + return + if position_1 is not None: + position_1 = cast(NDArray[np.floating], position_1) + _validate_position(position_1) + self._states_position_1 = position_1 + else: + position_1 = self._states_position_1 + if position_2 is not None: + position_2 = cast(NDArray[np.floating], position_2) + _validate_position(position_2) + self._states_position_2 = position_2 + else: + position_2 = self._states_position_2 + if radius is not None: + _validate_radius(radius) + self._states_radius = radius + else: + radius = self._states_radius + + # Validation check + if np.allclose(position_1, position_2): + raise ValueError( + f"Two positions must be different: {(position_1 - position_2)=}" + ) + + depth, center, angles = calculate_cylinder_orientation( + position_1, position_2 ) + self.object.location = center + self.object.rotation_euler = (0, angles[1], angles[0]) + self.object.scale[2] = depth + self.object.scale[0] = radius + self.object.scale[1] = radius + + def _create_cylinder( + self, + ) -> bpy.types.Object: + """ + Creates a new cylinder object. + """ + bpy.ops.mesh.primitive_cylinder_add( + radius=1.0, + depth=1.0, + ) # Fix keep these values as default. cylinder = bpy.context.active_object - cylinder.rotation_euler = (0, angles[1], angles[0]) - cylinder.scale[2] = depth return cylinder - def calc_cyl_orientation(self, pos1, pos2): - pos1 = np.array(pos1) - pos2 = np.array(pos2) - depth = np.linalg.norm(pos2 - pos1) - dz = pos2[2] - pos1[2] - dy = pos2[1] - pos1[1] - dx = pos2[0] - pos1[0] - center = (pos1 + pos2) / 2 - phi = np.arctan2(dy, dx) - theta = np.arccos(dz / depth) - angles = np.array([phi, theta]) - return depth, center, angles - - def update_position(self, pos1, pos2): - depth, center, angles = self.calc_cyl_orientation(pos1, pos2) - self.obj.location = (center[0], center[1], center[2]) - self.obj.rotation_euler = (0, angles[1], angles[0]) - self.obj.scale[2] = depth - - # computing deformation heat-map - max_def = 0.07 - - h = ( - -np.sqrt(self.obj.location[0] ** 2 + self.obj.location[2] ** 2) - / max_def - + 240 / 360 - ) - v = ( - np.sqrt(self.obj.location[0] ** 2 + self.obj.location[2] ** 2) - / max_def - * 0.5 - + 0.5 - ) + def set_keyframe(self, keyframe: int) -> None: + """ + Sets a keyframe at the given frame. - r, g, b = colorsys.hsv_to_rgb(h, 1, v) - self.update_color(r, g, b, 1) + Parameters + ---------- + keyframe : int + """ + self.object.keyframe_insert(data_path="location", frame=keyframe) + self.object.keyframe_insert(data_path="rotation_euler", frame=keyframe) + self.object.keyframe_insert(data_path="scale", frame=keyframe) - def update_color(self, r, g, b, a): - self.mat.diffuse_color = (r, g, b, a) +# TODO: Will be implemented in the future +class Frustum(KeyFrameControlMixin): # pragma: no cover + """ + This class provides a mesh interface for Blender Frustum objects. + Frustum objects are created with the given positions and radii. -# TODO: Refactor -class Frustum: - def __init__(self, pos1, pos2, radius1, radius2): - self.obj = self.create_frustum(pos1, pos2, radius1, radius2) - self.mat = bpy.data.materials.new(name="cyl_mat") - self.obj.active_material = self.mat + Parameters + ---------- + position_1 : NDArray + The position of the first end of the frustum object. (3D) + position_2 : NDArray + The position of the second end of the frustum object. (3D) + radius_1 : float + The radius of the first end of the frustum object. + radius_2 : float + The radius of the second end of the frustum object. + """ - def create_frustum(self, pos1, pos2, radius1, radius2): - depth, center, angles = self.calc_frust_orientation(pos1, pos2) - bpy.ops.mesh.primitive_cone_add( - radius1=radius1, radius2=radius2, depth=1, location=center - ) - frustum = bpy.context.active_object - frustum.rotation_euler = (0, angles[1], angles[0]) - frustum.scale[2] = depth - return frustum - - def calc_frust_orientation(self, pos1, pos2): - pos1 = np.array(pos1) - pos2 = np.array(pos2) - depth = np.linalg.norm(pos2 - pos1) - dz = pos2[2] - pos1[2] - dy = pos2[1] - pos1[1] - dx = pos2[0] - pos1[0] - center = (pos1 + pos2) / 2 - phi = np.arctan2(dy, dx) - theta = np.arccos(dz / depth) - angles = np.array([phi, theta]) - return depth, center, angles - - def update_position(self, pos1, pos2): - depth, center, angles = self.calc_frust_orientation(pos1, pos2) - self.obj.location = (center[0], center[1], center[2]) - self.obj.rotation_euler = (0, angles[1], angles[0]) - self.obj.scale[2] = depth - - # computing deformation heat-map - max_def = 0.07 - - h = ( - -np.sqrt(self.obj.location[0] ** 2 + self.obj.location[2] ** 2) - / max_def - + 240 / 360 - ) - v = ( - np.sqrt(self.obj.location[0] ** 2 + self.obj.location[2] ** 2) - / max_def - * 0.5 - + 0.5 - ) + input_keys = {"position_1", "position_2", "radius_1", "radius_2"} + + def __init__( + self, + position_1: NDArray, + position_2: NDArray, + radius_1: float, + radius_2: float, + ) -> None: + raise NotImplementedError + # self._obj = self._create_frustum( + # position_1, position_2, radius_1, radius_2 + # ) + # self.update_states(position_1, position_2, radius_1, radius_2) + + # self.mat = bpy.data.materials.new(name="cyl_mat") + # self.obj.active_material = self.mat + + @classmethod + def create(cls, states: MeshDataType) -> "Frustum": + raise NotImplementedError + # return cls( + # states["position_1"], + # states["position_2"], + # states["radius_1"], + # states["radius_2"], + # ) + + @property + def object(self) -> bpy.types.Object: + raise NotImplementedError + + def _create_frustum( + self, + position_1: NDArray, + position_2: NDArray, + radius_1: float, + radius_2: float, + ) -> bpy.types.Object: + raise NotImplementedError + # depth, center, angles = calculate_cylinder_orientation( + # position_1, position_2 + # ) + # bpy.ops.mesh.primitive_cone_add( + # radius1=radius_1, radius2=radius_2, depth=1, + # ) + # frustum = bpy.context.active_object + # frustum.rotation_euler = (0, angles[1], angles[0]) + # frustum.location = center + # frustum.scale[2] = depth + # return frustum + + def update_states( + self, + position_1: NDArray, + position_2: NDArray, + radius_1: float, + radius_2: float, + ) -> None: + raise NotImplementedError - r, g, b = colorsys.hsv_to_rgb(h, 1, v) - self.update_color(r, g, b, 1) + def set_keyframe(self, keyframe: int) -> None: + raise NotImplementedError - def update_color(self, r, g, b, a): - self.mat.diffuse_color = (r, g, b, a) - def update(self, pos1, pos2, time_step): - self.update_position(pos1, pos2) - # adding to keyframe - self.obj.keyframe_insert(data_path="location", frame=time_step) - self.obj.keyframe_insert(data_path="rotation_euler", frame=time_step) - self.obj.keyframe_insert(data_path="scale", frame=time_step) - self.mat.keyframe_insert(data_path="diffuse_color", frame=time_step) +if TYPE_CHECKING: + # This is required for explicit type-checking + data = {"position": np.array([0, 0, 0]), "radius": 1.0} + _: BlenderMeshInterfaceProtocol = Sphere.create(data) + data = { + "position_1": np.array([0, 0, 0]), + "position_2": np.array([1, 1, 1]), + "radius": 1.0, + } + _: BlenderMeshInterfaceProtocol = Cylinder.create(data) # type: ignore[no-redef] + data = { + "position_1": np.array([0, 0, 0]), + "position_2": np.array([1, 1, 1]), + "radius_1": 1.0, + "radius_2": 1.5, + } + _: BlenderMeshInterfaceProtocol = Frustum.create(data) # type: ignore[no-redef] diff --git a/src/bsr/macros.py b/src/bsr/macros.py index 15bf013..e4137f2 100644 --- a/src/bsr/macros.py +++ b/src/bsr/macros.py @@ -10,6 +10,16 @@ def clear_mesh_objects() -> None: bpy.ops.object.delete() +def scene_update() -> None: + """ + Update the scene + + Used to update object's matrix_world after transformations + (https://blender.stackexchange.com/questions/27667/incorrect-matrix-world-after-transformation) + """ + bpy.context.view_layer.update() + + def clear_materials() -> None: # Clear existing materials in the scene for material in bpy.data.materials: diff --git a/src/bsr/mixin.py b/src/bsr/mixin.py new file mode 100644 index 0000000..0652ac8 --- /dev/null +++ b/src/bsr/mixin.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +import bpy + +from .protocol import BlenderMeshInterfaceProtocol + + +# Base class +class KeyFrameControlMixin(ABC): + """ + This mixin class provides methods for manipulating keyframes. + By adding this mixin, the class will conform to the BlenderKeyframeManipulateProtocol. + Otherwise, each meethods must be implemented in the class. + """ + + def clear_animation(self: BlenderMeshInterfaceProtocol) -> None: + """ + Clear all keyframes of the object. + """ + self.object.animation_data_clear() + self.object.animation_data_create() + + @abstractmethod + def set_keyframe(self: BlenderMeshInterfaceProtocol, keyframe: int) -> None: + raise NotImplementedError diff --git a/src/bsr/protocol.py b/src/bsr/protocol.py new file mode 100644 index 0000000..6e7d622 --- /dev/null +++ b/src/bsr/protocol.py @@ -0,0 +1,79 @@ +__all__ = [ + "BlenderMeshInterfaceProtocol", + "CompositeProtocol", + "StackProtocol", +] + +from typing import ( + TYPE_CHECKING, + Any, + ParamSpec, + Protocol, + Type, + TypeAlias, + TypeVar, +) +from typing_extensions import Self + +from abc import ABC, abstractmethod + +import bpy +from numpy.typing import NDArray + + +class BlenderKeyframeManipulateProtocol(Protocol): + def clear_animation(self) -> None: ... + + def set_keyframe(self, keyframe: int) -> None: ... + + +MeshDataType: TypeAlias = dict[str, Any] +S = TypeVar("S", bound="BlenderMeshInterfaceProtocol") + + +class BlenderMeshInterfaceProtocol(BlenderKeyframeManipulateProtocol, Protocol): + """ + This protocol defines the interface for Blender mesh objects. + """ + + input_states: set[str] + + # TODO: For future implementation + # @property + # def data(self): ... + + # @property + # def material(self): ... + + @property + def object(self) -> Any: + """Returns associated Blender object.""" + + @classmethod + def create(cls: Type[S], states: MeshDataType) -> S: + """Creates a new mesh object with the given states.""" + + def update_states(self, *args: Any) -> None: + """Updates the mesh object with the given states.""" + + # def update_material(self, material) -> None: ... # TODO: For future implementation + + +class CompositeProtocol(BlenderMeshInterfaceProtocol, Protocol): + @property + def object( + self, + ) -> dict[str, list[BlenderMeshInterfaceProtocol | bpy.types.Object]]: + """Returns associated Blender object.""" + + +class StackProtocol(BlenderMeshInterfaceProtocol, Protocol): + DefaultType: Type + + def __len__(self) -> int: ... + + def __getitem__(self, index: int) -> BlenderMeshInterfaceProtocol: ... + + @property + def object(self) -> list[BlenderMeshInterfaceProtocol]: + """Returns associated Blender object.""" diff --git a/src/bsr/rod.py b/src/bsr/rod.py index a94aab9..add7fd8 100644 --- a/src/bsr/rod.py +++ b/src/bsr/rod.py @@ -1,61 +1,126 @@ __doc__ = """ Rod class for creating and updating rods in Blender """ -__all__ = ["Rod"] +__all__ = ["RodWithSphereAndCylinder", "Rod"] +from typing import TYPE_CHECKING + +import bpy import numpy as np +from numpy.typing import NDArray from .geometry import Cylinder, Sphere +from .mixin import KeyFrameControlMixin +from .protocol import CompositeProtocol -# TODO -class Rod: +class RodWithSphereAndCylinder(KeyFrameControlMixin): """ Rod class for managing visualization and rendering in Blender + + Parameters + ---------- + positions : NDArray + The positions of the sphere objects. Expected shape is (n_dim, n_nodes). + n_dim = 3 + radii : NDArray + The radii of the sphere objects. Expected shape is (n_nodes-1,). + """ - def __init__(self) -> None: - self.bpy_objs = None + input_states = {"positions", "radii"} + + def __init__(self, positions: NDArray, radii: NDArray) -> None: + # check shape of positions and radii + assert positions.ndim == 2, "positions must be 2D array" + assert positions.shape[0] == 3, "positions must have 3 rows" + assert radii.ndim == 1, "radii must be 1D array" + assert ( + positions.shape[-1] == radii.shape[-1] + 1 + ), "radii must have n_nodes-1 elements" + + # create sphere and cylinder objects + self.spheres: list[Sphere] = [] + self.cylinders: list[Cylinder] = [] + self._bpy_objs: dict[str, list[bpy.types.Object]] = { + "sphere": self.spheres, + "cylinder": self.cylinders, + } - def clear(self): - raise NotImplementedError("Not yet implemented") + self._build(positions, radii) - def build(self, positions: np.ndarray, radii: np.ndarray): - # TODO: Refactor + @property + def object(self) -> dict[str, list[bpy.types.Object]]: + """ + Return the dictionary of Blender objects: sphere and cylinder + """ + return self._bpy_objs + + @classmethod + def create(cls, states: dict[str, NDArray]) -> "RodWithSphereAndCylinder": + """ + Create a Rod object from the given states + + States must have the following keys: positions(n_nodes, 3), radii(n_nodes-1,) + """ + rod = cls(**states) + return rod + + def _build(self, positions: NDArray, radii: NDArray) -> None: + _radii = np.concatenate([radii, [0]]) + _radii[1:] += radii + _radii[1:-1] /= 2.0 for j in range(positions.shape[-1]): - sphere = Sphere(positions[:, j]) - self.bpy_objs["sphere"].append(sphere) - sphere.obj.keyframe_insert(data_path="location", frame=0) + sphere = Sphere(positions[:, j], _radii[j]) + self.spheres.append(sphere) - for j in range(positions.shape[-1] - 1): + for j in range(radii.shape[-1]): cylinder = Cylinder( - self.bpy_objs["sphere"][j].obj.location, - self.bpy_objs["sphere"][j + 1].obj.location, - ) - self.bpy_objs["cylinder"].append(cylinder) - cylinder.obj.keyframe_insert(data_path="location", frame=0) - - def update(self, keyframe: int, positions: np.ndarray, radii: np.ndarray): - if self.bpy_objs is None: - self.bpy_objs = {"sphere": [], "cylinder": []} - self.build(positions, radii) - return - # TODO: Refactor - # update all sphere and cylinder positions and write object to keyframe - for idx, sphere in enumerate(self.bpy_objs["sphere"]): - sphere.update_position(positions[:, idx]) - sphere.obj.keyframe_insert(data_path="location", frame=time_step) - - for idx, cylinder in enumerate(self.bpy_objs["cylinder"]): - cylinder.update_position( - self.bpy_objs["sphere"][idx].obj.location, - self.bpy_objs["sphere"][idx + 1].obj.location, + positions[:, j], + positions[:, j + 1], + radii[j], ) - cylinder.obj.keyframe_insert(data_path="location", frame=time_step) - cylinder.obj.keyframe_insert( - data_path="rotation_euler", frame=time_step - ) - cylinder.obj.keyframe_insert(data_path="scale", frame=time_step) - cylinder.mat.keyframe_insert( - data_path="diffuse_color", frame=time_step + self.cylinders.append(cylinder) + + def update_states(self, positions: NDArray, radii: NDArray) -> None: + """ + Update the states of the rod object + + Parameters + ---------- + positions : NDArray + The positions of the sphere objects. Expected shape is (n_nodes, 3). + radii : NDArray + The radii of the sphere objects. Expected shape is (n_nodes-1,). + """ + _radii = np.concatenate([radii, [0]]) + _radii[1:] += radii + _radii[1:-1] /= 2.0 + for idx, sphere in enumerate(self.spheres): + sphere.update_states(positions[:, idx], radii[idx]) + + for idx, cylinder in enumerate(self.cylinders): + cylinder.update_states( + positions[:, idx], positions[:, idx + 1], radii[idx] ) + + def set_keyframe(self, keyframe: int) -> None: + """ + Set keyframe for the rod object + """ + for idx, sphere in enumerate(self.spheres): + sphere.set_keyframe(keyframe) + + for idx, cylinder in enumerate(self.cylinders): + cylinder.set_keyframe(keyframe) + + +# Alias +Rod = RodWithSphereAndCylinder + +if TYPE_CHECKING: + data = { + "positions": np.array([[0, 0, 0], [1, 1, 1]]), + "radii": np.array([1.0, 1.0]), + } + _: CompositeProtocol = RodWithSphereAndCylinder.create(data) diff --git a/src/bsr/stack.py b/src/bsr/stack.py new file mode 100644 index 0000000..1e492a8 --- /dev/null +++ b/src/bsr/stack.py @@ -0,0 +1,104 @@ +__all__ = ["BaseStack", "RodStack", "create_rod_collection"] + +from typing import TYPE_CHECKING, Any, Protocol, Type, overload +from typing_extensions import Self + +from collections.abc import Sequence + +import bpy +import numpy as np +from numpy.typing import NDArray +from tqdm import tqdm + +from .mixin import KeyFrameControlMixin +from .protocol import BlenderMeshInterfaceProtocol, StackProtocol +from .rod import Rod +from .typing import RodType + + +class BaseStack(Sequence, KeyFrameControlMixin): + """ + A stack of objects that can be manipulated together. + Internally, we use a list-like structure to store the objects. + """ + + DefaultType: Type + + def __init__(self) -> None: + self._objs: list[BlenderMeshInterfaceProtocol] = [] + + @overload + def __getitem__(self, index: int, /) -> BlenderMeshInterfaceProtocol: ... + @overload + def __getitem__( + self, index: slice, / + ) -> list[BlenderMeshInterfaceProtocol]: ... + def __getitem__( + self, index: int | slice + ) -> BlenderMeshInterfaceProtocol | list[BlenderMeshInterfaceProtocol]: + return self._objs[index] + + def __len__(self) -> int: + return len(self._objs) + + @property + def object(self) -> list[BlenderMeshInterfaceProtocol]: + """ + Returns the objects in the stack. + """ + return self._objs + + def set_keyframe(self, keyframe: int) -> None: + """ + Sets a keyframe at the given frame. + """ + for obj in self._objs: + obj.set_keyframe(keyframe) + + @classmethod + def create( + cls, + states: dict[str, NDArray], + ) -> Self: + """ + Creates a stack of objects from the given states. + """ + self = cls() + keys = states.keys() + lengths = [i.shape[0] for i in states.values()] + assert len(set(lengths)) <= 1, "All states must have the same length" + num_objects = lengths[0] + + for oidx in range(num_objects): + state = {k: v[oidx] for k, v in states.items()} + obj = self.DefaultType.create(state) + self._objs.append(obj) + return self + + def update_states(self, *variables: NDArray) -> None: + """ + Updates the states of the objects. + """ + if not all([v.shape[0] == len(self) for v in variables]): + raise IndexError( + "All variables must have the same length as the stack" + ) + for idx in range(len(self)): + self[idx].update_states(*[v[idx] for v in variables]) + + +class RodStack(BaseStack): + input_states = {"positions", "radii"} + DefaultType: Type = Rod + + +# Alias for factory functions +create_rod_collection = RodStack.create + + +if TYPE_CHECKING: + data: dict[str, NDArray] = { + "positions": np.array([[[0, 0, 0], [1, 1, 1]]]), + "radii": np.array([[1.0, 1.0]]), + } + _: StackProtocol = RodStack.create(data) diff --git a/src/bsr/typing.py b/src/bsr/typing.py new file mode 100644 index 0000000..df718b6 --- /dev/null +++ b/src/bsr/typing.py @@ -0,0 +1,5 @@ +from typing import TypeAlias + +from .rod import RodWithSphereAndCylinder + +RodType: TypeAlias = RodWithSphereAndCylinder diff --git a/src/elastica_blender/converter/npz2blend.py b/src/elastica_blender/converter/npz2blend.py index 77367dd..859dcbb 100644 --- a/src/elastica_blender/converter/npz2blend.py +++ b/src/elastica_blender/converter/npz2blend.py @@ -4,6 +4,7 @@ import click import numpy as np +from tqdm import tqdm import bsr @@ -53,22 +54,30 @@ def construct_blender_file( if tags is None: position_history = data["position_history"] radius_history = data["radius_history"] - num_rods = position_history.shape[0] - num_nodes = position_history.shape[3] - rods = bsr.create_rod_collection(num_rods, num_nodes) - rods.update_history( - keyframes=time, position=position_history, radius=radius_history - ) + init_state = { + "position": position_history[:, 0, ...], + "radius": radius_history[:, 0, ...], + } + rods = bsr.create_rod_collection(init_state) + for tidx, _ in tqdm(enumerate(time), total=len(time)): + rods.update_states( + position_history[:, tidx, ...], radius_history[:, tidx, ...] + ) + rods.set_keyframe(tidx) else: for tag in tags: position_history = data[tag + "_position_history"] radius_history = data[tag + "_radius_history"] - num_rods = position_history.shape[0] - num_nodes = position_history.shape[3] - rods = bsr.create_rod_collection(num_rods, num_nodes, tag) - rods.update_history( - keyframes=time, position=position_history, radius=radius_history - ) + init_state = { + "position": position_history[:, 0, ...], + "radius": radius_history[:, 0, ...], + } + rods = bsr.create_rod_collection(init_state) + for tidx, _ in tqdm(enumerate(time), total=len(time)): + rods.update_states( + position_history[:, tidx, ...], radius_history[:, tidx, ...] + ) + rods.set_keyframe(tidx) bsr.save(output) diff --git a/src/elastica_blender/rod_callback.py b/src/elastica_blender/rod_callback.py index 769b39d..b8a6047 100644 --- a/src/elastica_blender/rod_callback.py +++ b/src/elastica_blender/rod_callback.py @@ -18,15 +18,20 @@ def __init__(self, step_skip: int) -> None: CallBackBaseClass.__init__(self) self.every = step_skip self.keyframe = 0 - self.bpy_objs = bsr.Rod() + self.bpy_objs: bsr.Rod def make_callback( self, system: RodType, time: np.floating, current_step: int ) -> None: if current_step % self.every == 0: - self.bpy_objs.update( - keyframe=self.key_frame, - positions=system.position_collection, - radii=system.radius_collection, - ) + if current_step == 0: + self.bpy_objs = bsr.Rod( + system.position_collection, system.radius_collection + ) + else: + self.bpy_objs.update_states( + positions=system.position_collection, + radii=system.radius_collection, + ) + self.bpy_objs.set_keyframe(self.key_frame) self.key_frame += 1 diff --git a/tests/elastica_blender/test_npz2blend.py b/tests/elastica_blender/test_npz2blend.py index 09b0f83..a3f086a 100644 --- a/tests/elastica_blender/test_npz2blend.py +++ b/tests/elastica_blender/test_npz2blend.py @@ -127,7 +127,9 @@ def test_construct_blender_file_run(self, mocker, data_setup_cases): call_count = 1 if tags is None else len(tags) assert bsr.create_rod_collection.call_count == call_count - assert rods_mock.update_history.call_count == call_count + assert ( + rods_mock.update_states.call_count == call_count * 4 + ) # 4 frames bsr.save.assert_called_once_with(output_path) diff --git a/tests/geometry/test_interface_keyframe_setting.py b/tests/geometry/test_interface_keyframe_setting.py new file mode 100644 index 0000000..16d3be5 --- /dev/null +++ b/tests/geometry/test_interface_keyframe_setting.py @@ -0,0 +1,78 @@ +import math + +import bpy +import numpy as np +import pytest + +from bsr.geometry import Cylinder, Sphere + + +def get_keyframes(obj_list): + keyframes = [] + for obj in obj_list: + animation_data = obj.animation_data + if animation_data is not None and animation_data.action is not None: + for fcurve in animation_data.action.fcurves: + for keyframe in fcurve.keyframe_points: + x, y = keyframe.co + if x not in keyframes: + keyframes.append(math.ceil(x)) + return keyframes + + +def count_number_of_keyframes_action(obj): + action = obj.animation_data.action + if action is None: + return 0 + else: + return len(action.fcurves[0].keyframe_points) + + +def test_set_keyframe_count_for_primitive_sphere(): + primitive = Sphere(position=np.array([0, 0, 0]), radius=1.0) + + primitive.set_keyframe(1) + assert count_number_of_keyframes_action(primitive.object) == 1 + + primitive.set_keyframe(2) + assert count_number_of_keyframes_action(primitive.object) == 2 + + # Setting keyfrome at the same frame should not increase the number of keyframes: + primitive.set_keyframe(2) + assert count_number_of_keyframes_action(primitive.object) == 2 + + primitive.clear_animation() + assert count_number_of_keyframes_action(primitive.object) == 0 + + primitive.set_keyframe(1) + assert count_number_of_keyframes_action(primitive.object) == 1 + + # Clear the test + primitive.clear_animation() + + +def test_set_keyframe_count_for_primitive_cylinder(): + primitive = Cylinder( + position_1=np.array([0, 0, 0]), + position_2=np.array([0, 0, 1]), + radius=1.0, + ) + + primitive.set_keyframe(1) + assert count_number_of_keyframes_action(primitive.object) == 1 + + primitive.set_keyframe(2) + assert count_number_of_keyframes_action(primitive.object) == 2 + + # Setting keyfrome at the same frame should not increase the number of keyframes: + primitive.set_keyframe(2) + assert count_number_of_keyframes_action(primitive.object) == 2 + + primitive.clear_animation() + assert count_number_of_keyframes_action(primitive.object) == 0 + + primitive.set_keyframe(1) + assert count_number_of_keyframes_action(primitive.object) == 1 + + # Clear the test + primitive.clear_animation() diff --git a/tests/geometry/test_primitive_geometry_mesh.py b/tests/geometry/test_primitive_geometry_mesh.py new file mode 100644 index 0000000..ff8f668 --- /dev/null +++ b/tests/geometry/test_primitive_geometry_mesh.py @@ -0,0 +1,141 @@ +import numpy as np +import pytest +from utils import get_mesh_limit + +from bsr.geometry import Cylinder, Sphere + +# Visual tolerance for the mesh limit +_VISUAL_ATOL = 1e-7 +_VISUAL_RTOL = 1e-4 + + +@pytest.mark.parametrize( + "center", + [ + np.array([10, 10, 10]), + np.array([10, 11, 10]), + np.array([10, 11, 11]), + np.array([11, 11, 11]), + ], +) +@pytest.mark.parametrize("radius", [1, 2, 3, 5.5]) +def test_sphere_radius_and_position(center, radius): + x_min, x_max = center[0] - radius, center[0] + radius + y_min, y_max = center[1] - radius, center[1] + radius + z_min, z_max = center[2] - radius, center[2] + radius + + sphere = Sphere(position=center, radius=radius) + + mesh_limit = get_mesh_limit(sphere) + + np.testing.assert_allclose( + (x_min, x_max, y_min, y_max, z_min, z_max), + mesh_limit, + rtol=_VISUAL_RTOL, + atol=_VISUAL_ATOL, + ) + + +@pytest.mark.parametrize( + "position_one", + [ + np.array([10, 10, 10]), + np.array([10, 11, -10]), + np.array([-10, 11, 11]), + ], +) +@pytest.mark.parametrize("length", [1, 10.5, -1, -10.5]) +@pytest.mark.parametrize("radius", [1, 3, 5.5]) +def test_x_cylinder_radius_and_positions(position_one, length, radius): + position_two = position_one + np.array([length, 0, 0]) + y, z = position_one[1], position_one[2] + + x_min, x_max = min(position_one[0], position_two[0]), max( + position_one[0], position_two[0] + ) + y_min, y_max = y - radius, y + radius + z_min, z_max = z - radius, z + radius + + cylinder = Cylinder( + position_1=position_one, position_2=position_two, radius=radius + ) + + mesh_limit = get_mesh_limit(cylinder) + + np.testing.assert_allclose( + (x_min, x_max, y_min, y_max, z_min, z_max), + mesh_limit, + rtol=_VISUAL_RTOL, + atol=_VISUAL_ATOL, + ) + + +@pytest.mark.parametrize( + "position_one", + [ + np.array([10, 10, 10]), + np.array([10, 11, -10]), + np.array([-10, 11, 11]), + ], +) +@pytest.mark.parametrize("length", [1, 10.5, -1, -10.5]) +@pytest.mark.parametrize("radius", [1, 3, 5.5]) +def test_y_cylinder_radius_and_positions(position_one, length, radius): + position_two = position_one + np.array([0, length, 0]) + x, z = position_one[0], position_one[2] + + x_min, x_max = x - radius, x + radius + y_min, y_max = min(position_one[1], position_two[1]), max( + position_one[1], position_two[1] + ) + z_min, z_max = z - radius, z + radius + + cylinder = Cylinder( + position_1=position_one, position_2=position_two, radius=radius + ) + + mesh_limit = get_mesh_limit(cylinder) + + print("Expected limits:", (x_min, x_max, y_min, y_max, z_min, z_max)) + print("Actual limits:", mesh_limit) + + np.testing.assert_allclose( + (x_min, x_max, y_min, y_max, z_min, z_max), + mesh_limit, + rtol=_VISUAL_RTOL, + atol=_VISUAL_ATOL, + ) + + +@pytest.mark.parametrize( + "position_one", + [ + np.array([10, 10, 10]), + np.array([10, 11, -10]), + np.array([-10, 11, 11]), + ], +) +@pytest.mark.parametrize("length", [1, 10.5, -1, -10.5]) +@pytest.mark.parametrize("radius", [1, 3, 5.5]) +def test_z_cylinder_radius_and_positions(position_one, length, radius): + position_two = position_one + np.array([0, 0, length]) + x, y = position_one[0], position_one[1] + + x_min, x_max = x - radius, x + radius + y_min, y_max = y - radius, y + radius + z_min, z_max = min(position_one[2], position_two[2]), max( + position_one[2], position_two[2] + ) + + cylinder = Cylinder( + position_1=position_one, position_2=position_two, radius=radius + ) + + mesh_limit = get_mesh_limit(cylinder) + + np.testing.assert_allclose( + (x_min, x_max, y_min, y_max, z_min, z_max), + mesh_limit, + rtol=_VISUAL_RTOL, + atol=_VISUAL_ATOL, + ) diff --git a/tests/geometry/utils.py b/tests/geometry/utils.py new file mode 100644 index 0000000..9655376 --- /dev/null +++ b/tests/geometry/utils.py @@ -0,0 +1,22 @@ +import bpy +import numpy as np + +from bsr.geometry import BlenderMeshInterfaceProtocol +from bsr.macros import scene_update + + +def get_mesh_limit(interface: BlenderMeshInterfaceProtocol): + """(For testing) Given blender mesh object, return xyz limit""" + + obj = interface.object + scene_update() + + vertices_coords = [] + for v in obj.data.vertices: + global_coord = obj.matrix_world @ v.co + vertices_coords.append(list(global_coord)) + vertices_coords = np.array(vertices_coords) + + x_min, y_min, z_min = np.min(vertices_coords, axis=0) + x_max, y_max, z_max = np.max(vertices_coords, axis=0) + return x_min, x_max, y_min, y_max, z_min, z_max diff --git a/tests/stack/test_base_stack_stacking.py b/tests/stack/test_base_stack_stacking.py new file mode 100644 index 0000000..5bf7c3d --- /dev/null +++ b/tests/stack/test_base_stack_stacking.py @@ -0,0 +1,137 @@ +import numpy as np +import pytest + +from bsr.stack import BaseStack + + +class MockObjectToStack: + def __init__(self, value): + self.value = value + + @property + def object(self): + return self.value + + @classmethod + def create(cls, states): + return cls(states["value"]) + + def update_states(self, value): + self.value = value + + +class MockStack(BaseStack[MockObjectToStack]): + DefaultType = MockObjectToStack + + +class MockStackStack(BaseStack[MockStack]): + DefaultType = MockStack + + +def test_stack(): + stack = MockStack.create({"value": np.array([1, 2, 3])}) + + assert len(stack) == 3 + assert stack[0].object == 1 + assert stack[1].object == 2 + assert stack[2].object == 3 + + +def test_recursive_stack(): + stack = MockStackStack.create( + { + "value": np.array( + [ + [1, 2, 3], + [4, 5, 6], + ] + ) + }, + ) + + assert len(stack) == 2 + assert stack[0][0].object == 1 + assert stack[0][1].object == 2 + assert stack[0][2].object == 3 + assert stack[1][0].object == 4 + assert stack[1][1].object == 5 + assert stack[1][2].object == 6 + + stack.update_states(np.ones((2, 3))) + + assert stack[0][0].object == 1 + assert stack[0][1].object == 1 + assert stack[0][2].object == 1 + assert stack[1][0].object == 1 + assert stack[1][1].object == 1 + assert stack[1][2].object == 1 + + +@pytest.mark.parametrize( + "update_data", + [ + np.ones(2), + np.ones((2, 3)), + np.ones(4), + np.ones((4, 3)), + ], +) +def test_update_wrong_size(update_data): + stack = MockStack.create( + { + "value": np.array( + [1, 2, 3], + ) + }, + ) + + assert len(stack) == 3 + with pytest.raises(IndexError): + stack.update_states(update_data) + + +class MockObjectToStack2: + def __init__(self, value1, value2): + self.value = (value1, value2) + + @property + def object(self): + return self.value + + @classmethod + def create(cls, states): + return cls(states["value1"], states["value2"]) + + def update_states(self, value1, value2): + self.value = (value1, value2) + + +class Mock2Stack(BaseStack[MockObjectToStack2]): + DefaultType = MockObjectToStack2 + + +def test_stack2(): + stack = Mock2Stack.create( + {"value1": np.array([1, 2, 3]), "value2": np.array([4, 5, 6])} + ) + + assert len(stack) == 3 + assert stack[0].object == (1, 4) + assert stack[1].object == (2, 5) + assert stack[2].object == (3, 6) + + stack.update_states(np.ones(3), np.ones(3) * 2) + + assert stack[0].object == (1, 2) + assert stack[1].object == (1, 2) + assert stack[2].object == (1, 2) + + +def test_stack2_wrong_size(): + stack = Mock2Stack.create( + {"value1": np.array([1, 2, 3]), "value2": np.array([4, 5, 6])} + ) + + assert len(stack) == 3 + with pytest.raises(IndexError): + stack.update_states(np.ones(2), np.ones(2) * 2) diff --git a/tests/stack/test_stack_properties.py b/tests/stack/test_stack_properties.py new file mode 100644 index 0000000..db48401 --- /dev/null +++ b/tests/stack/test_stack_properties.py @@ -0,0 +1,22 @@ +import pytest + +from bsr.stack import BaseStack + + +def test_object_property(): + stack = BaseStack() + stack._objs = [1, 2, 3] + assert stack.object == [1, 2, 3] + + +def test_set_keyframe(mocker): + stack = BaseStack() + mock_rod = mocker.Mock() + n_repeat = 3 + val = 5 + stack._objs = [mock_rod] * n_repeat + stack.set_keyframe(val) + + mock_rod.set_keyframe.assert_called() + assert mock_rod.set_keyframe.call_count == n_repeat + mock_rod.set_keyframe.assert_called_with(val) diff --git a/tests/test_blender_mesh_interface.py b/tests/test_blender_mesh_interface.py new file mode 100644 index 0000000..67ba135 --- /dev/null +++ b/tests/test_blender_mesh_interface.py @@ -0,0 +1,80 @@ +import bpy +import numpy as np +import pytest + +from bsr.geometry import Cylinder, Sphere + + +class TestBlenderMeshInterfaceObjectsSphere: + + @pytest.fixture(autouse=True) + def primitive(self): + Data = dict(position=np.array([0, 0, 0]), radius=1.0) + return Sphere.create(Data) + + def test_object_type(self, primitive): + assert isinstance(primitive.object, bpy.types.Object) + + def test_update_states_with_empty_data(self, primitive): + primitive.update_states() # Calling empty data should pass + assert True + + +class TestBlenderMeshInterfaceObjectsCylinder( + TestBlenderMeshInterfaceObjectsSphere +): + @pytest.fixture(autouse=True) + def primitive(self): + Data = dict( + position_1=np.array([0, 0, 0]), + position_2=np.array([0, 0, 1]), + radius=1.0, + ) + return Cylinder.create(Data) + + +@pytest.mark.parametrize( + "wrong_key", + [ + "__wrong_key", + 5, + ], +) +def test_update_states_warning_message_if_wrong_key_sphere( + wrong_key, +): + t = Sphere + data = {wrong_key: 0, "position": np.array([0, 0, 0]), "radius": 1.0} + with pytest.warns(UserWarning) as record: + t.create(data) + assert ( + f"not used as a part of the state definition" + in record[0].message.args[0] + ) + assert str(wrong_key) in record[0].message.args[0] + + +@pytest.mark.parametrize( + "wrong_key", + [ + "__wrong_key", + 5, + ], +) +def test_update_states_warning_message_if_wrong_key_cylinder( + wrong_key, +): + t = Cylinder + data = { + wrong_key: 0, + "position_1": np.array([0, 0, 0]), + "position_2": np.array([0, 0, 1]), + "radius": 1.0, + } + with pytest.warns(UserWarning) as record: + t.create(data) + assert ( + f"not used as a part of the state definition" + in record[0].message.args[0] + ) + assert str(wrong_key) in record[0].message.args[0] diff --git a/tests/test_cylinder_update.py b/tests/test_cylinder_update.py new file mode 100644 index 0000000..e56727e --- /dev/null +++ b/tests/test_cylinder_update.py @@ -0,0 +1,114 @@ +import bpy.types as bpy_types +import numpy as np +import pytest + +from bsr.geometry import Cylinder + + +@pytest.mark.parametrize( + "possible_cylinder_data", + [ + dict( + position_1=np.array([10, 10, 10]), + position_2=np.array([20, 20, 20]), + radius=10.0, + ), + dict(position_1=np.array([10, 10, 10]), radius=10.0), + dict(radius=10.0), + dict(position_1=np.array([10, 10, 10])), + ], +) +def test_update_states_with_data(possible_cylinder_data): + default_data = dict( + position_1=np.array([0, 0, 0]), + position_2=np.array([1, 1, 1]), + radius=1.0, + ) + primitive = Cylinder.create(default_data) + assert primitive.object is not None + np.testing.assert_allclose( + primitive.object.scale, np.array([1.0, 1.0, np.sqrt(3)]) + ) + np.testing.assert_allclose( + primitive.object.location, np.array([0.5, 0.5, 0.5]) + ) + + primitive.update_states(**possible_cylinder_data) + + default_data.update(possible_cylinder_data) # Update state dictionary + if ( + "position_1" in possible_cylinder_data + or "position_2" in possible_cylinder_data + ): + center_point = ( + default_data["position_1"] + default_data["position_2"] + ) / 2.0 + depth = np.linalg.norm( + default_data["position_1"] - default_data["position_2"] + ) + np.testing.assert_allclose(primitive.object.location, center_point) + np.testing.assert_allclose(primitive.object.scale[2], depth) + if "radius" in possible_cylinder_data: + np.testing.assert_allclose( + primitive.object.scale[0], default_data["radius"] + ) + np.testing.assert_allclose( + primitive.object.scale[1], default_data["radius"] + ) + + +@pytest.mark.parametrize( + "impossible_shaped_data", + [ + dict(position_1=np.array([10, 10, 10, 10]), radius=10.0), + dict(position_2=np.array([10, 10, 10, 10]), radius=10.0), + dict(position_1=np.array([0, 0, 0]), position_2=np.array([0, 0, 0])), + dict(radius=np.array([10, 10, 10])), + dict(radius=-1), + dict(radius=0), + dict(position_1=np.array([10, 10])), + ], +) +def test_update_states_with_wrong_shape(impossible_shaped_data): + default_data = dict( + position_1=np.array([0, 0, 0]), + position_2=np.array([1, 1, 1]), + radius=1.0, + ) + primitive = Cylinder.create(default_data) + with pytest.raises(ValueError): + primitive.update_states(**impossible_shaped_data) + + +@pytest.mark.parametrize( + "nan_data", + [ + dict(position_1=np.array([10, 10, 10]), radius=np.nan), + dict(position_1=np.array([np.nan, 10, 10]), radius=10.0), + dict(position_2=np.array([np.nan, 10, 10]), radius=10.0), + ], +) +def test_update_states_with_nan_values(nan_data): + default_data = dict( + position_1=np.array([0, 0, 0]), + position_2=np.array([1, 1, 1]), + radius=1.0, + ) + primitive = Cylinder.create(default_data) + with pytest.raises(ValueError) as exc_info: + primitive.update_states(**nan_data) + assert "contains NaN" in str(exc_info.value) + + +def test_cylinder_creator(): + default_data = dict( + position_1=np.array([0, 0, 0]), + position_2=np.array([1, 1, 1]), + radius=1.0, + ) + primitive = Cylinder.create(default_data) + old_cylinder = primitive.object + new_cylinder = primitive._create_cylinder() + assert new_cylinder is not None + assert old_cylinder is not new_cylinder + assert isinstance(new_cylinder, bpy_types.Object) diff --git a/tests/test_sphere_update.py b/tests/test_sphere_update.py new file mode 100644 index 0000000..11fba72 --- /dev/null +++ b/tests/test_sphere_update.py @@ -0,0 +1,77 @@ +import bpy.types as bpy_types +import numpy as np +import pytest + +from bsr.geometry import Sphere + + +# Sphere-specific tests +@pytest.mark.parametrize( + "possible_sphere_data", + [ + dict(position=np.array([10, 10, 10]), radius=10.0), + dict(radius=10.0), + dict(position=np.array([10, 10, 10])), + ], +) +def test_update_states_with_data(possible_sphere_data): + default_data = dict(position=np.array([0, 0, 0]), radius=1.0) + primitive = Sphere.create(default_data) + + np.testing.assert_allclose( + primitive.object.location, default_data["position"] + ) + np.testing.assert_allclose(primitive.object.scale, default_data["radius"]) + + primitive.update_states(**possible_sphere_data) + + if "position" in possible_sphere_data: + np.testing.assert_allclose( + primitive.object.location, possible_sphere_data["position"] + ) + if "radius" in possible_sphere_data: + np.testing.assert_allclose( + primitive.object.scale, possible_sphere_data["radius"] + ) + + +@pytest.mark.parametrize( + "impossible_shaped_data", + [ + dict(position=np.array([10, 10, 10, 10]), radius=10.0), + dict(radius=np.array([10, 10, 10])), + dict(radius=-1), + dict(radius=0), + dict(position=np.array([10, 10])), + ], +) +def test_update_states_with_wrong_shape(impossible_shaped_data): + default_data = dict(position=np.array([0, 0, 0]), radius=1.0) + primitive = Sphere.create(default_data) + with pytest.raises(ValueError): + primitive.update_states(**impossible_shaped_data) + + +@pytest.mark.parametrize( + "nan_data", + [ + dict(position=np.array([10, 10, 10]), radius=np.nan), + dict(position=np.array([np.nan, 10, 10]), radius=10.0), + ], +) +def test_update_states_with_nan_values(nan_data): + default_data = dict(position=np.array([0, 0, 0]), radius=1.0) + primitive = Sphere.create(default_data) + with pytest.raises(ValueError) as exc_info: + primitive.update_states(**nan_data) + assert "contains NaN" in str(exc_info.value) + + +def test_sphere_creator(): + default_data = dict(position=np.array([0, 0, 0]), radius=1.0) + primitive = Sphere.create(default_data) + old_sphere = primitive.object + new_sphere = primitive._create_sphere() + assert new_sphere is not None + assert old_sphere is not new_sphere + assert isinstance(new_sphere, bpy_types.Object)