Tolka modellförutsägelser med hjälp av funktionsvikt för permutation

Med hjälp av PFI (Permutation Feature Importance) får du lära dig hur du tolkar ML.NET förutsägelser för maskininlärningsmodellen. PFI ger det relativa bidrag som varje funktion ger till en förutsägelse.

Maskininlärningsmodeller betraktas ofta som ogenomskinliga rutor som tar indata och genererar utdata. De mellanliggande stegen eller interaktionerna mellan de funktioner som påverkar utdata förstås sällan. När maskininlärning introduceras i fler aspekter av vardagen, till exempel sjukvård, är det av yttersta vikt att förstå varför en maskininlärningsmodell fattar de beslut den gör. Om diagnoser till exempel görs av en maskininlärningsmodell behöver vårdpersonal ett sätt att undersöka de faktorer som gick till att göra diagnoserna. Att tillhandahålla rätt diagnos kan göra stor skillnad på om en patient har en snabb återhämtning eller inte. Ju högre förklaringsnivån i en modell är, desto större förtroende måste vårdpersonalen därför acceptera eller avvisa de beslut som fattas av modellen.

Olika tekniker används för att förklara modeller, varav en är PFI. PFI är en teknik som används för att förklara klassificerings- och regressionsmodeller som är inspirerade av Breimans randomskogspapper (se avsnitt 10). På en hög nivå fungerar det genom att slumpmässigt blanda data en funktion i taget för hela datamängden och beräkna hur mycket prestandamåttet av intresse minskar. Ju större förändring, desto viktigare är funktionen.

Genom att markera de viktigaste funktionerna kan modellbyggare dessutom fokusera på att använda en delmängd av mer meningsfulla funktioner som potentiellt kan minska brus och träningstid.

Läsa in data

Funktionerna i datamängden som används för det här exemplet finns i kolumnerna 1–12. Målet är att förutsäga Price.

Column Funktion Beskrivning
1 CrimeRate Brottsfrekvens per capita
2 ResidentialZones Bostadsområden i stan
3 CommercialZones Icke-bostadszoner i staden
4 NearWater Närhet till vattenmassa
5 ToxicWasteLevels Toxicitetsnivåer (PPM)
6 AverageRoomNumber Genomsnittligt antal rum i huset
7 HomeAge Hemmets ålder
8 BusinessCenterDistance Avstånd till närmaste affärsdistrikt
9 HighwayAccess Närhet till motorvägar
10 Taxrate Fastighetsskattesats
11 StudentTeacherRatio Förhållandet mellan elever och lärare
12 PercentPopulationBelowPoverty Procent av befolkningen som lever under fattigdom
13 Pris Priset på hemmet

Ett exempel på datamängden visas nedan:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Data i det här exemplet kan modelleras av en klass som HousingPriceData och läsas in i en IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Träna modellen

Kodexemplet nedan illustrerar processen att träna en linjär regressionsmodell för att förutsäga huspriser.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Förklara modellen med PFI (Permutation Feature Importance)

I ML.NET använder du PermutationFeatureImportance metoden för din respektive uppgift.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Resultatet av att använda PermutationFeatureImportance på träningsdatauppsättningen är ett ImmutableArray av RegressionMetricsStatistics objekten. RegressionMetricsStatistics innehåller sammanfattningsstatistik som medelvärde och standardavvikelse för flera observationer som RegressionMetrics är lika med det antal permutationer som anges av parametern permutationCount .

Måttet som används för att mäta funktionsvikt beror på vilken maskininlärningsuppgift som används för att lösa problemet. Regressionsaktiviteter kan till exempel använda ett vanligt utvärderingsmått, till exempel R-kvadrat för att mäta prioritet. Mer information om modellutvärderingsmått finns i utvärdera din ML.NET modell med mått.

Vikten, eller i det här fallet den absoluta genomsnittliga minskningen av R-kvadratmåttet som beräknas av PermutationFeatureImportance , kan sedan sorteras från viktigast till minst viktigt.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Om du skriver ut värdena för var och en av funktionerna i featureImportanceMetrics genereras utdata som liknar dem nedan. Tänk på att du bör förvänta dig att se olika resultat eftersom dessa värden varierar beroende på vilka data de ges.

Funktion Ändra till R-Kvadrat
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
Taxrate -0.008545
AverageRoomNumber -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

Om du tar en titt på de fem viktigaste funktionerna för denna datamängd påverkas priset på ett hus som förutspås av denna modell av dess närhet till motorvägar, elevlärarens förhållande mellan skolor i området, närhet till stora arbetsförmedlingar, fastighetsskattesats och genomsnittligt antal rum i hemmet.

Nästa steg