Skip to content

Commit

Permalink
updating stitching to improve speed (#845)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Feb 12, 2024
1 parent 4f56619 commit 6ce3bee
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 15 additions & 5 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
igood &= imfile[-len(imf):]==imf
if igood:
imn.append(im)

image_names = imn

# remove duplicates
Expand Down Expand Up @@ -240,29 +241,38 @@ def get_label_files(image_names, mask_filter, imf=None):
#elif os.path.exists(label_names[0] + '_seg.npy'):
# io_logger.info('labels found as _seg.npy files, converting to tif')
else:
raise ValueError('labels not provided with correct --mask_filter')
if not flow_names:
raise ValueError('labels not provided with correct --mask_filter')
else:
label_names = None
if not all([os.path.exists(label) for label in label_names]):
raise ValueError('labels not provided for all images in train and/or test set')
if not flow_names:
raise ValueError('labels not provided for all images in train and/or test set')
else:
label_names = None

return label_names, flow_names


def load_images_labels(tdir, mask_filter='_masks', image_filter=None, look_one_level_down=False, unet=False):
image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
nimg = len(image_names)

# training data
label_names, flow_names = get_label_files(image_names, mask_filter, imf=image_filter)

images = []
labels = []
k = 0
for n in range(nimg):
if os.path.isfile(label_names[n]):
if os.path.isfile(label_names[n]) or os.path.isfile(flow_names[0]):
print(image_names[n])
image = imread(image_names[n])
label = imread(label_names[n])
if label_names is not None:
label = imread(label_names[n])
if not unet:
if flow_names is not None and not unet:
print(flow_names[n])
flow = imread(flow_names[n])
if flow.shape[0]<4:
label = np.concatenate((label[np.newaxis,:,:], flow), axis=0)
Expand Down
4 changes: 2 additions & 2 deletions cellpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import logging
import os, warnings, time, tempfile, datetime, pathlib, shutil, subprocess
from tqdm import tqdm
from tqdm import tqdm, trange
from urllib.request import urlopen
from urllib.parse import urlparse
import cv2
Expand Down Expand Up @@ -403,7 +403,7 @@ def stitch3D(masks, stitch_threshold=0.25):
mmax = masks[0].max()
empty = 0

for i in range(len(masks)-1):
for i in trange(len(masks)-1):
iou = metrics._intersection_over_union(masks[i+1], masks[i])[1:,1:]
if not iou.size and empty == 0:
masks[i+1] = masks[i+1]
Expand Down

0 comments on commit 6ce3bee

Please sign in to comment.