import json
import logging
import os
import pathlib
import sys

import numpy as np
import SimpleITK as sitk
import torch
import torchio as tio
import yaml
from scipy.linalg import norm
from scipy.ndimage import binary_fill_holes
from tqdm import tqdm

def set_logger(log_path=None):
    """
    Set the root logger to log either to the terminal or to the file `log_path`.
    In general, it is useful to have a logger so that the training output is
    kept in a permanent file, e.g. `model_dir/train.log`.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        if not log_path:
            # log to the console
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setFormatter(logging.Formatter('%(message)s'))
            logger.addHandler(stream_handler)
        else:
            # log to a file
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            logger.addHandler(file_handler)
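
# Example usage (illustrative; `model_dir` is a hypothetical path):
#
#     set_logger(os.path.join("model_dir", "train.log"))
#     logging.info("training started")  # appended to model_dir/train.log
#
# Calling set_logger() with no argument logs to stdout instead.
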
def load_config_yaml(config_file):
    """Load a YAML configuration file into a Python dict."""
    with open(config_file, 'r') as f:
        return yaml.load(f, yaml.FullLoader)
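
# Example (hypothetical file): given a `config.yaml` containing the line
# `batch_size: 8`, load_config_yaml("config.yaml")["batch_size"] returns 8.
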
def resample(ctvol, is_label, original_spacing=.3, out_spacing=.4):
    """Resample a volume from `original_spacing` to `out_spacing` (isotropic, in mm)."""
    original_spacing = (original_spacing, original_spacing, original_spacing)
    out_spacing = (out_spacing, out_spacing, out_spacing)
    ctvol_itk = sitk.GetImageFromArray(ctvol)
    ctvol_itk.SetSpacing(original_spacing)
    original_size = ctvol_itk.GetSize()
    out_shape = [int(np.round(original_size[0] * (original_spacing[0] / out_spacing[0]))),
                 int(np.round(original_size[1] * (original_spacing[1] / out_spacing[1]))),
                 int(np.round(original_size[2] * (original_spacing[2] / out_spacing[2])))]
    # perform the resampling
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(out_spacing)
    resampler.SetSize(out_shape)
    resampler.SetOutputDirection(ctvol_itk.GetDirection())
    resampler.SetOutputOrigin(ctvol_itk.GetOrigin())
    resampler.SetTransform(sitk.Transform())
    resampler.SetDefaultPixelValue(0)  # background value for voxels outside the input volume
    if is_label:
        # nearest neighbour keeps label values discrete
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resampler.SetInterpolator(sitk.sitkBSpline)
    resampled_ctvol = resampler.Execute(ctvol_itk)
    return sitk.GetArrayFromImage(resampled_ctvol)
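
# Minimal sketch of the expected behaviour (illustrative, not part of the
# original pipeline): going from 0.3 mm to 0.4 mm isotropic spacing shrinks
# each axis by a factor of 0.75.
#
#     vol = np.random.rand(80, 80, 80).astype(np.float32)
#     out = resample(vol, is_label=False)  # B-spline interpolation
#     assert out.shape == (60, 60, 60)
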
def create_dataset(splits_todo, is_competitor, saving_dir):
    split_filepath = "path/to/splits.json"
    with open(split_filepath) as f:
        folder_splits = json.load(f)
    # for split in ['train', 'synthetic', 'val']:
    if is_competitor:
        base = '/path/to/data/SPARSE'
    else:
        base = "/path/to/data/DENSE"
    for split in splits_todo:
        dirs = [os.path.join(base, p) for p in folder_splits[split]]
        dataset = {'data': [], 'gt': []}
        for patient_dir in tqdm(dirs, desc=f"processing {split}"):
            gt_path = os.path.join(patient_dir, 'synthetic.npy') if is_competitor else os.path.join(patient_dir, 'gt_alpha.npy')
            data_path = os.path.join(patient_dir, 'data.npy')
            image = np.load(data_path)
            gt_orig = np.load(gt_path)
            # resample to the target voxel spacing
            image = resample(image, is_label=False)
            gt = resample(gt_orig, is_label=True)
            # DICOM_MAX = 3095 if is_competitor else 2100
            DICOM_MAX = 2100
            DICOM_MIN = 0
            image = np.clip(image, DICOM_MIN, DICOM_MAX)
            image = (image.astype(float) - DICOM_MIN) / (DICOM_MAX - DICOM_MIN)  # shift and rescale to [0, 1]
            if split not in ["test", "val"]:
                # cut the training volumes into overlapping 32x32x32 patches
                s = tio.Subject(
                    data=tio.ScalarImage(tensor=image[None]),
                    label=tio.LabelMap(tensor=gt[None]),
                )
                grid_sampler = tio.inference.GridSampler(
                    s,
                    patch_size=(32, 32, 32),
                    patch_overlap=(10, 10, 10),
                )
                patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)
                for patch in patch_loader:
                    patch_img = patch['data'][tio.DATA].squeeze().numpy()
                    patch_gt = patch['label'][tio.DATA].squeeze().numpy()
                    # keep only patches that contain some foreground
                    if np.sum(patch_gt) != 0:
                        dataset['data'].append(patch_img)
                        dataset['gt'].append(patch_gt)
            else:  # do not cut volumes for testing - we do this at runtime
                dataset['data'].append(image)
                dataset['gt'].append(gt_orig)
        out_dir = pathlib.Path(os.path.join(saving_dir, split))
        out_dir.mkdir(parents=True, exist_ok=True)
        for partition in ['data', 'gt']:
            # object array: the stored subvolumes may differ in shape
            a = np.empty(len(dataset[partition]), dtype=object)
            for i in range(len(dataset[partition])):
                a[i] = dataset[partition][i]
            np.save(os.path.join(saving_dir, split, f'{partition}.npy'), a)
        print(f"split {split} completed. created {len(dataset['data'])} subvolumes")
def create_synthetic():
    data_dir = "path/to/your/SPARSE/npy_files"
    for folder in os.listdir(data_dir):
        print(f"processing {folder}")
        gt = np.load(os.path.join(data_dir, folder, "gt_sparse.npy"))
        synthetic = np.zeros_like(gt)
        points = np.argwhere(gt == 1)
        # split the sparse annotations at the mid-sagittal plane (left/right canal)
        splits = [
            points[points[:, -1] < gt.shape[-1] // 2],
            points[points[:, -1] >= gt.shape[-1] // 2]
        ]
        for jj in range(2):
            points = splits[jj]
            points = points[np.lexsort((points[:, 2], points[:, 0], points[:, 1]))]
            for i in range(points.shape[0] - 1):
                # segment endpoints and tube radius
                p0 = np.array(points[i])
                p1 = np.array(points[i + 1])
                R = 1.6
                # vector in the direction of the segment axis
                v = p1 - p0
                mag = norm(v)
                if mag == 0:
                    # skip duplicated points
                    continue
                # unit vector in the direction of the axis
                v = v / mag
                # pick some vector not collinear with v
                not_v = np.array([1, 0, 0])
                if np.allclose(np.abs(v), not_v):
                    not_v = np.array([0, 1, 0])
                # build an orthonormal basis (n1, n2) of the plane perpendicular to v
                n1 = np.cross(v, not_v)
                n1 /= norm(n1)
                n2 = np.cross(v, n1)
                # the cylinder surface ranges over t in [0, mag] and theta in [0, 2*pi]
                t = np.linspace(0, mag, 100)
                theta = np.linspace(0, 2 * np.pi, 100)
                t, theta = np.meshgrid(t, theta)
                # generate the surface coordinates of a tube of radius R around the segment
                Z, Y, X = [p0[k] + v[k] * t + R * np.sin(theta) * n1[k] + R * np.cos(theta) * n2[k]
                           for k in [0, 1, 2]]
                # rasterize, keeping the fixed 4-voxel shift on the first axis
                # and clipping the indices to the volume bounds
                zz = np.clip((Z + 4).astype(int), 0, synthetic.shape[0] - 1)
                yy = np.clip(Y.astype(int), 0, synthetic.shape[1] - 1)
                xx = np.clip(X.astype(int), 0, synthetic.shape[2] - 1)
                synthetic[zz, yy, xx] = 1
        # turn the tube surfaces into solid tubes
        synthetic = binary_fill_holes(synthetic).astype(int)
        np.save(os.path.join(data_dir, folder, 'synthetic.npy'), synthetic)
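
# Sanity check of the frame construction above (illustrative only): n1 and n2
# should be unit vectors orthogonal to each other and to the segment axis v.
#
#     v = np.array([0.0, 0.0, 1.0])
#     n1 = np.cross(v, np.array([1, 0, 0])); n1 = n1 / norm(n1)
#     n2 = np.cross(v, n1)
#     assert abs(np.dot(v, n1)) < 1e-9 and abs(np.dot(n1, n2)) < 1e-9
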
if __name__ == '__main__':
    # generate the circle-expanded (synthetic) dataset - set your paths!
    create_synthetic()
    print("synthetic dataset has been created!")
    # generate the training and synthetic datasets (32x32x32 patches) and the
    # test set (volumes resampled to 0.4 mm isotropic spacing)
    create_dataset(['train', 'synthetic', 'val', 'test'], is_competitor=True, saving_dir="saving_dir/sparse")
    create_dataset(['train', 'val', 'test'], is_competitor=False, saving_dir="saving_dir/dense")
    print("subvolumes for training have been created!")