Skip to content

Commit

Permalink
wip: follow protocol (have question on the protocol of update_states …
Browse files Browse the repository at this point in the history
…method, will create an issue for it.)
  • Loading branch information
hanson-hschang committed Sep 9, 2024
1 parent 8472c5f commit 9987611
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 12 deletions.
44 changes: 43 additions & 1 deletion src/bsr/geometry/composite/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
"""
__all__ = ["Pose"]

from typing import TYPE_CHECKING, Any

import bpy
import numpy as np
from numpy.typing import NDArray

from bsr.geometry.primitives.simple import Cylinder, Sphere
from bsr.geometry.protocol import CompositeProtocol
from bsr.tools.keyframe_mixin import KeyFrameControlMixin


Expand Down Expand Up @@ -44,15 +48,41 @@ def __init__(
self.__unit_length = unit_length
self.__ratio = thickness_ratio

# create sphere and cylinder materials
self.spheres_material: list[bpy.types.Material] = []
self.cylinders_material: list[bpy.types.Material] = []
self._bpy_materials: dict[str, bpy.types.Material] = {
"spheres": self.spheres_material,
"cylinders": self.cylinders_material,
}

self._build(position, directors)

@property
def material(self) -> dict[str, bpy.types.Material]:
"""
Return the dictionary of Blender materials: spheres and cylinders
"""
return self._bpy_materials

@property
def object(self) -> dict[str, bpy.types.Object]:
"""
Return the dictionary of Blender objects: spheres and cylinders
"""
return self._bpy_objs

@classmethod
def create(cls, states: dict[str, NDArray]) -> "Pose":
"""
Create a Pose object from the given states
States must have the following keys: position(n_dim,), directors(n_dim, n_dim)
"""
position = states["position"]
directors = states["directors"]
return cls(position, directors)

def _build(self, position: NDArray, directors: NDArray) -> None:
"""
Build the pose object from the given position and directors
Expand All @@ -73,12 +103,14 @@ def _build(self, position: NDArray, directors: NDArray) -> None:
self.__unit_length * self.__ratio,
)
self.cylinders.append(cylinder)
self.cylinders_material.append(cylinder.material)

sphere = Sphere(
tip_position,
self.__unit_length * self.__ratio,
)
self.spheres.append(sphere)
self.spheres_material.append(sphere.material)

def update_states(self, position: NDArray, directors: NDArray) -> None:
"""
Expand All @@ -93,7 +125,7 @@ def update_states(self, position: NDArray, directors: NDArray) -> None:
sphere = self.spheres[i + 1]
sphere.update_states(tip_position)

def update_material(self, **kwargs) -> None:
def update_material(self, **kwargs: dict[str, Any]) -> None:
"""
Updates the material of the pose object
"""
Expand All @@ -112,3 +144,13 @@ def set_keyframe(self, keyframe: int) -> None:

for cylinder in self.cylinders:
cylinder.set_keyframe(keyframe)


if TYPE_CHECKING:
data = {
"position": np.array([0.0, 0.0, 0.0]),
"directors": np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
),
}
_: CompositeProtocol = Pose.create(data)
23 changes: 21 additions & 2 deletions src/bsr/geometry/composite/rod.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""
__all__ = ["RodWithSphereAndCylinder", "Rod"]

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from collections import defaultdict

import bpy
import numpy as np
Expand Down Expand Up @@ -39,8 +41,23 @@ def __init__(self, positions: NDArray, radii: NDArray) -> None:
"cylinder": self.cylinders,
}

# create sphere and cylinder materials
self.spheres_material: list[bpy.types.Material] = []
self.cylinders_material: list[bpy.types.Material] = []
self._bpy_materials: dict[str, list[bpy.types.Material]] = {
"sphere": self.spheres_material,
"cylinder": self.cylinders_material,
}

self._build(positions, radii)

@property
def material(self) -> dict[str, list[bpy.types.Material]]:
"""
Return the dictionary of Blender materials: sphere and cylinder
"""
return self._bpy_materials

@property
def object(self) -> dict[str, list[bpy.types.Object]]:
"""
Expand All @@ -65,6 +82,7 @@ def _build(self, positions: NDArray, radii: NDArray) -> None:
for j in range(positions.shape[-1]):
sphere = Sphere(positions[:, j], _radii[j])
self.spheres.append(sphere)
self.spheres_material.append(sphere.material)

for j in range(radii.shape[-1]):
cylinder = Cylinder(
Expand All @@ -73,6 +91,7 @@ def _build(self, positions: NDArray, radii: NDArray) -> None:
radii[j],
)
self.cylinders.append(cylinder)
self.cylinders_material.append(cylinder.material)

