Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update predict.py #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask

from scipy import ndimage
import cv2
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -61,7 +61,10 @@ def get_args():
parser.add_argument('--input_file_path', '-if', metavar='INPUT_FILE_PATH', help='Path to a single input file')
parser.add_argument('--output_file_path', '-of', metavar='OUTPUT_FILE_PATH', help='Path to a single output file')
parser.add_argument('--filter_mask', '-f_m', metavar='FILTER_MASK', default = 0, help='set 1 if you want to filter predicted mask for only biggest object')

parser.add_argument('--fill_in', '-f_i', metavar='FILL_IN', default = 0, help='set 1 if you want to fill in the predicted mask')
parser.add_argument('--inner_contour_to_fill', '-f_p', type=float, default = 1000.0, help='number of pixels to fill in the mask')
parser.add_argument('--max_contour_size', '-max_s', metavar='MAX_SIZE', type=int, default = 200000, help='maximum size of the mask')
parser.add_argument('--min_contour_size', '-min_s', metavar='MIN_SIZE', type=int, default = 50000, help='minimum size of the mask')
return parser.parse_args()

# Function to convert an image to RGB
Expand Down Expand Up @@ -138,6 +141,48 @@ def filter_output_image(mask):

return filtered_image


def process_contours(mask, inner_contour_area_to_fill):
# Find contours
cnts, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

# Create an empty image for contours
# img_contours = np.zeros(mask.shape, dtype=np.uint8)

for cnt_idx, cnt in enumerate(cnts):
cnt_area = cv2.contourArea(cnt)
#print(f"Contour index: {cnt_idx}, Area: {cnt_area}")

# If the contour area is smaller than the specified inner contour area, draw it
if cnt_area < inner_contour_area_to_fill:
cv2.drawContours(mask, [cnt], -1, color=255, thickness=-1)
#print(f"Filling inner contour index: {cnt_idx} with area: {cnt_area}")

return mask


def check_max_size(mask, min_size, max_size):
"""
Check if the mask size is within the allowed maximum size and return the mask or an empty mask.

Args:
mask (np.array): Input mask as a binary numpy array.
max_size (int): Maximum allowed size of the mask.

Returns:
np.array: Original mask if within the allowed size, otherwise an empty mask.
"""
# Ensure mask is binary
mask = mask > 0

# Count the number of pixels in the mask
num_pixels = np.sum(mask)

if min_size <= num_pixels <= max_size:
return mask.astype(np.uint8) * 255 # Ensure the mask is white (255)
else:
return np.zeros_like(mask, dtype=np.uint8) # Return an empty mask

def main(arg_list=None):
args = get_args()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
Expand Down Expand Up @@ -198,6 +243,12 @@ def main(arg_list=None):
if str(args.filter_mask) == "1":
mask_final = filter_output_image(mask_final)

if str(args.fill_in) == "1":
mask_final=process_contours(mask_final, args.inner_contour_to_fill)

# Check if the mask is within the allowed maximum size and return appropriate mask
mask_final = check_max_size(mask_final, args.min_contour_size, args.max_contour_size)

# Write the mask to the TIFF writer
tif_writer.write(mask_final, contiguous=True)

Expand Down