Skip to content

Commit

Permalink
added nonlinear but needs more testing
Browse files Browse the repository at this point in the history
  • Loading branch information
PolarBean committed Mar 25, 2024
1 parent c84bb72 commit 8298203
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 108 deletions.
165 changes: 64 additions & 101 deletions PyNutil/coordinate_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,55 @@ def folder_to_atlas_space(
segmentations,
)

def load_segmentation(segmentation_path: str):
"""Load a segmentation from a file."""
print(f"working on {segmentation_path}")
if segmentation_path.endswith(".dzip"):
print("Reconstructing dzi")
return reconstruct_dzi(segmentation_path)
else:
return cv2.imread(segmentation_path)

def remove_background(segmentation: np.array):
"""Remove the background from the segmentation and return the pixel id."""
segmentation_no_background = segmentation[~np.all(segmentation == 0, axis=2)]
print("length of non background pixels: ", len(segmentation_no_background))
pixel_id = segmentation_no_background[0]
print("detected pixel_id: ", pixel_id)
return segmentation_no_background, pixel_id

def get_region_areas(use_flat, atlas_labels, flat_file_atlas, seg_width, seg_height, slice_dict, atlas_volume, triangulation):
if use_flat:
region_areas = flat_to_dataframe(
atlas_labels, flat_file_atlas, (seg_width, seg_height)
)
else:
region_areas = flat_to_dataframe(
atlas_labels,
flat_file_atlas,
(seg_width, seg_height),
slice_dict["anchoring"],
atlas_volume,
triangulation
)
return region_areas

def get_transformed_coordinates(non_linear, slice_dict, method, scaled_x, scaled_y, centroids, scaled_centroidsX, scaled_centroidsY, triangulation):
new_x, new_y, centroids_new_x, centroids_new_y = None, None, None, None
if non_linear and "markers" in slice_dict:
if method in ["per_pixel", "all"] and scaled_x is not None:
new_x, new_y = transform_vec(triangulation, scaled_x, scaled_y)
if method in ["per_object", "all"] and centroids is not None:
centroids_new_x, centroids_new_y = transform_vec(triangulation, scaled_centroidsX, scaled_centroidsY)
else:
if method in ["per_pixel", "all"]:
new_x, new_y = scaled_x, scaled_y
if method in ["per_object", "all"]:
centroids_new_x, centroids_new_y = scaled_centroidsX, scaled_centroidsY
return new_x, new_y, centroids_new_x, centroids_new_y

def segmentation_to_atlas_space(
slice,
slice_dict,
segmentation_path,
atlas_labels,
flat_file_atlas=None,
Expand All @@ -251,113 +297,30 @@ def segmentation_to_atlas_space(
atlas_volume=None,
use_flat=False,
):
"""Combines many functions to convert a segmentation to atlas space. It takes care
of deformations."""
print(f"working on {segmentation_path}")
if segmentation_path.endswith(".dzip"):
print("Reconstructing dzi")
segmentation = reconstruct_dzi(segmentation_path)

else:
segmentation = cv2.imread(segmentation_path)
segmentation = load_segmentation(segmentation_path)
if pixel_id == "auto":

# Remove the background from the segmentation
segmentation_no_background = segmentation[~np.all(segmentation == 0, axis=2)]
# pixel_id = np.vstack(
# {tuple(r) for r in segmentation_no_background.reshape(-1, 3)}
# ) # Remove background
# Currently only works for a single label
print("length of non background pixels: ", len(segmentation_no_background))
pixel_id = segmentation_no_background[0]
print("detected pixel_id: ", pixel_id)

# Transform pixels to registration space (the registered image and segmentation have different dimensions)
seg_height = segmentation.shape[0]
seg_width = segmentation.shape[1]
reg_height = slice["height"]
reg_width = slice["width"]
if use_flat == True:
region_areas = flat_to_dataframe(
atlas_labels, flat_file_atlas, (seg_width, seg_height)
)
segmentation, pixel_id = remove_background(segmentation)
seg_height, seg_width = segmentation.shape[:2]
reg_height, reg_width = slice_dict["height"], slice_dict["width"]
if non_linear and "markers" in slice_dict:
triangulation = triangulate(reg_width, reg_height, slice_dict["markers"])
else:
region_areas = flat_to_dataframe(
atlas_labels,
flat_file_atlas,
(seg_width, seg_height),
slice["anchoring"],
atlas_volume,
)
# This calculates reg/seg
y_scale, x_scale = transform_to_registration(
seg_height, seg_width, reg_height, reg_width
)

triangulation = None
region_areas = get_region_areas(use_flat, atlas_labels, flat_file_atlas, seg_width, seg_height, slice_dict, atlas_volume, triangulation)
y_scale, x_scale = transform_to_registration(seg_height, seg_width, reg_height, reg_width)
centroids, points = None, None

if method in ["per_object", "all"]:
centroids, scaled_centroidsX, scaled_centroidsY = get_centroids(
segmentation, pixel_id, y_scale, x_scale, object_cutoff
)
centroids, scaled_centroidsX, scaled_centroidsY = get_centroids(segmentation, pixel_id, y_scale, x_scale, object_cutoff)
if method in ["per_pixel", "all"]:
scaled_y, scaled_x = get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale)

