返回介绍

操作语义

发布于 2025-01-22 23:08:20 字数 76652 浏览 0 评论 0 收藏 0

本文档介绍了在 ComputationBuilder 接口中定义的操作语义。通常来说,这些操作与 xla_data.proto 中 RPC 接口所定义的操作是一一对应的。

关于术语:广义数据类型 XLA 处理的是一个 N - 维数组,其元素均为某种数据类型(如 32 位浮点数)。在本文档中, 数组 表示任意维度的数组。为方便起见,有些特例使用人们约定俗成的更具体和更熟悉的名称;比如,1 维数组称为 向量 ,2 维数组称为 矩阵

BatchNormGrad

算法详情参见 ComputationBuilder::BatchNormGradbatch normalization 原始论文

计算 batch norm 的梯度

`BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)`

类型类型语义
operandComputationDataHandle待归一化的 n 维数组 (x)
scaleComputationDataHandle1 维数组 (\(\gamma\))
meanComputationDataHandle1 维数组 (\(\mu\))
varianceComputationDataHandle1 维数组 (\(\sigma^2\))
grad_outputComputationDataHandle传入 BatchNormTraining 的梯度(\( \nabla y\))
epsilonfloatε 值 (\(\epsilon\))
feature_indexint64operand 中的特征维数索引

对于特征维数中的每一个特征( feature_indexoperand 中特征维度的索引),此操作计算 operand 的梯度、在所有其他维度上的 offsetscalefeature_index 必须是 operand 中特征维度的合法索引。

[需要翻译]:The three gradients are defined by the following formulas (Assuming a 4-dimensional tensor as operand and (l) is the index for feature dimension):

\( coefl = \frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (\nabla y{ijkl} * (x{ijkl} - \mul) / (\sigma^2{l}+\epsilon)) \)

\( \nabla x{ijkl} = \gamma{l} (1/\sqrt{\sigma^2_{l}+\epsilon}) [\nabla y{ijkl} - mean(\nabla y) - (x{ijkl} - \mu_{l}) * coef_l] \)

\( \nabla \betal = \sum{i=1}^m\sum{j=1}^w\sum{k=1}^h \nabla y_{ijkl} \)

\( \nabla \gammal = \sum{i=1}^m\sum{j=1}^w\sum{k=1}^h \nabla y{ijkl} * ((x{ijkl} - \mul) / \sqrt{\sigma^2{l}+\epsilon}) \)

输入 meanvariance 表示在批处理和空间维度上的矩值。

输出类型是包含三个句柄的元组:

输出类型语义
grad_operandComputationDataHandle输入 operand 的梯度 (\( \nabla x\))
grad_scaleComputationDataHandle输入 scale 的梯度 (\( \nabla \gamma\))
grad_offsetComputationDataHandle输入 offset 的梯度 (\( \nabla \beta\))

BatchNormInference

算法详情参见 ComputationBuilder::BatchNormInferencebatch normalization 原始论文

在批处理和空间维度上归一化数组。

`BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)`

参数类型语义
operandComputationDataHandle待归一化的 n 维数组
scaleComputationDataHandle1 维数组
offsetComputationDataHandle1 维数组
meanComputationDataHandle1 维数组
varianceComputationDataHandle1 维数组
epsilonfloatε 值
feature_indexint64operand 中的特征维数索引

对于特征维数中的每一个特征( feature_indexoperand 中特征维度的索引),此操作计算在所有其他维度上的均值和方差,以及使用均值和方差归一化 operand 中的每个元素。 feature_index 必须是 operand 中特征维度的合法索引。

BatchNormInference 等价于在每批次中不计算 meanvariance 的情况下调用 BatchNormTraining 。它使用 meanvariance 作为估计值。此操作的目的是减少推断中的延迟,因此命名为 BatchNormInference

输出是一个 N 维的标准化数组,与输入 operand 的形状相同。

BatchNormTraining

算法详情参见 ComputationBuilder::BatchNormTrainingbatch normalization 原始论文

在批处理和空间维度上归一化数组。

`BatchNormTraining(operand, scale, offset, epsilon, feature_index)`

参数类型语义
operandComputationDataHandle待归一化的 N 维数组 normalized (x)
scaleComputationDataHandle1 维数组 (\(\gamma\))
offsetComputationDataHandle1 维数组 (\(\beta\))
epsilonfloatEpsilon 值 (\(\epsilon\))
feature_indexint64operand 中的特征维数索引

对于特征维数中的每一个特征( feature_indexoperand 中特征维度的索引),此操作计算在所有其他维度上的均值和方差,以及使用均值和方差归一化 operand 中的每个元素。 feature_index 必须是 operand 中特征维度的合法索引。

该算法对 operand \(x\) 中的每批次数据(包含 whm 元素作为空间维度的大小)按如下次序执行:

  • 在特征维度中,对每个特征 l 计算批处理均值 \(\mu_l\):
    \(\mul=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h x_{ijkl}\)
  • 计算批处理方差 \(\sigma^2_l\):
    \(\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2\)
  • 归一化、缩放和平移:
    \(y_{ijkl}=\frac{\gammal(x{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon}}+\beta_l\)

ε 值,通常是一个很小的数字,以避免 divide-by-zero 错误

输出类型是一个包含三个 ComputationDataHandle 类型元素的元组:

输出类型语义
outputComputationDataHandle与输入 operand (y) 具有相同形状的 N 维数组
batch_meanComputationDataHandle1 维数组 (\(\mu\))
batch_varComputationDataHandle1 维数组 (\(\sigma^2\))

输入 batch_meanbatch_var 表示使用上述公式在批处理和空间维度上计算的矩值。

BitcastConvertType

同样参见
ComputationBuilder::BitcastConvertType .

类似于 TensorFlow 中的 tf.bitcast ,对输入数据的每个元素进行 bitcast 操作,从而转化为目标形状。维度必须匹配,且转换是一对一的;如 s32 元素通过 bitcast 操作转化为 f32 。Bitcast 采用底层 cast 操作,所以不同浮点数表示法的机器会产生不同的结果。

`BitcastConvertType(operand, new_element_type)`

参数类型语义
operandComputationDataHandleD 维,类型为 T 的数组
new_element_typePrimitiveType类型 U

operand 和 目标形状的维度必须匹配。源和目标元素类型的位宽必须一致。源和目标元素类型不能是元组。

广播(Broadcast)

另请参阅 ComputationBuilder::Broadcast

通过在数组中复制数据来增加其维度。

`Broadcast(operand, broadcast_sizes)`

参数类型语义
operandComputationDataHandle待复制的数组
broadcast_sizesArraySlice<int64>新维度的形状大小

新的维度被插入在操作数(operand)的左侧,即,若 broadcast_sizes 的值为 {a0, ..., aN} ,而操作数(operand)的维度形状为 {b0, ..., bM} ,则广播后输出的维度形状为 {a0, ..., aN, b0, ..., bM}

新的维度指标被插入到操作数(operand)副本中,即

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

比如,若 operand 为一个值为 2.0f 的标量,且 broadcast_sizes{2, 3} ,则结果形状为 f32[2, 3] 的一个数组,且它的所有元素的值都为 2.0f

调用(Call)

另请参阅 ComputationBuilder::Call

给定参数情况下,触发计算。

`Call(computation, args...)`

参数类型语义
computationComputation类型为 T_0, T_1, ..., T_N ->S 的计算,它有 N 个任意类型的参数
argsN 个 ComputationDataHandle 的序列任意类型的 N 个 参数

参数 args 的数目和类型必须与计算 computation 相匹配。当然,没有参数 args 也是允许的。

钳制(Clamp)

另请参阅 ComputationBuilder::Clamp

将一个操作数钳制在最小值和最大值之间的范围内。

`Clamp(min, operand, max)`

参数类型语义
minComputationDataHandle类型为 T 的数组
operandComputationDataHandle类型为 T 的数组
maxComputationDataHandle类型为 T 的数组

给定操作数,最小和最大值,如果操作数位于最小值和最大值之间,则返回操作数,否则,如果操作数小于最小值,则返回最小值,如果操作数大于最大值,则返回最大值。即 clamp(a, x, b) = min(max(a, x), b)

输入的三个数组的维度形状必须是一样的。另外,也可以采用一种严格的 广播 形式,即 min 和/或 max 可以是类型为 T 的一个标量。

minmax 为标量的示例如下:

let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};