def update_states(self, positions: NDArray, radii: NDArray) -> None:
"""
Expand Down Expand Up @@ -104,7 +123,7 @@ def update_states(self, positions: NDArray, radii: NDArray) -> None:
positions[:, idx], positions[:, idx + 1], _radii[idx]
)

def update_material(self, **kwargs) -> None:
def update_material(self, **kwargs: dict[str, Any]) -> None:
"""
Updates the material of the rod object
"""
Expand Down
24 changes: 24 additions & 0 deletions src/bsr/geometry/composite/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class BaseStack(Sequence, KeyFrameControlMixin):

def __init__(self) -> None:
self._objs: list[BlenderMeshInterfaceProtocol] = []
self._mats: list[BlenderMeshInterfaceProtocol] = []

@overload
def __getitem__(self, index: int, /) -> BlenderMeshInterfaceProtocol: ...
Expand All @@ -40,6 +41,13 @@ def __getitem__(
def __len__(self) -> int:
return len(self._objs)

@property
def material(self) -> list[BlenderMeshInterfaceProtocol]:
"""
Returns the materials in the stack.
"""
return self._mats

@property
def object(self) -> list[BlenderMeshInterfaceProtocol]:
"""
Expand Down Expand Up @@ -72,6 +80,7 @@ def create(
state = {k: v[oidx] for k, v in states.items()}
obj = self.DefaultType.create(state)
self._objs.append(obj)
self._mats.append(obj.material)
return self

def update_states(self, *variables: NDArray) -> None:
Expand All @@ -85,6 +94,21 @@ def update_states(self, *variables: NDArray) -> None:
for idx in range(len(self)):
self[idx].update_states(*[v[idx] for v in variables])

def update_material(self, **kwargs: dict[str, NDArray]) -> None:
"""
Updates the material of the objects.
"""
for material_key, material_values in kwargs.items():
assert isinstance(
material_values, np.ndarray
), "Values of kwargs must be a numpy array"
if material_values.shape[0] != len(self):
raise IndexError(
"All values must have the same length as the stack"
)
for idx in range(len(self)):
self[idx].update_material({material_key: material_values[idx]})


class RodStack(BaseStack):
input_states = {"positions", "radii"}
Expand Down
40 changes: 33 additions & 7 deletions src/bsr/geometry/primitives/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
__all__ = ["Sphere", "Cylinder"]

from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

import warnings
from numbers import Number
Expand Down Expand Up @@ -130,6 +130,14 @@ def create(cls, states: MeshDataType) -> "Sphere":
)
return cls(states["position"], states["radius"])

@property
def material(self) -> bpy.types.Material:
"""
Access the Blender material.
"""

return self._material

@property
def object(self) -> bpy.types.Object:
"""
Expand Down Expand Up @@ -166,7 +174,7 @@ def update_states(
_validate_radius(radius)
self.object.scale = (radius, radius, radius)

def update_material(self, color: NDArray | None = None) -> None:
def update_material(self, **kwargs: dict[str, Any]) -> None:
"""
Updates the material of the sphere object.
Expand All @@ -175,7 +183,12 @@ def update_material(self, color: NDArray | None = None) -> None:
color : NDArray
The new color of the sphere object in RGBA format.
"""
if color is not None:

if "color" in kwargs:
color = kwargs["color"]
assert isinstance(
color, np.ndarray
), "Keyword argument `color` should be a numpy array."
assert color.shape == (
4,
), "Keyword argument color should be a 1D array with 4 elements: RGBA."
Expand Down Expand Up @@ -254,6 +267,14 @@ def create(cls, states: MeshDataType) -> "Cylinder":
)
return cls(states["position_1"], states["position_2"], states["radius"])

@property
def material(self) -> bpy.types.Material:
"""
Access the Blender material.
"""

return self._material

@property
def object(self) -> bpy.types.Object:
"""
Expand Down Expand Up @@ -320,16 +341,21 @@ def update_states(
self.object.scale[0] = radius
self.object.scale[1] = radius

def update_material(self, color: NDArray | None = None) -> None:
def update_material(self, **kwargs: dict[str, Any]) -> None:
"""
Updates the material of the cylinder object.
Parameters
----------
color : NDArray
The new color of the cylinder object in RGBA format.
kwargs : dict
Keyword arguments for the material update.
"""
if color is not None:

if "color" in kwargs:
color = kwargs["color"]
assert isinstance(
color, np.ndarray
), "Keyword argument `color` should be a numpy array."
assert color.shape == (
4,
), "Keyword argument `color` should be a 1D array with 4 elements: RGBA."
Expand Down
4 changes: 2 additions & 2 deletions src/bsr/geometry/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def object(self) -> Any:
def create(cls: Type[S], states: MeshDataType) -> S:
"""Creates a new mesh object with the given states."""

def update_states(self, **kwargs: Any) -> None:
def update_states(self, *args: Any) -> None:
"""Updates the mesh object with the given states."""

def update_material(self, **kwargs: Any) -> None:
def update_material(self, **kwargs: dict[str, Any]) -> None:
"""Updates the material of the mesh object."""


Expand Down

0 comments on commit 9987611

Please sign in to comment.