Тесты

Пробит-классификация на C#

Джеймс Маккафри

Исходный код можно скачать по ссылке

James McCaffreyПробит-классификация (probability unit, probit) — это метод машинного обучения (machine learning, ML), который можно использовать для прогнозирования в ситуациях, где зависимая переменная, значение которой нужно предсказать, является бинарной, т. е. она может принимать одно из двух возможных значений. Пробит-классификацию также называют пробит-регрессией и пробит-моделированием.

Пробит-классификация весьма похожа на классификацию методом логистической регрессии (logistic regression, LR). Эти два метода применяются к одним и тем же типам задач и дают сходные результаты. Выбор одного из этих двух методов обычно зависит от дисциплины, в которой вы работаете. Пробит часто применяют в экономике и финансах, а LR чаще используют в других областях.

Чтобы получить представление о том, что такое пробит-классификация, взгляните на демонстрационную программу на рис. 1.

Пробит-классификация в действии
Рис. 1. Пробит-классификация в действии

В демонстрации пробит-классификация используется для создания модели, которая предсказывает, умрет ли пациент в больнице, исходя из его возраста, пола и результатов исследования почек. Данные полностью искусственные. Первый элемент исходных данных выглядит так:

48.00   +1.00   4.40   0

Исходные данные (raw data) состоят из 30 элементов. Пол кодируется как –1 (мужчина) и +1 (женщина). Прогнозируемое значение Died находится в последнем столбце и кодируется как 0 = false (пациент выживет) и 1 = true. Таким образом, первый элемент данных определяет 48-летнюю женщину с оценкой состояния почек 4.40, которая выживет. Демонстрационная программа начинает с нормализации возраста и данных состояния почек, чтобы все значения имели примерно одинаковый порядок величин. Первый элемент данных после нормализации становится таким:

-0.74   +1.00   -0.61   0.00

Нормализованные значения меньше 0.0 (здесь: возраст и оценка состояния почек) находятся ниже средних, а значения больше 0.0 — выше средних.

Затем исходные данные случайным образом разделяются на обучающий набор с 24 элементами для создания модели и на проверочный набор с шестью элементами для оценки точности модели при применении к новым данным с неизвестными результатами.

Далее демонстрационная программа создает пробит-модель. «За кулисами» обучение осуществляется по методу симплексной оптимизации (simplex optimization) с максимальным количеством итераций, равным 100. После обучения отображаются весовые значения, которые определяют модель, — { –4.26, 2.30, –1.29, 3.45 }.

Первое весовое значение (–4.26) является универсальной константой и не применяется ни к какой конкретной предсказывающей переменной (predictor variable). Второе весовое значение (2.30) применяется к возрасту, третье (–1.29) — к полу, а четвертое (3.45) — к оценке состояния почек. Положительные весовые значения, например связанные с возрастом и оценкой состояния почек, означают, что более высокие значения предиктора указывают: зависимая переменная Died будет ближе к true.

Демонстрационная программа вычисляет точность модели на обучающем наборе из 24 элементов (100% правильных результатов) и на проверочном наборе (83,33%, или пять правильных результатов и один ошибочный). Более старшее из этих двух значений является точностью проверки (test accuracy). Это грубая оценка общей точности пробит-модели.

В этой статье предполагается, что вы умеете программировать хотя бы на среднем уровне и имеете ьазовое представление об ML-классификации, но ничего не знаете о пробит-классификации. Демонстрационная программа написана на C#, но у вас не должно возникнуть особых проблем, если вы захотите выполнить рефакторинг кода под другие .NET-языки. Демонстрационная программа слишком длинная, чтобы ее можно было представить в статье во всей ее полноте, но вы можете найти полный исходный код в сопутствующем этой статье пакете кода. Вся обработка обычных ошибок удалена, чтобы не затруднять восприятие основных идей.

Понимание пробит-классификации

Простым способом предсказать смерть из-за таких показателей, как возраст, пол и состояние почек, была бы линейная комбинация:

died = b0 + (b1)(age) + (b2)(sex) + (b3)(kidney)