折叠(Collapse)

tf.reshape

将一个数组的多个维度折叠为一个维度。

`Collapse(operand, dimensions)`

参数类型语义
operandComputationDataHandle类型为 T 的数组
dimensionsint64 矢量T 的维度形状的依次连续子集

tf.reshape

比如,令 v 为包含 24 个元素的数组:

let v = f32[4x2x3] {{{10, 11, 12},  {15, 16, 17}},
                    {{20, 21, 22},  {25, 26, 27}},
                    {{30, 31, 32},  {35, 36, 37}},
                    {{40, 41, 42},  {45, 46, 47}}};

// 折叠至一个维度,即只留下一个维度
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
                      20, 21, 22, 25, 26, 27,
                      30, 31, 32, 35, 36, 37,
                      40, 41, 42, 45, 46, 47};

// 折叠两个较低维度,剩下两个维度
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17},
                      {20, 21, 22, 25, 26, 27},
                      {30, 31, 32, 35, 36, 37},
                      {40, 41, 42, 45, 46, 47}};

// 折叠两个较高维度,剩下两个维度
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] {{10, 11, 12},
                      {15, 16, 17},
                      {20, 21, 22},
                      {25, 26, 27},
                      {30, 31, 32},
                      {35, 36, 37},
                      {40, 41, 42},
                      {45, 46, 47}};

串连(Concatenate)

另请参阅 ComputationBuilder::ConcatInDim

串连操作是将多个数组操作数合并成一个数组。输出数组与输入数组的秩必须是一样的(即要求输入数组的秩也要相同),并且它按输入次序包含了输入数组的所有元素。

`Concatenate(operands..., dimension)`

参数类型语义
operandsN 个 ComputationDataHandle 的序列类型为 T 维度为 [L0, L1, ...] 的 N 个数组。要求 N>=1
dimensionint64区间 [0, N) 中的一个整数值,令那些 operands 能够串连起来的维度名

除了 dimension 之外,其它维度都必须是一样的。这是因为 XLA 不支持 "不规则" 数组。还要注意的是,0-阶的标量值是无法串连在一起的(因为无法确定串连到底发生在哪个维度)。

1-维示例:

Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
>>> {2, 3, 4, 5, 6, 7}

2-维示例:

let a = {
  {1, 2},
  {3, 4},
  {5, 6},
};
let b = {
  {7, 8},
};
Concat({a, b}, 0)
>>> {
  {1, 2},
  {3, 4},
  {5, 6},
  {7, 8},
}

图表:

Conditional

另请参阅 ComputationBuilder::Conditional .

`Conditional(pred, true_operand, true_computation, false_operand, false_computation)`

参数类型语义
predComputationDataHandle类型为 PRED 的标量
true_operandComputationDataHandle类型为 T_0 的参数
true_computationComputation类型为 T_0 -> S 的计算
false_operandComputationDataHandle类型为 T_1 的参数
false_computationComputation类型为 T_0 -> S 的计算

如果 predtrue ,执行 true_computation ,如果 predfalse ,则返回结果。

true_computation 必须接受一个类型为 T_0 的单参数,并使用 true_operand 来调用,它们必须类型相同。 false_computation 必须接受一个类型为 T_1 的单参数,并使用 false_operand 来调用,它们必须类型相同。 true_computationfalse_computation 的返回值的类型必须相同。

注意,根据 pred 的值, true_computationfalse_computation 只能执行其中一个。

Conv (卷积)

另请参阅 ComputationBuilder::Conv

类似于 ConvWithGeneralPadding,但是边缘填充(padding)方式比较简单,要么是 SAME 要么是 VALID。SAME 方式将对输入( lhs )边缘填充零,使得在不考虑步长(striding)的情况下输出与输入的维度形状一致。VALID 填充方式则表示没有填充。

ConvWithGeneralPadding (卷积)

另请参阅 ComputationBuilder::ConvWithGeneralPadding

计算神经网络中使用的卷积。此处,一个卷积可被认为是一个 n-维窗口在一个 n-维底空间上移动,并对窗口的每个可能的位置执行一次计算。

参数类型语义
lhsComputationDataHandle秩为 n+2 的输入数组
rhsComputationDataHandle秩为 n+2 的内核权重数组
window_stridesArraySlice<int64>n-维内核步长数组
paddingArraySlice<pair<int64, int64>>n-维 (低,高) 填充数据
lhs_dilationArraySlice<int64>n-维左边扩张因子数组
rhs_dilationArraySlice<int64>n-维右边扩张因子数组

设 n 为空间维数。 lhs 参数是一个 n+2 阶数组,它描述底空间区域的维度。它被称为输入,其实 rhs 也是输入。在神经网络中,它们都属于输入激励。n+2 维的含义依次为:

  • batch : 此维中每个坐标表示执行卷积的一个独立输入
  • z/depth/features : 基空间区域中的每个 (y,x) 位置都指定有一个矢量,由这个维度来表示
  • spatial_dims : 描述了定义了底空间区域的那 n 个空间维度,窗口要在它上面移动

rhs 参数是一个 n+2 阶的数组,它描述了卷积过滤器/内核/窗口。这些维度的含义依次为:

  • output-z : 输出的 z 维度。
  • input-z : 此维度的大小等于 lhs 参数的 z 维度的大小。
  • spatial_dims : 描述了定义此 n-维窗口的那 n 个空间维度,此窗口用于在底空间上移动。

window_strides 参数指定了卷积窗口在空间维度上的步长。比如,如果步长为 3,则窗口只用放在第一个空间维度指标为 3 的倍数的那些位置上。

padding 参数指定了在底空间区域边缘填充多少个零。填充数目可以是负值 -- 这时数目绝对值表示执行卷积前要移除多少个元素。 padding[0] 指定维度 y 的填充对子, padding[1] 指定的是维度 x 的填充对子。每个填充对子包含两个值,第一个值指定低位填充数目,第二个值指定高位填充数目。低位填充指的是低指标方向的填充,高位填充则是高指标方向的填充。比如,如果 padding[1](2,3) ,则在第二个空间维度上,左边填充 2 个零,右边填充 3 个零。填充等价于在执行卷积前在输入 ( lhs ) 中插入这些零值。

