重新训练模型Re-train a model
了解如何在 ML.NET 中重新训练机器学习模型。Learn how to retrain a machine learning model in ML.NET.
这个世界和它周围的数据在不断变化。The world and the data around it change at a constant pace. 因此,模型也需要更改和更新。As such, models need to change and update as well. 借助 ML.NET 提供的功能,可以将已学习的模型参数作为起点并不断汲取以往经验来重新训练模型,而不必每次都从头开始。ML.NET provides functionality for re-training models using learned model parameters as a starting point to continually build on previous experience rather than starting from scratch every time.
以下算法可在 ML.NET 中重新训练:The following algorithms are re-trainable in ML.NET:
- AveragedPerceptronTrainerAveragedPerceptronTrainer
- FieldAwareFactorizationMachineTrainerFieldAwareFactorizationMachineTrainer
- LbfgsLogisticRegressionBinaryTrainerLbfgsLogisticRegressionBinaryTrainer
- LbfgsMaximumEntropyMulticlassTrainerLbfgsMaximumEntropyMulticlassTrainer
- LbfgsPoissonRegressionTrainerLbfgsPoissonRegressionTrainer
- LinearSvmTrainerLinearSvmTrainer
- OnlineGradientDescentTrainerOnlineGradientDescentTrainer
- SgdCalibratedTrainerSgdCalibratedTrainer
- SgdNonCalibratedTrainerSgdNonCalibratedTrainer
- SymbolicSgdLogisticRegressionBinaryTrainerSymbolicSgdLogisticRegressionBinaryTrainer
加载预先训练的模型Load pre-trained model
首先,将预先训练的模型加载到应用程序中。First, load the pre-trained model into your application. 若要了解有关加载训练管道和模型的详细信息,请参阅保存和加载已定型模型。To learn more about loading training pipelines and models, see Save and load a trained model.
// Create MLContext
MLContext mlContext = new MLContext();
// Define DataViewSchema of data prep pipeline and trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;
// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);
// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);
提取预先训练的模型参数Extract pre-trained model parameters
加载模型后,通过访问预先训练模型的 Model
属性来提取已学习的模型参数。Once the model is loaded, extract the learned model parameters by accessing the Model
property of the pre-trained model. 使用线性回归模型 OnlineGradientDescentTrainer
训练了预先训练的模型,该线性回归模型可创建输出 LinearRegressionModelParameters
的 RegressionPredictionTransformer
。The pre-trained model was trained using the linear regression model OnlineGradientDescentTrainer
which creates a RegressionPredictionTransformer
that outputs LinearRegressionModelParameters
. 这些线性回归模型参数包含模型已学习的偏差和权重或系数。These linear regression model parameters contain the learned bias and weights or coefficients of the model. 这些值将用作新的重新训练模型的起点。These values will be used as a starting point for the new re-trained model.
// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;
重新训练模型Re-train model
重新训练模型的过程与训练模型的过程没有什么不同。The process for retraining a model is no different than that of training a model. 唯一的区别是,除了数据之外,Fit
方法还将原始学习模型参数作为输入,并将它们用作重新训练过程的起点。The only difference is, the Fit
method in addition to the data also takes as input the original learned model parameters and uses them as a starting point in the re-training process.
// New Data
HousingData[] housingData = new HousingData[]
{
new HousingData
{
Size = 850f,
HistoricalPrices = new float[] { 150000f,175000f,210000f },
CurrentPrice = 205000f
},
new HousingData
{
Size = 900f,
HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
CurrentPrice = 210000f
},
new HousingData
{
Size = 550f,
HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
CurrentPrice = 180000f
}
};
//Load New Data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);
// Preprocess Data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);
// Retrain model
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
mlContext.Regression.Trainers.OnlineGradientDescent()
.Fit(transformedNewData, originalModelParameters);
比较模型参数Compare model parameters
如何知道是否真的进行了重新训练?How do you know if re-training actually happened? 一种方法是比较重新训练模型的参数是否与原始模型的参数不同。One way would be to compare whether the re-trained model's parameters are different than those of the original model. 下面的代码示例将原始模型与重新训练模型的权重进行比较,并将它们输出到控制台。The code sample below compares the original against the re-trained model weights and outputs them to the console.
// Extract Model Parameters of re-trained model
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model as LinearRegressionModelParameters;
// Inspect Change in Weights
var weightDiffs =
originalModelParameters.Weights.Zip(
retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();
Console.WriteLine("Original | Retrained | Difference");
for(int i=0;i < weightDiffs.Count();i++)
{
Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}
下表显示了可能的输出。The table below shows what the output might look like.
原始Original | 重新训练后Retrained | 差值Difference |
---|---|---|
33039.8633039.86 | 56293.7656293.76 | -23253.9-23253.9 |
29099.1429099.14 | 49586.0349586.03 | -20486.89-20486.89 |
28938.3828938.38 | 48609.2348609.23 | -19670.85-19670.85 |
30484.0230484.02 | 53745.4353745.43 | -23261.41-23261.41 |