if non_linear:
if "markers" in slice:
# This creates a triangulation using the reg width
triangulation = triangulate(reg_width, reg_height, slice["markers"])
if method in ["per_pixel", "all"]:
if scaled_x is not None:
new_x, new_y = transform_vec(triangulation, scaled_x, scaled_y)
else:
new_x, new_y = scaled_x, scaled_y
if method in ["per_object", "all"]:
if centroids is not None:
centroids_new_x, centroids_new_y = transform_vec(
triangulation, scaled_centroidsX, scaled_centroidsY
)
else:
centroids_new_x, centroids_new_y = (
scaled_centroidsX,
scaled_centroidsY,
)
else:
print(
f"No markers found for {slice['filename']}, result for section will be linear."
)
if method in ["per_pixel", "all"]:
new_x, new_y = scaled_x, scaled_y
if method in ["per_object", "all"]:
centroids_new_x, centroids_new_y = scaled_centroidsX, scaled_centroidsY
else:
if method in ["per_pixel", "all"]:
new_x, new_y = scaled_x, scaled_y
if method in ["per_object", "all"]:
centroids_new_x, centroids_new_y = scaled_centroidsX, scaled_centroidsY
# Scale U by Uxyz/RegWidth and V by Vxyz/RegHeight
if method in ["per_pixel", "all"]:
if new_x is not None:
points = transform_to_atlas_space(
slice["anchoring"], new_y, new_x, reg_height, reg_width
)
else:
points = np.array([])
if method in ["per_object", "all"]:
if centroids_new_x is not None:
centroids = transform_to_atlas_space(
slice["anchoring"],
centroids_new_y,
centroids_new_x,
reg_height,
reg_width,
)
else:
centroids = np.array([])


points_list[index] = np.array(points)
centroids_list[index] = np.array(centroids)
new_x, new_y, centroids_new_x, centroids_new_y = get_transformed_coordinates(non_linear, slice_dict, method, scaled_x, scaled_y, centroids, scaled_centroidsX, scaled_centroidsY, triangulation)
if method in ["per_pixel", "all"] and new_x is not None:
points = transform_to_atlas_space(slice_dict["anchoring"], new_y, new_x, reg_height, reg_width)
if method in ["per_object", "all"] and centroids_new_x is not None:
centroids = transform_to_atlas_space(slice_dict["anchoring"], centroids_new_y, centroids_new_x, reg_height, reg_width)
points_list[index] = np.array(points if points is not None else [])
centroids_list[index] = np.array(centroids if centroids is not None else [])
region_areas_list[index] = region_areas


Expand Down
38 changes: 32 additions & 6 deletions PyNutil/counting_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import struct
import cv2
from .generate_target_slice import generate_target_slice

from .visualign_deformations import transform_vec

