Tolka modellförutsägelser med hjälp av funktionsvikt för permutation
Med hjälp av PFI (Permutation Feature Importance) får du lära dig hur du tolkar ML.NET förutsägelser för maskininlärningsmodellen. PFI ger det relativa bidrag som varje funktion ger till en förutsägelse.
Maskininlärningsmodeller betraktas ofta som ogenomskinliga rutor som tar indata och genererar utdata. De mellanliggande stegen eller interaktionerna mellan de funktioner som påverkar utdata förstås sällan. När maskininlärning introduceras i fler aspekter av vardagen, till exempel sjukvård, är det av yttersta vikt att förstå varför en maskininlärningsmodell fattar de beslut den gör. Om diagnoser till exempel görs av en maskininlärningsmodell behöver vårdpersonal ett sätt att undersöka de faktorer som gick till att göra diagnoserna. Att tillhandahålla rätt diagnos kan göra stor skillnad på om en patient har en snabb återhämtning eller inte. Ju högre förklaringsnivån i en modell är, desto större förtroende måste vårdpersonalen därför acceptera eller avvisa de beslut som fattas av modellen.
Olika tekniker används för att förklara modeller, varav en är PFI. PFI är en teknik som används för att förklara klassificerings- och regressionsmodeller som är inspirerade av Breimans randomskogspapper (se avsnitt 10). På en hög nivå fungerar det genom att slumpmässigt blanda data en funktion i taget för hela datamängden och beräkna hur mycket prestandamåttet av intresse minskar. Ju större förändring, desto viktigare är funktionen.
Genom att markera de viktigaste funktionerna kan modellbyggare dessutom fokusera på att använda en delmängd av mer meningsfulla funktioner som potentiellt kan minska brus och träningstid.
Läsa in data
Funktionerna i datamängden som används för det här exemplet finns i kolumnerna 1–12. Målet är att förutsäga Price
.
Column | Funktion | Beskrivning |
---|---|---|
1 | CrimeRate | Brottsfrekvens per capita |
2 | ResidentialZones | Bostadsområden i stan |
3 | CommercialZones | Icke-bostadszoner i staden |
4 | NearWater | Närhet till vattenmassa |
5 | ToxicWasteLevels | Toxicitetsnivåer (PPM) |
6 | AverageRoomNumber | Genomsnittligt antal rum i huset |
7 | HomeAge | Hemmets ålder |
8 | BusinessCenterDistance | Avstånd till närmaste affärsdistrikt |
9 | HighwayAccess | Närhet till motorvägar |
10 | Taxrate | Fastighetsskattesats |
11 | StudentTeacherRatio | Förhållandet mellan elever och lärare |
12 | PercentPopulationBelowPoverty | Procent av befolkningen som lever under fattigdom |
13 | Pris | Priset på hemmet |
Ett exempel på datamängden visas nedan:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Data i det här exemplet kan modelleras av en klass som HousingPriceData
och läsas in i en IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Träna modellen
Kodexemplet nedan illustrerar processen att träna en linjär regressionsmodell för att förutsäga huspriser.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Förklara modellen med PFI (Permutation Feature Importance)
I ML.NET använder du PermutationFeatureImportance
metoden för din respektive uppgift.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Resultatet av att använda PermutationFeatureImportance
på träningsdatauppsättningen är ett ImmutableArray
av RegressionMetricsStatistics
objekten. RegressionMetricsStatistics
innehåller sammanfattningsstatistik som medelvärde och standardavvikelse för flera observationer som RegressionMetrics
är lika med det antal permutationer som anges av parametern permutationCount
.
Måttet som används för att mäta funktionsvikt beror på vilken maskininlärningsuppgift som används för att lösa problemet. Regressionsaktiviteter kan till exempel använda ett vanligt utvärderingsmått, till exempel R-kvadrat för att mäta prioritet. Mer information om modellutvärderingsmått finns i utvärdera din ML.NET modell med mått.
Vikten, eller i det här fallet den absoluta genomsnittliga minskningen av R-kvadratmåttet som beräknas av PermutationFeatureImportance
, kan sedan sorteras från viktigast till minst viktigt.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Om du skriver ut värdena för var och en av funktionerna i featureImportanceMetrics
genereras utdata som liknar dem nedan. Tänk på att du bör förvänta dig att se olika resultat eftersom dessa värden varierar beroende på vilka data de ges.
Funktion | Ändra till R-Kvadrat |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
Taxrate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
Om du tar en titt på de fem viktigaste funktionerna för denna datamängd påverkas priset på ett hus som förutspås av denna modell av dess närhet till motorvägar, elevlärarens förhållande mellan skolor i området, närhet till stora arbetsförmedlingar, fastighetsskattesats och genomsnittligt antal rum i hemmet.
Nästa steg
Feedback
https://aka.ms/ContentUserFeedback.
Kommer snart: Under hela 2024 kommer vi att fasa ut GitHub-problem som feedbackmekanism för innehåll och ersätta det med ett nytt feedbacksystem. Mer information finns i:Skicka och visa feedback för