# Novel class segmentation demo with Deep-MAC

Welcome to the novel class segmentation (with Deep-MAC) demo. This Colab loads a Deep-MAC model and tests it interactively with user-specified boxes. Deep-MAC was trained to detect and segment only COCO classes, but it generalizes well when segmenting within user-specified boxes of unseen classes.

Estimated time to run through this Colab (with GPU): 10-15 minutes.
Most of this time is spent installing TensorFlow, downloading the checkpoint,
and running inference for the first time. Once all of that is done, running
on new images is very fast.

# Prerequisites

Please change runtime to GPU.

# Installation and Imports

This takes 3-4 minutes.

In [None]:
import os
import pathlib

# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models


In [None]:
# Install the Object Detection API
%%bash
cd models/research/
# Generate Python modules from the API's protobuf definitions; the output is
# written into the current directory tree (--python_out=.).
protoc object_detection/protos/*.proto --python_out=.
# Place the TF2 packaging script at the directory root so that the
# `pip install ... .` command below can use it.
cp object_detection/packages/tf2/setup.py .

# The latest tf-models-official installs tensorflow 2.9 which has an
# incompatible CuDNN dependency. Here we restrict ourselves to versions 2.8 and
# below.
python -m pip install "tf-models-official<=2.8" .

In [None]:
import glob
import io
import logging
import os
import random
import warnings

import imageio
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage
import matplotlib
from matplotlib import patches
import matplotlib.pyplot as plt
import numpy as np
from object_detection.utils import colab_utils
from object_detection.utils import ops
from object_detection.utils import visualization_utils as viz_utils
from PIL import Image, ImageDraw, ImageFont
import scipy.misc
from six import BytesIO
from skimage import color
from skimage import transform
from skimage import util
from skimage.color import rgb_colors
import tensorflow as tf

%matplotlib inline

# Palette used for box outlines and mask overlays: six hand-picked colors
# followed by every entry of skimage's `color_dict`.
COLORS = [
    rgb_colors.cyan,
    rgb_colors.orange,
    rgb_colors.pink,
    rgb_colors.purple,
    rgb_colors.limegreen,
    rgb_colors.crimson,
    *color.color_dict.values(),
]
# Shuffle in place so the order in which colors are handed out is randomized.
random.shuffle(COLORS)

# Silence log records of severity WARNING and below.
logging.disable(logging.WARNING)


def read_image(path):
  """Read an image file into an RGB uint8 numpy array.

  The image is returned at its native resolution; no resizing is done here.

  Args:
    path: Location of the image file. It is opened with `tf.io.gfile.GFile`.

  Returns:
    A uint8 numpy array of shape [height, width, 3].
  """
  with tf.io.gfile.GFile(path, 'rb') as f:
    # Force three channels: the plotting and resizing helpers in this notebook
    # unpack a 3-D shape and blend with RGB colors, so grayscale, palette or
    # RGBA inputs would otherwise fail downstream.
    img = Image.open(f).convert('RGB')
    # Materialize the pixels while the file handle is still open.
    return np.array(img, dtype=np.uint8)


def resize_for_display(image, max_height=600):
  """Resize an image to `max_height` rows while keeping its aspect ratio.

  Args:
    image: A numpy array of shape [height, width, channels].
    max_height: Height, in pixels, of the returned image.

  Returns:
    The resized image as a uint8 numpy array of shape
    [max_height, new_width, channels].
  """
  height, width, _ = image.shape
  # Scale the width by the same factor applied to the height. The target
  # shape must use `max_height` (not the original height), otherwise the
  # image is stretched horizontally instead of being resized.
  new_width = max(1, int(width * max_height / height))
  with warnings.catch_warnings():
    # Suppress UserWarnings raised during the resize / dtype conversion.
    warnings.simplefilter("ignore", UserWarning)
    return util.img_as_ubyte(
        transform.resize(image, (max_height, new_width)))


def get_mask_prediction_function(model):
  """Build a single-image mask predictor around `model`.

  Args:
    model: Callable invoked as `model(images, boxes)` on batched inputs; its
      result must provide a 'detection_masks' entry.

  Returns:
    A `tf.function` that takes one image and its boxes and returns the
    result of `ops.reframe_box_masks_to_image_masks` for that image.
  """

  @tf.function
  def predict_masks(image, boxes):
    # The static shape is needed to size the output masks.
    img_height, img_width, _ = image.shape.as_list()

    # The model consumes batches, so add a leading axis of size one.
    image_batch = image[tf.newaxis]
    box_batch = boxes[tf.newaxis]
    outputs = model(image_batch, box_batch)

    # Strip the batch axis again before reframing masks to the image size.
    box_masks = outputs['detection_masks'][0]
    return ops.reframe_box_masks_to_image_masks(
        box_masks, box_batch[0], img_height, img_width)

  return predict_masks


def plot_image_annotations(image, boxes, masks, darken_image=0.5):
  """Draw boxes and their masks on top of a dimmed copy of `image`.

  Args:
    image: Array of shape [height, width, channels] to draw on.
    boxes: Iterable of [ymin, xmin, ymax, xmax] boxes; each coordinate is
      multiplied by the image height or width before drawing.
    masks: Iterable of 2-D masks, one per box, thresholded at 0.5.
    darken_image: Factor the pixel values are multiplied by before display.

  Returns:
    The matplotlib axes holding the rendered image and overlays.
  """
  _, ax = plt.subplots(figsize=(16, 12))
  ax.set_axis_off()

  dimmed = (image * darken_image).astype(np.uint8)
  ax.imshow(dimmed)

  img_height, img_width, _ = dimmed.shape
  palette_size = len(COLORS)

  for index, (box, mask) in enumerate(zip(boxes, masks)):
    # Colors are reused cyclically when there are more boxes than colors.
    overlay_color = np.array(COLORS[index % palette_size])

    # Convert the box to pixel coordinates.
    ymin, xmin, ymax, xmax = box
    top, bottom = ymin * img_height, ymax * img_height
    left, right = xmin * img_width, xmax * img_width

    outline = patches.Rectangle(
        (left, top), right - left, bottom - top,
        linewidth=2.5, edgecolor=overlay_color, facecolor='none')
    ax.add_patch(outline)

    # Build an RGBA layer: a solid color whose alpha channel is the binary
    # mask, so only pixels inside the mask are tinted.
    binary_mask = (mask > 0.5).astype(np.float32)
    solid_color = (
        np.ones_like(dimmed) * overlay_color[np.newaxis, np.newaxis, :])
    rgba_overlay = np.concatenate(
        [solid_color, binary_mask[:, :, np.newaxis]], axis=2)
    ax.imshow(rgba_overlay, alpha=0.5)

  return ax

# Load Deep-MAC Model

This can take up to 5 minutes.

In [None]:
print('Downloading and untarring model')
# Fetch the model archive into the current working directory.
!wget http://download.tensorflow.org/models/object_detection/tf2/20210329/deepmac_1024x1024_coco17.tar.gz
# Copy the archive under test_data/ and extract from that copy. tar is run
# without -C, so the files are unpacked into the current directory.
!cp deepmac_1024x1024_coco17.tar.gz models/research/object_detection/test_data/
!tar -xzf models/research/object_detection/test_data/deepmac_1024x1024_coco17.tar.gz
# Move the extracted directory under test_data/ so that model_path resolves.
!mv deepmac_1024x1024_coco17 models/research/object_detection/test_data/
model_path = 'models/research/object_detection/test_data/deepmac_1024x1024_coco17/saved_model'

print('Loading SavedModel')
model = tf.keras.models.load_model(model_path)
# Wrap the loaded model in a function that predicts masks for a single image
# and its boxes.
prediction_function = get_mask_prediction_function(model)

# Load image

In [None]:
# Sample image located inside the cloned models repository. Point image_path
# at a different file to run the demo on another image.
image_path = 'models/research/object_detection/test_images/image3.jpg'
image = read_image(image_path)

# Annotate an image with one or more boxes

This model is trained on COCO categories, but we encourage you to try segmenting
anything you want!

Don't forget to hit **submit** when done.

In [None]:
# Annotate on a resized copy of the image.
display_image = resize_for_display(image)

# NOTE(review): colab_utils.annotate presumably fills boxes_list in place once
# the annotations are submitted (its return value is not used, and the
# visualization cell reads boxes_list[0]) -- confirm against colab_utils.
boxes_list = []
colab_utils.annotate([display_image], boxes_list)

# In case you didn't want to label...

Run this cell only if you didn't annotate anything above and would prefer to just use our pre-annotated boxes. Don't forget to uncomment the lines in the cell before running it.



In [None]:
# Pre-annotated boxes for the sample image. Each row is
# [ymin, xmin, ymax, xmax]; the plotting code scales these values by the image
# height and width. Uncomment the lines below to use them.
# boxes_list = [np.array([[0.000, 0.160, 0.362, 0.812],
#                         [0.340, 0.286, 0.472, 0.619],
#                         [0.437, 0.008, 0.650, 0.263],
#                         [0.382, 0.003, 0.538, 0.594],
#                         [0.518, 0.444, 0.625,0.554]], dtype=np.float32)]

# Visualize mask predictions

In [None]:

%matplotlib inline

# Only one image was annotated, so its boxes are the first entry of the list.
# This raises IndexError if no boxes were submitted or uncommented above.
boxes = boxes_list[0]
# Run the model on the full-resolution image together with the chosen boxes.
masks = prediction_function(tf.convert_to_tensor(image),
                            tf.convert_to_tensor(boxes, dtype=tf.float32))
# The plotting helper works on numpy arrays, so convert the predicted masks.
plot_image_annotations(image, boxes, masks.numpy())
plt.show()