| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import os |
|
|
| import cv2 |
| import fastremap |
| import numpy as np |
| import PIL |
| import tifffile |
| import torch |
| import torch.nn.functional as F |
| from cellpose.dynamics import compute_masks, masks_to_flows |
| from cellpose.metrics import _intersection_over_union, _true_positive |
| from monai.apps import get_logger |
| from monai.data import MetaTensor |
| from monai.transforms import MapTransform |
| from monai.utils import ImageMetaKey, convert_to_dst_type |
|
|
# Module-wide logger shared by the transforms below.
logger = get_logger(module_name="VistaCell")
|
|
|
|
class LoadTiffd(MapTransform):
    """Load 2D image/label files into channel-first ``MetaTensor`` objects.

    Files whose extension is ``tif``/``tiff`` (case-insensitive) are read with
    ``tifffile``; every other file is opened with Pillow. The decoded array is
    converted to channel-first layout and then normalised by key:

    * ``"label"``: only the first channel is kept.
    * ``"image"``: forced to exactly 3 channels (a single channel is repeated,
      a 2-channel image gets its first channel appended, extra channels are dropped).

    The array shape as decoded from disk (before any transpose) is stored in the
    metadata under ``ImageMetaKey.SPATIAL_SHAPE`` and the path under
    ``ImageMetaKey.FILENAME_OR_OBJ``.

    Raises:
        ValueError: if the decoded array is neither 2D nor 3D.
    """

    def __call__(self, data):
        # ``import PIL`` alone does not load the ``PIL.Image`` submodule; import it explicitly.
        from PIL import Image

        d = dict(data)
        for key in self.key_iterator(d):
            filename = d[key]
            # Lower-case so ".TIF"/".TIFF" files are also read with tifffile instead of Pillow.
            extension = os.path.splitext(filename)[1][1:].lower()

            if extension in ("tif", "tiff"):
                img_array = tifffile.imread(filename)
                image_size = img_array.shape
                # A 3D TIFF whose last axis has <= 3 entries is treated as channels-last (H, W, C).
                if img_array.ndim == 3 and img_array.shape[-1] <= 3:
                    img_array = np.transpose(img_array, (2, 0, 1))
            else:
                # The context manager closes the file handle Pillow keeps open for lazy decoding;
                # np.array() forces the decode while the file is still open.
                with Image.open(filename) as pil_image:
                    img_array = np.array(pil_image)
                image_size = img_array.shape
                # 3D Pillow arrays are treated as channels-last (H, W, C).
                if img_array.ndim == 3:
                    img_array = np.transpose(img_array, (2, 0, 1))

            if img_array.ndim not in (2, 3):
                raise ValueError(
                    "Unsupported image dimensions, filename " + str(filename) + " shape " + str(img_array.shape)
                )

            if img_array.ndim == 2:
                img_array = img_array[np.newaxis]  # add a leading channel axis

            if key == "label":
                if img_array.shape[0] > 1:
                    logger.warning(
                        f"Strange case, label with several channels {filename} shape {img_array.shape}, keeping only first"
                    )
                    img_array = img_array[[0]]
            elif key == "image":
                if img_array.shape[0] == 1:
                    img_array = np.repeat(img_array, repeats=3, axis=0)
                elif img_array.shape[0] == 2:
                    logger.warning(
                        f"Strange case, image with 2 channels {filename} shape {img_array.shape}, appending first channel to make 3"
                    )
                    img_array = np.stack((img_array[0], img_array[1], img_array[0]), axis=0)
                elif img_array.shape[0] > 3:
                    logger.warning(
                        f"Strange case, image with >3 channels, {filename} shape {img_array.shape}, keeping first 3"
                    )
                    img_array = img_array[:3]

            meta_data = {ImageMetaKey.FILENAME_OR_OBJ: filename, ImageMetaKey.SPATIAL_SHAPE: image_size}
            d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)

        return d
