2019 年 8 月

第 34 卷,第 8 期

[测试运行]

针对多臂老虎机问题的 UCB1 算法

作者 James McCaffrey

James McCaffrey假设你在赌场,面前有三台老虎机。有 30 个筹码。每台机器会依据不同的概率分布显示获胜,而这些分布对你来说是未知的。你的目标是快速找到最好的机器,这样你就可以最大限度地赢钱。这是一个多臂老虎机问题的示例,之所以这样命名是因为老虎机被通俗地称为单臂老虎机。

在日常工作环境中,不太可能要和赌场的老虎机打交道。但是多臂老虎机场景对应许多现实生活中的问题。例如,一家制药公司,有三种用于治疗某种疾病的新药,该公司必须通过最少的人体临床试验来找出哪种药物是最有效的。而一个有几种新方案的在线广告营销活动需要尽快找到哪种方案能使收益最大化。

有许多不同的算法可以用于多臂老虎机问题。UCB1(置信上限,版本 1)算法是从数学角度而言最复杂的算法之一,但令人惊讶的是,它是最容易实现的算法之一。要了解什么是 UCB1 算法以及了解本文要讨论的问题,较好的方法是查看图 1 中的演示运行。

UCB1 算法演示运行
图 1 UCB1 算法演示运行

该演示设置了三个基于 0 的索引机器,每个机器具有不同的获胜概率。这三个概率是 (0.3, 0.7, 0.5),所以机器 [1] 是最佳机器。每拉一次,如果机器显示获胜,它支付 1 美元,如果显示失败,它支付 0 美元。从每台机器各玩一次开始进行 UBC1 算法。在演示中,机器 [0] 和 [1] 显示获胜,但机器 [2] 显示失败。

UCB1 算法是迭代的。演示指定了初始化下拉后的六次测试。在第一次测试中,算法计算每台机器的平均奖励。因为机器 [0] 和 [1] 在初始化阶段显示获胜,它们当前的平均奖励 = 1.00 美元/ 1 次下拉 = 1.00 美元。因为机器 [2] 显示失败了,它当前的平均奖励 = 0.00 美元/ 1 次下拉 = 0.00 美元。

使用当前的平均奖励和当前的测试次数巧妙地计算每台机器的决策值。对于第一次测试,决策值与平均奖励相同。要玩的臂/机器具有最大的决策值。此时,机器 [0] 和 [1] 与最大决策值相等。相对于机器 [1],机器 [0] 是任意选择的。然后,玩机器 [0],但显示失败。

在第 2 次测试中,机器 [0] 的更新平均奖励为 1.00 美元/ 2 次下拉 = 0.50 美元。机器 [1] 和 [2] 的平均奖励仍分别为 1.00 美元和 0.00 美元,因为没有玩这两台机器。决策值计算为 (1.33, 2.18, 1.18),因此,选择机器 [1] 并且它会显示获胜。

此过程继续进行至第 6 次测试。此时,UCB1 算法似乎是成功的,因为最佳老虎机机器 [1],已经玩了最多次数 (4) 并且具有最高的平均奖励 ($0.75)。

UCB1 算法相当聪明。看看图 1 中的第 5 次测试。累积奖励为 ($1.00, $3.00, $0.00),机器玩的次数为 (2, 4, 1)。因此,机器 [2] 自初始化阶段的初始失败以来就没有再尝试过。通过选择机器 [0] 或 [1] 可以继续使用简单的算法,但 UCB1 通过利用找到的最佳机器来平衡对机器特征的探索并选择机器 [2]。

UCB1 算法专门用于付款值为 0 或 1 的老虎机问题。这称为伯努利过程。UCB1 可以应用于其他类型的问题,比如付款遵循高斯分布的问题。但与伯努利不同的付款分布越多,UCB1 算法的性能就越差。我不推荐将 UCB1 用于非伯努利问题,但我的一些同事认为,如果保守地使用,UCB1 还是能够成功的。

本文中的信息基于 P. Auer、N. Cesa-Bianchi 和 P. Fischer 于 2002 年合著的研究论文《Finite-Time Analysis of the Multiarmed Bandit Problem》(多臂老虎机问题的有限时间分析)。除 UCB1 之外,此论文提出了一种适用于高斯分布多臂老虎机问题的 UCB-Normal 算法。

本文假设你具有使用 C# 或 C 系列语言(如 Python 或 Java)的中级或中级以上编程技能,但并不假定你对 UCB1 算法有任何了解。此演示使用 C# 进行编码,但如果愿意,将代码重构为其他语言应该不会存在任何问题。本文将提供完整的演示代码。源代码也可以在随附的下载文件中找到。

了解 UCB1 算法

UCB1 算法的关键是将测试 t 中的一组平均奖励转换为一组决策值的函数,然后将决策值用于确定要玩的机器。该方程如图 2 所示。换句话说,在测试 t 时,从所有臂中选择具有最大平均奖励 (r-hat) 加上置信上限(即平方根项)的臂 a。此处,n(a) 是下拉臂 a 的次数。