lhs_dilationrhs_dilation 参数指定了扩张系数,分别应用于 lhs 和 rhs 的每个空间维度上。如果在一个空间维度上的扩张系数为 d,则 d-1 个洞将被插入到这个维度的每一项之间,从而增加数组的大小。这些洞被填充上 no-op 值,对于卷积来说表示零值。

tf.nn.conv2d_transpose

输出形状的维度含义依次为:

  • batch : 和输入( lhs )具有相同的 batch 大小。
  • z : 和内核( rhs )具有相同的 output-z 大小。
  • spatial_dims : 每个卷积窗口的有效位置的值。

卷积窗口的有效位置是由步长和填充后的底空间区域大小所决定的。

为描述卷积到底干了什么,考虑一个二维卷积,为输出选择某个固定的 batchzyx 坐标。则 (y,x) 是底空间区域中的某个窗口的一个角的位置(比如左上角,具体是哪个要看你如何编码其空间维度)。现在,我们从底空间区域中得到了一个二维窗口,其中每一个二维点都指定有一个一维矢量,所以,我们得到一个三维盒子。对于卷积过滤器而言,因为我们固定了输出坐标 z ,我们也有一个三维盒子。这两个盒子具有相同的维度,所以我们可以让它们逐个元素地相乘并相加(类似于点乘)。最后得到输出值。

注意,如果 output-z 等于一个数,比如 5,则此窗口的每个位置都在输出的 z 维 上产生 5 个值。这些值对应于卷积过滤器的不同部分,即每个 output-z 坐标,都由一个独立的三维盒子生成。所以,你可以将其想象成 5 个分立的卷积,每个都用了不同的过滤器。

下面是一个考虑了填充和步长的二维卷积伪代码:

for (b, oz, oy, ox) {  // 输出坐标
  value = 0;
  for (iz, ky, kx) {  // 内核坐标和输入 z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if (底空间区域内的(iy, ix) 是不在填充位置上的) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}

ConvertElementType

另请参阅
ComputationBuilder::ConvertElementType .

与 C++ 中逐元素的 static_cast 类似,对输入数据的每个元素进行转换操作,从而转化为目标形状。维度必须匹配,且转换是一对一的;如 s32 元素通过 s32 -to- f32 转换过程转换为 f32

`ConvertElementType(operand, new_element_type)`

参数类型语义
operandComputationDataHandleD 维类型为 T 的数组
new_element_typePrimitiveType类型 U

操作数和目标形状的维度必须匹配。源和目标元素类型不能是元组。

一个 T=s32U=f32 的转换将执行标准化的 int-to-float 转化过程,如 round-to-nearest-even。

注意:目前没有指定精确的 float-to-int 和 visa-versa 转换,但是将来可能作为转换操作的附加参数。不是所有的目标都实现了所有可能的转换。

let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}

CrossReplicaSum

另请参阅 ComputationBuilder::CrossReplicaSum

跨多个副本(replica)的求和。

`CrossReplicaSum(operand)`

参数类型语义
operandComputationDataHandle跨多个副本待求和的数组。

输出的维度形状与输入形状一样。比如,如果有两个副本,而操作数在这两个副本上的值分别为 (1.0, 2.5)(3.0, 5.25) ,则此操作在两个副本上的输出值都是 (4.0, 7.75)

计算 CrossReplicaSum 的结果需要从每个副本中获得一个输入,所以,如果一个副本执行一个 CrossReplicaSum 结点的次数多于其它副本,则前一个副本将永久等待。因此这些副本都运行的是同一个程序,这种情况发生的机会并不多,其中一种可能的情况是,一个 while 循环的条件依赖于输入的数据,而被输入的数据导致此循环在一个副本上执行的次数多于其它副本。

CustomCall

另请参阅 ComputationBuilder::CustomCall

在计算中调用由用户提供的函数。

`CustomCall(target_name, args..., shape)`

参数类型语义
target_namestring函数名称。一个指向这个符号名称的调用指令会被发出
argsN 个 ComputationDataHandle 的序列传递给此函数的 N 个任意类型的参数
shapeShape此函数的输出维度形状

不管参数的数目和类型,此函数的签名(signature)都是一样的。

extern "C" void target_name(void* out, void** in);

比如,如果使用 CustomCall 如下:

let x = f32[2] {1,2};
let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};

CustomCall("myfunc", {x, y}, f32[3x3])

myfunc 实现的一个示例如下:

extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}

这个用户提供的函数不能有副作用,而且它的执行结果必须是确定的(即两次同样的调用不能有不同结果)。

注:用户函数的黑箱特点限制了编译器的优化潜力。所以,尽量使用原生的 XLA 操作来表示你的计算;只有在迫不得已的情况下才使用 CustomCall。

点乘(Dot)

另请参阅 ComputationBuilder::Dot

`Dot(lhs, rhs)`

参数类型语义
lhsComputationDataHandle类型为 T 的数组
rhsComputationDataHandle类型为 T 的数组

此操作的具体语义由它的两个操作数的秩来决定:

输入输出语义
矢量 [n] dot 矢量 [n]标量矢量点乘
矩阵 [m x k] dot 矢量 [k]矢量 [m]矩阵矢量乘法
矩阵 [m x k] dot 矩阵 [k x n]矩阵 [m x n]矩阵矩阵乘法

此操作执行的是 lhs 的最后一维与 rhs 的倒数第二维之间的乘法结果的求和。因而计算结果会导致维度的 "缩减"。 lhsrhs 缩减的维度必须具有相同的大小。在实际中,我们会用到矢量之间的点乘,矢量/矩阵点乘,以及矩阵间的乘法。

DotGeneral

另请参阅
ComputationBuilder::DotGeneral .

`DotGeneral(lhs, rhs, dimension_numbers)`

参数类型语义
lhsComputationDataHandle类型为 T 的数组
rhsComputationDataHandle类型为 T 的数组
dimension_numbersDotDimensionNumbers类型为 T 的数组

和点乘一样,但是对于 'lhs' 和 'rhs' 允许收缩和指定批处理维数。

DotDimensionNumbers 成员类型语义
'lhs_contracting_dimensions'repeated int64'lhs' 转换维数
'rhs_contracting_dimensions'repeated int64'rhs' 转换维数
'lhs_batch_dimensions'repeated int64'lhs' 批处理维数
'rhs_batch_dimensions'repeated int64'rhs' 批处理维数

DotGeneral 根据 'dimension_numbers' 指定的维数进行转换操作,然后计算点积和。

与 'lhs' 和 'rhs' 有关的转换维数不需要相同,但是在 'lhs_contracting_dimensions' 和 'rhs_contracting_dimensions' 数组必须按照相同的顺序列出,同时具有相同的维数大小。且需要同时与 'lhs' 和 'rhs' 在同一个维度上。

以转换维数为例:

