Image Extraction via Python and Deep Learning



Would you like an easy way to extract interesting regions from photographs? For example, given a photograph such as this:



It would be cool to have an AI tool that automatically identifies the parrot and the primate, and then extracts the individual images as follows:



And it would be even cooler if this tool could do this extraction in bulk, at a reasonable speed. With such a technology, it would be a simple matter to go through large directories of images, automatically categorizing their contents and extracting the matching image segments.

In this short tutorial I'll demonstrate how to build this tool. It will find up to 80 different classes of objects in an image and then optionally extract them.

A few years ago this would have been science fiction. But, as you'll see, building it now is dead simple. Given the capabilities of modern AI, all we'll need is a suitable deep neural net and 100 lines of Python code.

Let's get started!




For this project we'll apply the Mask R-CNN neural network. As of early 2020, this network was state-of-the-art for object detection and instance segmentation. Our particular network comes pretrained on the COCO dataset, a standard reference database for AI image research.
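
For reference, here's a sampling of the 80 object classes COCO covers, in the list form the Mask R-CNN sample code typically ships with (truncated here for space; index 0 is a special background class). This is the constant that segment.py references below:

# a sampling of the 80 COCO object classes ('BG' is background, index 0)
COCO_CLASS_NAMES = ['BG', 'person', 'bicycle', 'car', 'motorcycle',
                    'airplane', 'bus', 'train', 'truck', 'boat',
                    # ... 70 more classes, ending with:
                    'teddy bear', 'hair drier', 'toothbrush']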

In technical parlance, segments are the regions of interest inside an image. Mask R-CNN is trained to label all the segments it discovers in a given picture. For example, in our case it returns "bird" and "person". However, we can also turn the network around and extract the pixel bitmap behind each label. In this way we retrieve the coordinates of all the pixels that composed "bird" and "person" in the original image.
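
To make that concrete, here's a tiny numpy sketch (toy data, not actual network output) showing how a mask bitmap gives back the coordinates of its active pixels:

import numpy as np

# a toy 4x4 mask: 1 marks pixels belonging to the detected object
mask = np.array([[0, 0, 0, 0],
                 [0, 1, 1, 0],
                 [0, 1, 1, 0],
                 [0, 0, 0, 0]], dtype=np.uint8)

# argwhere returns the (row, col) coordinates of every active pixel
coords = np.argwhere(mask == 1)   # [[1 1] [1 2] [2 1] [2 2]]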

When we do that, this is what we actually pull from the network:


These images are region masks. In this case, they define the regions of a person and a bird. Given that data, it's a simple matter to extract the actual image segments themselves.

Let's see how this works. There are two files in this project: segment.py and extract.py. Segment.py uses Mask R-CNN to generate the mask bitmaps. Extract.py then uses those bitmaps to actually extract the segments. These two programs can be combined to automatically extract any number of objects from any number of images.
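
As a rough illustration of how the two combine in bulk, here's a hypothetical driver script (the directory names and the 'x' output prefix are my own inventions, not part of the project):

import os
import subprocess

INPUT_DIR = 'images'            # source photographs
CONVERSION_DIR = 'conversions'  # mask bitmaps land here
OUTPUT_DIR = 'extractions'      # extracted segments land here

os.makedirs(CONVERSION_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

for image_name in os.listdir(INPUT_DIR):
    image_path = os.path.join(INPUT_DIR, image_name)

    # generate one mask file per detected segment
    subprocess.run(['python', 'segment.py', image_path, CONVERSION_DIR], check=True)

    # extract every segment whose mask came from this image
    root = image_name.split('.')[0]
    for mask_name in os.listdir(CONVERSION_DIR):
        if mask_name.startswith(root + '.'):
            mask_path = os.path.join(CONVERSION_DIR, mask_name)
            output_path = os.path.join(OUTPUT_DIR, 'x' + mask_name)
            subprocess.run(['python', 'extract.py', image_path, mask_path, output_path], check=True)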

Segment.py stores the mask bitmaps in a temporary conversion directory, one file per bitmap. Each file name encodes the root file name, the confidence score, the label, and the bounding rectangle.

For example, here is the mask file that encodes our parrot:

testimage.99.bird.144_28_172_258.png

This means it's a high-confidence (99%) region for a bird, with a bounding rectangle whose origin is (144, 28) and whose dimensions are 172x258.
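
As a quick sanity check, here's how that name decodes in a few lines of Python (the same parsing extract.py performs below):

name = 'testimage.99.bird.144_28_172_258.png'

root, score, label, rectangle, extension = name.split('.')
x, y, width, height = (int(value) for value in rectangle.split('_'))

print(score, label)          # 99 bird
print(x, y, width, height)   # 144 28 172 258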

Now let's take a look at the details. First, here's the main code section for segment.py:


import os
import sys

import numpy as np
import skimage.io
from PIL import Image

# these two imports assume the standard Matterport Mask R-CNN layout
import coco
import mrcnn.model as modellib

# (MODEL_PATH, MODEL_WEIGHTS_PATH and COCO_CLASS_NAMES are
# constants defined elsewhere in segment.py)

PATH_INPUT_IMAGE = sys.argv[1]
PATH_CONVERSION_DIR = sys.argv[2]

file_name = os.path.basename(PATH_INPUT_IMAGE).split('.')[0]

# create an inference instance of the DNN config object
class InferenceConfig(coco.CocoConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
config = InferenceConfig()

# create the model object in inference mode, fold in the weights
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_PATH, config=config)
model.load_weights(MODEL_WEIGHTS_PATH, by_name=True)

# run the model on the input image
image_input = skimage.io.imread(PATH_INPUT_IMAGE)
results = model.detect([image_input], verbose=1)

# unpack all results
result = results[0]
class_ids = result['class_ids']
masks = result['masks'].astype(np.uint8)
scores = result['scores']
rois = result['rois']

#
# for each region identified:
# get the score, label, dimensions and mask,
# modify the mask so that all active pixels are
# white with the background black, then
# save the bitmap to the conversion directory
#
regionFileList = []
for index, class_id in enumerate(class_ids):

    region_label = COCO_CLASS_NAMES[class_id].replace(' ', '_')
    score = int(scores[index] * 100)
    (y1, x1, y2, x2) = rois[index]  # bounding box for the mask
    width = x2 - x1
    height = y2 - y1

    # slice off the bitmap for this object
    bitmap = masks[:, :, index]

    # make positive mask pixels white
    bitmap[bitmap > 0] = 255

    path_output_image = f'{PATH_CONVERSION_DIR}/{file_name}.{score}.{region_label}.{x1}_{y1}_{width}_{height}.png'
    image_region = Image.fromarray(bitmap, 'L')
    image_region.save(path_output_image, 'PNG')
    print(path_output_image)

    regionFileList.append(path_output_image)

# lastly, compute the background region
# (negative of all other regions)
computeBackgroundRegion(file_name, regionFileList)



The algorithm is straightforward. First, we initialize our neural network with the COCO weights. We then feed our input image bitmap into the network in inference mode. The result is an array of labels and their corresponding bitmaps, identifying every image segment that the neural network saw. From this data, we generate the mask images and store them as files in the conversion directory.
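
One loose end: the listing closes with a call to computeBackgroundRegion, whose body isn't shown above. Conceptually it just ORs all the region masks together and inverts the result, so that everything no detected object claimed becomes one final background mask. A minimal sketch of such a helper (my reconstruction, not the project's actual code) might look like this:

def computeBackgroundRegion(file_name, regionFileList):
    # OR all region masks together; whatever no object claimed
    # is, by definition, background
    combined = None
    for path in regionFileList:
        mask = np.array(Image.open(path).convert('L'))
        combined = mask if combined is None else np.maximum(combined, mask)

    if combined is None:
        return  # no regions were detected, so nothing to invert

    # invert: background pixels become white, object pixels black
    background = 255 - combined
    height, width = background.shape
    path_output = f'{PATH_CONVERSION_DIR}/{file_name}.100.background.0_0_{width}_{height}.png'
    Image.fromarray(background, 'L').save(path_output, 'PNG')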

Once this process completes, extract.py has all the information necessary to do segment extraction.

Here's the full code listing for extract.py:


import os
import sys

import cv2
import numpy as np

# return the origin coordinates and dimensions of the region
# (these are encoded in the mask file name)
def getRegionAttributes(image_region):
    image_region = os.path.basename(image_region)
    (x, y, w, h) = image_region.split('.')[3].split('_')
    return (int(x), int(y), int(w), int(h))


# extract the specified region within the image
def extract_region(image, region):
    extracted_image = np.copy(region)
    rows = region.shape[0]
    cols = region.shape[1]
    for row in range(rows):
        for col in range(cols):
            if region[row, col, 0] == 255:
                extracted_image[row, col] = image[row, col]
            else:
                extracted_image[row, col] = (255, 255, 255)

    return extracted_image


if __name__ == '__main__':

    if len(sys.argv) != 4:
        print('usage: python extract.py path_input_image path_region path_output_image')
        exit()

    PATH_INPUT_IMAGE = sys.argv[1]
    PATH_REGION = sys.argv[2]
    PATH_OUTPUT_IMAGE = sys.argv[3]

    image_input = cv2.imread(PATH_INPUT_IMAGE)
    region_input = cv2.imread(PATH_REGION)

    # extract the region, then crop it to match the region mask
    (x, y, w, h) = getRegionAttributes(PATH_REGION)
    extracted_image = extract_region(image_input, region_input)
    cropped_extracted_image = extracted_image[y:y+h, x:x+w]
    result_image = cv2.resize(cropped_extracted_image, (w, h))

    cv2.imwrite(PATH_OUTPUT_IMAGE, result_image)
    print(PATH_OUTPUT_IMAGE)


We take as inputs the path to the image, the mask file for the segment we wish to extract, and the path for the output. Next we gather the attributes for the mask (encoded in the mask file name itself, as described earlier). We then extract the segment region by effectively ANDing its bitmap with the original image. Finally, we use array slicing to crop out the result and resize it to the original dimensions.
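
As an aside, the per-pixel loop in extract_region is easy to follow but slow on large images. The same masking can be done in a single vectorized numpy call; a drop-in variant (not the version used above) would be:

def extract_region_fast(image, region):
    # keep the source pixel wherever the mask is white;
    # paint everything else white
    mask = region[:, :, 0] == 255
    return np.where(mask[:, :, None], image, 255).astype(image.dtype)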

And there it is: an extracted image.

View the full github listing here.

This code runs pretty well even on a slow computer. For instance, on a down-market t2.large EC2 instance (no GPU), it takes about 20 seconds to identify and extract all segments from a typical image.

Christopher Minson


© 2024 Christopher Minson LLC