DML_JOIN_OPERATOR_DESC 结构 (directml.h)

沿指定轴连接输入张量数组。

仅当输入张量在所有维度中的大小相同时,输入张量才可联接,但联接轴可能包含任何非零大小。 输出大小等于输入大小(联接轴除外),联接轴是所有输入联接轴大小的总和。 下面的伪代码演示了这些约束。

joinSize = 0;

for (i = 0; i < InputCount; i++) {
    assert(inputTensors[i]->DimensionCount == outputTensor->DimensionCount);
    for (dim = 0; dim < outputTensor->DimensionCount; dim++) {
        if (dim == Axis) { joinSize += inputTensors[i]->Sizes[dim]; }
        else { assert(inputTensors[i]->Sizes[dim] == outputTensor->Sizes[dim]); }
    }
}

assert(joinSize == outputTensor->Sizes[Axis]);

联接单个输入张量只会生成输入张量的副本。

此运算符是 DML_SPLIT_OPERATOR_DESC的反函数。

语法

struct DML_JOIN_OPERATOR_DESC {
  UINT                  InputCount;
  const DML_TENSOR_DESC *InputTensors;
  const DML_TENSOR_DESC *OutputTensor;
  UINT                  Axis;
};

成员

InputCount

类型: UINT

此字段确定 InputTensors 数组的大小。 此值必须大于 0。

InputTensors

类型:_Field_size_ (InputCount) const DML_TENSOR_DESC*

一个数组,其中包含要联接到单个输出张量中的张量的说明。 此数组中的所有输入张量必须具有相同的大小,但联接轴可能具有任何非零值。

OutputTensor

类型: const DML_TENSOR_DESC*

要向其写入联接输入张量的张量。 输出大小必须与所有输入张量具有相同的大小,联接轴除外,联接轴必须等于所有输入的联接轴大小之和。

Axis

类型: UINT

要联接的输入张量维度的索引。 除此轴之外,所有输入和输出张量在所有维度中都必须具有相同的大小。 此值必须位于范围 [0, OutputTensor.DimensionCount - 1]中。

示例

示例 1。 联接只有一个可能的轴的张量

在此示例中,张量只能沿着第四个维度 (轴 3) 联接。 无法联接任何其他轴,因为第四维中的张量大小不匹配。

InputCount: 2
Axis: 3

InputTensors[0]: (Sizes:{1, 1, 2, 3}, DataType:FLOAT32)
[[[[ 1,  2,  3],
   [ 4,  5,  6]]]]

InputTensors[1]: (Sizes:{1, 1, 2, 4}, DataType:FLOAT32)
[[[[ 7,  8,  9, 10],
   [11, 12, 13, 14]]]]

OutputTensor: (Sizes:{1, 1, 2, 7}, DataType:FLOAT32)
[[[[ 1,  2,  3,  7,  8,  9, 10],
   [ 4,  5,  6, 11, 12, 13, 14]]]]

示例 2。 联接具有多个可能轴的张量:

以下示例使用相同的输入张量。 由于所有输入在所有维度中具有相同的大小,因此可以沿任何维度联接它们。

InputCount: 3

InputTensors[0]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[1, 2],
   [3, 4]]]]

InputTensors[1]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[5, 6],
   [7, 8]]]]

InputTensors[2]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[9, 10],
   [11, 12]]]]

联接轴 1:

Axis: 1

OutputTensor: (Sizes:{1, 3, 2, 2}, DataType:FLOAT32)
[[[[1, 2],
   [3, 4]],

  [[5, 6],
   [7, 8]],

  [[9, 10],
   [11, 12]]]]

联接轴 2:

Axis: 2

OutputTensor: (Sizes:{1, 1, 6, 2}, DataType:FLOAT32)
[[[[1, 2],
   [3, 4],
   [5, 6],
   [7, 8],
   [9, 10],
   [11, 12]]]]

联接轴 3:

Axis: 3

OutputTensor: (Sizes:{1, 1, 2, 6}, DataType:FLOAT32)
[[[[1, 2, 5, 6, 9, 10],
   [3, 4, 7, 8, 11, 12]]]]

可用性

此运算符是在 中 DML_FEATURE_LEVEL_1_0引入的。

张量约束

InputTensorsOutputTensor 必须具有相同的 DataTypeDimensionCount

张量支持

DML_FEATURE_LEVEL_4_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensors 输入数组 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputTensor 输出 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_3_0 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensors 输入数组 4 到 5 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputTensor 输出 4 到 5 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensors 输入数组 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_1_0 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensors 输入数组 4 FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16
OutputTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16

要求

   
标头 directml.h