где b0, b1, b2, b3 — весовые значения, которые нужно как-то определить, чтобы вычисленные на обучающем наборе выходные значения близко соответствовали известным значениям зависимой переменной. Логистическая регрессия расширяет эту идею с помощью более сложной функции прогнозирования:

z = b0 + (b1)(age) + (b2)(sex) + (b3)(kidney)
died = 1.0 / (1.0 + e-z)

Математика здесь весьма сложна, но функция прогнозирования, называемая логистической сигмоидальной функцией (logistic sigmoid function), к нашему удобству всегда возвращает значение между 0.0 и 1.0, что можно интерпретировать как вероятность (probability). Пробит-классификация использует другую функцию прогнозирования:

z = b0 + (b1)(age) + (b2)(sex) + (b3)(kidney)
died = Phi(z)

Функция Phi(z) называется стандартной нормальной кумулятивной функцией плотности распределения вероятностей (standard normal cumulative density function) (обычно используют аббревиатуру CDF), и она всегда возвращает значение между 0.0 и 1.0. CDF весьма сложна, потому что простого уравнения для нее нет. CDF для значения z — это область под знаменитой кривой распределения Гаусса (bell-shaped curve function) (гауссовой функцией) от отрицательной бесконечности до z.

Звучит куда сложнее, чем есть на самом деле. Взгляните на график на рис. 2. Он показывает логистическую сигмоидальную функцию и функцию CDF, точки которых отложены на графике рядом. Важно, что для любого значения z, хотя нижележащие функции совершенно разные, обе функции возвращают значение между 0.0 и 1.0, которое можно интерпретировать как вероятность.

График кумулятивной функции плотности распределения вероятностей
Рис. 2. График кумулятивной функции плотности распределения вероятностей

Standard normal CDF Стандартная нормальная CDF
Cumulative Density(z) Кумулятивная плотность (z)
CDF CDF
Log sigmoid Логистическая сигмоида

Значит, с точки зрения разработчика, первая задача — написать функцию, которая вычисляет CDF для значения z. Простого уравнения для вычисления CDF нет, но есть десятки экзотично выглядящих аппроксимаций. Один из самых распространенных способов аппроксимации функции CDF — расчет функции erf (сокращение от Error Function), используя уравнение под названием «A&S 7.1.26», и применение erf для определения CDF. Код для функции CDF представлен на рис. 3.

Рис. 3. Функция CDF на C#

static double CumDensity(double z)
{
  double p = 0.3275911;
  double a1 = 0.254829592;
  double a2 = -0.284496736;
  double a3 = 1.421413741;
  double a4 = -1.453152027;
  double a5 = 1.061405429;

  int sign;
  if (z < 0.0)
    sign = -1;
  else
    sign = 1;

  double x = Math.Abs(z) / Math.Sqrt(2.0);
  double t = 1.0 / (1.0 + p * x);
  double erf = 1.0 - (((((a5 * t + a4) * t) + a3) *
    t + a2) * t + a1) * t * Math.Exp(-x * x);
  return 0.5 * (1.0 + (sign * erf));
}

Подведем итог. Пробит-классификация использует функцию CDF для расчета выходного значения. Функция CDF также называется phi. CDF — это область под кривой распределения Гаусса, и простого уравнения для нее нет. Распространенный способ аппроксимации CDF — применение формулы A&S 7.1.26 для получения erf и последующего использования erf для получения CDF.

Располагая функцией CDF, легко вычислить пробит-вывод для набора входных значений и набора весовых значений:

public double ComputeOutput(
  double[] dataItem, double[] weights)
{
  double z = 0.0;
  z += weights[0]; // константа b0
  // Данные могут включать Y
  for (int i = 0; i < weights.Length - 1; ++i)
    // Пропускаем первое весовое значение
    z += (weights[i + 1] * dataItem[i]);
  return CumDensity(z);
}

Вторая задача при написании кода для пробит-классификации — определение значений для весов, чтобы при передаче обучающих данных вычисленные выходные значения близко соответствовали известным выходным значениям. На эту задачу можно посмотреть и иначе: ее цель — свести к минимуму ошибку между вычисленными и известными выходными значениями. Это называют обучением модели с помощью числовой оптимизации.

