Re-train a model

Learn how to retrain a machine learning model in ML.NET.

The world and the data around it change constantly, so models need to be updated as well. ML.NET supports re-training a model using its learned parameters as a starting point, so that each round of training builds on previous experience rather than starting from scratch.

The following algorithms are re-trainable in ML.NET:

Load pre-trained model

First, load the pre-trained model into your application. To learn more about loading training pipelines and models, see Save and load a trained model.

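using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

// Create MLContext
MLContext mlContext = new MLContext();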

// Define DataViewSchema of data prep pipeline and trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;

// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);

// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);

Extract pre-trained model parameters

Once the model is loaded, extract the learned model parameters by accessing the Model property of the pre-trained model. The pre-trained model was trained with the linear regression trainer OnlineGradientDescentTrainer, which creates a RegressionPredictionTransformer that outputs LinearRegressionModelParameters. These parameters contain the model's learned bias and weights (coefficients), which serve as the starting point for the re-trained model.

// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
    ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;
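
Because the as operator yields null when the loaded model does not hold linear regression parameters, a defensive variant of the extraction step can be useful. The following is a minimal sketch, assuming the Microsoft.ML NuGet package is referenced; the names ModelParameterHelper and TryGetLinearParameters are hypothetical and not part of ML.NET.

```csharp
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

public static class ModelParameterHelper
{
    // Hypothetical helper (not an ML.NET API): returns true only when the
    // transformer exposes linear regression model parameters.
    public static bool TryGetLinearParameters(
        ITransformer model,
        out LinearRegressionModelParameters parameters)
    {
        // ISingleFeaturePredictionTransformer<out TModel> is covariant, so a
        // prediction transformer over any model type matches the <object> form.
        if (model is ISingleFeaturePredictionTransformer<object> predictor &&
            predictor.Model is LinearRegressionModelParameters linear)
        {
            parameters = linear;
            return true;
        }

        // Not a single-feature prediction transformer, or not linear regression.
        parameters = null;
        return false;
    }
}
```

A caller could then stop early with a clear error when TryGetLinearParameters(trainedModel, out var parameters) returns false, instead of failing later on a null reference.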

Re-train model

Re-training follows the same process as training, with one difference: in addition to the data, the Fit method takes the original learned model parameters as input and uses them as the starting point for re-training.
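
To illustrate what "using learned parameters as a starting point" means, here is a conceptual sketch in plain C# of one pass of online gradient descent for a linear model. It is not ML.NET's OnlineGradientDescentTrainer implementation; the class name WarmStartSketch, the squared-loss update, and the learning rate are assumptions chosen for illustration only.

```csharp
using System;

public static class WarmStartSketch
{
    // One pass of per-example gradient descent for y = bias + w . x with
    // squared loss. Conceptual only; not how ML.NET implements its trainer.
    public static (float[] Weights, float Bias) Fit(
        float[][] features, float[] labels,
        float[] initialWeights, float initialBias,
        float learningRate = 0.1f)
    {
        // Copy the starting point so the caller's array is left unchanged.
        float[] weights = (float[])initialWeights.Clone();
        float bias = initialBias;

        for (int row = 0; row < features.Length; row++)
        {
            // Predict with the current parameters.
            float prediction = bias;
            for (int j = 0; j < weights.Length; j++)
                prediction += weights[j] * features[row][j];

            float error = prediction - labels[row];

            // Step against the gradient of 0.5 * error^2.
            for (int j = 0; j < weights.Length; j++)
                weights[j] -= learningRate * error * features[row][j];
            bias -= learningRate * error;
        }

        return (weights, bias);
    }
}
```

Passing zeros for initialWeights and initialBias corresponds to training from scratch, while passing previously learned values corresponds to re-training: the updates start from, and only adjust, what the model already knows.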

// New Data
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f, 175000f, 210000f },
        CurrentPrice = 205000f
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
        CurrentPrice = 210000f
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
        CurrentPrice = 180000f
    }
};

// Load New Data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);

// Preprocess Data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);

// Retrain model
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
    mlContext.Regression.Trainers.OnlineGradientDescent()
        .Fit(transformedNewData, originalModelParameters);
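
The snippet above relies on a HousingData type that this article does not define. The following is a possible sketch of it, assuming the Microsoft.ML NuGet package is referenced. The attribute choices are assumptions inferred from the fields used above: a fixed-length vector of three historical prices, and a label column named "Label", which is the trainer's default label column name.

```csharp
using Microsoft.ML.Data;

// Assumed shape of the input type; the real definition may differ.
public class HousingData
{
    public float Size { get; set; }

    // Declares a fixed-length vector of three historical prices.
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    // Exposes this property as the "Label" column (assumed label name).
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}
```

Whatever the actual definition is, its column names and types need to match what the loaded data preparation pipeline expects, since that pipeline is applied to the new data unchanged.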

Compare model parameters

How do you know whether re-training actually happened? One way is to check whether the re-trained model's parameters differ from those of the original model. The code sample below computes the difference between the original and re-trained model weights and prints the values to the console.

// Extract Model Parameters of re-trained model
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model;

// Inspect Change in Weights
var weightDiffs =
    originalModelParameters.Weights.Zip(
        retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();

Console.WriteLine("Original | Retrained | Difference");
for (int i = 0; i < weightDiffs.Length; i++)
{
    Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}

The table below shows what the output might look like.

Original | Retrained | Difference
33039.86 | 56293.76 | -23253.9
29099.14 | 49586.03 | -20486.89
28938.38 | 48609.23 | -19670.85
30484.02 | 53745.43 | -23261.41
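
Reading a list of differences by eye works for four weights but not for many. As a rough sketch, the check can be reduced to a single yes/no answer by comparing the largest absolute change against a tolerance. The class WeightComparison, its method names, and the default tolerance are hypothetical choices, not part of ML.NET.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class WeightComparison
{
    // Largest absolute per-weight change between two equally sized weight lists.
    public static float MaxAbsoluteChange(
        IReadOnlyList<float> original, IReadOnlyList<float> retrained)
    {
        if (original.Count != retrained.Count)
            throw new ArgumentException("Weight lists must have the same length.");

        // Nothing to compare, so report no change.
        if (original.Count == 0)
            return 0f;

        return original.Zip(retrained, (o, r) => Math.Abs(o - r)).Max();
    }

    // True when at least one weight moved by more than the tolerance.
    public static bool HasChanged(
        IReadOnlyList<float> original, IReadOnlyList<float> retrained,
        float tolerance = 1e-6f)
    {
        return MaxAbsoluteChange(original, retrained) > tolerance;
    }
}
```

Since Weights on LinearRegressionModelParameters is a read-only list of floats, the call could look like WeightComparison.HasChanged(originalModelParameters.Weights, retrainedModelParameters.Weights). For the sample values shown above, every weight moved by roughly 20,000 or more, so the check would report a change.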