Skip to content

Commit

Permalink
add testing framework
Browse files Browse the repository at this point in the history
  • Loading branch information
PolarBean committed Oct 23, 2024
1 parent b1a73a1 commit d6bb97f
Show file tree
Hide file tree
Showing 256 changed files with 391 additions and 63,686 deletions.
2 changes: 1 addition & 1 deletion PyNutil/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .main import PyNutil
from .main import PyNutil
93 changes: 75 additions & 18 deletions PyNutil/coordinate_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def transform_to_registration(seg_height, seg_width, reg_height, reg_width):
x_scale = reg_width / seg_width
return y_scale, x_scale


# related to coordinate extraction
def find_matching_pixels(segmentation, id):
"""This function returns the Y and X coordinates of all the pixels in the segmentation that match the id provided."""
Expand Down Expand Up @@ -207,14 +208,14 @@ def folder_to_atlas_space(
[t.join() for t in threads]
# Flatten points_list

points_len = [
len(points) if None not in points else 0 for points in points_list
]
points_len = [len(points) if None not in points else 0 for points in points_list]
centroids_len = [
len(centroids) if None not in centroids else 0 for centroids in centroids_list
]
len(centroids) if None not in centroids else 0 for centroids in centroids_list
]
points_list = [points for points in points_list if None not in points]
centroids_list = [centroids for centroids in centroids_list if None not in centroids]
centroids_list = [
centroids for centroids in centroids_list if None not in centroids
]
if len(points_list) == 0:
points = np.array([])
else:
Expand All @@ -224,7 +225,6 @@ def folder_to_atlas_space(
else:
centroids = np.concatenate(centroids_list)


return (
np.array(points),
np.array(centroids),
Expand All @@ -234,6 +234,7 @@ def folder_to_atlas_space(
segmentations,
)


def load_segmentation(segmentation_path: str):
"""Load a segmentation from a file."""
print(f"working on {segmentation_path}")
Expand All @@ -243,14 +244,25 @@ def load_segmentation(segmentation_path: str):
else:
return cv2.imread(segmentation_path)


def detect_pixel_id(segmentation: np.array):
"""Remove the background from the segmentation and return the pixel id."""
segmentation_no_background = segmentation[~np.all(segmentation == 0, axis=2)]
pixel_id = segmentation_no_background[0]
print("detected pixel_id: ", pixel_id)
return pixel_id

def get_region_areas(use_flat, atlas_labels, flat_file_atlas, seg_width, seg_height, slice_dict, atlas_volume, triangulation):

def get_region_areas(
use_flat,
atlas_labels,
flat_file_atlas,
seg_width,
seg_height,
slice_dict,
atlas_volume,
triangulation,
):
if use_flat:
region_areas = flat_to_dataframe(
atlas_labels, flat_file_atlas, (seg_width, seg_height)
Expand All @@ -262,24 +274,38 @@ def get_region_areas(use_flat, atlas_labels, flat_file_atlas, seg_width, seg_hei
(seg_width, seg_height),
slice_dict["anchoring"],
atlas_volume,
triangulation
triangulation,
)
return region_areas

def get_transformed_coordinates(non_linear, slice_dict, method, scaled_x, scaled_y, centroids, scaled_centroidsX, scaled_centroidsY, triangulation):

def get_transformed_coordinates(
non_linear,
slice_dict,
method,
scaled_x,
scaled_y,
centroids,
scaled_centroidsX,
scaled_centroidsY,
triangulation,
):
new_x, new_y, centroids_new_x, centroids_new_y = None, None, None, None
if non_linear and "markers" in slice_dict:
if method in ["per_pixel", "all"] and scaled_x is not None:
new_x, new_y = transform_vec(triangulation, scaled_x, scaled_y)
if method in ["per_object", "all"] and centroids is not None:
centroids_new_x, centroids_new_y = transform_vec(triangulation, scaled_centroidsX, scaled_centroidsY)
centroids_new_x, centroids_new_y = transform_vec(
triangulation, scaled_centroidsX, scaled_centroidsY
)
else:
if method in ["per_pixel", "all"]:
new_x, new_y = scaled_x, scaled_y
if method in ["per_object", "all"]:
centroids_new_x, centroids_new_y = scaled_centroidsX, scaled_centroidsY
return new_x, new_y, centroids_new_x, centroids_new_y


def segmentation_to_atlas_space(
slice_dict,
segmentation_path,
Expand All @@ -305,20 +331,51 @@ def segmentation_to_atlas_space(
triangulation = triangulate(reg_width, reg_height, slice_dict["markers"])
else:
triangulation = None
region_areas = get_region_areas(use_flat, atlas_labels, flat_file_atlas, seg_width, seg_height, slice_dict, atlas_volume, triangulation)
y_scale, x_scale = transform_to_registration(seg_height, seg_width, reg_height, reg_width)
region_areas = get_region_areas(
use_flat,
atlas_labels,
flat_file_atlas,
seg_width,
seg_height,
slice_dict,
atlas_volume,
triangulation,
)
y_scale, x_scale = transform_to_registration(
seg_height, seg_width, reg_height, reg_width
)
centroids, points = None, None
scaled_centroidsX, scaled_centroidsY, scaled_x, scaled_y = None, None, None, None
scaled_centroidsX, scaled_centroidsY, scaled_x, scaled_y = None, None, None, None
if method in ["per_object", "all"]:
centroids, scaled_centroidsX, scaled_centroidsY = get_centroids(segmentation, pixel_id, y_scale, x_scale, object_cutoff)
centroids, scaled_centroidsX, scaled_centroidsY = get_centroids(
segmentation, pixel_id, y_scale, x_scale, object_cutoff
)
if method in ["per_pixel", "all"]:
scaled_y, scaled_x = get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale)

new_x, new_y, centroids_new_x, centroids_new_y = get_transformed_coordinates(non_linear, slice_dict, method, scaled_x, scaled_y, centroids, scaled_centroidsX, scaled_centroidsY, triangulation)
new_x, new_y, centroids_new_x, centroids_new_y = get_transformed_coordinates(
non_linear,
slice_dict,
method,
scaled_x,
scaled_y,
centroids,
scaled_centroidsX,
scaled_centroidsY,
triangulation,
)
if method in ["per_pixel", "all"] and new_x is not None:
points = transform_to_atlas_space(slice_dict["anchoring"], new_y, new_x, reg_height, reg_width)
points = transform_to_atlas_space(
slice_dict["anchoring"], new_y, new_x, reg_height, reg_width
)
if method in ["per_object", "all"] and centroids_new_x is not None:
centroids = transform_to_atlas_space(slice_dict["anchoring"], centroids_new_y, centroids_new_x, reg_height, reg_width)
centroids = transform_to_atlas_space(
slice_dict["anchoring"],
centroids_new_y,
centroids_new_x,
reg_height,
reg_width,
)
points_list[index] = np.array(points if points is not None else [])
centroids_list[index] = np.array(centroids if centroids is not None else [])
region_areas_list[index] = region_areas
Expand Down
13 changes: 10 additions & 3 deletions PyNutil/counting_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .generate_target_slice import generate_target_slice
from .visualign_deformations import transform_vec


# related to counting and load
def label_points(points, label_volume, scale_factor=1):
"""This function takes a list of points and assigns them to a region based on the region_volume.
Expand Down Expand Up @@ -119,7 +120,6 @@ def pixel_count_per_region(
"""Read flat file, write into an np array, assign label file values, return array"""



def read_flat_file(file):
with open(file, "rb") as f:
b, w, h = struct.unpack(">BII", f.read(9))
Expand Down Expand Up @@ -163,6 +163,7 @@ def rescale_image(image, rescaleXY):
w, h = rescaleXY
return cv2.resize(image, (h, w), interpolation=cv2.INTER_NEAREST)


def assign_labels_to_image(image, labelfile):
w, h = image.shape
allen_id_image = np.zeros((h, w)) # create an empty image array
Expand All @@ -186,7 +187,7 @@ def count_pixels_per_label(image, scale_factor=False):

def warp_image(image, triangulation, rescaleXY):
if rescaleXY is not None:
w,h = rescaleXY
w, h = rescaleXY
else:
h, w = image.shape
reg_h, reg_w = image.shape
Expand All @@ -211,8 +212,14 @@ def warp_image(image, triangulation, rescaleXY):
new_image = image[newY, newX]
return new_image


def flat_to_dataframe(
labelfile, file=None, rescaleXY=None, image_vector=None, volume=None, triangulation=None
labelfile,
file=None,
rescaleXY=None,
image_vector=None,
volume=None,
triangulation=None,
):
if (image_vector is not None) and (volume is not None):
image = generate_target_slice(image_vector, volume)
Expand Down
18 changes: 10 additions & 8 deletions PyNutil/generate_target_slice.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np
import math


def generate_target_slice(ouv, atlas):
width = None
height = None
ox, oy, oz, ux, uy, uz, vx, vy, vz = ouv
width = np.floor(math.hypot(ux,uy,uz)).astype(int) + 1
height = np.floor(math.hypot(vx,vy,vz)).astype(int) + 1
width = np.floor(math.hypot(ux, uy, uz)).astype(int) + 1
height = np.floor(math.hypot(vx, vy, vz)).astype(int) + 1
data = np.zeros((width, height), dtype=np.uint32).flatten()
xdim, ydim, zdim = atlas.shape
y_values = np.arange(height)
Expand All @@ -17,19 +18,20 @@ def generate_target_slice(ouv, atlas):
wx = ux * (x_values / width)
wy = uy * (x_values / width)
wz = uz * (x_values / width)
lx = np.floor(hx[:, None] + wx).astype(int)
ly = np.floor(hy[:, None] + wy).astype(int)
lz = np.floor(hz[:, None] + wz).astype(int)
valid_indices = (0 <= lx) & (lx < xdim) & (0 <= ly) & (ly < ydim) & (0 <= lz) & (lz < zdim)
lx = np.floor(hx[:, None] + wx).astype(int)
ly = np.floor(hy[:, None] + wy).astype(int)
lz = np.floor(hz[:, None] + wz).astype(int)
valid_indices = (
(0 <= lx) & (lx < xdim) & (0 <= ly) & (ly < ydim) & (0 <= lz) & (lz < zdim)
)
valid_indices = valid_indices.flatten()
lxf = lx.flatten()
lyf = ly.flatten()
lzf = lz.flatten()
valid_lx = lxf[valid_indices]
valid_ly = lyf[valid_indices]
valid_lz = lzf[valid_indices]
atlas_slice = atlas[valid_lx,valid_ly,valid_lz]
atlas_slice = atlas[valid_lx, valid_ly, valid_lz]
data[valid_indices] = atlas_slice
data_im = data.reshape((height, width))
return data_im

39 changes: 23 additions & 16 deletions PyNutil/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,17 @@ def __init__(
self.colour = colour
self.atlas_name = atlas_name
if (atlas_path or label_path) and atlas_name:
raise ValueError("Please only specify an atlas_path and a label_path or an atlas_name, atlas and label paths are only used for loading custom atlases")
raise ValueError(
"Please only specify an atlas_path and a label_path or an atlas_name, atlas and label paths are only used for loading custom atlases"
)
if atlas_path and label_path:
self.atlas_volume, self.atlas_labels = self.load_custom_atlas(atlas_path, label_path)
self.atlas_volume, self.atlas_labels = self.load_custom_atlas(
atlas_path, label_path
)
else:
self.atlas_volume, self.atlas_labels = self.load_atlas_data(atlas_name=atlas_name)
self.atlas_volume, self.atlas_labels = self.load_atlas_data(
atlas_name=atlas_name
)
###This is just because of the migration to BrainGlobe

def load_atlas_data(self, atlas_name):
Expand All @@ -129,22 +135,23 @@ def load_atlas_data(self, atlas_name):
# this could potentially be moved into init
print("loading atlas volume")
atlas = brainglobe_atlasapi.BrainGlobeAtlas(atlas_name=atlas_name)
atlas_structures = {'idx':[i['id'] for i in atlas.structures_list],
'name':[i['name'] for i in atlas.structures_list],
'r':[i['rgb_triplet'][0] for i in atlas.structures_list],
'g':[i['rgb_triplet'][1] for i in atlas.structures_list],
'b':[i['rgb_triplet'][2] for i in atlas.structures_list]
}
atlas_structures['idx'].insert(0,0)
atlas_structures['name'].insert(0,'Clear Label')
atlas_structures['r'].insert(0,0)
atlas_structures['g'].insert(0,0)
atlas_structures['b'].insert(0,0)
atlas_structures = {
"idx": [i["id"] for i in atlas.structures_list],
"name": [i["name"] for i in atlas.structures_list],
"r": [i["rgb_triplet"][0] for i in atlas.structures_list],
"g": [i["rgb_triplet"][1] for i in atlas.structures_list],
"b": [i["rgb_triplet"][2] for i in atlas.structures_list],
}
atlas_structures["idx"].insert(0, 0)
atlas_structures["name"].insert(0, "Clear Label")
atlas_structures["r"].insert(0, 0)
atlas_structures["g"].insert(0, 0)
atlas_structures["b"].insert(0, 0)

atlas_labels = pd.DataFrame(atlas_structures)
if "allen_mouse_" in atlas_name:
if "allen_mouse_" in atlas_name:
print("reorienting allen atlas into quicknii space...")
atlas_volume = np.transpose(atlas.annotation,[2,0,1])[:,::-1,::-1]
atlas_volume = np.transpose(atlas.annotation, [2, 0, 1])[:, ::-1, ::-1]
else:
atlas_volume = atlas.annotation
print("atlas labels loaded ✅")
Expand Down
2 changes: 1 addition & 1 deletion PyNutil/metadata/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import *
from . import *
2 changes: 0 additions & 2 deletions PyNutil/read_and_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def load_visualign_json(filename):
"slices": slices,
}



else:
slices = vafile["slices"]
if len(slices) > 1:
Expand Down
1 change: 1 addition & 0 deletions PyNutil/visualign_deformations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This code was written by Gergely Csucs and Rembrandt Bakker"""

import numpy as np


Expand Down
Loading

0 comments on commit d6bb97f

Please sign in to comment.