Skip to content

Commit

Permalink
new implementation of target slice that exactly matches QuickNII
Browse files Browse the repository at this point in the history
  • Loading branch information
PolarBean authored Mar 25, 2024
1 parent 3c60d97 commit 9b3c9d3
Showing 1 changed file with 32 additions and 80 deletions.
112 changes: 32 additions & 80 deletions PyNutil/generate_target_slice.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,35 @@
import numpy as np
import math

def generate_target_slize(ouv, atlas):
width = None
height = None
ox, oy, oz, ux, uy, uz, vx, vy, vz = ouv
width = np.floor(math.hypot(ux,uy,uz)).astype(int) + 1
height = np.floor(math.hypot(vx,vy,vz)).astype(int) + 1
data = np.zeros((width, height), dtype=np.uint32).flatten()
xdim, ydim, zdim = atlas.shape
y_values = np.arange(height)
x_values = np.arange(width)
hx = ox + vx * (y_values / height)
hy = oy + vy * (y_values / height)
hz = oz + vz * (y_values / height)
wx = ux * (x_values / width)
wy = uy * (x_values / width)
wz = uz * (x_values / width)
lx = np.floor(hx[:, None] + wx).astype(int)
ly = np.floor(hy[:, None] + wy).astype(int)
lz = np.floor(hz[:, None] + wz).astype(int)
valid_indices = (0 <= lx) & (lx < xdim) & (0 <= ly) & (ly < ydim) & (0 <= lz) & (lz < zdim)
valid_indices = valid_indices.flatten()
lxf = lx.flatten()
lyf = ly.flatten()
lzf = lz.flatten()
valid_lx = lxf[valid_indices]
valid_ly = lyf[valid_indices]
valid_lz = lzf[valid_indices]
atlas_slice = atlas[valid_lx,valid_ly,valid_lz]
data[valid_indices] = atlas_slice
data_im = data.reshape((height, width))
return data_im

def generate_target_slice(alignment, volume):
Ox, Oy, Oz, Ux, Uy, Uz, Vx, Vy, Vz = alignment
##just for mouse for now
bounds = [455, 527, 319]
X_size = np.sqrt(np.sum(np.square((Ux, Uy, Uz))))
Z_size = np.sqrt(np.sum(np.square((Vx, Vy, Vz))))
X_size = np.round(X_size).astype(int)
Z_size = np.round(Z_size).astype(int)
# make this into a grid (0,0) to (320,456)
Uarange = np.arange(0, 1, 1 / X_size)
Varange = np.arange(0, 1, 1 / Z_size)
Ugrid, Vgrid = np.meshgrid(Uarange, Varange)
Ugrid_x = Ugrid * Ux
Ugrid_y = Ugrid * Uy
Ugrid_z = Ugrid * Uz
Vgrid_x = Vgrid * Vx
Vgrid_y = Vgrid * Vy
Vgrid_z = Vgrid * Vz

X_Coords = (Ugrid_x + Vgrid_x).flatten() + Ox
Y_Coords = (Ugrid_y + Vgrid_y).flatten() + Oy
Z_Coords = (Ugrid_z + Vgrid_z).flatten() + Oz

X_Coords = np.round(X_Coords).astype(int)
Y_Coords = np.round(Y_Coords).astype(int)
Z_Coords = np.round(Z_Coords).astype(int)

out_bounds_Coords = (
(X_Coords > bounds[0])
| (Y_Coords > bounds[1])
| (Z_Coords > bounds[2])
| (X_Coords < 0)
| (Y_Coords < 0)
| (Z_Coords < 0)
)
X_pad = X_Coords.copy()
Y_pad = Y_Coords.copy()
Z_pad = Z_Coords.copy()

X_pad[out_bounds_Coords] = 0
Y_pad[out_bounds_Coords] = 0
Z_pad[out_bounds_Coords] = 0

regions = volume[X_pad, Y_pad, Z_pad]
##this is a quick hack to solve rounding errors
C = len(regions)
compare = C - X_size * Z_size
if abs(compare) == X_size:
if compare > 0:
Z_size += 1
if compare < 0:
Z_size -= 1
elif abs(C - X_size * Z_size) == Z_size:
if compare > 0:
X_size += 1
if compare < 0:
X_size -= 1
elif abs(C - X_size * Z_size) == Z_size + X_size:
if compare > 0:
X_size += 1
Z_size += 1
if compare < 0:
X_size -= 1
Z_size -= 1
elif abs(C - X_size * Z_size) == Z_size - X_size:
if compare > 0:
X_size += 1
Z_size -= 1
if compare < 0:
X_size -= 1
Z_size += 1
elif abs(C - X_size * Z_size) == X_size - Z_size:
if compare > 0:
X_size -= 1
Z_size += 1
if compare < 0:
X_size += 1
Z_size -= 1
regions = regions.reshape((abs(Z_size), abs(X_size)))
return regions

0 comments on commit 9b3c9d3

Please sign in to comment.