# related to counting and load
def label_points(points, label_volume, scale_factor=1):
Expand Down Expand Up @@ -117,10 +117,7 @@ def pixel_count_per_region(

"""Read flat file and write into an np array"""
"""Read flat file, write into an np array, assign label file values, return array"""
import struct
import cv2
import numpy as np
import pandas as pd



def read_flat_file(file):
Expand Down Expand Up @@ -187,12 +184,41 @@ def count_pixels_per_label(image, scale_factor=False):
return df_area_per_label


def warp_image(image, triangulation, rescaleXY):
if rescaleXY is not None:
w,h = rescaleXY
else:
h, w = image.shape
reg_h, reg_w = image.shape
oldX, oldY = np.meshgrid(np.arange(reg_w), np.arange(reg_h))
oldX = oldX.flatten()
oldY = oldY.flatten()
h_scale = h / reg_h
w_scale = w / reg_w
oldX = oldX * w_scale
oldY = oldY * h_scale
newX, newY = transform_vec(triangulation, oldX, oldY)
newX = newX / w_scale
newY = newY / h_scale
newX = newX.reshape(reg_h, reg_w)
newY = newY.reshape(reg_h, reg_w)
newX = newX.astype(int)
newY = newY.astype(int)
newX[newX >= reg_w] = reg_w - 1
newY[newY >= reg_h] = reg_h - 1
newX[newX < 0] = 0
newY[newY < 0] = 0
new_image = image[newY, newX]
return new_image

def flat_to_dataframe(
labelfile, file=None, rescaleXY=None, image_vector=None, volume=None
labelfile, file=None, rescaleXY=None, image_vector=None, volume=None, triangulation=None
):
if (image_vector is not None) and (volume is not None):
image = generate_target_slice(image_vector, volume)
image = np.float64(image)
if triangulation is not None:
image = warp_image(image, triangulation, rescaleXY)
elif file.endswith(".flat"):
image = read_flat_file(file)
elif file.endswith(".seg"):
Expand Down
67 changes: 67 additions & 0 deletions messing_around_files/demo_image_warp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import json
import cv2
import matplotlib.pyplot as plt
import os
import sys
import numpy as np
import nrrd
atlas_path = r"/home/harryc/Github/PyNutilWeb/server/PyNutil/PyNutil/metadata/annotation_volumes/annotation_25_reoriented_2017.nrrd"

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from PyNutil.generate_target_slice import generate_target_slice
from PyNutil.visualign_deformations import triangulate, transform_vec

def make_slice_ordinal(data):
for i, slice in enumerate(np.unique(data)):
data[data==slice] = i
return data
data_path = r"/home/harryc/Github/PyNutilWeb/server/PyNutil/test_data/PyNutil_testdataset_Nonlin_SY_fixed_bigcaudoputamen.json"
with open(data_path, "r") as f:
data = json.load(f)

volume, _ = nrrd.read(atlas_path)
demo = data["slices"][0]
demo_alignment = demo["anchoring"]
demo_markers = demo["markers"]
h = demo["height"]
w = demo["width"]


image = generate_target_slice(demo_alignment, volume)
image = make_slice_ordinal(image)
plt.imshow(image)
plt.show()

triangulation = triangulate(w, h, demo_markers)





def warp_image(image, triangulation, h,w):
reg_h, reg_w = image.shape

oldX, oldY = np.meshgrid(np.arange(reg_w), np.arange(reg_h))
oldX = oldX.flatten()
oldY = oldY.flatten()
h_scale = h / reg_h
w_scale = w / reg_w
oldX = oldX * w_scale
oldY = oldY * h_scale
newX, newY = transform_vec(triangulation, oldX, oldY)
newX = newX / w_scale
newY = newY / h_scale
newX = newX.reshape(reg_h, reg_w)
newY = newY.reshape(reg_h, reg_w)
newX = newX.astype(int)
newY = newY.astype(int)
newX[newX >= reg_w] = reg_w - 1
newY[newY >= reg_h] = reg_h - 1
newX[newX < 0] = 0
newY[newY < 0] = 0
new_image = image[newY, newX]
return new_image

new_image = warp_image(image, triangulation, h, w)
plt.imshow(new_image)
plt.show()
4 changes: 3 additions & 1 deletion testOOP.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from PyNutil import PyNutil
import os

os.chdir("..")
pnt = PyNutil(settings_file=r"PyNutil/test/test8_PyNutil_fixed.json")
##Use flat can be set to True if you want to use the flat file
# instead of the visualign json (this is only useful for testing and will be removed)
pnt.get_coordinates(object_cutoff=0, use_flat=True)
pnt.get_coordinates(object_cutoff=0, use_flat=False)

pnt.quantify_coordinates()

Expand Down

0 comments on commit 8298203

Please sign in to comment.