Skip to content

Commit

Permalink
adding ability to draw multiple ROIs in single z-plane for complex sh…
Browse files Browse the repository at this point in the history
…apes (TODO documentation)
  • Loading branch information
carsen-stringer committed Sep 13, 2023
1 parent 95f8d98 commit 3bcdd36
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 49 deletions.
100 changes: 55 additions & 45 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,8 @@ def remove_stroke(self, delete_points=True, stroke_ind=-1):
if self.outlinesOn:
self.layerz[stroke[outpix,1],stroke[outpix,2]] = np.array(self.outcolor)
if delete_points:
self.current_point_set = self.current_point_set[:-1*(stroke[:,-1]==1).sum()]
# self.current_point_set = self.current_point_set[:-1*(stroke[:,-1]==1).sum()]
del self.current_point_set[stroke_ind]
self.update_layer()

del self.strokes[stroke_ind]
Expand Down Expand Up @@ -1416,10 +1417,9 @@ def update_crosshairs(self):

def add_set(self):
if len(self.current_point_set) > 0:
self.current_point_set = np.array(self.current_point_set)
while len(self.strokes) > 0:
self.remove_stroke(delete_points=False)
if len(self.current_point_set) > 8:
if len(self.current_point_set[0]) > 8:
color = self.colormap[self.ncells,:3]
median = self.add_mask(points=self.current_point_set, color=color)
if median is not None:
Expand All @@ -1436,51 +1436,61 @@ def add_set(self):
self.current_point_set = []
self.update_layer()

def add_mask(self, points=None, color=(100,200,50)):
def add_mask(self, points=None, color=(100,200,50), dense=True):
# points is list of strokes

points_all = np.concatenate(points, axis=0)

# loop over z values
median = []
if points.shape[1] < 3:
points = np.concatenate((np.zeros((points.shape[0],1), "int32"), points), axis=1)
zdraw = np.unique(points[:,0])
zdraw = np.unique(points_all[:,0])
zrange = np.arange(zdraw.min(), zdraw.max()+1, 1, int)
zmin = zdraw.min()
pix = np.zeros((2,0), "uint16")
mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool")
k=0
for z in zdraw:
iz = points[:,0] == z
vr = points[iz,1]
vc = points[iz,2]
# get points inside drawn points
mask = np.zeros((np.ptp(vr)+4, np.ptp(vc)+4), np.uint8)
pts = np.stack((vc-vc.min()+2,vr-vr.min()+2), axis=-1)[:,np.newaxis,:]
mask = cv2.fillPoly(mask, [pts], (255,0,0))
ar, ac = np.nonzero(mask)
ar, ac = ar+vr.min()-2, ac+vc.min()-2
# get dense outline
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
pvc, pvr = contours[-2][0].squeeze().T
vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
# concatenate all points
ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
# if these pixels are overlapping with another cell, reassign them
ioverlap = self.cellpix[z][ar, ac] > 0
if (~ioverlap).sum() < 8:
print('ERROR: cell too small without overlaps, not drawn')
return None
elif ioverlap.sum() > 0:
ar, ac = ar[~ioverlap], ac[~ioverlap]
# compute outline of new mask
mask = np.zeros((np.ptp(ar)+4, np.ptp(ac)+4), np.uint8)
mask[ar-ar.min()+2, ac-ac.min()+2] = 1
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
pvc, pvr = contours[-2][0].squeeze().T
vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
self.draw_mask(z, ar, ac, vr, vc, color)