lhs = { {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
        {2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
                                 {15.0, 30.0} }

'lhs' 和 'rhs' 的批处理维数必须相同,在两个数组中必须以相同的顺序列出,同时维数大小必须相同。[需要翻译]and must be ordered before contracting and non-contracting/non-batch dimension numbers。

批处理维数的例子(批处理大小为 2,2x2 矩阵):

lhs = { { {1.0, 2.0},
          {3.0, 4.0} },
        { {5.0, 6.0},
          {7.0, 8.0} } }

rhs = { { {1.0, 0.0},
          {0.0, 1.0} },
        { {1.0, 0.0},
          {0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
                                   {3.0, 4.0} },
                                 { {5.0, 6.0},
                                   {7.0, 8.0} } }
InputOutputSemantics
[b0, m, k] dot [b0, k, n][b0, m, n]batch matmul
[b0, b1, m, k] dot [b0, b1, k, n][b0, b1, m, n]batch matmul

[需要翻译]It follows that the resulting dimension number starts with the batch dimension, then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs' non-contracting/non-batch dimension.

DynamicSlice

另请参阅 ComputationBuilder::DynamicSlice .

DynamicSlice 从动态 start_indices 输入数组中提取子数组。 size_indices 为每个维度的切片大小,它在每个维度上指定了切片范围:[start, start + size)。 start_indices 的秩必须为 1,且维数大小等于 operand 的秩。

注意:当前实现未定义切片索引越界(错误的运行时生成的'start_indices')的情况。

`DynamicSlice(operand, start_indices, size_indices)`

参数类型语义
operandComputationDataHandle类型为 T 的 N 维数组
start_indicesComputationDataHandleN 个整数组成的秩为 1 的数组,其中包含每个维度的起始切片索引。值必须大于等于 0
size_indicesArraySlice<int64>N 个整数组成的列表,其中包含每个维度的切片大小。值必须大于 0,且 start + size 必须小于等于维度大小,从而避免封装维数大小的模运算

1 维示例如下:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
  {2.0, 3.0}

2 维示例如下:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

DynamicUpdateSlice

另请参见
ComputationBuilder::DynamicUpdateSlice .

DynamicUpdateSlice 是在输入数组 operand 上,通过切片 update 操作覆盖 start_indices 后生成的结果。 update 的形状决定了更新后结果的子数组的形状。 start_indices 的秩必须为 1,且维数大小等于 operand 的秩。

注意:当前实现未定义切片索引越界(错误的运行时生成的'start_indices')的情况。

`DynamicUpdateSlice(operand, update, start_indices)`

参数类型语义
operandComputationDataHandle类型为 T 的 N 维数组
updateComputationDataHandle类型为 T 的包含切片更新的 N 维数组,每个维度的更新形状必须大于 0 ,且 start + update 必须小于维度大小,从而避免越界更新索引
start_indicesComputationDataHandleN 个整数组成的秩为 1 的数组,其中包含每个维度的起始切片索引。值必须大于等于 0

1 维示例如下:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
  {0.0, 1.0, 5.0, 6.0, 4.0}

2 维示例如下:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }
let u =
 { {12.0,  13.0},
   {14.0,  15.0},
   {16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
 { {0.0,  1.0,  2.0},
   {3.0, 12.0, 13.0},
   {6.0, 14.0, 15.0},
   {9.0, 16.0, 17.0} }

逐个元素的二元算术操作

另请参阅 ComputationBuilder::Add

XLA 支持多个逐个元素的二元算术操作。

`Op(lhs, rhs)`

其中 Op 可以是如下操作之一: Add (加法), Sub (减法), Mul (乘法), Div (除法), Rem (余数), Max (最大值), Min (最小值), LogicalAnd (逻辑且), 或 LogicalOr (逻辑或)。

参数类型语义
lhsComputationDataHandle左操作数:类型为 T 的数组
rhsComputationDataHandle右操作数:类型为 T 的数组

广播语义

OpRem 时,结果的符号与被除数一致,而结果的绝对值总是小于除数的绝对值。

不过,还是可以用如下接口来支持不同秩操作数的广播:

`Op(lhs, rhs, broadcast_dimensions)`

其中 Op 的含义同上。这种接口用于具有不同秩的数组之间的算术操作(比如将一个矩阵与一个矢量相加)。

广播语义

逐个元素的比较操作

另请参阅 ComputationBuilder::Eq

XLA 还支持标准的逐个元素的二元比较操作。注意:当比较浮点类型时,遵循的是标准的 IEEE 754 浮点数语义。

`Op(lhs, rhs)`

其中 Op 可以是如下操作之一: Eq (相等), Ne (不等), Ge (大于或等于), Gt (大于), Le (小于或等于), Lt (小于)。

参数类型语义
lhsComputationDataHandle左操作数:类型为 T 的数组
rhsComputationDataHandle右操作数:类型为 T 的数组

广播语义

要想用广播来比较不同秩的数组,需要用到如下接口:

`Op(lhs, rhs, broadcast_dimensions)`

其中 Op 含义同上。这种接口应该用于不同阶的数组之间的比较操作(比如将一个矩阵加到一个矢量上)。

广播语义

逐个元素的一元函数

ComputationBuilder 支持下列逐个元素的一元函数:

`Abs(operand)` 逐个元素的绝对值 x -> |x|

`Ceil(operand)` 逐个元素的整数上界 x -> ⌈x⌉

`Cos(operand)` 逐个元素的余弦 x -> cos(x)

`Exp(operand)` 逐个元素的自然幂指数 x -> e^x

`Floor(operand)` 逐个元素的整数下界 x -> ⌊x⌋

`IsFinite(operand)` 测试 operand 的每个元素是否是有限的,即不是正无穷或负无穷,也不是 NaN 。该操作返回一个 PRED 值的数组,维度形状与输入一致,数组中的元素当且仅当相应的输入是有限时为 true ,否则为 false

`Log(operand)` 逐个元素的自然对数 x -> ln(x)

`LogicalNot(operand)` 逐个元素的逻辑非 x -> !(x)

`Neg(operand)` 逐个元素取负值 x -> -x

`Sign(operand)` 逐个元素求符号 x -> sgn(x) ,其中

$$\text{sgn}(x) = \begin{cases} -1 & x < 0\ 0 & x = 0\ 1 & x > 0 \end{cases}$$

它使用的是 operand 的元素类型的比较运算符。

`Tanh(operand)` 逐个元素的双曲正切 x -> tanh(x)

参数类型语义
operandComputationDataHandle函数的操作数

该函数应用于 operand 数组的每个元素,从而形成具有相同形状的数组。它允许操作数为标量(秩 0 )

Gather[需要翻译]

The XLA gather operation stitches together several slices (each slice at a potentially different runtime offset) of an input tensor into an output tensor.

General Semantics

See also ComputationBuilder::Gather . For a more intuitive description, see the "Informal Description" section below.

`gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)`

ArgumentsTypeSemantics
operandComputationDataHandleThe tensor we’re gathering from.
gather_indicesComputationDataHandleTensor containing the starting indices of the slices we're we're stitching together into the output tensor.
index_vector_dimint64The dimension in gather_indices that contains the starting indices.
output_window_dimsArraySlice<int64>The set of dimensions in the output shape that are window dimensions (defined below). Not all window dimensions may be present in the output shape.
elided_window_dimsArraySlice<int64>The set of window dimensions that are not present in the output shape. window_bounds[i] must be 1 for all i in elided_window_dims .
window_boundsArraySlice<int64>window_bounds[i] is the bounds for window dimension i . This includes both the window dimensions that are explicitly part of the output shape (via output_window_dims ) and the window dimensions that are elided (via elided_window_dims ).
gather_dims_to_operand_dimsArraySlice<int64>A dimension map (the array is interpreted as mapping i to gather_dims_to_operand_dims[i] ) from the gather indices in gather_indices to the operand index space. It has to be one-to-one and total.

For every index Out in the output tensor, we compute two things (more precisely described later):

  • An index into gather_indices.rank - 1 dimensions of gather_indices , which gives us a starting index of a slice, operand slice, in the operand tensor. These gather_indices.rank - 1 dimensions are all the dimensions in gather_indices except index_vector_dim .
  • A window index that has the same rank as the operand. This index is composed of the values in Out at dimensions output_window_dims , embedded with zeroes according to elided_window_dims .

The window index is the relative index of the element in operand slice that should be present in the output at index Out .

The output is a tensor of rank output_window_dims.size + gather_indices.rank - 1 . Additionally, as a shorthand, we define output_gather_dims of type ArraySlice<int64> as the set of dimensions in the output shape but not in output_window_dims , in ascending order. E.g. if the output tensor has rank 5 , output_window_dims is { 2 , 4 } then output_gather_dims is { 0 , 1 , 3 }

If index_vector_dim is equal to gather_indices.rank we implicitly consider gather_indices to have a trailing 1 dimension (i.e. if gather_indices was of shape [6,7] and index_vector_dim is 2 then we implicitly consider the shape of gather_indices to be [6,7,1] ).

The bounds for the output tensor along dimension i is computed as follows:

  1. If i is present in output_gather_dims (i.e. is equal to output_gather_dims[k] for some k ) then we pick the corresponding dimension bounds out of gather_indices.shape , skipping index_vector_dim (i.e. pick gather_indices.shape.dims [ k ] if k < index_vector_dim and gather_indices.shape.dims [ k + 1 ] otherwise).
  2. If i is present in output_window_dims (i.e. equal to output_window_dims [ k ] for some k ) then we pick the corresponding bound out of window_bounds after accounting for elided_window_dims (i.e. we pick adjusted_window_bounds [ k ] where adjusted_window_bounds is window_bounds with the bounds at indices elided_window_dims removed).

The operand index In corresponding to an output index Out is computed as follows:

  1. Let G = { Out [ k ] for k in output_gather_dims }. Use G to slice out vector S such that S [ i ] = gather_indices [Combine( G , i )] where Combine(A, b) inserts b at position index_vector_dim into A. Note that this is well defined even if G is empty -- if G is empty then S = gather_indices .
  2. Create an index, S`in`, into operand using S by scattering S using the gather_dims_to_operand_dims map ( S`in` is the starting indices for operand slice mentioned above). More precisely:
    1. S`in`[ gather_dims_to_operand_dims [ k ]] = S [ k ] if k < gather_dims_to_operand_dims.size .
    2. S`in`[ _ ] = 0 otherwise.
  3. Create an index W`in` into operand by scattering the indices at the output window dimensions in Out according to the elided_window_dims set ( W`in` is the window index mentioned above). More precisely:
    1. W`in`[ window_dims_to_operand_dims ( k )] = Out [ k ] if k < output_window_dims.size ( window_dims_to_operand_dims is defined below).
    2. W`in`[ _ ] = 0 otherwise.
  4. In is W`in` + S`in` where + is element-wise addition.

window_dims_to_operand_dims is the monotonic function with domain [ 0 , output_window_dims.size ) and range [ 0 , operand.rank ) \ elided_window_dims . So if, e.g., output_window_dims.size is 4 , operand.rank is 6 and elided_window_dims is { 0 , 2 } then window_dims_to_operand_dims is { 01 , 13 , 24 , 35 }.

Informal Description and Examples

index_vector_dim is set to gather_indices.rank - 1 in all of the examples that follow. More interesting values for index_vector_dim does not change the operation fundamentally, but makes the visual representation more cumbersome.

To get an intuition on how all of the above fits together, let's look at an example that gathers 5 slices of shape [8,6] from a [16,11] tensor. The position of a slice into the [16,11] tensor can be represented as an index vector of shape S64[2] , so the set of 5 positions can be represented as a S64[5,2] tensor.

The behavior of the gather operation can then be depicted as an index transformation that takes [ G , W`0`, W`1`], an index in the output shape, and maps it to an element in the input tensor in the following way:

We first select an ( X , Y ) vector from the gather indices tensor using G . The element in the output tensor at index [ G , W`0`, W`1`] is then the element in the input tensor at index [ X + W`0`, Y + W`1`].

window_bounds is [8,6] , which decides the range of W`0` and W`1`, and this in turn decides the bounds of the slice.

This gather operation acts as a batch dynamic slice with G as the batch dimension.

The gather indices may be multidimensional. For instance, a more general version of the example above using a "gather indices" tensor of shape [4,5,2]
would translate indices like this:

Again, this acts as a batch dynamic slice G`0` and G`1` as the batch dimensions. The window bounds are still [8,6] .

The gather operation in XLA generalizes the informal semantics outlined above in the following ways:

  1. We can configure which dimensions in the output shape are the window dimensions (dimensions containing W`0`, W`1` in the last example). The output gather dimensions (dimensions containing G`0`, G`1` in the last example) are defined to be the output dimensions that are not window dimensions.
  2. The number of output window dimensions explicitly present in the output shape may be smaller than the input rank. These "missing" dimensions, which are listed explicitly as elided_window_dims , must have a window bound of 1 . Since they have a window bound of 1 the only valid index for them is 0 and eliding them does not introduce ambiguity.
  3. The slice extracted from the "Gather Indices" tensor (( X , Y ) in the last example) may have fewer elements than the input tensor rank, and an explicit mapping dictates how the index should be expanded to have the same rank as the input.

As a final example, we use (2) and (3) to implement tf.gather_nd :

G`0` and G`1` are used to slice out a starting index from the gather indices tensor as usual, except the starting index has only one element, X . Similarly, there is only one output window index with the value W`0`. However, before being used as indices into the input tensor, these are expanded in accordance to "Gather Index Mapping" ( gather_dims_to_operand_dims in the formal description) and "Window Mapping" ( window_dims_to_operand_dims in the formal description) into [ 0 , W`0`] and [ X , 0 ] respectively, adding up to [ X , W`0`]. In other words, the output index [ G`0`, G`1`, W`0`] maps to the input index [ GatherIndices [ G`0`, G`1`, 0 ], X ] which gives us the semantics for tf.gather_nd .

window_bounds for this case is [1,11] . Intuitively this means that every index X in the gather indices tensor picks an entire row and the result is the concatenation of all these rows.

GetTupleElement

另请参阅 ComputationBuilder::GetTupleElement

将索引添加到编译时常量的元组中。

该值必须是编译时常量,这样才可以通过形状推断确定结果值的类型。

概念上,这类似于 C++ 中的 std::get<int N>(t)

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1);  // 推断出的形状匹配 s32.

tf.tuple

Infeed

另请参阅 ComputationBuilder::Infeed .

`Infeed(shape)`

参数类型语义
shapeShape从 Infeed 接口读取数据的维度形状。此形状的数据布局必须与发送到设备上的数据相匹配;否则行为是未定义的

从设备的隐式 Infeed 流接口读取单个数据项,根据给定的形状和布局来进行解析,并返回一个此数据的 ComputationDataHandle 。在一个计算中允许有多个 Infeed 操作,但这些 Infeed 操作之间必须是全序的。比如,下面代码中两个 Infeed 是全序的,因为在不同 while 循环之间有依赖关系。

result1 = while (condition, init = init_value) {
  Infeed(shape)
}

result2 = while (condition, init = result1) {
  Infeed(shape)
}

不支持嵌套的元组形状。对于一个空的元组形状,Infeed 操作通常是一个 no-op,因而不会从设备的 Infeed 中读取任何数据。

注意:我们计划允许支持没有全序的多个 Infeed 操作,在这种情况下,编译器将提供信息,确定这些 Infeed 操作在编译后的程序中如何串行化。

映射(Map)

另请参阅 ComputationBuilder::Map

`Map(operands..., computation)`

参数类型语义
operandsN 个 ComputationDataHandle 的序列类型为 T0..T{N-1} 的 N 个数组
computationComputation类型为 T_0, T_1, ..., T_{N + M -1} -> S 的计算,有 N 个类型为 T 的参数,和 M 个任意类型的参数
dimensionsint64 array映射维度的数组
static_operandsM 个 ComputationDataHandle 的序列任意类型的 M 个数组

将一个标量函数作用于给定的 operands 数组,可产生相同维度的数组,其中每个元素都是映射函数(mapped function)作用于相应输入数组中相应元素的结果,而 static_operandscomputation 的附加输入。

此映射函数可以是任意计算过程,只不过它必须有 N 个类型为 T 的标量参数,和单个类型为 S 的输出。输出的维度与输入 operands 相同,只不过元素类型 T 换成了 S。

比如, Map(op1, op2, op3, computation, par1)elem_out <- computation(elem1, elem2, elem3, par1) 将输入数组中的每个(多维)指标映射产生输出数组。

填充(Pad)

另请参阅 ComputationBuilder::Pad

`Pad(operand, padding_value, padding_config)`

参数类型语义
operandComputationDataHandle类型为 T 的数组
padding_valueComputationDataHandle类型为 T 的标量,用于填充
padding_configPaddingConfig每个维度的两端的填充量 (low, high)

通过在数组周围和数组之间进行填充,可以将给定的 operand 数组扩大,其中 padding_valuepadding_config 用于配置每个维度的边缘填充和内部填充的数目。

PaddingConfigPaddingConfigDimension 的一个重复字段,它对于每个维度都包含有三个字段: edge_padding_low , edge_padding_highinterior_paddingedge_padding_lowedge_padding_high 分别指定了该维度上低端(指标为 0 那端)和高端(最高指标那端)上的填充数目。边缘填充数目可以是负值 — 负的填充数目的绝对值表示从指定维度移除元素的数目。 interior_padding 指定了在每个维度的任意两个相邻元素之间的填充数目。逻辑上,内部填充应发生在边缘填充之前,所有在负边缘填充时,会从经过内部填充的操作数之上再移除边缘元素。如果边缘填充配置为 (0, 0),且内部填充值都是 0,则此操作是一个 no-op。下图展示的是二维数组上不同 edge_paddinginterior_padding 值的示例。

Recv

另请参阅
ComputationBuilder::Recv .

`Recv(shape, channel_handle)`

参数类型语义
shapeShape要接收的数据的形状
channel_handleChannelHandle发送/接收对的唯一标识

从另一台共享相同通道句柄的计算机的 Send 指令接收指定形状的数据,返回一个接收数据的 ComputationDataHandle。

客户端 Recv 操作的客户端 API 是同步通信。但是,指令内分解成 2 个 HLO 指令( RecvRecvDone )用于异步数据传输。请参考 HloInstruction::CreateRecvHloInstruction::CreateRecvDone

`Recv(const Shape& shape, int64 channel_id)`

分配资源从具有相同 channel_id 的 Send 指令接收数据。返回已分配资源的上下文,该上下文随后通过 RecvDone 指令等待数据传输完成。上下文是 {接收缓冲区 (形状), 请求标识符(U32)} 的元组,且只能用于 RecvDone 指令。

`RecvDone(HloInstruction context)`

给定一个由 Recv 指令创建的上下文,等待数据传输完成并返回接收的数据。

Reduce

另请参阅 ComputationBuilder::Reduce

将一个归约函数作用于一个数组。

`Reduce(operand, init_value, computation, dimensions)`

参数类型语义
operandComputationDataHandle类型为 T 的数组
init_valueComputationDataHandle类型为 T 的标量
computationComputation类型为 T, T -> T 的计算
dimensionsint64 数组待归约的未排序的维度数组

从概念上看,归约(Reduce)操作将输入数组中的一个或多个数组归约为标量。结果数组的秩为 rank(operand) - len(dimensions)init_value 是每次归约的初值,如果后端有需求也可以在计算中插入到任何地方。所以,在大多数情况下, init_value 应该为归约函数的一个单位元(比如,加法中的 0)。

归约函数的执行顺序是任意的,即可能是非确定的。因而,归约函数不应对运算的结合性敏感。

有些归约函数,比如加法,对于浮点数并没有严格遵守结合率。不过,如果数据的范围是有限的,则在大多数实际情况中,浮点加法已经足够满足结合率。当然,我们也可以构造出完全不遵守结合率的归约函数,这时,XLA 归约就会产生不正确或不可预测的结果。

下面是一个示例,对 1D 数组 [10, 11, 12, 13] 进行归约,归约函数为 f (即参数 computation ),则计算结果为:

f(10, f(11, f(12, f(init_value, 13)))

但它还有其它很多种可能性,比如:

f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(13, init_value))))

下面是一段实现归约的伪代码,归约计算为求和,初值为 0。

result_shape <- 从 operand_shape 的维度中移除所有待归约的维度

# 遍历 result_shape 中的所有元素,这里,r 的数目等于 result 的秩
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
  # 初始化 result 的元素
  result[r0, r1...] <- 0

  # 遍历所有的归约维度
  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
    # 用 operand 的元素的值来增加 result 中的元素的值
    # operand 的元素的索引由所有的 ri 和 di 按正确的顺序构造而来
    # (构造得到的索引用来访问 operand 的整个形状)
    result[r0, r1...] += operand[ri... di]

下面是一个对 2D 数组(矩阵)进行归约的示例。其形状的秩为 2,0 维大小为 2,1 维大小为 3:

对 0 维或 1 维进行求和归约:

注意,两个归约结果都是一维数组。图中将一个显示为行,另一个显示为列,但这只是为了可视化效果。

下面是一个更复杂的 3D 数组的例子。它的秩为 3 ,形状为 (4,2,3)。为简单起见,我们让 1 到 6 这几个数字沿 0 维复制 4 份。

类似于二维的情况,我们可以只归约一个维度。如果我们归约第 0 维,我们得到一个二阶数组,它沿第 0 维的所有值会合并为一个标量:

|  4   8  12 |
| 16  20  24 |

如果我们归约第 2 维,结果仍然是一个二阶数组,沿第 2 维的所有值合并为一个标量:

| 6  15 |
| 6  15 |
| 6  15 |
| 6  15 |

注意,输出中剩下的维度的顺序与它们在输入中的相对顺序保持一致,只不过维度的名称(数字)会发生变化,因为数组的秩发生了变化。

我们也可以归约多个维度。对 0 维和 1 维进行求和归约,将得到一个一维数组 | 20 28 36 |

对这个三维数组的所有元素进行求和归约,得到一个标量 84

ReducePrecision

另请参阅 ComputationBuilder::ReducePrecision

当浮点数转换为低精度格式(比如 IEEE-FP16)然后转换回原格式时,值可能会发生变化,ReducePrecision 对这种变化进行建模。低精度格式中的指数(exponent)和尾数(mantissa)的位数目是可以任意指定的,不过不是所有硬件实现都支持所有的位大小。

`ReducePrecision(operand, mantissa_bits, exponent_bits)`

参数类型语义
operandComputationDataHandle浮点类型 T 的数组
exponent_bitsint32低精度格式中的指数位数
mantissa_bitsint32低精度格式中的尾数位数

结果为类型为 T 的数组。输入值被舍入至与给定尾数位的数字最接近的那个值(采用的是"偶数优先"原则)。而超过指数位所允许的值域时,输入值会被视为正无穷或负无穷。 NaN 值会保留,不过它可能会被转换为规范化的 NaN 值。

低精度格式必须至少有一个指数位(为了区分零和无穷,因为两者的尾数位都为零),且尾数位必须是非负的。指数或尾数位可能会超过类型 T ;这种情况下,相应部分的转换就仅仅是一个 no-op 了。

ReduceWindow

另请参阅 ComputationBuilder::ReduceWindow

将一个归约函数应用于输入多维数组的每个窗口内的所有元素上,输出一个多维数组,其元素个数等于合法窗口的元素数目。一个池化层可以表示为一个 ReduceWindow

`ReduceWindow(operand, init_value, computation, window_dimensions, window_strides, padding)`

参数类型语义
operandComputationDataHandle类型为 T 的 N 维数组。这是窗口放置的底空间区域
init_valueComputationDataHandle归约的初始值。细节请参见 规约
computationComputation类型为 T, T -> T 的归约函数,应用于每个窗口内的所有元素
window_dimensionsArraySlice<int64>表示窗口维度值的整数数组
window_stridesArraySlice<int64>表示窗口步长值的整数数组
paddingPadding窗口的边缘填充类型(Padding\:\:kSame 或 Padding\:\:kValid)

下列代码和图为一个使用 ReduceWindow 的示例。输入是一个大小为 [4x6] 的矩阵,window_dimensions 和 window_stride_dimensions 都是 [2x3]。

// 创建一个归约计算(求最大值)
Computation max;
{
  ComputationBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().ConsumeValueOrDie();
}

// 用最大值归约计算来创建一个 ReduceWindow 计算
ComputationBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input, *max,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);

