DataOperationsCatalog.CrossValidationSplit 方法

定義

將資料集分割成多個交叉驗證摺疊,每個摺疊各包含一組訓練集和測試集。 如果提供了 samplingKeyColumnName,則會依該資料行將資料列分組。

public System.Collections.Generic.IReadOnlyList<Microsoft.ML.DataOperationsCatalog.TrainTestData> CrossValidationSplit (Microsoft.ML.IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = default, int? seed = default);
member this.CrossValidationSplit : Microsoft.ML.IDataView * int * string * Nullable<int> -> System.Collections.Generic.IReadOnlyList<Microsoft.ML.DataOperationsCatalog.TrainTestData>
Public Function CrossValidationSplit (data As IDataView, Optional numberOfFolds As Integer = 5, Optional samplingKeyColumnName As String = Nothing, Optional seed As Nullable(Of Integer) = Nothing) As IReadOnlyList(Of DataOperationsCatalog.TrainTestData)

參數

data
IDataView

要分割的資料集。

numberOfFolds
Int32

交叉驗證摺疊的數目。

samplingKeyColumnName
String

要用於分組資料列的資料行名稱。 如果兩個範例的 samplingKeyColumnName 值相同,則保證它們會出現在相同的子集 (定型或測試) 中。 這可用來確保標籤不會從定型集外洩至測試集。 請注意,執行排名實驗時,samplingKeyColumnName 必須是 GroupId 資料行。 如果為 null,則不會執行任何資料列分組。

seed
Nullable<Int32>

亂數產生器的種子,用來選取交叉驗證摺疊的資料列。

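傳回

IReadOnlyList<DataOperationsCatalog.TrainTestData>

交叉驗證摺疊的清單;每個元素包含一個定型集 (TrainSet) 與一個測試集 (TestSet)。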

範例

using System;
using System.Collections.Generic;
using Microsoft.ML;

namespace Samples.Dynamic
{
    /// <summary>
    /// Sample class showing how to use <c>DataOperationsCatalog.CrossValidationSplit</c>,
    /// first with a sampling key column (rows sharing a key always land in the
    /// same subset) and then without one.
    /// </summary>
    public static class CrossValidationSplit
    {
        public static void Example()
        {
            // Create the MLContext, the entry point for all ML.NET operations.
            // NOTE: no seed is passed here or to CrossValidationSplit, so the
            // assignment of rows to folds (and therefore the output listed
            // below) may vary between runs; pass `seed:` for reproducible folds.
            var mlContext = new MLContext();

            // Generate 10 data points whose Group cycles 0, 1, 2, 0, 1, ...
            var examples = GenerateRandomDataPoints(10);

            // Convert the examples list to an IDataView object, which is consumable
            // by ML.NET API.
            var dataview = mlContext.Data.LoadFromEnumerable(examples);

            // Cross validation splits your data randomly into a set of "folds" and
            // creates one Train/Test pair per fold: that fold is the Test set and
            // the remaining folds form the Train set. Here the Group column is
            // passed as the sampling key, so all rows with the same Group value
            // always end up together in either the Train or the Test set.
            var folds = mlContext.Data
                .CrossValidationSplit(dataview, numberOfFolds: 3,
                samplingKeyColumnName: "Group");

            PrintFold(0);
            // The data in the Train split.
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.7680227]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 1], [Features, 0.4421779]
            // [Group, 2], [Features, 0.9775497]
            // 
            // The data in the Test split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 0], [Features, 0.2737045]

            PrintFold(1);
            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]
            // 
            // The data in the Test split.
            // [Group, 1], [Features, 0.8173254]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 1], [Features, 0.4421779]

            PrintFold(2);
            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 1], [Features, 0.4421779]
            // [Group, 0], [Features, 0.2737045]
            // 
            // The data in the Test split.
            // [Group, 2], [Features, 0.7680227]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 2], [Features, 0.9775497]

            // Example of a split without specifying a sampling key column: rows
            // are assigned to folds individually, so the same Group value can
            // appear in both the Train and the Test set.
            folds = mlContext.Data.CrossValidationSplit(dataview, numberOfFolds: 3);

            PrintFold(0);
            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 1], [Features, 0.4421779]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]
            // 
            // The data in the Test split.
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]

            PrintFold(1);
            // The data in the Train split.
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 1], [Features, 0.4421779]
            // 
            // The data in the Test split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]

            PrintFold(2);
            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]
            // 
            // The data in the Test split.
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 1], [Features, 0.4421779]

            // Materializes the Train and Test sets of fold `foldIndex` and prints
            // them. `folds` is read at call time, so the calls made after it is
            // reassigned above print the unkeyed split.
            void PrintFold(int foldIndex)
            {
                var trainSet = mlContext.Data
                    .CreateEnumerable<DataPoint>(folds[foldIndex].TrainSet,
                    reuseRowObject: false);

                var testSet = mlContext.Data
                    .CreateEnumerable<DataPoint>(folds[foldIndex].TestSet,
                    reuseRowObject: false);

                PrintPreviewRows(trainSet, testSet);
            }
        }

        // Yields `count` data points. Group cycles through 0, 1, 2 (i % 3) so the
        // groups have roughly equal sizes; Features is a uniform random value in
        // [0, 1) drawn from a Random seeded with `seed`, so the data itself is
        // reproducible. There is no label, and the features are not correlated
        // with anything.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                yield return new DataPoint
                {
                    Group = i % 3,
                    Features = (float)random.NextDouble()
                };
            }
        }

        // One example: a group id (used as the sampling key in Example()) and a
        // single feature value. A data set is a collection of such examples.
        private class DataPoint
        {
            public float Group { get; set; }

            public float Features { get; set; }
        }

        // Prints the Train rows, a blank line, then the Test rows, using the same
        // "[Group, g], [Features, f]" format as the expected-output comments in
        // Example().
        private static void PrintPreviewRows(IEnumerable<DataPoint> trainSet,
            IEnumerable<DataPoint> testSet)
        {
            Console.WriteLine($"The data in the Train split.");
            foreach (var row in trainSet)
            {
                Console.WriteLine(FormatRow(row));
            }

            Console.WriteLine($"\nThe data in the Test split.");
            foreach (var row in testSet)
            {
                Console.WriteLine(FormatRow(row));
            }
        }

        // Formats a single row as "[Group, g], [Features, f]".
        private static string FormatRow(DataPoint row) =>
            $"[Group, {row.Group}], [Features, {row.Features}]";
    }
}

適用於