Machine Learning Image Popularity

Written by Steve Spagnola
Last updated Feb. 6, 2019

What makes people click, like & share online images? Is it the colors, composition, contrast, tones or something else?

In this post, we’ll walk through developing an algorithm to predict whether or not an image is popular on GrubHub with 65% accuracy.

Part 1 - Getting Training Image Data

In this exercise, we’ll keep things simple and focus on predicting whether an image’s click-through rate will exceed a certain threshold.

Ideally, we’d like to train our system on images with a 50:50 split: 50% of images have a click-through rate below a certain threshold and 50% have a click-through rate above it.
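If you do have a click-through rate for every image, one simple way to get that 50:50 split is to use the median click-through rate as the threshold. Below is a minimal sketch. The `label_by_median` helper and the sample rates are hypothetical, and the rates are expressed in percent.

```python
from statistics import median


def label_by_median(ctr_by_image):
    """Label each image popular (True) if its click-through rate is above the median.

    ctr_by_image: dict mapping an image identifier to its click-through rate.
    With distinct rates and an even number of images, this gives an exact 50:50 split.
    """
    threshold = median(ctr_by_image.values())
    labels = {image: ctr > threshold for image, ctr in ctr_by_image.items()}
    return threshold, labels


# hypothetical click-through rates (in percent) for four images
ctrs = {'a.jpg': 1.0, 'b.jpg': 2.5, 'c.jpg': 4.0, 'd.jpg': 5.5}
threshold, labels = label_by_median(ctrs)
print(threshold)  # 3.25
print(labels)
```

Ties at the median, or an odd number of images, will make the split slightly uneven. With thousands of images that difference rarely matters.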

If you already have your own images with associated click-through rates (thousands of images are recommended), feel free to skip to Part 2. Otherwise, we’ll discuss how you could collect images from GrubHub to follow along with this exercise.

If you don’t have thousands of images with click-through rates to get started, we can get a little creative & resourceful using public data.

Take a look at GrubHub and click on your favorite restaurant. Check out their menu and you’ll notice some items are featured as Most Popular, typically with images attached.

Let’s inspect this data more closely - we’ll find that within some sections of the menu where multiple items have images (e.g. in the appetizers section), there are cases where one item rose to “Most Popular” while other menu items were left behind. We can assume that the “Most Popular” item was ordered more than its peers and thus rose to the “Most Popular” section.

We can’t say for sure that the most popular items were ordered more because of their photos. The title and description can make all the difference too, or the popular item may simply have been on the menu longer. For this exercise, however, we will assume that within a menu section, the popular item has a higher click-through rate than its peers with photos.

Now we can create our training data: popular food images vs. non-popular food images.

Collecting the Data

You can just use Google Chrome’s web inspector to collect the images yourself while browsing GrubHub, or you can check out the Stevesie GrubHub Data API for more information.

Download Popular & Unpopular Images

Once you have the JSON data for each restaurant’s menu (with the links to the images), we can download the images locally to build our machine learning algorithm.

# python

import json
import os

import requests

RESTAURANT_JSON_FILEPATH = '~/Desktop/training_restaurants.json'
TARGET_IMAGE_DIRECTORY = '~/Desktop/training_images'

with open(os.path.expanduser(RESTAURANT_JSON_FILEPATH), 'r') as f:
    all_restaurants = json.load(f)

all_popular_urls = []
all_unpopular_urls = []

for item in all_restaurants['items']:
    restaurant = item['object']['restaurant']
    for menu_category in restaurant.get('menu_category_list', []):

        # only compare menu items that have a photo attached
        menu_items_with_images = \
            [menu_item for menu_item in menu_category['menu_item_list'] if 'media_image' in menu_item]

        if len(menu_items_with_images) > 0:
            category_popular_urls = []
            category_unpopular_urls = []

            for menu_item in menu_items_with_images:
                is_popular = menu_item['popular']
                media_image = menu_item['media_image']
                image_url = '{}{}.{}'.format(
                    media_image['base_url'], media_image['public_id'], media_image['format'])

                if is_popular:
                    category_popular_urls.append(image_url)
                else:
                    category_unpopular_urls.append(image_url)

            # keep the category only if it has both a popular and an unpopular photo
            if len(category_popular_urls) > 0 and len(category_unpopular_urls) > 0:
                all_popular_urls += category_popular_urls
                all_unpopular_urls += category_unpopular_urls

def dedupe_list(inspect_list, check_list):
    # drop URLs that also appear in the other list, since their label is ambiguous
    return [url for url in inspect_list if url not in check_list]

def write_urls(image_urls, directory_path):
    directory_path = os.path.expanduser(directory_path)
    if not os.path.exists(directory_path):
        os.makedirs(directory_path)

    for image_url in image_urls:
        url_name = image_url.split('/')[-1]
        r = requests.get(image_url, stream=True)
        # stream the image to disk in 1 KB blocks
        with open(os.path.join(directory_path, url_name), 'wb') as write_image:
            for block in r.iter_content(1024):
                if not block:
                    break
                write_image.write(block)

write_urls(dedupe_list(all_popular_urls, all_unpopular_urls), os.path.join(TARGET_IMAGE_DIRECTORY, 'popular'))
write_urls(dedupe_list(all_unpopular_urls, all_popular_urls), os.path.join(TARGET_IMAGE_DIRECTORY, 'unpopular'))

Part 2 - Strategize

If you fail at this step, everything you do going forward will be a waste of time and you’ll have to come back here. Take a deep breath and really think about the problem you’re trying to solve and its context.

Don’t Throw AI at It

You may be saying to yourself: now I have 2 sets of images I want to classify… I know what that sounds like! If you Google how to train and perform image classification, you’ll likely land on something like Simple Image Classification using Convolutional Neural Network.

If you throw your images at this algorithm, you’re going to see poor results.

We’re Learning the Invisible

Predicting image engagement is not a traditional image classification problem (e.g. what is in this image?). Instead, we want to know whether the image will be well-received online based on how it looks: the colors, lighting, angles, etc…

[Figure: Popular training images]

[Figure: Unpopular training images]

Something About Those Colors

Just by looking at my training data, I can see that the popular images just seem to… pop more than the other images. Something about the colors they use, how those colors are distributed and the contrast between them makes them more appealing.

I’m going to hypothesize going forward that we can get some predictive power just by analyzing the main colors of images, so we will proceed by focusing on those features and ignoring everything else (e.g. raw pixels, lines, shadows, etc…).

Part 3 - Extract Features

Feature extraction is arguably one of the most important steps in machine learning. A learning algorithm is only as good as the data that it’s fed - feed your algorithm the wrong data (or irrelevant data), and you’re going to see poor results.

Dominant Colors

I want to get the dominant color from each image. After Googling around a little bit, Finding Dominant Image Colours Using Python proved extremely helpful in documenting the approach using K-Means clustering.
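The core of that approach can be sketched on a synthetic array of pixels instead of a real photo. This sketch uses scikit-learn’s `KMeans`. The script below uses `MiniBatchKMeans`, which follows the same idea but scales better to large images. The `dominant_colors` helper is hypothetical. Here, a color’s "weight" means the fraction of pixels assigned to its cluster.

```python
import numpy as np
from sklearn.cluster import KMeans


def dominant_colors(pixels, num_colors):
    """Return (rgb, weight) pairs sorted from most to least dominant.

    pixels: array of shape (n_pixels, 3) holding RGB values.
    weight: fraction of pixels assigned to that color's cluster.
    """
    kmeans = KMeans(n_clusters=num_colors, n_init=10, random_state=0)
    labels = kmeans.fit_predict(pixels)
    # count pixels per cluster and normalize to fractions of the image
    weights = np.bincount(labels, minlength=num_colors) / len(labels)
    # indices of clusters ordered by descending weight
    order = np.argsort(weights)[::-1]
    return [(kmeans.cluster_centers_[i], weights[i]) for i in order]


# synthetic "image": 60 pure red pixels and 40 pure blue pixels
pixels = np.array([[255, 0, 0]] * 60 + [[0, 0, 255]] * 40, dtype=float)
for rgb, weight in dominant_colors(pixels, num_colors=2):
    print(rgb, weight)
```

On this input, the sketch reports red with a weight of 0.6 ahead of blue with 0.4. Sorting by weight matters because it makes "color 1" mean "the most dominant color" for every image, so the same feature column is comparable across images.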

Build a Training CSV File

We now want to transform our raw data (the images in each folder) into a CSV file with the color summary for each image. We’ll write a quick Python script to accomplish this:

# python

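import csv
import os

import cv2
import numpy as np
from sklearn.cluster import MiniBatchKMeans

SOURCE_IMAGES_FOLDER = '~/Desktop/training_images'
TARGET_FILEPATH = '~/Desktop/training_features.csv'
NUM_COLORS = 5  # number of dominant colors kept per image


# kept as a standalone function so the predictor can reuse it
def features_from_image(image_filepath):
    img = cv2.imread(image_filepath)
    if img is None:
        raise ValueError('Could not read image: {}'.format(image_filepath))

    # OpenCV loads images as BGR; convert to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # reshape to a flat list of pixels, one (r, g, b) row per pixel
    img = img.reshape((img.shape[0] * img.shape[1], 3))

    # cluster the pixels; each cluster center is one dominant color
    kmeans = MiniBatchKMeans(n_clusters=NUM_COLORS)
    kmeans.fit(img)

    colors = kmeans.cluster_centers_
    labels = kmeans.labels_

    # weight of each color = share of the image's pixels in its cluster
    weights = np.bincount(labels, minlength=NUM_COLORS) / len(labels)

    # sort from most to least dominant so CSV columns are comparable across images
    features = []
    for i in np.argsort(weights)[::-1]:
        r, g, b = colors[i]
        features += [r, g, b, weights[i]]
    return features


def image_features(directory_path, is_popular):
    directory_path = os.path.expanduser(directory_path)
    all_features = []
    for filename in os.listdir(directory_path):
        filepath = os.path.join(directory_path, filename)
        # row layout: filename, is_popular, then (r, g, b, weight) for each color
        all_features.append([filename, is_popular] + features_from_image(filepath))
    return all_features


def write_image_features(directory_path, is_popular, writer):
    for row in image_features(directory_path, is_popular):
        writer.writerow(row)


with open(os.path.expanduser(TARGET_FILEPATH), 'w', newline='') as f:
    writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_ALL)
    write_image_features(os.path.join(SOURCE_IMAGES_FOLDER, 'popular'), True, writer)
    write_image_features(os.path.join(SOURCE_IMAGES_FOLDER, 'unpopular'), False, writer)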

We’ll also be able to use this script to generate our testing features, which we’ll get to later.

Part 4 - Explore Features

Once we have our features from the raw data, let’s examine them a bit to make sure they line up with our assumptions.

# python

import csv
import os

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

FEATURE_FILEPATH = '~/Desktop/training_features.csv'

fig = plt.figure()
ax_unpop = fig.add_subplot(1, 2, 1, projection='3d', title='Unpopular')
ax_pop = fig.add_subplot(1, 2, 2, projection='3d', title='Popular')

def feature_tuple_from_row(row):
    # row layout: filename, is_popular, then (r, g, b, weight) for 5 colors
    # CSV values are strings, so compare the label text rather than calling bool()
    is_popular = row[1] == 'True'
    values = [float(v) for v in row[2:22]]
    # (is_popular, [[r, g, b, weight], ...])
    return (is_popular, [values[i:i + 4] for i in range(0, 20, 4)])

def rgb_to_hex(r, g, b):
    return '#%02x%02x%02x' % (int(r), int(g), int(b))

with open(os.path.expanduser(FEATURE_FILEPATH), 'r') as f:
    csv_reader = csv.reader(f)
    for row in csv_reader:
        is_popular, colors = feature_tuple_from_row(row)

        to_plot = ax_pop if is_popular else ax_unpop
        points = []

        for color in colors:
            rgb = [color[0], color[1], color[2]]
            # one dot per dominant color at its RGB position, sized by its weight
            to_plot.scatter(*rgb, s=100 * color[3], color=rgb_to_hex(*rgb))
            points.append(rgb)

        # connect the image's dominant colors with a translucent polygon
        # filled with the image's average color
        avg_r = np.array([color[0] for color in colors]).mean()
        avg_g = np.array([color[1] for color in colors]).mean()
        avg_b = np.array([color[2] for color in colors]).mean()
        face_color = [avg_r / 255, avg_g / 255, avg_b / 255]

        max_weight = max(color[3] for color in colors)
        pc = Poly3DCollection([points], linewidths=1, alpha=max_weight)
        pc.set_facecolor(face_color)
        to_plot.add_collection3d(pc)

plt.show()
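Plots are good for building intuition, and a quick numeric check can confirm what you think you see. The following is a minimal sketch that compares the per-class mean of a single feature. It runs on hypothetical rows shaped as `(is_popular, features)`. With the real data, you might compare the weight of each image’s most dominant color between the two classes. The `class_means` helper and the sample rows are illustrative only.

```python
from statistics import mean


def class_means(rows, feature_index):
    """Mean of one feature for popular vs. unpopular rows.

    rows: iterable of (is_popular, features) tuples.
    Returns (popular_mean, unpopular_mean).
    """
    popular = [features[feature_index] for is_popular, features in rows if is_popular]
    unpopular = [features[feature_index] for is_popular, features in rows if not is_popular]
    return mean(popular), mean(unpopular)


# hypothetical rows: (is_popular, [weight of the most dominant color])
rows = [(True, [0.50]), (True, [0.70]), (False, [0.30]), (False, [0.40])]
print(class_means(rows, 0))
```

A gap between the two means does not prove that a feature is predictive, but a feature whose two class means are nearly identical is unlikely to help much on its own.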

Part 5 - Build Your Model

Once you’ve gotten to know your features, consider which machine learning algorithm would be a good fit. Each data point in our color data has several dominant colors, and we hypothesize that the differences between those colors matter too. We expect that the relationships between these raw colors, and their interactions with each other, will help determine whether an image becomes popular.

This sounds like a good case for using a Support Vector Machine (or SVM).
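One thing to keep in mind with SVMs is that they are sensitive to feature scale. In our features, the RGB differences can range from -255 to 255, while the weight differences stay between -1 and 1. The training script below feeds the raw features straight into the SVM. As a variation worth trying, rather than the setup behind the reported accuracy, you could standardize the features inside a scikit-learn pipeline. The sketch below runs on synthetic data from `make_classification`, which stands in for the real color features (43 columns, matching the feature count used below).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# synthetic stand-in for the color features: 200 samples, 43 features
X, y = make_classification(n_samples=200, n_features=43, n_informative=8, random_state=0)
# mimic the scale mismatch: the first three columns span hundreds, the rest stay small
X[:, :3] *= 255

# scale every feature to zero mean / unit variance before the polynomial-kernel SVM
model = make_pipeline(StandardScaler(), SVC(kernel='poly', degree=3, C=1.0))
scores = cross_val_score(model, X, y, cv=5)
print('Accuracy: %0.2f (+/- %0.2f)' % (scores.mean(), scores.std() * 2))
```

Putting the scaler inside the pipeline means it is refit on each training fold during cross-validation, so the test folds never leak into the scaling statistics.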

# python

import csv
import os
from itertools import combinations

import joblib
import numpy as np
from sklearn import svm
from sklearn.model_selection import cross_val_score

FEATURE_FILEPATH = '~/Desktop/training_features.csv'
TARGET_MODEL_FILEPATH = '~/Desktop/svm_model.joblib'

def feature_tuple_from_row(row):
    # same helper as in Part 4: (is_popular, [[r, g, b, weight] x 5])
    is_popular = row[1] == 'True'
    values = [float(v) for v in row[2:22]]
    return (is_popular, [values[i:i + 4] for i in range(0, 20, 4)])

def extract_label_and_features(row):
    is_popular, colors = feature_tuple_from_row(row)

    # most dominant color first
    sorted_colors = sorted(colors, key=lambda color: color[3], reverse=True)
    rgbw = np.array(sorted_colors)

    # assumption: unweighted_* is the plain mean of each channel over the 5 colors,
    # avg_* is the same mean weighted by each color's pixel share
    unweighted_r, unweighted_g, unweighted_b = rgbw[:, :3].mean(axis=0)
    avg_r, avg_g, avg_b = np.average(rgbw[:, :3], axis=0, weights=rgbw[:, 3])

    features = [
        unweighted_r - avg_r,
        unweighted_g - avg_g,
        unweighted_b - avg_b,
    ]

    # (r, g, b, weight) differences between every pair of dominant colors,
    # in the order (0, 1), (0, 2), ..., (3, 4)
    for i, j in combinations(range(len(sorted_colors)), 2):
        features += [sorted_colors[i][k] - sorted_colors[j][k] for k in range(4)]

    return (is_popular, features)

X_train = []
Y_train = []

with open(os.path.expanduser(FEATURE_FILEPATH), 'r') as f:
    csv_reader = csv.reader(f)
    for row in csv_reader:
        y, x = extract_label_and_features(row)
        X_train.append(x)
        Y_train.append(y)

clf = svm.SVC(kernel='poly', max_iter=1000000, degree=3, C=100000.0)

# estimate accuracy with 5-fold cross-validation, then fit on all the training data
scores = cross_val_score(clf, X_train, Y_train, cv=5)
print('Accuracy: %0.2f (+/- %0.2f)' % (scores.mean(), scores.std() * 2))
clf.fit(X_train, Y_train)
joblib.dump(clf, os.path.expanduser(TARGET_MODEL_FILEPATH))

Part 6 - Test & Evaluate Your Model


# python

import csv
import itertools
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

FEATURE_DATA_FILE_PATH_TEST = '~/Desktop/test_features.csv'
SVM_MODEL_FILEPATH = '~/Desktop/svm_model.joblib'

# extract_label_and_features() (and feature_tuple_from_row()) from Part 5
# must be defined or imported here.

# adapted from the scikit-learn "Confusion matrix" documentation example
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

X_test = []
Y_test = []

with open(os.path.expanduser(FEATURE_DATA_FILE_PATH_TEST), 'r') as f:
    csv_reader = csv.reader(f)

    for row in csv_reader:
        y, x = extract_label_and_features(row)
        X_test.append(x)
        Y_test.append(y)

clf = joblib.load(os.path.expanduser(SVM_MODEL_FILEPATH))

scores = clf.score(X_test, Y_test)
print('Test accuracy: %0.2f' % scores)

Y_predict = clf.predict(X_test)

# Compute confusion matrix: rows are true labels, columns are predictions
# (index 0 = unpopular/False, index 1 = popular/True)
cnf_matrix = confusion_matrix(Y_test, Y_predict)

class_names = ['Unpopular', 'Popular']

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

# precision of the "popular" class: of everything predicted popular, how much really was
total_popular_predictions = cnf_matrix[0][1] + cnf_matrix[1][1]
correct_popular_predictions = cnf_matrix[1][1]

print('Popular Distribution')
print((cnf_matrix[1][0] + cnf_matrix[1][1]) / cnf_matrix.sum())
print('Relevant Precision')
print(correct_popular_predictions / total_popular_predictions)

plt.show()

Evaluate the Confusion Matrix

[Figure: Confusion matrix]

We can see that the algorithm is correct on 29% of its popular predictions, compared with a baseline of 20% of images that are popular in our test set. In other words, if the algorithm predicts that a photo will become popular, that photo has a 29% chance of actually making it, versus 20% for a photo picked at random.

What’s more impressive, however, is how well the algorithm filters out images that will be unpopular.
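The 29% figure above is the precision of the "popular" class, and the 20% figure is the base rate of popular images in the test set. A natural way to put a number on how well the model filters out unpopular images is the negative predictive value (NPV): of the images predicted unpopular, the fraction that really are unpopular. The following is a minimal sketch that computes all three from a confusion matrix. The labels are made up and are not the article’s test data, and `popularity_metrics` is a hypothetical helper.

```python
from sklearn.metrics import confusion_matrix


def popularity_metrics(y_true, y_pred):
    """Summarize a binary popular (1) / unpopular (0) classifier.

    Returns the base rate of popular images, the precision of "popular"
    predictions, and the negative predictive value of "unpopular" predictions.
    """
    # with labels=[0, 1], ravel() yields tn, fp, fn, tp in that order
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    total = tn + fp + fn + tp
    return {
        'base_rate': (fn + tp) / total,   # share of images that are actually popular
        'precision': tp / (tp + fp),      # P(popular | predicted popular)
        'npv': tn / (tn + fn),            # P(unpopular | predicted unpopular)
    }


# hypothetical test set: 10 images, 2 of them popular
y_true = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
y_pred = [1, 0, 1, 0, 0, 0, 0, 0, 0, 0]
print(popularity_metrics(y_true, y_pred))
```

Here the NPV is 0.875. That number should be compared with 0.8, the score you would get by calling every image unpopular, in the same way that the 29% precision is compared with the 20% base rate.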

Part 7 - Build Your Predictor

Now for the fun part. Once we’re happy with our model, we can use it to predict whether a photo will do well online. Remember that the accuracy is only 65%, so take its predictions with a grain of salt. Still, when you’re on the fence between two images to post, this algorithm may make a decent tie breaker.

# python

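import os

import joblib

IMAGE_FILEPATH = '~/Desktop/predict_me.jpg'
SVM_MODEL_FILEPATH = '~/Desktop/svm_model.joblib'

# features_from_image() from Part 3 and extract_label_and_features() from Part 5
# must be defined or imported here.

clf = joblib.load(os.path.expanduser(SVM_MODEL_FILEPATH))

# raw dominant colors, in the same (r, g, b, weight) x 5 layout as the training CSV
color_features = features_from_image(os.path.expanduser(IMAGE_FILEPATH))

# apply the same feature transformation used in training; the label slot is a placeholder
_, image_features = extract_label_and_features([IMAGE_FILEPATH, 'False'] + color_features)

predictions = clf.predict([image_features])
print('Popular' if predictions[0] else 'Unpopular')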
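For the tie-breaker use case, a continuous score can be more useful than a yes/no label, because both candidate images may get the same label. `SVC.decision_function` returns a signed distance from the decision boundary. The candidate with the higher value is the one the model leans further toward the second class in `clf.classes_`, which is `True` (popular) for the model trained above. Below is a minimal sketch. `pick_more_popular` and the synthetic stand-in model are hypothetical, and in practice the two feature vectors would come from the feature pipeline above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import SVC


def pick_more_popular(clf, features_a, features_b):
    """Return 0 if candidate A leans more toward 'popular', else 1.

    Assumes a binary classifier whose second class (clf.classes_[1]) is 'popular',
    so a larger decision_function value leans further toward popular.
    """
    scores = clf.decision_function([features_a, features_b])
    return int(np.argmax(scores))


# hypothetical stand-in model trained on synthetic 43-feature data
X, y = make_classification(n_samples=200, n_features=43, random_state=0)
clf = SVC(kernel='poly', degree=3).fit(X, y)

choice = pick_more_popular(clf, X[0], X[1])
print('Post candidate', 'A' if choice == 0 else 'B')
```

Given the model’s modest accuracy, a small difference between the two scores is weak evidence either way. Treat the result as a nudge, not a verdict.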