在维度中,步长为 1 表示在此维度上两个相邻窗口间隔一个元素,为了让窗口互相不重叠,window_stride_dimensions 和 window_dimensions 应该要相等。下图给出了两种不同步长设置的效果。边缘填充应用于输入的每个维度,计算过程实际发生在填充之后的数组上。

归约函数的执行顺序是任意的,因而结果可能是非确定性的。所以,归约函数应该不能对计算的结合性太过敏感。更多细节,参见 Reduce 关于结合性的讨论。

Reshape

另请参阅 ComputationBuilder::ReshapeCollapse 操作。

变形操作(reshape)是将一个数组的维度变成另外一种维度设置。

`Reshape(operand, new_sizes)`
`Reshape(operand, dimensions, new_sizes)`

参数类型语义
operandComputationDataHandle类型为 T 的数组
dimensionsint64 vector维度折叠的顺序
new_sizesint64 vector新维度大小的矢量

从概念上看,变形操作首先将一个数组拉平为一个一维矢量,然后将此矢量展开为一个新的形状。输入参数是一个类型为 T 的任意数组,一个编译时常量的维度指标数组,以及表示结果维度大小的一个编译时常量的数组。如果给出了 dimensions 参数,这个矢量中的值必须是 T 的所有维度的一个置换,其默认值为 {0, ..., rank - 1}dimensions 中的维度的顺序是从最慢变化维(最主序)到最快变化维(最次序),按照这个顺序依次将所有元素折叠到一个维度上。 new_sizes 矢量决定了输出数组的维度大小。 new_sizes[0] 表示第 0 维的大小, new_sizes[1] 表示的是第 1 维的大小,依此类推。 new_sizes 中的维度值的乘积必须等于 operand 的维度值的乘积。将折叠的一维数组展开为由 new_sizes 定义的多维数组时, new_sizes 中的维度的顺序也是最慢变化维(最主序)到最快变化维(最次序)。

