Instance Segmentation with Detectron2 and Remo

In this tutorial, we do transfer learning on a Mask R-CNN model from Detectron2. We use Remo to facilitate exploring, accessing and managing the dataset.

In particular, we will:

  • Browse through our images and annotations
  • Quickly visualize the main properties of the dataset and make some initial observations
  • Create a train / test split without moving data around, using Remo image tags
  • Fine-tune a pre-trained Mask R-CNN model from Detectron2 and run some inference
  • Visually compare mask predictions with the ground truth, and draw possible conclusions on how to improve performance

Along the way, we will see how browsing images, annotations and predictions helps to gather insights on the dataset and on the model.

Before proceeding, we need to install the required dependencies.

This can be done by executing the next cell. Once complete, restart your runtime to ensure that the installed packages can be detected.

This tutorial requires a CUDA-enabled GPU, whether you run it locally or on Google Colab.

!pip install imantics
!pip install git+https://github.com/facebookresearch/fvcore.git
!git clone https://github.com/facebookresearch/detectron2 detectron2_repo
!pip install -e detectron2_repo

Let us then import the required packages.

import remo
remo.set_viewer('jupyter')

import numpy as np
import os
from PIL import Image
import glob
import random
random.seed(600)

from imantics import Polygons, Mask

import torch, torchvision

# Detectron 2 files
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.engine import DefaultTrainer
from detectron2.data.datasets import register_coco_instances
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
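
Since training will run on the GPU, it is worth confirming that PyTorch can actually see a CUDA device before going any further. This is just a quick sanity check using standard PyTorch calls:

# Sanity check: the tutorial assumes a CUDA-enabled GPU is available
assert torch.cuda.is_available(), 'No CUDA device found - this tutorial will not run on CPU'
print('PyTorch', torch.__version__, '| GPU:', torch.cuda.get_device_name(0))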

Adding Data to Remo

  • The Dataset used is a subset of the MS COCO Dataset.
  • The directory structure of the dataset is:
    ├── coco_instance_segmentation_dataset
        ├── images
            ├── image_1.jpg
            ├── image_2.jpg
            ├── ...
        ├── annotations
            ├── coco_instance_segmentation.json
    
if not os.path.exists('coco_instance_segmentation_dataset.zip'):
  !wget https://s-3.s3-eu-west-1.amazonaws.com/coco_instance_segmentation_dataset.zip
  !unzip coco_instance_segmentation_dataset.zip
else:
  print('Files already downloaded')

# Paths to the annotation and image folders
path_to_annotations = 'coco_instance_segmentation_dataset/annotations/'
path_to_images = 'coco_instance_segmentation_dataset/images/'

Train / test split

In Remo, we can use tags to organise our images. Among other things, this allows us to generate train / test splits without the need to move image files around.

To do this, we just need to pass a dictionary (mapping tags to the relevant images paths) to the function remo.generate_image_tags().

# Collect all image paths and shuffle them
im_list = glob.glob(path_to_images + '/**/*.jpg', recursive=True)
im_list = random.sample(im_list, len(im_list))

train_idx = round(len(im_list) * 0.8)
test_idx  = train_idx + round(len(im_list) * 0.2)

tags_dict =  {'train' : im_list[0:train_idx], 
              'test' : im_list[train_idx:test_idx]}

train_test_split_file_path = os.path.join(path_to_annotations, 'images_tags.csv') 
remo.generate_image_tags(tags_dictionary  = tags_dict, 
                         output_file_path = train_test_split_file_path, 
                         append_path = False)
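
As a quick check on the split, we can print how many images ended up under each tag (plain Python, nothing Remo-specific):

# Number of images assigned to each tag
print('train images:', len(tags_dict['train']))
print('test images: ', len(tags_dict['test']))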

Create a dataset

To create a dataset we can use remo.create_dataset(), specifying the path to data and annotations.

For a complete list of formats supported, you can refer to the docs.

coco_instance_segmentation_dataset = remo.create_dataset(name = 'coco_instance_segmentation_dataset', local_files = [path_to_annotations, path_to_images], annotation_task='Instance Segmentation')

Visualizing the dataset

To view and explore images and labels, we can use Remo directly from the notebook. We just need to call dataset.view().

coco_instance_segmentation_dataset.view()

instance_segmentation_view

Looking at the dataset, we notice some interesting points:

  • Pictures of the animals are taken from a variety of angles

  • In some cases, instances of different classes overlap and occlude each other, e.g. a Zebra standing right beside a Giraffe

  • The pose of the same class varies considerably across instances

Dataset Statistics

Using Remo, we can quickly visualize some key Dataset properties that can help us with our modelling, without needing to write extra boilerplate code.

This can be done either from code, or using the visual interface.

coco_instance_segmentation_dataset.get_annotation_statistics()

[{'AnnotationSet ID': 430, 'AnnotationSet name': 'Instance segmentation', 'n_images': 6, 'n_classes': 2, 'n_objects': 22, 'top_3_classes': [{'name': 'Zebra', 'count': 13}, {'name': 'Giraffe', 'count': 9}], 'creation_date': None, 'last_modified_date': '2020-11-04T11:35:11.415574Z'}]

coco_instance_segmentation_dataset.view_annotation_stats()

instance_view_annotations

Looking at the statistics we can gain some useful insights like:

  • Zebra has the highest number of instances per image. This means we have a somewhat unbalanced dataset, and we might expect the model to perform better on Zebras

  • The class distribution looks similar in both the train and test sets. This is good!

Exporting the dataset

To export annotations according to the train, test split in a format accepted by the model, we use the dataset.export_annotations_to_file() method, and filter by the desired tag.

For a complete list of formats supported, you can refer to the docs.

path_to_train = path_to_annotations + 'coco_instance_segmentation_train.json'
path_to_test = path_to_annotations + 'coco_instance_segmentation_test.json'
coco_instance_segmentation_dataset.export_annotations_to_file(path_to_train, annotation_format='coco', filter_by_tags=['train'], export_tags=False, append_path=False)
coco_instance_segmentation_dataset.export_annotations_to_file(path_to_test, annotation_format='coco', filter_by_tags=['test'], export_tags=False, append_path=False)
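
To double-check the export, we can load one of the generated files and count its entries. This sketch assumes the files follow the standard COCO JSON keys (images, annotations, categories):

import json

# Quick sanity check on the exported annotations
with open(path_to_train) as f:
    coco_train = json.load(f)
print(len(coco_train['images']), 'train images |',
      len(coco_train['annotations']), 'annotations |',
      len(coco_train['categories']), 'classes')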

Detectron2

Here we will start working with the Detectron2 framework written in PyTorch.

Feeding Data into Detectron2

To use a custom dataset with Detectron2, you first need to register it.

The register_coco_instances method takes in the following parameters:

  • name: a unique name for the dataset (e.g. coco_instance_segmentation_train)

  • metadata: a dict of extra metadata (here we pass an empty dict)

  • path_to_annotations: path to the annotation file, in COCO JSON format

  • path_to_images: path to the folder containing the images

Detectron2 stores this information in its DatasetCatalog and MetadataCatalog, so it can be retrieved for later operations.

register_coco_instances('coco_instance_segmentation_train', {}, path_to_train, path_to_images)
register_coco_instances('coco_instance_segmentation_test', {}, path_to_test, path_to_images)

train_metadata = MetadataCatalog.get('coco_instance_segmentation_train')
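
As a quick check, we can load the registered training set once; this also populates the class names on the metadata object, so we can confirm what Detectron2 picked up from the COCO JSON:

# Loading the dataset dicts also fills in thing_classes on the metadata
train_dataset_dicts = DatasetCatalog.get('coco_instance_segmentation_train')
print(len(train_dataset_dicts), 'training images | classes:', train_metadata.thing_classes)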

Training the Model

For the sake of the tutorial, our Mask R-CNN model will have a ResNet-50-FPN backbone, pre-trained on COCO train2017. This can be loaded directly from Detectron2's model zoo.

To train the model, we specify the following details:

  • model_yaml_path: configuration file for the Mask R-CNN model

  • model_weights_path: path to the pre-trained weights in the Detectron2 model zoo (the detectron2:// prefix is resolved to a download URL)

The parameters can be tweaked by overriding the corresponding variables in the cfg.

model_yaml_path = './detectron2_repo/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'
model_weights_path = 'detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl'

cfg = get_cfg()
cfg.merge_from_file(model_yaml_path)
cfg.DATASETS.TRAIN = ('coco_instance_segmentation_train',)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_weights_path # initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.02
cfg.SOLVER.MAX_ITER = 150    # a small number of iterations is enough for this toy dataset; train longer for better results
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2   # our subset only contains Zebra and Giraffe
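
If you want to see every setting that will actually be used, the merged configuration can be dumped as YAML (cfg.dump() comes from the underlying yacs config object):

# Print the full merged configuration as YAML
print(cfg.dump())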

Instantiating the Trainer

We instantiate the trainer with the required configuration, and finally kick off the training.

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
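
Training losses and other metrics are written to cfg.OUTPUT_DIR; if you want to monitor the run, TensorBoard can be pointed at that folder. The magics below assume a Jupyter or Colab environment:

# Optional: monitor training curves written by the DefaultTrainer
%load_ext tensorboard
%tensorboard --logdir output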

Performance Metrics

In order to evaluate the model performance on metrics such as Average Precision (AP) and Mean Average Precision (mAP), we will use the COCOEvaluator in the detectron2 package.

evaluator = COCOEvaluator("coco_instance_segmentation_test", cfg, False, output_dir="./output/")
val_loader = build_detection_test_loader(cfg, "coco_instance_segmentation_test")
print(inference_on_dataset(trainer.model, val_loader, evaluator))

[11/18 21:18:15 d2.data.datasets.coco]: Loaded 1 images in COCO format from coco_instance_segmentation_dataset/annotations/coco_instance_segmentation_test.json
[11/18 21:18:15 d2.data.common]: Serializing 1 elements to byte tensors and concatenating them all ...
[11/18 21:18:15 d2.data.common]: Serialized dataset takes 0.00 MiB
[11/18 21:18:15 d2.data.dataset_mapper]: Augmentations used in training: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1333, sample_style='choice')]
[11/18 21:18:15 d2.evaluation.evaluator]: Start inference on 1 images
[11/18 21:18:15 d2.evaluation.evaluator]: Inference done 1/1. 0.4105 s / img. ETA=0:00:00
[11/18 21:18:15 d2.evaluation.evaluator]: Total inference time: 0:00:00.440677 (0.440677 s / img per device, on 1 devices)
[11/18 21:18:15 d2.evaluation.evaluator]: Total inference pure compute time: 0:00:00 (0.410518 s / img per device, on 1 devices)
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Preparing results for COCO format ...
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Saving results to ./output/coco_instances_results.json
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Evaluating predictions ...
Loading and preparing results...
DONE (t=0.00s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type bbox
COCOeval_opt.evaluate() finished in 0.00 seconds.
Accumulating evaluation results...
COCOeval_opt.accumulate() finished in 0.00 seconds.
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.800
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.800
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.400
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Evaluation results for bbox:
|   AP   |  AP50   |  AP75   |  APs   |  APm  |  APl  |
|:------:|:-------:|:-------:|:------:|:-----:|:-----:|
| 80.000 | 100.000 | 100.000 | 80.000 |  nan  |  nan  |
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Some metrics cannot be computed and is shown as NaN.
Loading and preparing results...
DONE (t=0.00s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type segm
COCOeval_opt.evaluate() finished in 0.00 seconds.
Accumulating evaluation results...
COCOeval_opt.accumulate() finished in 0.00 seconds.
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.800
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.800
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.400
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Evaluation results for segm:
|   AP   |  AP50   |  AP75   |  APs   |  APm  |  APl  |
|:------:|:-------:|:-------:|:------:|:-----:|:-----:|
| 80.000 | 100.000 | 100.000 | 80.000 |  nan  |  nan  |
[11/18 21:18:15 d2.evaluation.coco_evaluation]: Some metrics cannot be computed and is shown as NaN.
OrderedDict([('bbox', {'AP': 80.0, 'AP50': 100.0, 'AP75': 100.0, 'APs': 80.0, 'APm': nan, 'APl': nan}),
             ('segm', {'AP': 80.0, 'AP50': 100.0, 'AP75': 100.0, 'APs': 80.0, 'APm': nan, 'APl': nan})])

Visualizing Predictions

Using Remo, we can easily browse our predictions and compare them with the ground-truth.

We will do this by uploading the model predictions to a new AnnotationSet, which we call model_predictions.

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set the testing threshold for this model
cfg.DATASETS.TEST = ('coco_instance_segmentation_test', )
predictor = DefaultPredictor(cfg)
test_dataset_dicts = DatasetCatalog.get('coco_instance_segmentation_test')
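
Before uploading anything to Remo, we can spot-check a single prediction with Detectron2's own Visualizer (imported earlier but not used so far). This is a minimal sketch; note that DefaultPredictor expects a BGR image by default (cfg.INPUT.FORMAT), hence the channel flip:

# Visualize one prediction with Detectron2's Visualizer
d = test_dataset_dicts[0]
im_rgb = np.array(Image.open(d['file_name']).convert('RGB'))
outputs = predictor(im_rgb[:, :, ::-1])                    # RGB -> BGR for the predictor
v = Visualizer(im_rgb, metadata=train_metadata, scale=0.8)
out = v.draw_instance_predictions(outputs['instances'].to('cpu'))
Image.fromarray(out.get_image())                           # renders inline in a notebook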

To display the labels as class-name strings rather than numeric IDs, we use a dictionary mapping between the two.

mapping = {k: v for k, v in enumerate(train_metadata.thing_classes)}
for d in test_dataset_dicts:    
    im = np.array(Image.open(d['file_name']))
    outputs = predictor(im)
    pred_classes = outputs['instances'].get('pred_classes').cpu().numpy()
    masks = outputs['instances'].get('pred_masks').cpu().permute(1, 2, 0).numpy()
    image_name = d['file_name']
    annotations = []

    # Convert each predicted mask to polygons and wrap it in a Remo annotation.
    # If no instances were detected, masks.shape[2] is 0 and the loop is simply skipped.
    for i in range(masks.shape[2]):
        polygons = Mask(masks[:, :, i]).polygons()
        if not polygons.segmentation:
            continue
        annotation = remo.Annotation()
        annotation.img_filename = image_name
        annotation.classes = mapping[pred_classes[i]]
        annotation.segment = polygons.segmentation[0]
        annotations.append(annotation)

model_predictions = coco_instance_segmentation_dataset.create_annotation_set(annotation_task = 'Instance Segmentation', name = 'model_predictions')

coco_instance_segmentation_dataset.add_annotations(annotations, annotation_set_id=model_predictions.id)
coco_instance_segmentation_dataset.view()

instance_model_predictions

By visualizing the predicted masks against the ground truth, we can go beyond summary performance metrics, visually inspect the model's failure modes, and iterate to improve it.

Looking at one picture, we notice:

  • The giraffe's legs are not picked up by the model; this might be due to occlusion by the tree

  • A part of the gazelle's body (not present in the annotation) is mistaken as part of a zebra, possibly due to feature similarity in pose.

  • The model is able to distinguish the zebra in the background, which was quite occluded. This is good!

In practice, we would look at all the pictures and at per-class model performance before drawing conclusions. But based on this one picture, we can already come up with the following:

Potential improvements

  • Add more training images containing trees and other occlusions. The occluded examples could initially be labelled with a separate tag so we can track their count, and experiments would then tell us whether treating them as a distinct class helps

  • Annotate Gazelles as a separate class.

  • The obvious one: train for more iterations and with more data

References