Image Segmentation

Hopefully you won’t often be faced with the task of manually segmenting images. However, for the times when you must, it’s nice not to have to leave the comfort of Python for another program. That is what the image_segmenter class is for.

(Credit where it’s due: this tool was developed as part of a final project in Pavlos Protopapas’ class AC295, and you can read more about it in the project’s final write-up on Towards Data Science.)

import matplotlib.pyplot as plt
import matplotlib.cbook as cbook
import numpy as np
from mpl_interactions import image_segmenter

# load a sample image that ships with Matplotlib
with cbook.get_sample_data('ada.png') as image_file:
    image = plt.imread(image_file)

# keep a reference to the object: if it is garbage collected, its callbacks will fail
segmenter = image_segmenter(image, nclasses=3, mask_colors='red', mask_alpha=0.76, figsize=(7, 7))

# if working in a Jupyter Notebook, display the segmenter
display(segmenter)

This creates a Matplotlib figure showing the image. zoom_factory() and panhandler() are applied automatically, so you can:

  • Scroll to zoom

  • Use middle click to pan

  • Left click and drag to lasso a region of the image and add it to the mask
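To make the lasso step more concrete, here is a minimal sketch of how a lasso polygon can be turned into a set of selected pixels: test each pixel center against the polygon with matplotlib.path.Path.contains_points. This is an illustration of the general technique, not necessarily how image_segmenter does it internally; the helper lasso_to_selection is hypothetical, and it assumes the polygon vertices are (x, y) pairs in the pixel coordinates that imshow uses.

```python
import numpy as np
from matplotlib.path import Path


def lasso_to_selection(vertices, shape):
    """Return a boolean (H, W) array marking pixels whose centers lie inside
    the polygon `vertices`, given as (x, y) pairs in pixel coordinates."""
    height, width = shape
    # imshow centers pixel (row, col) at x=col, y=row
    ys, xs = np.mgrid[:height, :width]
    centers = np.column_stack([xs.ravel(), ys.ravel()])
    # Path treats the vertex list as a closed polygon
    inside = Path(vertices).contains_points(centers)
    return inside.reshape(height, width)


# a triangle drawn over the top-left corner of a 5x5 image
selection = lasso_to_selection([(-0.5, -0.5), (3.2, -0.5), (-0.5, 3.2)], (5, 5))
```

Pixels whose row and column indices sum to 2 or less fall inside this triangle, so six pixels are selected.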


You can switch which class you are marking by setting the segmenter’s current_class attribute:

segmenter.current_class = 2
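Conceptually, the current class is the integer label written into the mask for each new selection. The sketch below shows that bookkeeping in plain NumPy. The helper paint is hypothetical, and two details are assumptions rather than facts taken from the library's source: that 0 means "unlabeled", and that a later selection overwrites the labels of an earlier one where they overlap.

```python
import numpy as np


def paint(mask, selection, current_class):
    """Write `current_class` into every selected pixel of `mask` (in place).

    Later selections overwrite earlier ones where they overlap.
    """
    mask[selection] = current_class
    return mask


mask = np.zeros((4, 4), dtype=np.uint8)  # 0 = unlabeled (assumption)

first = np.zeros((4, 4), dtype=bool)
first[:2, :] = True   # top two rows

second = np.zeros((4, 4), dtype=bool)
second[1:3, :] = True  # rows 1 and 2, overlapping the first selection

paint(mask, first, current_class=1)
paint(mask, second, current_class=2)
# each row now holds a single class: 1, 2, 2, 0
```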

You can access the 2D mask at any time with:

plt.figure()
plt.imshow(segmenter.mask)
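Showing the raw mask with imshow maps class labels through a colormap. If you would rather see the labels laid over the image, one option is to convert the integer mask into a semi-transparent RGBA overlay, which is roughly what the mask_colors and mask_alpha parameters control in the segmenter's own figure. The sketch below is not the library's rendering code: colorize_mask is a hypothetical helper, and it assumes 0 means "unlabeled" and that class k uses the k-th color.

```python
import numpy as np
from matplotlib.colors import to_rgba


def colorize_mask(mask, colors, alpha=0.75):
    """Return an (H, W, 4) RGBA overlay for an integer class mask.

    Pixels labeled 0 stay fully transparent; pixels labeled k
    (1 <= k <= len(colors)) get colors[k - 1] with opacity `alpha`.
    """
    overlay = np.zeros(mask.shape + (4,), dtype=float)
    for k, color in enumerate(colors, start=1):
        overlay[mask == k] = to_rgba(color, alpha=alpha)
    return overlay


# a tiny stand-in for segmenter.mask
mask = np.array([[0, 1], [2, 3]])
overlay = colorize_mask(mask, ["red", "green", "blue"], alpha=0.76)
```

With a real segmenter you could call plt.imshow(image) followed by plt.imshow(colorize_mask(segmenter.mask, ['red', 'green', 'blue'])) to draw the labels on top of the image.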

Loading existing masks

You can also load an existing mask. You only need to ensure that (1) it contains no values greater than nclasses, and (2) it has the same height and width as the image. There are currently no safeguards for this, and exceptions raised inside a Matplotlib callback can be hard to see in the notebook, so be careful!
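Since the segmenter does not check these conditions itself, it can be worth validating a mask before passing it in. A minimal sketch, assuming the mask is a 2D array of integer labels and the image is an (H, W) or (H, W, channels) array; check_mask is a hypothetical helper, and rejecting negative values is an extra assumption beyond the two conditions stated above.

```python
import numpy as np


def check_mask(mask, image, nclasses):
    """Raise ValueError if `mask` does not match `image` or has out-of-range labels."""
    mask = np.asarray(mask)
    # the mask must cover the image's height and width (channels are ignored)
    if mask.shape != image.shape[:2]:
        raise ValueError(f"mask shape {mask.shape} != image shape {image.shape[:2]}")
    # labels must lie in 0..nclasses (the lower bound is an assumption)
    if mask.min() < 0 or mask.max() > nclasses:
        raise ValueError(f"mask values must lie in 0..{nclasses}")
    return mask


image = np.zeros((3, 4, 4))  # stand-in for an RGBA image
good = check_mask(np.array([[0, 1, 2, 3]] * 3), image, nclasses=3)
```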

# download a previously saved mask
import io

import requests

response = requests.get('https://github.com/ianhi/mpl-interactions/raw/41ebd90674e2136e87240ba81ae509dee15a63a7/examples/ada-mask.npy')
response.raise_for_status()
mask = np.load(io.BytesIO(response.content))  # read the .npy bytes into an array

preloaded = image_segmenter(image, nclasses=3, mask=mask)
display(preloaded)
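The ada-mask.npy file used above is a standard NumPy .npy file. To reuse your own labels in the same way, you can save the mask with np.save and later load it with np.load and pass it as mask=. The sketch round-trips through an in-memory buffer so it has no side effects; the small array is a stand-in for segmenter.mask, and in practice you would pass a file path to np.save instead.

```python
import io

import numpy as np

# stand-in for segmenter.mask
mask = np.array([[0, 1, 1], [2, 2, 3]], dtype=np.uint8)

# np.save also accepts a file path; a buffer keeps this example self-contained
buffer = io.BytesIO()
np.save(buffer, mask)
buffer.seek(0)
restored = np.load(buffer)  # shape, dtype, and values are preserved
```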