比如,令 v 为包含 24 个元素的数组:

let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
                    {{20, 21, 22}, {25, 26, 27}},
                    {{30, 31, 32}, {35, 36, 37}},
                    {{40, 41, 42}, {45, 46, 47}}};

依次折叠:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47}};

乱序折叠:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24]  {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
                          15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};

let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
                          {31, 41, 12}, {22, 32, 42},
                          {15, 25, 35}, {45, 16, 26},
                          {36, 46, 17}, {27, 37, 47}};


let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
                              {11, 21}, {31, 41},
                              {12, 22}, {32, 42}},
                             {{15, 25}, {35, 45},
                              {16, 26}, {36, 46},
                              {17, 27}, {37, 47}}};

作为特例,单元素数组和标量之间可以用变形操作相互转化。比如:

Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] {{5}};

Rev (反转)

另请参阅 ComputationBuilder::Rev

`Rev(operand, dimensions)`

参数类型语义
operandComputationDataHandle类型为 T 的数组
dimensionsArraySlice<int64>待反转的维度

反转操作是将 operand 数组沿指定的维度 dimensions 对元素的顺序反转,产生一个形状相同的数组。operand 数组的每个元素被存储在输出数组的变换后的位置上。元素的原索引位置在每个待倒置维度上都被反转了,得到其在输出数组中的索引位置(即,如果一个大小为 N 的维度是待倒置的,则索引 i 被变换为 N-i-i)。