|
|
|
|
class SaveTiffd(MapTransform):
    """Write label tensors to ``<stem>.tif`` files named after their source file.

    Args:
        output_dir: destination directory (created on every call if missing).
        data_root_dir: directory that source paths are made relative to when
            ``nested_folder`` is enabled.
        nested_folder: if true, reproduce the source file's directory layout
            (relative to ``data_root_dir``) underneath ``output_dir``.
        *args, **kwargs: forwarded to ``MapTransform``.
    """

    def __init__(self, output_dir, data_root_dir="/", nested_folder=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.output_dir, self.data_root_dir, self.nested_folder = output_dir, data_root_dir, nested_folder

    def set_data_root_dir(self, data_root_dir):
        """Replace the root directory used to build nested output paths."""
        self.data_root_dir = data_root_dir

    def __call__(self, data):
        d = dict(data)
        os.makedirs(self.output_dir, exist_ok=True)

        for key in self.key_iterator(d):
            seg = d[key]
            src_path = seg.meta[ImageMetaKey.FILENAME_OR_OBJ]
            stem = os.path.splitext(os.path.basename(src_path))[0]

            target_dir = self.output_dir
            if self.nested_folder:
                sub_dir = os.path.relpath(os.path.dirname(src_path), self.data_root_dir)
                target_dir = os.path.join(self.output_dir, sub_dir)
                os.makedirs(target_dir, exist_ok=True)

            out_path = os.path.join(target_dir, stem + ".tif")

            arr = seg.cpu().numpy()
            peak = arr.max()
            # Pick the smallest unsigned dtype whose range covers the largest label id.
            for limit, dtype in ((255, np.uint8), (65535, np.uint16)):
                if peak <= limit:
                    arr = arr.astype(dtype)
                    break
            else:
                arr = arr.astype(np.uint32)

            tifffile.imwrite(out_path, arr)
            print(f"Saving {out_path} shape {arr.shape} max {arr.max()} dtype {arr.dtype}")

        return d
|
|
|
|
class LabelsToFlows(MapTransform):
    """Build Cellpose-style flow targets from instance labels.

    For each selected key the label is renumbered with ``fastremap.renumber`` and
    passed to ``cellpose.dynamics.masks_to_flows``. The value stored under
    ``flow_key`` is a float tensor: the foreground mask (``label > 0.5``)
    concatenated along axis 0 with the flows. It is placed on the label's device.

    Args:
        flow_key: dictionary key that receives the flow tensor.
        *args, **kwargs: forwarded to ``MapTransform``.
    """

    def __init__(self, flow_key, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flow_key = flow_key

    def __call__(self, data):
        d = dict(data)
        for key in self.key_iterator(d):
            # .cpu() lets GPU-resident labels be converted. The explicit copy matters: when the
            # tensor is already int32 on CPU, .int()/.cpu()/.numpy() all return views, and the
            # in-place renumbering below would otherwise rewrite the caller's label tensor.
            label = np.array(d[key].int().cpu().numpy(), copy=True)
            label = fastremap.renumber(label, in_place=True)[0]

            # NOTE(review): assumes the label is (1, H, W); only plane 0 is used for the flows.
            veci = masks_to_flows(label[0], device=None)

            flows = np.concatenate((label > 0.5, veci), axis=0).astype(np.float32)
            flows = convert_to_dst_type(flows, d[key], dtype=torch.float, device=d[key].device)[0]
            d[self.flow_key] = flows

        return d
|
|
|
|
class LogitsToLabels:
    """Convert network logits into an instance mask with Cellpose's ``compute_masks``.

    Channel 0 of the logits is used as the cell-probability map. Channels ``1:``
    are used as the flow field ``dP``.

    Args:
        niter: ``niter`` passed to ``compute_masks`` (default 200).
        cellprob_threshold: ``cellprob_threshold`` passed to ``compute_masks`` (default 0.4).
        flow_threshold: ``flow_threshold`` passed to ``compute_masks`` (default 0.4).
        interp: ``interp`` passed to ``compute_masks`` (default True).
    """

    def __init__(self, niter=200, cellprob_threshold=0.4, flow_threshold=0.4, interp=True):
        self.niter = niter
        self.cellprob_threshold = cellprob_threshold
        self.flow_threshold = flow_threshold
        self.interp = interp

    def _compute(self, dp, cellprob, device):
        """Run ``compute_masks`` on ``device`` with this instance's parameters."""
        return compute_masks(
            dp,
            cellprob,
            niter=self.niter,
            cellprob_threshold=self.cellprob_threshold,
            flow_threshold=self.flow_threshold,
            interp=self.interp,
            device=device,
        )

    def __call__(self, logits, filename=None):
        """Return ``(pred_mask, p)`` as produced by ``compute_masks``.

        Args:
            logits: tensor whose first axis holds the cell-probability channel followed
                by the flow channels; it is converted to a float32 numpy array on CPU.
            filename: used only in the warning logged when the first attempt fails.
        """
        device = logits.device
        logits = logits.float().cpu().numpy()
        dp = logits[1:]
        cellprob = logits[0]

        try:
            pred_mask, p = self._compute(dp, cellprob, device)
        except RuntimeError as e:
            # The first attempt uses the logits' original device. On failure, retry with
            # device=None (the log message assumes the failure happened on the GPU).
            logger.warning(f"compute_masks failed on GPU retrying on CPU {logits.shape} file {filename} {e}")
            pred_mask, p = self._compute(dp, cellprob, None)

        return pred_mask, p
|
|
|
|
class LogitsToLabelsd(MapTransform):
    """Dictionary wrapper around ``LogitsToLabels``.

    For each selected key the logits are replaced by the predicted instance mask.
    The second value returned by ``LogitsToLabels`` is stored under ``f"{key}_centroids"``.
    """

    def __call__(self, data):
        out = dict(data)
        to_labels = LogitsToLabels()
        for key in self.key_iterator(out):
            mask, second = to_labels(out[key])
            out[key], out[f"{key}_centroids"] = mask, second
        return out
|
|
|
|
class SaveTiffExd(MapTransform):
    """Save predicted instance masks and, optionally, outline overlays and polygons.

    For each selected key, the mask is cast to the smallest unsigned integer type
    that holds its largest label. It is then written with ``tifffile`` to
    ``<output_dir>/<image basename>[_<output_postfix>]<output_ext>``. The
    following entries of the data dictionary act as per-call options:

    * ``"output_dir"`` / ``"output_ext"``: override the constructor values.
    * ``"overlayed_masks"``: when true, that same file is overwritten with the source
      image (read with OpenCV) with each instance outlined in a random colour.
    * ``"output_contours"``: when true, the polygons are also written to
      ``contours.json``.

    A ``meta.json`` with the source image size and polygon count is always written.

    NOTE(review): the mask is assumed to be 2D, or to have only singleton leading axes.

    Args:
        output_dir: default output directory.
        output_ext: default extension of the saved mask file.
        output_postfix: suffix appended to the basename as ``"_" + postfix``; falsy disables it.
        image_key: key of the source image whose metadata supplies the filename and size.
        *args, **kwargs: forwarded to ``MapTransform``.
    """

    def __init__(self, output_dir, output_ext=".png", output_postfix="seg", image_key="image", *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.output_dir = output_dir
        self.output_ext = output_ext
        self.output_postfix = output_postfix
        self.image_key = image_key

    def to_polygons(self, contours):
        """Convert OpenCV contours to ``[[x, y], ...]`` integer lists, skipping contours with < 3 points."""
        polygons = []
        for contour in contours:
            if len(contour) < 3:
                continue
            polygons.append(np.squeeze(contour).astype(int).tolist())
        return polygons

    @staticmethod
    def _storage_dtype(max_label):
        """Return the smallest unsigned integer dtype able to store ``max_label``."""
        if max_label <= 255:
            return np.uint8
        if max_label <= 65535:
            return np.uint16
        return np.uint32

    @staticmethod
    def _as_2d(label):
        """Drop leading axes (which must have size 1) so OpenCV receives an (H, W) array."""
        return label.reshape(label.shape[-2:]) if label.ndim > 2 else label

    def _draw_overlay(self, mask, filename, output_filepath):
        """Outline every instance of ``mask`` on the source image, save it, and return the polygons.

        If the source image cannot be read, a warning is logged and the mask file is left
        as written. Polygons are still collected in that case.
        """
        logger.info(f"Overlay Masks: Reading original Image: {filename}")
        image = cv2.imread(filename) if filename else None
        if image is None:
            logger.warning(f"Overlay Masks: could not read original image {filename}; skipping overlay")

        polygons = []
        # Include the largest label id (range end is exclusive).
        for i in range(1, int(mask.max()) + 1):
            # findContours needs an 8-bit single-channel image.
            instance = (mask == i).astype(np.uint8)
            color = np.random.choice(range(256), size=3).tolist()
            contours, _ = cv2.findContours(instance, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            polygons.extend(self.to_polygons(contours))
            if image is not None:
                cv2.drawContours(image, contours, -1, color, 1)

        if image is not None:
            cv2.imwrite(output_filepath, image)
            logger.info(f"Overlay Masks: Saving {output_filepath}")
        return polygons

    def __call__(self, data):
        d = dict(data)

        output_dir = d.get("output_dir", self.output_dir)
        output_ext = d.get("output_ext", self.output_ext)
        overlayed_masks = d.get("overlayed_masks", False)
        output_contours = d.get("output_contours", False)

        # Create the directory that is actually written to (it may be overridden per call).
        os.makedirs(output_dir, exist_ok=True)

        img = d.get(self.image_key, None)
        filename = img.meta.get(ImageMetaKey.FILENAME_OR_OBJ) if img is not None else None
        image_size = img.meta.get(ImageMetaKey.SPATIAL_SHAPE) if img is not None else None
        basename = os.path.splitext(os.path.basename(filename))[0] if filename else "mask"
        logger.info(f"File: {filename}; Base: {basename}")

        for key in self.key_iterator(d):
            label = d[key]
            output_filename = f"{basename}{'_' + self.output_postfix if self.output_postfix else ''}{output_ext}"
            output_filepath = os.path.join(output_dir, output_filename)
            lm = label.max()
            logger.info(f"Mask Shape: {label.shape}; Instances: {lm}")

            label = label.astype(self._storage_dtype(lm))
            tifffile.imwrite(output_filepath, label)
            logger.info(f"Saving {output_filepath}")

            # Work on the in-memory mask: re-reading the file with cv2.IMREAD_GRAYSCALE
            # would reduce 16-bit label ids to 8 bits.
            mask = self._as_2d(label)
            if overlayed_masks:
                polygons = self._draw_overlay(mask, filename, output_filepath)
            else:
                # findContours treats every non-zero pixel as foreground. A binary uint8 mask
                # gives the same outlines without rescaling ids, which divided by zero on
                # empty masks, failed on uint32, and rounded small ids to 0.
                foreground = (mask > 0).astype(np.uint8)
                contours, _ = cv2.findContours(foreground, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
                polygons = self.to_polygons(contours)

            meta_json = {"image_size": image_size, "contours": len(polygons)}
            with open(os.path.join(output_dir, "meta.json"), "w") as fp:
                json.dump(meta_json, fp, indent=2)

            if output_contours:
                logger.info(f"Total Polygons: {len(polygons)}")
                with open(os.path.join(output_dir, "contours.json"), "w") as fp:
                    json.dump({"count": len(polygons), "contours": polygons}, fp, indent=2)

        return d
|
|
|
|
| |
class CellLoss:
    """Cellpose-style training loss.

    Combines two terms:

    * half the MSE between the predicted flow channels (``1:``) and 5x the target flows;
    * binary cross-entropy with logits between prediction channel 0 and target channel 0.
    """

    def __call__(self, y_pred, y):
        flow_pred, flow_target = y_pred[:, 1:], y[:, 1:]
        prob_logit, prob_target = y_pred[:, [0]], y[:, [0]]

        flow_term = F.mse_loss(flow_pred, 5 * flow_target)
        prob_term = F.binary_cross_entropy_with_logits(prob_logit, prob_target)
        return 0.5 * flow_term + prob_term
|
|
|
|
| |
class CellAcc:
    """Instance-segmentation accuracy ``TP / (TP + FP + FN)`` at IoU threshold 0.5.

    Instances are matched with Cellpose's ``_intersection_over_union`` and
    ``_true_positive``. The instance count of each mask is taken to be its maximum
    label id, so labels are assumed to be numbered consecutively from 1.
    Torch tensors are moved to CPU and converted to numpy.

    Returns 1.0 when neither mask contains any instance. The formula would
    otherwise be 0/0.
    """

    def __call__(self, mask_pred, mask_true):
        if isinstance(mask_true, torch.Tensor):
            mask_true = mask_true.cpu().numpy()
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.cpu().numpy()

        n_true = np.max(mask_true)
        n_pred = np.max(mask_pred)
        if n_true == 0 and n_pred == 0:
            # Nothing to miss and nothing over-predicted: report a perfect score instead of
            # NaN (plus a RuntimeWarning), which would poison averaged validation metrics.
            return 1.0

        iou = _intersection_over_union(mask_true, mask_pred)[1:, 1:]  # drop background row/col
        tp = _true_positive(iou, th=0.5)

        fp = n_pred - tp
        fn = n_true - tp
        return tp / (tp + fp + fn)
|
|