Простого способа обучить большинство ML-классификаторов, в том числе пробит-классификаторов, нет. Существует примерно десяток основных методов, которые вы можете использовать, и у каждого метода есть десятки вариаций. К распространенным методам обучения относятся простой градиентный спуск (simple gradient descent), обратное распространение (back-propagation), алгоритм Ньютона-Рафсона, оптимизация роя частиц, эволюционная оптимизация (evolutionary optimization) и L-BFGS. Демонстрационная программа использует один из старейших и простейших методов обучения — симплексную оптимизацию (simplex optimization).

Понимание симплексной оптимизации

Грубо говоря, симплекс — это треугольник. Идея симплексной оптимизации заключается в том, чтобы начать с тремя возможными решениями (отсюда и название — симплекс). Одно из решений будет «лучшим» (имеющим наименьшую ошибку), другое — «худшим» (наибольшая ошибка), а третье — «другим». Затем алгоритм симплексной оптимизации создает три новых потенциальных решения: «расширенное» (expanded), «отраженное» (reflected) и «сжатое» (contracted). Каждое из них сравнивается с текущим худшим решением, и, если любой из новых кандидатов оказывается лучше (с меньшей ошибкой), он заменяет худшее решение.

Симплексная оптимизация показана на рис. 4. В простом случае, где решение состоит из двух значений, например (1.23, 4.56), решение можно рассматривать как точку на плоскости (x, y). В левой части рис. 4 видно, что три новых решения-кандидата генерируются из решений «лучшее», «худшее» и «другое».

Симплексная оптимизация
Рис. 4. Симплексная оптимизация

other другое
worst худшее
contracted сжатое
centroid центроид
reflected отраженное
expanded расширенное
best лучшее
Three candidates to replace worst Три кандидата на замену худшего решения
other' другое'
worst' худшее'
Shrinking Сокращение

Сначала вычисляется центроид. Центроид — это среднее лучшего и другого решения. В двух измерениях это точка посередине между точками лучшего и другого решений. Затем проводится воображаемая линия, которая начинается с точки худшего решения и проходит через центроид. Сжатый кандидат находится между точками худшего решения и центроида. Отраженный кандидат располагается на воображаемой линии за центроидом, а расширенный — за точкой отраженного решения.

На каждой итерации симплексной оптимизации, если один из кандидатов (расширенный, отраженный или сжатый) лучше, чем текущее худшее решение, последнее заменяется этим кандидатом. Если ни один из трех сгенерированных кандидатов не лучше худшего решения, текущее худшее и другое решение смещаются в направлении к лучшему решению где-то между своими текущими позициями и позицией лучшего решения, как показано в правой части рис. 4.

После каждой итерации формируется новый виртуальный треугольник «лучший-другой-худший», который все ближе и ближе к оптимальному решению. Если делать снимок каждого треугольника при последовательном подходе, то смещение треугольников напоминает движение заостренной капли по плоскости в стиле одноклеточной амебы. По этой причине симплексная оптимизация иногда называется оптимизацией методом амебы (amoeba method optimization).

Существует много вариаций симплексной оптимизации, которые отличаются в том, насколько далеки сжатое, отраженное и расширенное решения-кандидаты от текущего центроида, и в порядке, в котором проверяются решения-кандидаты, чтобы узнать, лучше ли они текущего худшего решения. Самая распространенная форма симплексной оптимизации — алгоритм Нелдера-Мида (Nelder-Mead algorithm). Демонстрационная программа использует более простую вариацию, у которой нет конкретного названия.

В случае пробит-классификации каждое потенциальное решение является набором весовых значений. В псевдокоде на рис. 5 показана вариация симплексной оптимизации, применяемая в демонстрационной программе.

Рис. 5. Псевдокод для симплексной оптимизации, используемой демонстрационной программой

