Trénování modelu strojového učení pomocí křížového ověřování

Naučte se používat křížové ověřování k trénování robustnějších modelů strojového učení v ML.NET.

Křížové ověření je technika trénování a vyhodnocení modelu, která rozdělí data do několika oddílů a trénuje více algoritmů v těchto oddílech. Tato technika zlepšuje odolnost modelu tím, že z trénovacího procesu vydrží data. Kromě zlepšení výkonu u nezoznaných pozorování může být v prostředích s omezenými daty efektivním nástrojem pro trénování modelů s menší datovou sadou.

Data a datový model

Data ze souboru, který má následující formát:

Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41

Data lze modelovat podle třídy jako HousingData a načíst do objektu IDataView.

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

Příprava dat

Před použitím dat před sestavením modelu strojového učení je předzpracujte. V této ukázce SizeHistoricalPrices se sloupce zkombinují do jednoho vektoru funkce, který je výstupem do nového sloupce volaného Features metodou Concatenate . Kromě získání dat do formátu očekávaného algoritmy ML.NET zřetězení sloupců optimalizuje následné operace v kanálu použitím operace jednou pro zřetězený sloupec místo jednotlivých samostatných sloupců.

Jakmile se sloupce zkombinují do jednoho vektoru, použije se u sloupce, NormalizeMinMax který získá Size a HistoricalPrices ve stejném rozsahu mezi 0–Features1.

// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
    mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));

// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);

Trénování modelu s křížovým ověřováním

Po předběžném zpracování dat je čas model vytrénovat. Nejprve vyberte algoritmus, který je nejvíce v souladu s úlohou strojového učení, který se má provést. Vzhledem k tomu, že predikovaná hodnota je číselně souvislá hodnota, je úkol regresní. Jedním z regresních algoritmů implementovaných ML.NET je StochasticDualCoordinateAscentCoordinator algoritmus. K trénování modelu pomocí křížového ověření použijte metodu CrossValidate .

Poznámka:

I když tato ukázka používá lineární regresní model, crossValidate se vztahuje na všechny ostatní úlohy strojového učení v ML.NET s výjimkou detekce anomálií.

// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();

// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);

CrossValidate provádí následující operace:

  1. Rozdělí data do několika oddílů, které se rovnají hodnotě zadané v parametru numberOfFolds . Výsledkem každého oddílu TrainTestData je objekt.
  2. Model se vytrénuje na každém oddílu pomocí zadaného estimátoru algoritmů strojového učení v trénovací sadě dat.
  3. Výkon každého modelu se vyhodnocuje pomocí Evaluate metody testovací datové sady.
  4. Model spolu s jeho metrikami se vrátí pro každý z těchto modelů.

Výsledek uložený v cvResults kolekci CrossValidationResult objektů. Tento objekt zahrnuje vytrénovaný model i metriky, které jsou přístupné jak ve formě, tak ModelMetrics i vlastnosti. V této ukázce Model je vlastnost typu ITransformer a Metrics vlastnost je typu RegressionMetrics.

Vyhodnocení modelu

K metrikám pro různé vytrénované modely je možné přistupovat prostřednictvím Metrics vlastnosti jednotlivého CrossValidationResult objektu. V tomto případě je metrika R-Squared přístupná a uložená v proměnné rSquared.

IEnumerable<double> rSquared =
    cvResults
        .Select(fold => fold.Metrics.RSquared);

Pokud zkontrolujete obsah rSquared proměnné, měl by mít výstup pět hodnot v rozsahu od 0 do 1, kde je to nejlepší. Pomocí metrik, jako je R-Squared, vyberte modely od nejlepších po nejhorší výkon. Pak výběrem horního modelu proveďte předpovědi nebo proveďte další operace.

// Select all models
ITransformer[] models =
    cvResults
        .OrderByDescending(fold => fold.Metrics.RSquared)
        .Select(fold => fold.Model)
        .ToArray();

// Get Top Model
ITransformer topModel = models[0];