median.append(np.array([np.median(ar), np.median(ac)]))
ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(0, "int")
for stroke in points:
stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
iz = stroke[:,0] == z
vr = stroke[iz,1]
vc = stroke[iz,2]
if iz.sum() > 0:
# get points inside drawn points
mask = np.zeros((np.ptp(vr)+4, np.ptp(vc)+4), np.uint8)
pts = np.stack((vc-vc.min()+2,vr-vr.min()+2), axis=-1)[:,np.newaxis,:]
mask = cv2.fillPoly(mask, [pts], (255,0,0))
ar, ac = np.nonzero(mask)
ar, ac = ar+vr.min()-2, ac+vc.min()-2
# get dense outline
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
pvc, pvr = contours[-2][0].squeeze().T
vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
# concatenate all points
ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
# if these pixels are overlapping with another cell, reassign them
ioverlap = self.cellpix[z][ar, ac] > 0
if (~ioverlap).sum() < 8:
print('ERROR: cell too small without overlaps, not drawn')
return None
elif ioverlap.sum() > 0:
ar, ac = ar[~ioverlap], ac[~ioverlap]
# compute outline of new mask
mask = np.zeros((np.ptp(ar)+4, np.ptp(ac)+4), np.uint8)
mask[ar-ar.min()+2, ac-ac.min()+2] = 1
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
pvc, pvr = contours[-2][0].squeeze().T
vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
ars = np.concatenate((ars, ar), axis=0)
acs = np.concatenate((acs, ac), axis=0)
vrs = np.concatenate((vrs, vr), axis=0)
vcs = np.concatenate((vcs, vc), axis=0)
self.draw_mask(z, ars, acs, vrs, vcs, color)

median.append(np.array([np.median(ars), np.median(acs)]))
mall[z-zmin, ar, ac] = True
pix = np.append(pix, np.vstack((ar, ac)), axis=-1)
pix = np.append(pix, np.vstack((ars, acs)), axis=-1)

mall = mall[:, pix[0].min():pix[0].max()+1, pix[1].min():pix[1].max()+1].astype(np.float32)
ymin, xmin = pix[0].min(), pix[1].min()
Expand Down Expand Up @@ -1514,12 +1524,12 @@ def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
self.cellpix[z, vr, vc] = idx
self.cellpix[z, ar, ac] = idx
self.outpix[z, vr, vc] = idx
self.layerz[ar, ac, :3] = color
if self.masksOn:
self.layerz[ar, ac, -1] = self.opacity
if self.outlinesOn:
self.layerz[vr, vc] = np.array(self.outcolor)

if z==self.currentZ:
self.layerz[ar, ac, :3] = color
if self.masksOn:
self.layerz[ar, ac, -1] = self.opacity
if self.outlinesOn:
self.layerz[vr, vc] = np.array(self.outcolor)

def compute_scale(self):
self.diameter = float(self.Diameter.text())
Expand Down
6 changes: 2 additions & 4 deletions cellpose/gui/guiparts.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def hoverEvent(self, ev):
self.drawAt(ev.pos())
if self.is_at_start(ev.pos()):
self.end_stroke()

else:
ev.acceptClicks(QtCore.Qt.RightButton)
#ev.acceptClicks(QtCore.Qt.LeftButton)
Expand Down Expand Up @@ -607,11 +606,11 @@ def end_stroke(self):
self.parent.stroke_appended = True
self.parent.current_stroke = np.array(self.parent.current_stroke)
ioutline = self.parent.current_stroke[:,3]==1
self.parent.current_point_set.extend(list(self.parent.current_stroke[ioutline]))
self.parent.current_point_set.append(list(self.parent.current_stroke[ioutline]))
self.parent.current_stroke = []
if self.parent.autosave:
self.parent.add_set()
if len(self.parent.current_point_set) > 0 and self.parent.autosave:
if len(self.parent.current_point_set) and len(self.parent.current_point_set[0]) > 0 and self.parent.autosave:
self.parent.add_set()
self.parent.in_stroke = False

Expand All @@ -623,7 +622,6 @@ def tabletEvent(self, ev):

def drawAt(self, pos, ev=None):
mask = self.strokemask
set = self.parent.current_point_set
stroke = self.parent.current_stroke
pos = [int(pos.y()), int(pos.x())]
dk = self.drawKernel
Expand Down

0 comments on commit 3bcdd36

Please sign in to comment.