Случайно инициализируем лучшее, худшее и другое решения
Цикл maxEpochs раз
  Создаем центроид из худшего и другого
  Создаем расширенное
  if расширенное лучше худшего, заменяем худшее расширенным
    и продолжаем цикл
  Создаем отраженное
  if отраженное лучше худшего, заменяем худшее отраженным
    и продолжаем цикл
  Создаем сжатое
  if сжатое лучше худшего, заменяем худшее сжатым
    и продолжаем цикл
  Создаем случайное решение
  if случайное решение лучше худшего, заменяем худшее
    и продолжаем цикл
  Смещаем худшее и другое в направлении лучшего
Конец цикла
return лучшее из найденных решений

Симплексная оптимизация, как и любые другие ML-алгоритмы оптимизации, имеет свои плюсы и минусы. Однако она (сравнительно) проста в реализации и обычно (хоть и не всегда) хорошо работает на практике.

Демонстрационная программа

Чтобы создать демонстрационную программу, я запустил Visual Studio, выбрал шаблон консольного приложения на C# и назвал программу ProbitClassification. В этой демонстрационной программе нет значимых зависимостей от конкретной версии Microsoft .NET Framework, поэтому подойдет любая сравнительно недавняя версия Visual Studio. После загрузки кода шаблона я переименовал в окне Solution Explorer файл Program.cs в ProbitProgram.cs, и Visual Studio автоматически переименовал класс Program.

Начало кода демонстрационной программы показано на рис. 6. Вымышленные данные «зашиты» в саму программу. В реальном сценарии ваши данные должны храниться в каком-то текстовом файле, и вам потребуется написать вспомогательный метод для загрузки данных в память. Исходные данные отображаются в консоли с использованием определенного в программе вспомогательного метода ShowData:

Console.WriteLine("\nRaw data: \n");
Console.WriteLine("       Age       Sex      Kidney   Died");
Console.WriteLine("=======================================");
ShowData(data, 5, 2, true);

Затем нормализуются столбцы 0 и 2 исходных данных:

Console.WriteLine("Normalizing age and kidney data");
int[] columns = new int[] { 0, 2 };
double[][] means = Normalize(data, columns);
Console.WriteLine("Done");
Console.WriteLine("\nNormalized data: \n");
ShowData(data, 5, 2, true);

Рис. 6. Начало кода демонстрационной программы

using System;
namespace ProbitClassification
{
  class ProbitProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine(
        "\nBegin Probit Binary Classification demo");
      Console.WriteLine(
        "Goal is to predict death (0 = false, 1 = true)");
      double[][] data = new double[30][];
      data[0] = new double[] { 48, +1, 4.40, 0 };
      data[1] = new double[] { 60, -1, 7.89, 1 };
      // И т. д.
      data[29] = new double[] { 68, -1, 8.38, 1 };
...

Метод Normalize сохраняет и возвращает средние и среднеквадратичные отклонения всех столбцов, чтобы при поступлении новых данных их можно было нормализовать, используя те же параметры, что и при обучении модели. Затем нормализованные данные разделяются на обучающий набор (80%) и проверочный (20%):

Console.WriteLine(
  "Creating train (80%) and test (20%) matrices");
double[][] trainData;
double[][] testData;
MakeTrainTest(data, 0, out trainData, out testData);
Console.WriteLine("Done");
Console.WriteLine("\nNormalized training data: \n");
ShowData(trainData, 3, 2, true);

Возможно, вы захотите параметризовать метод MakeTrainTest, чтобы он принимал процент элементов, помещаемых в обучающий набор. Далее создается экземпляр объекта пробит-классификатора, определенного в программе:

int numFeatures = 3; // возраст, пол, почки
Console.WriteLine("Creating probit binary classifier");
ProbitClassifier pc = new ProbitClassifier(numFeatures);

После этого происходит обучение пробит-классификатора, используя симплексную оптимизацию для поиска таких весовых значений, при которых вычисленные выходные значения близко соответствуют известным выходным значениям:

int maxEpochs = 100; // 100 дает репрезентативную демонстрацию
Console.WriteLine("Setting maxEpochs = " + maxEpochs);
Console.WriteLine("Starting training");
double[] bestWeights = pc.Train(trainData, maxEpochs, 0);
Console.WriteLine("Training complete");
Console.WriteLine("\nBest weights found:");
ShowVector(bestWeights, 4, true);

Демонстрационная программа завершается вычислением точности классификации модели на обучающих и проверочных данных:

...
  double testAccuracy = pc.Accuracy(testData, bestWeights);
  Console.WriteLine("Prediction accuracy on test data = " +
    testAccuracy.ToString("F4"));
  Console.WriteLine(
    "\nEnd probit binary classification demo\n");
  Console.ReadLine();
} // Main

