Make predictions with a trained model
Learn how to use a trained model to make predictions
Create data models
Input data
using Microsoft.ML.Data;

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}
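The LoadColumn attributes only take effect when data is read from a delimited text file. Each index is a zero-based column position, and LoadColumn(1, 3) gathers columns 1 through 3 into the single HistoricalPrices vector. The sketch below assumes a comma-separated file with no header row. The file name housing-sample.csv, the helper class HousingFileSample and the CurrentPrice values are all invented for illustration.

```csharp
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    // Zero-based columns 1, 2 and 3 are read into one fixed-size vector.
    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

// Hypothetical helper: writes a small CSV and reads it back using the attributes above.
public static class HousingFileSample
{
    public static List<HousingData> LoadSample()
    {
        // Invented rows: Size, three historical prices, current price (no header).
        string path = Path.Combine(Path.GetTempPath(), "housing-sample.csv");
        File.WriteAllLines(path, new[]
        {
            "850,150000,175000,210000,230000",
            "550,99000,98000,130000,140000",
        });

        var mlContext = new MLContext();
        IDataView data = mlContext.Data.LoadFromTextFile<HousingData>(path, separatorChar: ',', hasHeader: false);

        // Materialize the rows as objects to see how each file column was mapped.
        return mlContext.Data.CreateEnumerable<HousingData>(data, reuseRowObject: false).ToList();
    }
}
```

Calling HousingFileSample.LoadSample() should return two rows, each with a three-element HistoricalPrices array. LoadFromEnumerable, which appears later in this article, ignores LoadColumn and maps properties by name (or by ColumnName) instead.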
Output data
Like the Features and Label input column names, ML.NET has default names for the predicted value columns produced by a model. Depending on the task, the name may differ.
Because the algorithm used in this sample is a linear regression algorithm, the default name of the output column is Score. The ColumnName attribute on the PredictedPrice property maps that Score column to the property.
class HousingPrediction
{
    [ColumnName("Score")]
    public float PredictedPrice { get; set; }
}
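Here is one illustration of how the default names vary by task. A binary classification model produces a bool PredictedLabel column and a float Score column, and calibrated binary trainers also add a float Probability column. A hypothetical output class for such a model might look like the following. The class name BinaryPrediction and the property name Prediction are assumptions.

```csharp
using Microsoft.ML.Data;

// Hypothetical output class for a binary classification model.
public class BinaryPrediction
{
    // Predicted class; ML.NET's default name for this column is PredictedLabel.
    [ColumnName("PredictedLabel")]
    public bool Prediction { get; set; }

    // Raw score; the property name already matches the default column name.
    public float Score { get; set; }

    // Only produced by calibrated trainers; drop this property if the model has no Probability column.
    public float Probability { get; set; }
}
```

As with HousingPrediction, ColumnName is needed only when the property name differs from the column name.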
Set up a prediction pipeline
Whether you make a single prediction or a batch prediction, the prediction pipeline needs to be loaded into the application. This pipeline contains both the data preprocessing transformations and the trained model. The code snippet below loads the prediction pipeline from a file named model.zip.
using Microsoft.ML;

// Create MLContext
MLContext mlContext = new MLContext();
// Load the trained model (prediction pipeline) and the schema of its input data
DataViewSchema predictionPipelineSchema;
ITransformer predictionPipeline = mlContext.Model.Load("model.zip", out predictionPipelineSchema);
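The Load call expects model.zip to have been written earlier by mlContext.Model.Save. The following minimal sketch produces such a file. It assumes a toy pipeline (Concatenate, NormalizeMinMax and the Sdca regression trainer) fitted on a handful of invented rows. The helper class ModelFileSample and all training values are hypothetical, and the resulting model is not the one behind the prediction values quoted in this article.

```csharp
using Microsoft.ML;
using Microsoft.ML.Data;

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

// Hypothetical helper: writes a toy model to disk and reads it back.
public static class ModelFileSample
{
    public static void SaveToyModel(string path)
    {
        var mlContext = new MLContext(seed: 0);

        // Invented training rows; CurrentPrice (the Label column) is made up.
        HousingData[] trainingRows =
        {
            new HousingData { Size = 600f, HistoricalPrices = new[] { 100000f, 110000f, 125000f }, CurrentPrice = 130000f },
            new HousingData { Size = 1000f, HistoricalPrices = new[] { 170000f, 200000f, 240000f }, CurrentPrice = 250000f },
            new HousingData { Size = 1300f, HistoricalPrices = new[] { 220000f, 250000f, 290000f }, CurrentPrice = 310000f },
        };
        IDataView trainingData = mlContext.Data.LoadFromEnumerable(trainingRows);

        // Preprocessing and trainer are chained, so both end up in the saved file.
        var pipeline = mlContext.Transforms.Concatenate("Features", "Size", "HistoricalPrices")
            .Append(mlContext.Transforms.NormalizeMinMax("Features"))
            .Append(mlContext.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features"));
        ITransformer model = pipeline.Fit(trainingData);

        // Save the fitted pipeline together with the schema of the data it expects.
        mlContext.Model.Save(model, trainingData.Schema, path);
    }

    public static (ITransformer Pipeline, DataViewSchema InputSchema) LoadModel(string path)
    {
        var mlContext = new MLContext();
        ITransformer predictionPipeline = mlContext.Model.Load(path, out DataViewSchema predictionPipelineSchema);
        return (predictionPipeline, predictionPipelineSchema);
    }
}
```

Saving the preprocessing steps and the trainer as one chain means the loaded pipeline can be applied directly to raw HousingData rows.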
Single prediction
To make a single prediction, create a PredictionEngine using the loaded prediction pipeline.
// Create a PredictionEngine
PredictionEngine<HousingData, HousingPrediction> predictionEngine = mlContext.Model.CreatePredictionEngine<HousingData, HousingPrediction>(predictionPipeline);
Then, use the Predict method and pass in your input data as a parameter. Notice that the Predict method does not require the input to be an IDataView. It handles the conversion from the input data type internally, so you can pass in an object of the input data type directly. Additionally, since CurrentPrice is the target or label you're trying to predict for new data, it's assumed there is no value for it at the moment, so it is left unset.
// Input data
HousingData inputData = new HousingData
{
    Size = 900f,
    HistoricalPrices = new float[] { 155000f, 190000f, 220000f }
};
// Get prediction
HousingPrediction prediction = predictionEngine.Predict(inputData);
If you access the PredictedPrice property of the prediction object (the property mapped to the Score column), you should get a value similar to 150079.
Multiple predictions
Given the following data, load it into an IDataView. In this case, the IDataView is named inputData. Because CurrentPrice is the target or label you're trying to predict using new data, it's assumed there is no value for it at the moment.
// Actual data
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f, 175000f, 210000f }
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f }
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f }
    }
};
// Load the data into an IDataView
IDataView inputData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);
Then, use the Transform method to apply the data transformations and generate predictions.
// Predicted Data
IDataView predictions = predictionPipeline.Transform(inputData);
Inspect the predicted values by using the GetColumn method.
// Get Predictions
float[] scoreColumn = predictions.GetColumn<float>("Score").ToArray();
The predicted values in the Score column should look like the following:
Observation | Prediction |
---|---|
1 | 144638.2 |
2 | 150079.4 |
3 | 107789.8 |
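The batch steps can be combined into one runnable sketch. It assumes a toy regression pipeline fitted in memory on invented rows in place of the one loaded from model.zip, so its scores will not match the table above. The helper class BatchPredictionSample and the training values are hypothetical. The sketch also uses CreateEnumerable as an alternative to GetColumn for reading the results back as HousingPrediction objects.

```csharp
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

public class HousingData
{
    // LoadColumn attributes omitted: they only matter when reading text files.
    public float Size { get; set; }
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

public class HousingPrediction
{
    [ColumnName("Score")]
    public float PredictedPrice { get; set; }
}

// Hypothetical helper class for this sketch.
public static class BatchPredictionSample
{
    // Stand-in for the pipeline loaded from model.zip, fitted on invented rows.
    static ITransformer TrainToyPipeline(MLContext mlContext)
    {
        HousingData[] rows =
        {
            new HousingData { Size = 600f, HistoricalPrices = new[] { 100000f, 110000f, 125000f }, CurrentPrice = 130000f },
            new HousingData { Size = 1000f, HistoricalPrices = new[] { 170000f, 200000f, 240000f }, CurrentPrice = 250000f },
            new HousingData { Size = 1300f, HistoricalPrices = new[] { 220000f, 250000f, 290000f }, CurrentPrice = 310000f },
        };
        var pipeline = mlContext.Transforms.Concatenate("Features", "Size", "HistoricalPrices")
            .Append(mlContext.Transforms.NormalizeMinMax("Features"))
            .Append(mlContext.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features"));
        return pipeline.Fit(mlContext.Data.LoadFromEnumerable(rows));
    }

    public static (float[] Scores, List<HousingPrediction> Predictions) PredictBatch()
    {
        var mlContext = new MLContext(seed: 0);
        ITransformer predictionPipeline = TrainToyPipeline(mlContext);

        HousingData[] housingData =
        {
            new HousingData { Size = 850f, HistoricalPrices = new[] { 150000f, 175000f, 210000f } },
            new HousingData { Size = 900f, HistoricalPrices = new[] { 155000f, 190000f, 220000f } },
            new HousingData { Size = 550f, HistoricalPrices = new[] { 99000f, 98000f, 130000f } },
        };

        // Wrap the array in an IDataView so Transform can consume it.
        IDataView inputData = mlContext.Data.LoadFromEnumerable(housingData);

        // Transform is lazy: rows are scored when the result is enumerated.
        IDataView predictions = predictionPipeline.Transform(inputData);

        // Option 1: read a single column by name.
        float[] scores = predictions.GetColumn<float>("Score").ToArray();

        // Option 2: map each row onto the output class (Score -> PredictedPrice).
        List<HousingPrediction> rows = mlContext.Data
            .CreateEnumerable<HousingPrediction>(predictions, reuseRowObject: false)
            .ToList();

        return (scores, rows);
    }
}
```

GetColumn is convenient for one column. CreateEnumerable applies the same ColumnName mapping as the output class, so each PredictedPrice equals the matching Score value.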