Rev 操作的一个用途是在神经网络的梯度计算时沿两个窗口维度对卷积权重值进行倒置。

RngNormal

另请参阅 ComputationBuilder::RngNormal

RngNormal 构造一个符合 $$(\mu, \sigma)$$ 正态随机分布的指定形状的随机数组。参数 musigma 为 F32 类型的标量值,而输出形状为 F32 的数组。

`RngNormal(mean, sigma, shape)`

参数类型语义
muComputationDataHandle类型为 F32 的标量,指定生成的数的均值
sigmaComputationDataHandle类型为 F32 的标量,指定生成的数的标准差
shapeShape类型为 F32 的输出的形状

RngUniform

另请参阅 ComputationBuilder::RngUniform

RngNormal 构造一个符合区间 $$[a,b)$$ 上的均匀分布的指定形状的随机数组。参数和输出形状可以是 F32、S32 或 U32,但是类型必须是一致的。此外,参数必须是标量值。如果 $$b <= a$$,输出结果与具体的实现有关。

`RngUniform(a, b, shape)`

参数类型语义
aComputationDataHandle类型为 T 的标量,指定区间的下界
bComputationDataHandle类型为 T 的标量,指定区间的上界
shapeShape类型为 T 的输出的形状

Select

另请参阅
ComputationBuilder::Select .

基于 predicate 数组的值,从两个输入数组构造输出数组。

`Select(pred, on_true, on_false)`

参数类型语义
predComputationDataHandle类型为 PRED 的数组
on_trueComputationDataHandle类型为 T 的数组
on_falseComputationDataHandle类型为 T 的数组

数组 on_trueon_false 的形状必须相同。这也是输出数组的形状。数组 pred 必须与 on_trueon_false 具有相同的维度,且值为 PRED 类型。

对于 pred 的每个元素 P ,当 P 值为 true 时,相应的输出值从 on_true 中获取,否则从 on_false 中获取。由于 broadcasting 限制, pred 可以是类型为 PRED 的标量。此时,当 pred 值为 true 时,输出数组为 on_true ,否则为 on_false