UCB1 的关键方程
图 2 UCB1 的关键方程

该方程看起来复杂,实际比较简单,最好用示例来解释。假设,在演示中,算法在测试 t = 5,累积奖励为 (1.00, 3.00, 0.00),且臂数为 (2, 4, 1)。第一步是计算每个臂的平均奖励:(1.00 / 2, 3.00 / 4, 0.00 / 1) = (0.50, 0.75, 0.00)。然后,臂 [0] 的决策值计算如下:

decision[0] = 0.50 + sqrt( 2 * ln(5) / 2 )
                  = 0.50 + sqrt(1.61)
                  = 0.50 + 1.27
                  = 1.77

同样,臂 [1] 和 [2] 的决策值为:

decision[1] = 0.75 + sqrt( 2 * ln(5) / 4 )
                  = 0.75 + sqrt(0.80)
                  = 0.75 + 0.90
                  = 1.65
decision[2] = 0.00 + sqrt( 2 * ln(5) / 1 )
                  = 0.00 + sqrt(3.22)
                  = 0.00 + 1.79
                  = 1.79

因为已经玩过的臂的次数在置信上限项的小数部分的分母中,所以较小的值会增加决策值。这允许很少下拉的臂最终有机会与具有高平均奖励的臂对阵。非常棒。

演示程序

图 3 展示了完整的演示程序(为节省空间,进行了少量小幅改动)。为了创建程序,我启动了 Visual Studio 并创建了一个名为 BanditUCB 的新控制台应用程序。我使用的是 Visual Studio 2017,但该演示没有重要的 .NET Framework 依赖项。

图 3 UCB1 算法演示程序

using System;
namespace BanditUCB
{
  class BanditProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("Begin UCB1 bandit demo \n");
      Console.WriteLine("Three arms with true means u1 = " +
       "0.3, u2 = 0.7, u3 = 0.5");
      Random rnd = new Random(20);
      int N = 3;
      int trials = 6;
      double p = 0.0;
      double[] means = new double[] { 0.3, 0.7, 0.5 };
      double[] cumReward = new double[N];
      int[] armCounts = new int[N];
      double[] avgReward = new double[N];
      double[] decValues = new double[N];
      // Play each arm once to get started
      Console.WriteLine("Playing each arm once: ");
      for (int i = 0; i < N; ++i) {
        Console.Write("[" + i + "]: ");
        p = rnd.NextDouble();
        if (p < means[i]) {
          Console.WriteLine("win");
          cumReward[i] += 1.0;
        }
        else {
          Console.WriteLine("lose");
          cumReward[i] += 0.0;
        }
        ++armCounts[i];
      }
      Console.WriteLine("-------------");
      for (int t = 1; t <= trials; ++t) {
        Console.WriteLine("trial #" + t);
        Console.Write("curr cum reward: ");
        for (int i = 0; i < N; ++i)
          Console.Write(cumReward[i].ToString("F2") + " ");
        Console.Write("\ncurr arm counts: ");
        for (int i = 0; i < N; ++i)
          Console.Write(armCounts[i] + " ");
        Console.Write("\ncurr avg reward: ");
        for (int i = 0; i < N; ++i) {
          avgReward[i] = (cumReward[i] * 1.0) / armCounts[i];
          Console.Write(avgReward[i].ToString("F2") + " ");
        }
        Console.Write("\ndecision values: ");
        for (int i = 0; i < N; ++i) {
          decValues[i] = avgReward[i] +
            Math.Sqrt( (2.0 * Math.Log(t) / armCounts[i]) );
          Console.Write(decValues[i].ToString("F2") + " ");
        }
        int selected = ArgMax(decValues);
        Console.WriteLine("\nSelected machine = [" +
          selected + "]");
        p = rnd.NextDouble();
        if (p < means[selected]) {
          cumReward[selected] += 1.0;
          Console.WriteLine("result: a WIN");
        }
        else {
          cumReward[selected] += 0.0;
          Console.WriteLine("result: a LOSS");
        }
        ++armCounts[selected];
        Console.WriteLine("-------------");
      } // t
      Console.WriteLine("End bandit UCB1 demo ");
      Console.ReadLine();
    } // Main
    static int ArgMax(double[] vector)
    {
      double maxVal = vector[0];
      int result = 0;
      for (int i = 0; i < vector.Length; ++i) {
        if (vector[i] > maxVal) {
          maxVal = vector[i];
          result = i;
        }
      }
      return result;
    }
  } // Program
} // ns

加载模板代码后,在编辑器窗口的顶部,我删除了所有不需要的命名空间引用,只留下了对顶级系统命名空间的引用。在“解决方案资源管理器”窗口中,右键单击 Program.cs 文件,将其重命名为更具描述性的 BanditProgram.cs,并允许 Visual Studio 自动重命名类 Program。