Демонстрационная программа не делает прогнозы для ранее не встречавшихся данных. Формирование прогноза могло бы выглядеть так:

// Более пожилой мужчина с более низкой оценкой состояния почек
double[] unknownNormalized = new double[] { 0.25, -1.0, 0.50 };
int died = pc.ComputeDependent(unknownNormalized, bestWeights);
if (died == 0)
  Console.WriteLine("Predict survive");
else if (died == 1)
  Console.WriteLine("Predict die");

Этот код предполагает, что независимые x-переменные (возраст, пол и оценка состояния почек) были нормализованы с помощью средних и среднеквадратичных отклонений, полученных в процессе нормализации обучающих данных.

Класс ProbitClassifier

Общая структура класса ProbitClassifier представлена на рис. 7. Определение ProbitClassifier содержит вложенный класс Solution. Этот подкласс наследует от интерфейса IComparable, чтобы массив из трех объектов Solution можно было автоматически сортировать для получения лучшего, другого и худшего решений. Обычно я не люблю извилистые приемы кодирования, но в этой ситуации выигрыш несколько перевешивает дополнительную сложность.

Рис. 7. Класс ProbitClassifier

public class ProbitClassifier
{
  private int numFeatures; // число независимых переменных
  private double[] weights; // b0 = константа
  private Random rnd;

  public ProbitClassifier(int numFeatures) { . . }
  public double[] Train(double[][] trainData, int maxEpochs,
    int seed) { . . }
  private double[] Expanded(double[] centroid,
    double[] worst) { . . }
  private double[] Contracted(double[] centroid,
    double[] worst) { . . }
  private double[] RandomSolution() { . . }
  private double Error(double[][] trainData,
    double[] weights) { . . }
  public void SetWeights(double[] weights) { . . }
  public double[] GetWeights() { . . }
  public double ComputeOutput(double[] dataItem,
    double[] weights) { . . }
  private static double CumDensity(double z) { . . }
  public int ComputeDependent(double[] dataItem,
    double[] weights) { . . }
  public double Accuracy(double[][] trainData,
    double[] weights) { . . }

  private class Solution : IComparable<Solution>
  {
    // Здесь должно быть определение класса
  }
}

В ProbitClassifier два метода вывода. Метод ComputeOutput возвращает значение между 0.0 и 1.0 и применяется при обучении для вычисления значения ошибки. Метод ComputeDependent — это оболочка ComputeOutput, которая возвращает 0, если выходное значение меньше или равно 0.5, либо 1, если выходное значение больше 0.5. Эти возвращаемые значения используются при вычислении точности.

Заключение

Пробит-классификация — один из старейших методов ML. Поскольку пробит-классификация так похожа на классификацию по логистической регрессии (LR), здравый смысл подсказывает, что можно использовать либо один метод, либо другой. Так как LR немного проще в реализации, чем пробит, она используется чаще, чем пробит-классификация, которая со временем стала чем-то вроде гражданина второго сорта в ML. Однако пробит-классификация зачастую очень эффективна и может быть ценным пополнением в вашем арсенале алгоритмов ML.


Джеймс Маккафри (Dr. James McCaffrey) работает на Microsoft Research в Редмонде (штат Вашингтон). Принимал участие в создании нескольких продуктов Microsoft, в том числе Internet Explorer и Bing. С ним можно связаться по адресу jammc@microsoft.com.

Выражаю благодарность за рецензирование статьи экспертам Microsoft Research Натану Брауну (Nathan Brown) и Кирку Олинику (Kirk Olynyk).