非标量 pred 的示例如下:

let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};

标量 pred 的示例如下:

let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};

支持元组之间的 Selections 操作。因此元组认为是标量类型。如果 on_trueon_false 为元组(必须形状相同),则 pred 必须是类型为 PRED 的标量。

SelectAndScatter

另请参阅 ComputationBuilder::SelectAndScatter

这个操作可视为一个复合操作,它先在 operand 数组上计算 ReduceWindow ,以便从每个窗口中选择一个数,然后将 source 数组散布到选定元素的指标位置上,从而构造出一个与 operand 数组形状一样的输出数组。二元函数 select 用于从每个窗口中选出一个元素,当调用此函数时,第一个参数的指标矢量的字典序小于第二个参数的指标矢量。如果第一个参数被选中,则 select 返回 true ,如果第二个参数被选中,则返回 false 。而且该函数必须满足传递性,即如果 select(a, b)select(b, c) 都为 true ,则 select(a, c) 也为 true 。这样,被选中的元素不依赖于指定窗口中元素访问的顺序。

scatter 函数作用在输出数组的每个选中的指标上。它有两个标量参数:

  1. 输出数组中选中指标处的值
  2. source 中被放置到选中指标处的值

它根据这两个参数返回一个标量值,用于更新输出数组中选中指标处的值。最开始的时候,输出数组所有指标处的值都被设为 init_value

输出数组与 operand 数组的形状相同,而 source 数组必须与 operand 上应用 ReduceWindow 之后的形状相同。 SelectAndScatter 可用于神经网络池化层中梯度值的反向传播。

`SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)`

参数类型语义
operandComputationDataHandle类型为 T 的数组,窗口在它上面滑动
selectComputation类型为 T, T -> PRED 的二元计算,它被应用到每个窗口中的所有元素上;如果选中第一个元素返回 true ,如果选中第二个元素返回 false
window_dimensionsArraySlice<int64>表示窗口维度值的整数数组
window_stridesArraySlice<int64>表示窗口步长值的整数数组
paddingPadding窗口边缘填充类型(Padding\:\:kSame 或 Padding\:\:kValid)
sourceComputationDataHandle类型为 T 的数组,它的值用于散布
init_valueComputationDataHandle类型为 T 的标量值,用于输出数组的初值
scatterComputation类型为 T, T -> T 的二元计算,应用于 source 的每个元素和它的目标元素

下图为 SelectAndScatter 的示例,其中 select 函数计算它的参数中的最大值。注意,当窗口重叠时,如图 (2) 所示, operand 的一个指标可能会被不同窗口多次选中。在此图中,值为 9 的元素被顶部的两个窗口(蓝色和红色)选中,从而二元加法函数 scatter 产生值为 8 的输出值(2+6)。

scatter 函数的执行顺序是任意的,因而可能会出现不确定的结果。所以, scatter 函数不应该对计算的结合性过于敏感。更多细节,参见 Reduce 一节中关于结合性的讨论。

Send

另请参阅 ComputationBuilder::Send

`Send(operand, channel_handle)`

参数类型语义
operandComputationDataHandle待发送的数据(类型为 T 的数组)
channel_handleChannelHandle发送/接收 对的唯一标识符

将给定的 operand 数据发送到另一台计算机上共享相同通道句柄的 Recv 中。不返回任何数据。

Recv 操纵类似, Send 操作的客户端 API 为同步通信,并在内部分解为 2 个 HLO 指令( SendSendDone )以使用异步数据传输。另请参阅 HloInstruction::CreateSendHloInstruction::CreateSendDone

`Send(HloInstruction operand, int64 channel_id)`

发起 operand 的异步传输过程,将数据传输到具有相同通道 id 的 Recv 指令分配的资源中。返回一个上下文,随后使用 SendDone 指令等待数据传输完成。上下文是 {operand (shape), request identifier
(U32)} 的二元组,且只能用于 SendDone 指令。

`SendDone(HloInstruction context)`

根据 Send 指令创建的上下文,等待数据传输完成。指令不返回任何数据。

Scheduling of channel instructions

每个通道的 4 个指令 ( Recv , RecvDone , Send , SendDone ) 的执行顺序如下。

  • Recv happens before Send
  • Send happens before RecvDone
  • Recv happens before RecvDone
  • Send happens before SendDone

当后端编译器为通过通道指令进行通信的每一个计算生成一个线性调度时,在计算过程中不能有循环。例如,下面的调度会产生死循环。

Slice

另请参阅 ComputationBuilder::Slice

Slice 用于从输入数组中提取出一个子数组。子数组与输入数组的秩相同,它的值在输入数组的包围盒中,此包围盒的维度和指标作为 slice 操作的参数而给出。

`Slice(operand, start_indices, limit_indices)`

参数类型语义
operandComputationDataHandle类型为 T 的 N 维数组
start_indicesArraySlice<int64>N 个整数的数组,包含每个维度的切片的起始指标。值必须大于等于零
limit_indicesArraySlice<int64>N 个整数的数组,包含每个维度的切片的结束指标(不包含)。每个维度的结束指标必须严格大于其起始指标,且小于等于维度大小

1-维示例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}

2-维示例:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

Sort

另请参阅 ComputationBuilder::Sort

Sort 用于对输入数组中的元素进行排序。

`Sort(operand)`

参数类型语义
operandComputationDataHandle待排序数组

Transpose

tf.reshape

`Transpose(operand)`

参数类型语义
operandComputationDataHandle待转置的数组
permutationArraySlice<int64>指定维度重排列的方式

Transpose 将 operand 数组的维度重排列,所以
∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]

这等价于 Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions))。

Tuple

另请参阅 ComputationBuilder::Tuple

一个元组(tuple)包含一些数据句柄,它们各自都有自己的形状。

概念上看,它类似于 C++ 中的 std::tuple

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);

元组可通过 GetTupleElement 操作来解析(访问)。

While

另请参阅 ComputationBuilder::While

`While(condition, body, init)`

参数类型语义
conditionComputation类型为 T -> PRED 的计算,它定义了循环终止的条件
bodyComputation类型为 T -> T 的计算,它定义了循环体
initTconditionbody 的参数的初始值

While 顺序执行循环体 body ,直到 condition 失败。这类似于很多语言中的 while 循环,不过,它有如下的区别和限制:

  • 一个 While 结点有一个类型为 T 的返回值,它是最后一次执行 body 的结果。
  • 类型为 T 的形状是由统计确定的,在整个迭代过程中,它都是保持不变的。
  • While 结点之间不允许嵌套。这个限制可能会在未来某些目标平台上取消。

该计算的类型为 T 的那些参数使用 init 作为迭代的第一次计算的初值,并在接下来的迭代中由 body 来更新。

While 结点的一个主要使用安例是实现神经网络中的训练的重复执行。下面是一个简化版的伪代码,和一个表示计算过程的图。实际代码可以在 while_test.cc 中找到。此例中的 T 类型为一个 Tuple ,它包含一个 int32 值,表示迭代次数,还有一个 vector[10] ,用于累加结果。它有 1000 次迭代,每一次都会将一个常数矢量累加到 result(1) 上。

// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}

如果您发现本页面存在错误或可以改进,请 点击此处 帮助我们改进。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文