为了使主要概念尽可能清晰,省略所有常规的错误检查。所有控制逻辑都包含在 Main 方法中。有一个名为 ArgMax 的帮助程序函数,它返回数值阵列中最大值的索引。例如,如果阵列包含值 (5.0, 7.0, 2.0, 9.0),则 ArgMax 返回 3。

设置 UCB1

演示程序使用以下语句开头:

Random rnd = new Random(20);
int N = 3;
int trials = 6;
double p = 0.0;
double[] means = new double[] { 0.3, 0.7, 0.5 };

Random 对象用于确定所选机器是显示获胜还是失败。之所以使用种子值 20,仅仅是因为它提供具有代表性的演示。名为 means 的阵列可能已被重新命名为 probsWin。但由于每台机器付款 1 美元或 0 美元,则每台机器的平均值(平均数)与获胜概率相同。例如,如果一台机器的获胜概率为 0.6 并且你玩了 1,000 次,那么赢得 1 美元的次数应该约 600 次。每次下拉的平均值为 600 美元/1000 = 0.60。

计算决策值

演示程序使用直接映射到图 2**** 中的 UCB1 方程来计算决策值:

for (int i = 0; i < N; ++i) {
  decValues[i] = avgReward[i] +
    Math.Sqrt( (2.0 * Math.Log(t) / armCounts[i]) );
  Console.Write(decValues[i] + " ");
}

如果 t = 0(0 的对数为负无穷大)或者任何臂计数为 0(除以 0),则计算将引发异常。但是,在每台机器均尝试一次的 UCB1 初始化阶段会阻止发生任何异常情况。

在计算出决策值之后,由以下语句确定要玩的机器:

int selected = ArgMax(decValues);
Console.WriteLine("Selected machine = [" + selected + "]");

ArgMax 函数返回最大决策值的索引。如果两个或多个决策值均为最大值,则 ArgMax 返回遇到的第一个索引。这引入了对小型索引机器的轻微偏差。消除这种偏差的一种方法是重构 ArgMax,以便在出现平局时,随机选择其中一个相同的索引。

Epsilon 贪婪算法

UCB1 算法与另一种称为 epsilon-greedy 的多臂老虎机算法密切相关。epsilon-greedy 算法首先为 epsilon 指定一个较小的值。然后在每次测试中,生成 0.0 和 1.0 之间的随机概率值。如果生成的概率小于 (1 - epsilon),则选择具有当前最大平均奖励的臂。否则,随机选择一个臂。基于演示程序结构的 epsilon-greedy 实现可能如下所示:

//int selected = ArgMax(decValues);  // UCB1
double epsilon = 0.05;  // Epsilon-greedy
int selected;
double pr = rnd.NextDouble();  // [0.0, 1.0)
if (pr < (1 - epsilon))
  selected = ArgMax(avgReward);  // Usually
else
  selected = rnd.Next(0, N);  // 5% of the time

基本的 epsilon-greedy 算法的几种变体之一是随着时间的推移缓慢减小 epsilon 的值。这样便于在运行初期专注于探索,然后强调在后续运行中发现最佳臂。epsilon-greedy 最大的问题是没有简单的方法来确定 epsilon 的合适值。

在我看来,比较 UCB1 和 epsilon-greedy 以及许多其他多臂老虎机算法的研究结果尚无定论。根据我的经验,没有单个始终如一的最佳算法可供使用,如果可能的话,使用模拟真实问题的方法,通过不同算法来运行一些试验是一个非常好的实践。

比较不同的多臂老虎机算法的标准方法是计算遗憾指标。遗憾是系统预期值(假设你知道最佳臂)以及试验中系统实际值之间的差值。例如,假设玩了演示系统的三台机器 10 次,你赢了 6 次,输了 4 次。总奖励为 6.00 美元。但如果假设全部下拉最好的臂(获胜概率等于 0.7)10 次,那么平均总奖励为 7.00 美元。因此,遗憾为 7.00 美元 - 6.00 美元 = 1.00 美元。

总结

我心里把机器学习分成三类:监督式学习,即已经具备包含已知正确答案的训练数据;非监督式学习,即具备没有正确答案的数据;强化学习 (RL),其中正确或不正确的结果被称为奖励(可能是负数),并且结果来自函数而不是数据。多臂老虎机问题通常被视作是 RL 的一部分,但我的一些研究同事认为多臂老虎机是一个独特的问题类型。

有多种算法可用于多臂老虎机场景。根据我的经验,除了本文中描述的 UCB1 和 epsilon-greedy 算法之外,实践中最长使用的算法称为 Thompson Sampling。可以在《MSDN 杂志》2018 年 2 月刊 (msdn.com/magazine/mt829274) 中了解该算法。


Dr.James McCaffrey 供职于华盛顿地区雷蒙德市沃什湾的 Microsoft Research。他参与开发过多个重要 Microsoft 产品(包括 Azure 和必应)。Dr.可通过 jamccaff@microsoft.com 与 McCaffrey 取得联系。

衷心感谢以下 Microsoft 技术专家对本文的审阅:Chris Lee 和 Ricky Loynd


在 MSDN 杂志论坛讨论这篇文章