Tutorial: Retrain a TensorFlow image classifier with transfer learning and ML.NET

Learn how to retrain an image classification TensorFlow model with transfer learning and ML.NET. The original model was trained to classify individual images. After retraining, the new model organizes the images into broad categories.

Training an Image Classification model from scratch requires setting millions of parameters, a large amount of labeled training data, and a vast amount of compute resources (hundreds of GPU hours). Transfer learning is not as effective as training a custom model from scratch. However, it lets you shortcut this process by working with thousands of images instead of millions of labeled images, and you can build a customized model fairly quickly (within an hour on a machine without a GPU).

In this tutorial, you learn how to:

  • Understand the problem
  • Reuse and tune the pre-trained model
  • Classify images

What is transfer learning?

What if you could reuse a model that's already been pre-trained to solve a similar problem, and retrain either all or some of the layers of that model to make it solve your problem? This technique of reusing part of an already trained model to build a new model is known as transfer learning.

Image classification sample overview

The sample is a console application that uses ML.NET to build an image classifier by reusing a pre-trained model to classify images with a small amount of training data.

You can find the source code for this tutorial at the dotnet/samples repository. Note that by default, the .NET project configuration for this tutorial targets .NET Core 2.2.

Prerequisites

Select the appropriate machine learning task

Deep learning is a subset of Machine Learning, which is revolutionizing areas like Computer Vision and Speech Recognition.

Deep learning models are trained by using large sets of labeled data and neural networks that contain multiple learning layers. Deep learning:

  • Performs better on some tasks like Computer Vision.

  • Performs well on large amounts of data.

Image Classification is a common Machine Learning task that allows us to automatically classify images into multiple categories such as:

  • Detecting whether an image contains a human face.
  • Distinguishing cats from dogs.

Another example, shown in the following images, is determining whether an image shows food, a toy, or an appliance:

(Images: a pizza, a teddy bear, and a toaster)

Note

The preceding images belong to Wikimedia Commons. They are attributed later in this tutorial, together with the training and testing images.

Transfer learning includes a few strategies, such as retraining all layers or retraining from the penultimate layer. This tutorial explains and shows how to use the penultimate layer strategy, which starts from a model that's already been pre-trained to solve a specific problem. The pre-trained layers, up to and including the penultimate layer, are kept as they are, and their output becomes the input features of a final layer that is retrained to solve a new problem. Reusing the pre-trained model as part of your new model saves significant time and resources.
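The penultimate layer strategy can be sketched in plain C#. This is a hypothetical illustration, not ML.NET code. `FrozenFeatures` stands in for the pre-trained layers: it is computed but never updated, much like Inception's `softmax2_pre_activation` output in this tutorial. `TrainHead` trains only a new final layer, a multinomial logistic regression, using plain stochastic gradient descent. This optimizer is a deliberate simplification; the `LbfgsMaximumEntropy` trainer used later optimizes with L-BFGS. All class and method names in the sketch are made up for this example.

```csharp
using System;
using System.Linq;

// Hypothetical sketch of the penultimate layer strategy (not ML.NET code).
// The "pre-trained" part is a fixed function that is never updated;
// only the weights of the new final layer are learned.
public static class PenultimateLayerSketch
{
    // Stand-in for the frozen pre-trained layers: maps a raw input to a feature vector.
    // In the tutorial, this role is played by Inception's softmax2_pre_activation output.
    public static double[] FrozenFeatures(double[] input) =>
        new[] { input[0] + input[1], input[0] - input[1], 1.0 }; // constant 1.0 acts as a bias input

    // Turns raw scores into probabilities that add up to 1.
    public static double[] Softmax(double[] z)
    {
        double max = z.Max(); // subtracting the max avoids overflow in Math.Exp
        double[] e = z.Select(v => Math.Exp(v - max)).ToArray();
        double sum = e.Sum();
        return e.Select(v => v / sum).ToArray();
    }

    // Raw score of each class: the dot product of that class's weights with the features.
    public static double[] Scores(double[,] w, double[] x)
    {
        var z = new double[w.GetLength(0)];
        for (int k = 0; k < z.Length; k++)
            for (int j = 0; j < x.Length; j++)
                z[k] += w[k, j] * x[j];
        return z;
    }

    // Trains only the new final layer (multinomial logistic regression) with plain
    // stochastic gradient descent on cross-entropy.
    public static double[,] TrainHead(double[][] inputs, int[] labels, int classes,
                                      int epochs = 500, double learningRate = 0.05)
    {
        var w = new double[classes, FrozenFeatures(inputs[0]).Length];
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            for (int i = 0; i < inputs.Length; i++)
            {
                double[] x = FrozenFeatures(inputs[i]); // features are computed, never trained
                double[] p = Softmax(Scores(w, x));
                for (int k = 0; k < classes; k++)
                {
                    // Gradient of cross-entropy with respect to score k is p[k] - y[k],
                    // where y is the one-hot target (1 for the true class, 0 otherwise).
                    double gradient = p[k] - (k == labels[i] ? 1.0 : 0.0);
                    for (int j = 0; j < x.Length; j++)
                        w[k, j] -= learningRate * gradient * x[j];
                }
            }
        }
        return w;
    }

    // Predicts the class with the highest probability.
    public static int Predict(double[,] w, double[] input)
    {
        double[] p = Softmax(Scores(w, FrozenFeatures(input)));
        return Array.IndexOf(p, p.Max());
    }
}
```

For example, training on six two-dimensional points drawn from three well-separated clusters should produce a final layer that classifies all six correctly. Throughout training, only the weights returned by `TrainHead` change.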

Your image classification model reuses the Inception model, a popular image recognition model trained on the ImageNet dataset where the TensorFlow model tries to classify entire images into a thousand classes, like “Umbrella”, “Jersey”, and “Dishwasher”.

The Inception v1 model is a deep convolutional neural network. It achieves reasonable performance on hard visual recognition tasks, matching or exceeding human performance in some domains. The model/algorithm was developed by multiple researchers and is based on the original paper "Rethinking the Inception Architecture for Computer Vision" by Szegedy et al.

Because the Inception model has already been pre-trained on a large set of images, it contains the image features needed for image identification. The lower layers recognize simple features (such as edges), and the higher layers recognize more complex features (such as shapes). The final layer is trained against a much smaller set of data because you're starting with a pre-trained model that already understands how to classify images. Your model classifies images into more than two categories, so it's an example of a multi-class classifier.

TensorFlow is a popular deep learning and machine learning toolkit that enables training deep neural networks (and general numeric computations). ML.NET implements it as a transformer. In this tutorial, you use it to reuse the Inception model.

As shown in the following diagram, you add a reference to the ML.NET NuGet packages in your .NET Core or .NET Framework applications. Under the covers, ML.NET includes and references the native TensorFlow library that allows you to write code that loads an existing trained TensorFlow model file for scoring.

(Diagram: ML.NET architecture with the TensorFlow transform)

The Inception model is trained to classify images into a thousand categories, but you need to classify images in a smaller category set, and only those categories. Enter the transfer part of transfer learning. You can transfer the Inception model's ability to recognize and classify images to the new limited categories of your custom image classifier.

You're going to retrain the final layer of that model using a set of three categories:

  • Food
  • Toy
  • Appliance

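Your layer uses a multinomial logistic regression algorithm to find the correct category. Rather than a hard yes or no, the algorithm assigns each category a probability, and these probabilities add up to 1. During training, each image's target has a value of one for the correct category and zero for the others. At prediction time, the algorithm chooses the category with the highest probability.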
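The following math block writes out this multinomial logistic regression in its standard textbook form:

- **x** is the feature vector for an image; here, that's the `softmax2_pre_activation` output.
- K = 3 is the number of categories.
- **w**_k and b_k are the weights and bias learned for category k.

Training minimizes the cross-entropy against the one-hot targets y_ik. This simplified form ignores any regularization the trainer may add. The same quantity is reported later as the Log-loss metric.

```latex
\begin{aligned}
p(k \mid \mathbf{x}) &= \frac{\exp(\mathbf{w}_k^\top \mathbf{x} + b_k)}{\sum_{j=1}^{K} \exp(\mathbf{w}_j^\top \mathbf{x} + b_j)},
  \qquad \hat{y} = \arg\max_{k}\, p(k \mid \mathbf{x}) \\
\mathcal{L} &= -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{K} y_{ik} \ln p(k \mid \mathbf{x}_i),
  \qquad y_{ik} = \begin{cases} 1 & \text{if } k \text{ is the true category of image } i \\ 0 & \text{otherwise} \end{cases}
\end{aligned}
```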

DataSet

There are two data sources: the tags.tsv file and the image files. The tags.tsv file contains two columns: the first one is defined as ImagePath, and the second one is the Label corresponding to the image. The file doesn't have a header row and looks like this:

broccoli.jpg	food
pizza.jpg	food
pizza2.jpg	food
teddy2.jpg	toy
teddy3.jpg	toy
teddy4.jpg	toy
toaster.jpg	appliance
toaster2.png	appliance
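To make the file format concrete, the following hypothetical helper splits each tags.tsv line on the tab character into an image path and a label, and then counts the images per label. `TagsTsvSketch` is not part of the tutorial or of ML.NET. It assumes every line has exactly two tab-separated columns, as in the example above.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

// Hypothetical helper (not part of the tutorial or ML.NET): parses tags.tsv-style
// lines, which have no header row, into (ImagePath, Label) pairs.
public static class TagsTsvSketch
{
    public static List<(string ImagePath, string Label)> Parse(IEnumerable<string> lines, string folder) =>
        lines.Where(line => !string.IsNullOrWhiteSpace(line))   // ignore blank lines
             .Select(line => line.Split('\t'))                  // column 0: file name, column 1: label
             .Select(cols => (ImagePath: Path.Combine(folder, cols[0]), Label: cols[1]))
             .ToList();

    // Counts images per label, a quick check for class imbalance before training.
    public static Dictionary<string, int> CountByLabel(IEnumerable<(string ImagePath, string Label)> rows) =>
        rows.GroupBy(row => row.Label).ToDictionary(g => g.Key, g => g.Count());
}
```

Applied to the eight example lines above, it yields three food images, three toy images, and two appliance images.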

The training and testing images are located in the assets folders that you'll download in a zip file. These images belong to Wikimedia Commons.

Wikimedia Commons, the free media repository. Retrieved 10:48, October 17, 2018 from:
https://commons.wikimedia.org/wiki/Pizza
https://commons.wikimedia.org/wiki/Toaster
https://commons.wikimedia.org/wiki/Teddy_bear

Create a console application

Create a project

  1. Create a .NET Core Console Application called "TransferLearningTF".

  2. Install the Microsoft.ML NuGet Package:

    In Solution Explorer, right-click on your project and select Manage NuGet Packages. Choose "nuget.org" as the Package source, select the Browse tab, search for Microsoft.ML. Click on the Version drop-down, select the 1.0.0 package in the list, and select the Install button. Select the OK button on the Preview Changes dialog and then select the I Accept button on the License Acceptance dialog if you agree with the license terms for the packages listed. Repeat these steps for Microsoft.ML.ImageAnalytics v1.0.0 and Microsoft.ML.TensorFlow v0.12.0.

Prepare your data

  1. Download the project assets directory zip file, and unzip it.

  2. Copy the assets directory into your TransferLearningTF project directory. This directory and its subdirectories contain the data and support files (except for the Inception model, which you'll download and add in the next step) needed for this tutorial.

  3. Download the Inception model, and unzip it.

  4. Copy the contents of the inception5h directory just unzipped into your TransferLearningTF project assets\inputs-train\inception directory. This directory contains the model and additional support files needed for this tutorial, as shown in the following image:

    (Image: Inception directory contents)

  5. In Solution Explorer, right-click each of the files in the assets directory and its subdirectories and select Properties. Under Advanced, change the value of Copy to Output Directory to Copy if newer.

Create classes and define paths

Add the following additional using statements to the top of the Program.cs file:

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms.Image;

Create static fields to hold the paths to the various assets, and fields for the LabelTokey and PredictedLabelValue column names:

  • _assetsPath has the path to the assets.
  • _trainTagsTsv has the path to the training image data tags.tsv file.
  • _predictImageListTsv has the path to the .tsv file that lists the images to classify.
  • _trainImagesFolder has the path to the images used to train the model.
  • _predictImagesFolder has the path to the images to be classified by the trained model.
  • _predictSingleImage has the path to a single image to be classified by the trained model.
  • _inceptionPb has the path to the pre-trained Inception model to be reused to retrain your model.
  • _inputImageClassifierZip has the path where the trained model is loaded from.
  • _outputImageClassifierZip has the path where the trained model is saved.
  • LabelTokey is the name of the column that holds the Label value mapped to a key.
  • PredictedLabelValue is the name of the column that holds the predicted label value.

Add the following code just above the Main method to define those paths and column names:

static readonly string _assetsPath = Path.Combine(Environment.CurrentDirectory, "assets");
static readonly string _trainTagsTsv = Path.Combine(_assetsPath, "inputs-train", "data", "tags.tsv");
static readonly string _predictImageListTsv = Path.Combine(_assetsPath, "inputs-predict", "data", "image_list.tsv");
static readonly string _trainImagesFolder = Path.Combine(_assetsPath, "inputs-train", "data");
static readonly string _predictImagesFolder = Path.Combine(_assetsPath, "inputs-predict", "data");
static readonly string _predictSingleImage = Path.Combine(_assetsPath, "inputs-predict-single", "data", "toaster3.jpg");
static readonly string _inceptionPb = Path.Combine(_assetsPath, "inputs-train", "inception", "tensorflow_inception_graph.pb");
static readonly string _inputImageClassifierZip = Path.Combine(_assetsPath, "inputs-predict", "imageClassifier.zip");
static readonly string _outputImageClassifierZip = Path.Combine(_assetsPath, "outputs", "imageClassifier.zip");
private static string LabelTokey = nameof(LabelTokey);
private static string PredictedLabelValue = nameof(PredictedLabelValue);

Create some classes for your input data, and predictions. Add a new class to your project:

  1. In Solution Explorer, right-click the project, and then select Add > New Item.

  2. In the Add New Item dialog box, select Class and change the Name field to ImageData.cs. Then, select the Add button.

    The ImageData.cs file opens in the code editor. Add the following using statement to the top of ImageData.cs:

using Microsoft.ML.Data;

Remove the existing class definition and add the following code for the ImageData class to the ImageData.cs file:

public class ImageData
{
    [LoadColumn(0)]
    public string ImagePath;

    [LoadColumn(1)]
    public string Label;
}

ImageData is the input image data class and has the following string fields:

  • ImagePath contains the image file name.
  • Label contains a value for the image label.

Add a new class to your project for ImagePrediction:

  1. In Solution Explorer, right-click the project, and then select Add > New Item.

  2. In the Add New Item dialog box, select Class and change the Name field to ImagePrediction.cs. Then, select the Add button.

    The ImagePrediction.cs file opens in the code editor. Remove both the System.Collections.Generic and the System.Text using statements at the top of ImagePrediction.cs.

Remove the existing class definition and add the following code, which has the ImagePrediction class, to the ImagePrediction.cs file:

public class ImagePrediction : ImageData
{
    public float[] Score;

    public string PredictedLabelValue;
}

ImagePrediction is the image prediction class and has the following fields:

  • Score contains the confidence scores for the image, one per category; the highest score belongs to the predicted label.
  • PredictedLabelValue contains a value for the predicted image classification label.

ImagePrediction is the class used for prediction after the model has been trained. It inherits from ImageData, so it also has two fields from that class:

  • ImagePath holds the image path.
  • Label is used to retrain the model.

The PredictedLabelValue field is used during prediction and evaluation. Evaluation uses the input training data, the predicted values, and the model.

The MLContext class is a starting point for all ML.NET operations. Initializing mlContext creates a new ML.NET environment that can be shared across the model creation workflow objects. It's similar, conceptually, to DbContext in Entity Framework.

Initialize variables in Main

Initialize the mlContext variable with a new instance of MLContext. Replace the Console.WriteLine("Hello World!") line with the following code in the Main method:

MLContext mlContext = new MLContext(seed: 1);

Create a struct for default parameters

The Inception model has several default parameters you need to pass in. Create a struct to map the default parameter values to friendly names with the following code, just after the Main() method:

private struct InceptionSettings
{
    public const int ImageHeight = 224;
    public const int ImageWidth = 224;
    public const float Mean = 117;
    public const float Scale = 1;
    public const bool ChannelsLast = true;
}

Create a display utility method

Since you'll display the image data and the related predictions more than once, create a display utility method to handle displaying the image and prediction results.

The DisplayResults() method executes the following tasks:

  • Displays the predicted results.

Create the DisplayResults() method, just after the InceptionSettings struct, using the following code:

private static void DisplayResults(IEnumerable<ImagePrediction> imagePredictionData)
{

}

The Transform() method populated ImagePath in ImagePrediction along with the predicted fields. As the ML.NET process progresses, each component adds columns, which makes it easy to display the results. Add the following code to the DisplayResults() method:

foreach (ImagePrediction prediction in imagePredictionData)
{
    Console.WriteLine($"Image: {Path.GetFileName(prediction.ImagePath)} predicted as: {prediction.PredictedLabelValue} with score: {prediction.Score.Max()} ");
}

You'll call the DisplayResults() method in the ReuseAndTuneInceptionModel() and ClassifyImages() methods.

Create a .tsv file utility method

The ReadFromTsv() method executes the following tasks:

  • Reads the image list .tsv file.
  • Combines the folder path with each image file name.
  • Loads the file data into an IEnumerable<ImageData> object.

Create the ReadFromTsv() method, just after the DisplayResults() method, using the following code:

public static IEnumerable<ImageData> ReadFromTsv(string file, string folder)
{

}

The following code parses the .tsv file and combines the folder path with each image file name to set the ImagePath property. It then loads the result into an ImageData object. Add it as the body of the ReadFromTsv() method. You need the fully qualified file path to display the prediction results.

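return File.ReadAllLines(file)
 // Split each line on the tab character; column 0 is the image file name.
 .Select(line => line.Split('\t'))
 // Combine the folder with the file name to get the fully qualified image path.
 .Select(line => new ImageData()
 {
     ImagePath = Path.Combine(folder, line[0])
 });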

There are three major concepts in ML.NET: Data, Transformers, and Estimators.

Reuse and tune pre-trained model

Add the following call to the ReuseAndTuneInceptionModel() method as the next line of code in the Main() method:

var model = ReuseAndTuneInceptionModel(mlContext, _trainTagsTsv, _trainImagesFolder, _inceptionPb, _outputImageClassifierZip);

The ReuseAndTuneInceptionModel() method executes the following tasks:

  • Loads the data
  • Extracts and transforms the data.
  • Scores the TensorFlow model.
  • Tunes (retrains) the model.
  • Displays model results.
  • Evaluates the model.
  • Returns the model.

Create the ReuseAndTuneInceptionModel() method, just after the InceptionSettings struct and just before the DisplayResults() method, using the following code:

public static ITransformer ReuseAndTuneInceptionModel(MLContext mlContext, string dataLocation, string imagesFolder, string inputModelLocation, string outputModelLocation)
{

}

Load the data

Data in ML.NET is represented by the IDataView interface. IDataView is a flexible, efficient way of describing tabular data (numeric and text). Data can be loaded from a text file or in real time (for example, SQL database or log files) to an IDataView object.

Load the data using the MLContext.Data.LoadFromTextFile wrapper. Add the following code as the next line in the ReuseAndTuneInceptionModel() method:

var data = mlContext.Data.LoadFromTextFile<ImageData>(path: dataLocation, hasHeader: false);

Extract Features and transform the data

Pre-processing and cleaning data are important tasks that occur before a dataset is used effectively for machine learning. Using data without these modeling tasks can produce misleading results.

Machine learning algorithms understand featurized data, and when dealing with deep neural networks you must adapt the images to the format expected by the network. That format is a numeric vector.

The trainer needs the label as a numeric key rather than as text. Use the MapValueToKey() method to transform the Label column into a numeric key type column named LabelTokey, and add it to the dataset. Name the result estimator, because you'll append the image transforms, the TensorFlow scoring step, and the trainer to it. Add the next line of code:

var estimator = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: LabelTokey, inputColumnName: "Label")
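Conceptually, MapValueToKey assigns each distinct label an integer key. The following hypothetical sketch in plain C# illustrates the idea; it is not ML.NET's implementation. It assumes that keys are assigned in order of first occurrence starting at 1, with 0 left for missing values. Check the MapValueToKey reference for the exact ordering and missing-value rules. `KeyToValue` shows the reverse lookup that MapKeyToValue performs at the end of the pipeline.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration of value-to-key mapping (not ML.NET's implementation).
public static class ValueToKeySketch
{
    // Assigns keys 1, 2, 3, ... to labels in the order they first appear.
    public static Dictionary<string, uint> BuildKeyMap(IEnumerable<string> labels)
    {
        var map = new Dictionary<string, uint>();
        foreach (string label in labels)
            if (!map.ContainsKey(label))
                map[label] = (uint)map.Count + 1; // key 0 is left unused for "missing"
        return map;
    }

    // Reverse lookup from key back to the original label; throws if the key is unknown.
    public static string KeyToValue(Dictionary<string, uint> map, uint key) =>
        map.First(pair => pair.Value == key).Key;
}
```

With the labels in the order they appear in tags.tsv, this sketch maps food to 1, toy to 2, and appliance to 3.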

Your image processing estimator uses a pre-trained Deep Neural Network (DNN) as a featurizer for feature extraction. The network expects its input in a specific format, so you use several image transforms to get the image data into that form:

  1. The LoadImages transform loads the images into memory as Bitmap objects.
  2. The ResizeImages transform resizes the images, because the pre-trained model expects a fixed input width and height.
  3. The ExtractPixels transform extracts the pixels from the input images and converts them into a numeric vector.

Add these image transforms as the next lines of code:

.Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: imagesFolder, inputColumnName: nameof(ImageData.ImagePath)))
.Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: InceptionSettings.ImageWidth, imageHeight: InceptionSettings.ImageHeight, inputColumnName: "input"))
.Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: InceptionSettings.ChannelsLast, offsetImage: InceptionSettings.Mean))
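The following hypothetical sketch in plain C# shows the idea behind pixel extraction; it is not ML.NET's implementation. It flattens an RGB image into a single float vector in one of two layouts:

- **Interleaved (channels last):** the R, G, and B values of each pixel sit next to each other. This is what `interleavePixelColors: InceptionSettings.ChannelsLast` requests.
- **Planar:** all red values come first, then all green, then all blue.

The sketch assumes the offset is subtracted from each channel value before scaling, so with `Mean = 117` and `Scale = 1` a channel value of 117 becomes 0. Check the ExtractPixels reference for the exact formula.

```csharp
using System;

// Hypothetical sketch of pixel extraction (not ML.NET's implementation):
// flattens an RGB image into a float vector, applying (value - offset) * scale.
public static class PixelSketch
{
    // pixels[y, x, c] holds channel c (0 = R, 1 = G, 2 = B) of the pixel at row y, column x.
    public static float[] Extract(byte[,,] pixels, bool interleave, float offset, float scale)
    {
        int h = pixels.GetLength(0), w = pixels.GetLength(1), c = pixels.GetLength(2);
        var result = new float[h * w * c];
        for (int y = 0; y < h; y++)
            for (int x = 0; x < w; x++)
                for (int ch = 0; ch < c; ch++)
                {
                    // Interleaved: channels of one pixel are adjacent (R,G,B,R,G,B,...).
                    // Planar: all values of one channel are adjacent (R,R,...,G,G,...,B,B,...).
                    int index = interleave
                        ? (y * w + x) * c + ch
                        : ch * h * w + y * w + x;
                    result[index] = (pixels[y, x, ch] - offset) * scale;
                }
        return result;
    }
}
```

For example, consider a 1×2 image with pixels (117, 118, 119) and (120, 121, 122), an offset of 117, and a scale of 1. The interleaved vector is 0, 1, 2, 3, 4, 5, and the planar vector is 0, 3, 1, 4, 2, 5.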

The LoadTensorFlowModel is a convenience method that allows the TensorFlow model to be loaded once and then creates the TensorFlowEstimator using ScoreTensorFlowModel. The ScoreTensorFlowModel extracts specified outputs (the Inception model's image features softmax2_pre_activation), and scores a dataset using the pre-trained TensorFlow model.

softmax2_pre_activation is the output of the Inception model's final layer before the softmax function is applied. It is a vector of raw scores, one per Inception class. Softmax would turn these scores into probabilities that add up to 1, assuming that an image belongs to only one category. In this tutorial, the pre-activation vector instead serves as the image features for your new final layer. That layer produces its own probabilities for your three categories, as shown in the following example:

Class        Probability
Food         0.001
Toy          0.95
Appliance    0.049

Append the TensorFlowTransform to the estimator with the following line of code:

.Append(mlContext.Model.LoadTensorFlowModel(inputModelLocation).
    ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))

Choose a training algorithm

To add the training algorithm, call the mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy() wrapper method and append it to the estimator. It learns from the historical data using two columns: the Inception image features (softmax2_pre_activation) as the feature column and the key-mapped label (LabelTokey) as the label column. Add the trainer with the following code:

.Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: LabelTokey, featureColumnName: "softmax2_pre_activation"))

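The trainer writes its prediction to the PredictedLabel column as a key. Map that key back to the original label text in the PredictedLabelValue column, and end the estimator chain: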

.Append(mlContext.Transforms.Conversion.MapKeyToValue(PredictedLabelValue, "PredictedLabel"))
.AppendCacheCheckpoint(mlContext);

The Fit() method trains your model by transforming the dataset and applying the training. Fit the model to the training dataset and return the trained model by adding the following as the next line of code in the ReuseAndTuneInceptionModel() method:

ITransformer model = estimator.Fit(data);

The Transform() method makes predictions for multiple input rows of a dataset. Transform the training data by adding the following code to ReuseAndTuneInceptionModel():

var predictions = model.Transform(data);

Convert your image data and prediction IDataView objects into strongly typed IEnumerable collections for easier display. Use the MLContext.Data.CreateEnumerable() method to do that, using the following code:

var imageData = mlContext.Data.CreateEnumerable<ImageData>(data, false, true);
var imagePredictionData = mlContext.Data.CreateEnumerable<ImagePrediction>(predictions, false, true);

Call the DisplayResults() method to display your data and predictions as the next line in the ReuseAndTuneInceptionModel() method:

DisplayResults(imagePredictionData);

Once you have the prediction set, the Evaluate() method:

  • Assesses the model (compares the predicted values with the actual dataset Labels).

  • Returns the model performance metrics.

Add the following code to the ReuseAndTuneInceptionModel() method as the next line:

var multiclassContext = mlContext.MulticlassClassification;
var metrics = multiclassContext.Evaluate(predictions, labelColumnName: LabelTokey, predictedLabelColumnName: "PredictedLabel");

The following metrics are evaluated for image classification:

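  • Log-loss: the negative natural logarithm of the probability the model assigns to each image's true category, averaged over all images. You want Log-loss to be as close to zero as possible.

  • Per class Log-loss: the same average, computed separately over the images of each category. You want per class Log-loss to be as close to zero as possible.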
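To see how these two metrics relate, the following hypothetical helper recomputes log-loss from the probability each image assigns to its true category. `LogLossSketch` is not an ML.NET API. It clamps probabilities at a small epsilon to avoid ln(0); the epsilon value is an assumption.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper (not an ML.NET API) that recomputes multiclass log-loss
// from the probability each example assigned to its true class.
public static class LogLossSketch
{
    // Assumed clamp value so that a probability of 0 doesn't produce infinity.
    private const double Epsilon = 1e-15;

    // Loss of one example: -ln(p), where p is the probability of the true class.
    public static double ExampleLoss(double trueClassProbability) =>
        -Math.Log(Math.Max(trueClassProbability, Epsilon));

    // Average loss over all examples, regardless of class.
    public static double LogLoss(IEnumerable<(string Label, double TrueClassProbability)> examples) =>
        examples.Average(e => ExampleLoss(e.TrueClassProbability));

    // Average loss over the examples of each class separately.
    public static Dictionary<string, double> PerClassLogLoss(
        IEnumerable<(string Label, double TrueClassProbability)> examples) =>
        examples.GroupBy(e => e.Label)
                .ToDictionary(g => g.Key, g => g.Average(e => ExampleLoss(e.TrueClassProbability)));
}
```

The Results section below reports a score for each of the eight training images. All eight are predicted correctly, so each reported score is the probability of the true category. Feeding those scores to the helper and grouping by label gives:

- Per class log-loss of about 0.02775 (food), 0.01863 (toy), and 0.02174 (appliance).
- Overall log-loss of about 0.02283.

Both match the reported metrics. The overall value is the per class average weighted by the number of images in each class (3, 3, and 2).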

Use the following code to display the metrics:

Console.WriteLine($"LogLoss is: {metrics.LogLoss}");
Console.WriteLine($"PerClassLogLoss is: {String.Join(" , ", metrics.PerClassLogLoss.Select(c => c.ToString()))}");

Add the following code to return the trained model as the next line:

return model;

Classify images with a loaded model

Add the following call to the ClassifyImages() method as the next line of code in the Main method:

ClassifyImages(mlContext, _predictImageListTsv, _predictImagesFolder, _outputImageClassifierZip, model);

The ClassifyImages() method executes the following tasks:

  • Reads the image list .tsv file into an IEnumerable<ImageData>.
  • Predicts image classifications based on test data.

Create the ClassifyImages() method, just after the ReuseAndTuneInceptionModel() method and just before the DisplayResults() method, using the following code:

public static void ClassifyImages(MLContext mlContext, string dataLocation, string imagesFolder, string outputModelLocation, ITransformer model)
{

}

First, call the ReadFromTsv() method to create an IEnumerable<ImageData> that contains the fully qualified path for each ImagePath. You need that file path to display your data with the prediction results. You also need to convert the IEnumerable<ImageData> to an IDataView that you'll use to predict. Add the following code as the next two lines in the ClassifyImages() method:

var imageData = ReadFromTsv(dataLocation, imagesFolder);
var imageDataView = mlContext.Data.LoadFromEnumerable<ImageData>(imageData);

As you did previously with the training image data, predict the category of the test image data using the Transform() method of the model passed in. Add the following code to the ClassifyImages() method for the predictions and to convert the predictions IDataView into an IEnumerable for pairing and display:

var predictions = model.Transform(imageDataView);
var imagePredictionData = mlContext.Data.CreateEnumerable<ImagePrediction>(predictions, false, true);

To pair and display your test image data and predictions, add the following code to call the DisplayResults() method previously created as the next line in the ClassifyImages() method:

DisplayResults(imagePredictionData);

Classify a single image with a loaded model

Add the following call to the ClassifySingleImage() method as the next line of code in the Main method:

ClassifySingleImage(mlContext, _predictSingleImage, _outputImageClassifierZip, model);

The ClassifySingleImage() method executes the following tasks:

  • Loads an ImageData instance.
  • Predicts image classification based on test data.

Create the ClassifySingleImage() method, just after the ClassifyImages() method and just before the DisplayResults() method, using the following code:

public static void ClassifySingleImage(MLContext mlContext, string imagePath, string outputModelLocation, ITransformer model)
{

}

First, create an ImageData object that contains the fully qualified path and file name of the single image. Add the following code as the next lines in the ClassifySingleImage() method:

var imageData = new ImageData()
{
    ImagePath = imagePath
};

The PredictionEngine class is a convenience API that performs a prediction on a single instance of data. Its Predict() method makes a prediction on a single row of data. Pass imageData to the PredictionEngine to predict the image category by adding the following code to ClassifySingleImage():

// Make prediction function (input = ImageData, output = ImagePrediction)
var predictor = mlContext.Model.CreatePredictionEngine<ImageData, ImagePrediction>(model);
var prediction = predictor.Predict(imageData);

Display the prediction result as the next line of code in the ClassifySingleImage() method:

Console.WriteLine($"Image: {Path.GetFileName(imageData.ImagePath)} predicted as: {prediction.PredictedLabelValue} with score: {prediction.Score.Max()} ");

Results

After following the previous steps, run your console app (Ctrl + F5). Your results should be similar to the following output. You may see warnings or processing messages, but these messages have been removed from the following results for clarity.

=============== Training classification model ===============
Image: broccoli.jpg predicted as: food with score: 0.976743
Image: pizza.jpg predicted as: food with score: 0.9751652
Image: pizza2.jpg predicted as: food with score: 0.9660203
Image: teddy2.jpg predicted as: toy with score: 0.9748783
Image: teddy3.jpg predicted as: toy with score: 0.9829691
Image: teddy4.jpg predicted as: toy with score: 0.9868168
Image: toaster.jpg predicted as: appliance with score: 0.9769174
Image: toaster2.png predicted as: appliance with score: 0.9800823
=============== Classification metrics ===============
LogLoss is: 0.0228266745633507
PerClassLogLoss is: 0.0277501705149937 , 0.0186303530571291 , 0.0217359128952187
=============== Making classifications ===============
Image: broccoli.png predicted as: food with score: 0.905548
Image: pizza3.jpg predicted as: food with score: 0.9709008
Image: teddy6.jpg predicted as: toy with score: 0.9750155
=============== Making single image classification ===============
Image: toaster3.jpg predicted as: appliance with score: 0.9625379

C:\Program Files\dotnet\dotnet.exe (process 4304) exited with code 0.
Press any key to close this window . . .

Congratulations! You've now successfully built a machine learning model for image classification by reusing a pre-trained TensorFlow model in ML.NET.

You can find the source code for this tutorial at the dotnet/samples repository.

In this tutorial, you learned how to:

  • Understand the problem
  • Reuse and tune the pre-trained model
  • Classify images with a loaded model

Check out the Machine Learning samples GitHub repository to explore an expanded image classification sample.