Interpretace predikcí modelu pomocí důležitosti funkce Permutation

Pomocí funkce Permutation Feature Importance (PFI) se naučíte interpretovat ML.NET predikce modelu strojového učení. PFI dává relativní příspěvek každé funkci k predikci.

Modely strojového učení se často považují za neprůžné rámečky, které přebírají vstupy a generují výstup. Přechodné kroky nebo interakce mezi funkcemi, které ovlivňují výstup, jsou zřídka srozumitelné. Vzhledem k tomu, že strojové učení je zavedeno do více aspektů každodenního života, jako je zdravotní péče, je nanejvýš důležité pochopit, proč model strojového učení dělá rozhodnutí. Pokud jsou například diagnostiky vytvořené modelem strojového učení, potřebují odborníci na zdravotnictví způsob, jak se podívat na faktory, které se dostaly do provádění diagnostiky. Poskytnutí správné diagnózy může mít velký rozdíl na tom, zda má pacient rychlé obnovení, nebo ne. Proto čím vyšší je úroveň vysvětlitelnosti modelu, tím větší spolehlivost zdravotnických odborníků musí přijmout nebo odmítnout rozhodnutí, která model udělal.

K vysvětlení modelů se používají různé techniky, z nichž jedna je PFI. PFI je technika používaná k vysvětlení klasifikačních a regresních modelů inspirovaných dokumentem Random Forests společnosti Breiman (viz část 10). Na vysoké úrovni je způsob, jakým funguje, náhodným náhodném prohazování jedné funkce pro celou datovou sadu a výpočtu toho, kolik metrik výkonu zájmu klesá. Čím větší je změna, tím důležitější je tato funkce.

Kromě toho se tvůrci modelů můžou zaměřit na používání podmnožina smysluplnějších funkcí, které můžou potenciálně snížit šum a dobu trénování.

Načtení dat

Funkce v datové sadě používané pro tuto ukázku jsou ve sloupcích 1–12. Cílem je předpovědět Price.

Column Funkce Popis
1 CrimeRate Míra trestné činnosti na obyvatele
2 ResidentialZones Obytné zóny ve městě
3 CommercialZones Nebytové zóny ve městě
4 NearWater Blízkost vodního těla
5 ToxicWasteLevels Úrovně toxicity (PPM)
6 AverageRoomNumber Průměrný počet místností v domě
7 HomeAge Věk domu
8 BusinessCenterDistance Vzdálenost k nejbližší obchodní čtvrti
9 HighwayAccess Blízkost dálnic
10 TaxRate Sazba daně z nemovitostí
11 StudentTeacherRatio Poměr studentů k učitelům
12 PercentPopulationBelowPoverty Procento populace žijící pod chudobou
13 Cena Cena domu

Ukázka datové sady je znázorněná níže:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Data v této ukázce mohou být modelována třídou jako HousingPriceData a načtena do .IDataView

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Trénování modelu

Následující ukázka kódu znázorňuje proces trénování lineárního regresního modelu, který předpovídá ceny domů.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Vysvětlení modelu s důležitostí funkcí permutace (PFI)

V ML.NET použijte metodu PermutationFeatureImportance pro příslušný úkol.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Výsledkem použití PermutationFeatureImportance na trénovací datové sadě je ImmutableArrayRegressionMetricsStatistics objekt. RegressionMetricsStatistics poskytuje souhrnné statistiky, jako jsou střední a směrodatná odchylka pro více pozorování RegressionMetrics rovnajících se počtu permutací určených parametrem permutationCount .

Metrika použitá k měření důležitosti funkcí závisí na úloze strojového učení použité k vyřešení vašeho problému. Například regresní úlohy můžou k měření důležitosti použít běžnou metriku vyhodnocení, například R-squared. Další informace o metrikách vyhodnocení modelu najdete v tématu vyhodnocení modelu ML.NET metrikami.

Důležitost nebo v tomto případě je možné absolutní průměrné snížení metriky R na druhou mocninu vypočítané podle toho, co PermutationFeatureImportance je poté seřazeno od nejdůležitějších po nejméně důležité.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Při tisku hodnot jednotlivých funkcí by featureImportanceMetrics se výstup podobný následujícímu vygeneroval. Mějte na paměti, že byste měli očekávat různé výsledky, protože tyto hodnoty se liší v závislosti na zadaných datech.

Funkce Změna na R-Squared
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
TaxRate -0.008545
AverageRoomNumber -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

Když se podíváme na pět nejdůležitějších funkcí pro tuto datovou sadu, cena domu předpovězená tímto modelem je ovlivněna jeho blízkostí k dálnicím, poměrem učitelů studentů v dané oblasti, blízkost hlavních pracovních center, sazbou daně z nemovitostí a průměrným počtem místností v domácnosti.

Další kroky