This document was originally written by Yizhi Liu

Forward

Set auto_broadcast to True when declaring the kernel, as in the sketch below. TVM will map buffer[i][j][k] to buffer[i][0][k] if dimension j’s shape equals 1.
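
A minimal sketch of such a forward kernel, mirroring the style of the backward_vadd example later in this document (the auto_broadcast flag on the defop decorator is assumed here to enable the buffer mapping described above):

def compute_add(dtype, ndim):
    A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
    B = tvm.placeholder([tvm.var() for _ in range(ndim)], name='B', dtype=dtype)
    # Element-wise add; the broadcasting itself is handled by the auto_broadcast buffers.
    C = tvm.compute([tvm.var() for _ in range(ndim)],
                    lambda *index: A[index] + B[index], name='C')
    s = tvm.create_schedule(C.op)
    return s, A, B, C


@defop(name="vadd", target="cpu", auto_broadcast=True,
       dtype=AllTypes, ndim=[5])
def vadd(dtype, ndim):
    s, A, B, C = compute_add(dtype, ndim)
    fused = s[C].fuse(*[axis for axis in C.op.axis])
    s[C].parallel(fused)
    return s, [A, B, C]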

Backward

Overview and the ideas behind the implementation

TVM op kernels are typically generated at compile time, when the input shapes are still unknown. In this scenario, we have no idea which axes are to be broadcast. In contrast with the forward computation, the broadcast axes are reduced by summation in the backward pass, and the axes to be reduced (and thus the axes to broadcast) must be known at compile time.

A natural solution is to enumerate all cases. For example, if A.shape=(m, n), we have 4 cases:

  • m != 1 && n != 1
  • m == 1 && n != 1
  • m != 1 && n==1
  • m == 1 && n == 1

We can therefore generate 4 TVM kernels to cover all the cases.

Generally, suppose we are considering two input operands of n dimensions, and we label each dimension with 1 if it needs broadcasting and 0 otherwise. With this labeling, each case corresponds to a bit string (e.g., in the above example, the four cases give the bit strings “00”, “10”, “01” and “11”). The problem is that there may be as many as 2^n bit strings, and thus 2^n TVM kernels, which is prohibitively many.

The optimization is to merge consecutive 1s and 0s within the bit string. It is easy to verify that consecutive broadcasting axes can jointly form a combined axis, and so can consecutive axes that do not broadcast.

For example, if A.shape=(m, n, k), originally we have 2^3=8 cases, but after merging, only 2*3=6 cases (which are 0, 1, 01, 10, 010, and 101) are left:

original   merged
000        0
001        01
010        010
011        01
100        10
101        101
110        10
111        1

Note that after merging, any two consecutive bits in the bit string must differ, so the bit string is uniquely determined by its length and its leading bit, and the number of possible bit strings is reduced from 2^n to 2n.
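
As a quick sanity check (a standalone Python snippet, not part of the kernel code), one can enumerate all 2^n bit strings and count the distinct merged forms:

from itertools import product

def merge_bits(bits):
    # Collapse each run of consecutive equal bits into a single bit.
    merged = []
    for b in bits:
        if not merged or merged[-1] != b:
            merged.append(b)
    return tuple(merged)

n = 3
merged_strings = {merge_bits(bits) for bits in product((0, 1), repeat=n)}
print(len(merged_strings))     # 6 == 2 * n, versus 2 ** n == 8 before merging
print(sorted(merged_strings))  # the strings 0, 01, 010, 1, 10, 101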

The following two sections describe how to implement this algorithm for simple operators, whose input gradients depend only on the output gradients (that is, USE_NONE).

Runtime: Before Invoking TVM Kernels

First, the shapes of the input operands are padded with leading 1s so that they share the same number of dimensions.

Then, we identify which axes are to be broadcast and which are not, label them with 1 and 0 respectively to form a bit string, and merge the consecutive 1s and 0s.

Finally, invoke the TVM kernel with the reshaped output gradient.

For example, consider two inputs x and y with shapes [2, 2, 1, 2, 2] and [1, 1, 2, 2, 1], and suppose we are calculating the gradient of x. The output gradient has shape [2, 2, 2, 2, 2]. The bit string is “00100”, and the reshaped output gradient has shape [4, 2, 4] (because we merge the first two dimensions as well as the last two dimensions).

Similarly, suppose we are calculating the gradient of y. The bit string will be “11001” and the reshaped output gradient is of shape [4, 4, 2].
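
The label-and-merge step can be sketched in plain Python as follows (the function name prepare_reduce and its signature are hypothetical; the real logic lives in MXNet's operator runtime):

def prepare_reduce(ishape, oshape):
    # Pad the input shape with leading 1s to match the output's number of dimensions.
    ndim = len(oshape)
    ishape = (1,) * (ndim - len(ishape)) + tuple(ishape)
    # Label each axis: 1 if it was broadcast in forward (so it must be summed
    # in backward), 0 otherwise.
    bits = [1 if idim == 1 and odim != 1 else 0
            for idim, odim in zip(ishape, oshape)]
    # Merge runs of consecutive equal bits, multiplying their extents together.
    merged_bits, merged_shape = [], []
    for bit, dim in zip(bits, oshape):
        if merged_bits and merged_bits[-1] == bit:
            merged_shape[-1] *= dim
        else:
            merged_bits.append(bit)
            merged_shape.append(dim)
    # The leading bit selects which of the two compiled kernels to invoke.
    return merged_bits[0], merged_shape

print(prepare_reduce([2, 2, 1, 2, 2], [2, 2, 2, 2, 2]))  # (0, [4, 2, 4])
print(prepare_reduce([1, 1, 2, 2, 1], [2, 2, 2, 2, 2]))  # (1, [4, 4, 2])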

Compile-time: Generating TVM Kernels

For a given number of input dimensions, two kernels are generated: one for the case where the first axis broadcasts, and one for the case where it does not. The broadcasting axes are simply summed over; the helper function reduce_axes is available for this.

Take the add operator as an example:

def compute_backward_vadd(dtype, ndim, reduce1st, req):
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
    reducer = tvm.comm_reducer(lambda x, y: x + y,
                               lambda t: tvm.const(0, dtype=t), name="sum")
    ret = reduce_axes(X, axes, reducer)
    in_grad_a, in_grad = assign_by_req(ret, req)
    s = tvm.create_schedule(in_grad.op)
    return s, X, in_grad_a, in_grad, [ret, in_grad]


@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
       ndim=[5], reduce1st=[0, 1],
       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd(dtype, ndim, reduce1st, req):
    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
    for t in c_list:
        axes = [axis for axis in t.op.axis]
        fused = s[t].fuse(*axes)
        s[t].parallel(fused)
    return s, [X, in_grad_a, in_grad]

Here we are generating kernels for two cases: 01010 and 10101. The ndim variable determines the length of the bit string, and the reduce1st variable determines its leading bit.

The backward computation takes place in the function compute_backward_vadd. First, axes is assigned either 01010 or 10101, as determined by reduce1st. Then, we set the reducer to summation; in fact, the backward computation of any broadcast operator reduces the broadcast axes by summation. Finally, the reduction is performed by reduce_axes, which reduces the dimensions labeled 1 in axes.

A detailed explanation of req, which controls the operation request type (kWriteTo or kAddTo) and is generally unrelated to broadcasting, is omitted here.

Complicated operators

For operators like multiply, the input gradients depend not only on the output gradient but also on the input data. Suppose the inputs to the multiplication are x and y, the output gradient is z, and we are to compute the gradient of x.

First, compute temp = y * z with auto broadcast.

Then, merge the axes of temp as described above and sum over the broadcasting axes.
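
As a rough NumPy illustration of these two steps (reusing the example shapes from the runtime section above; the real computation happens inside the generated TVM kernels):

import numpy as np

x = np.random.rand(2, 2, 1, 2, 2)
y = np.random.rand(1, 1, 2, 2, 1)
z = np.random.rand(2, 2, 2, 2, 2)  # output gradient of x * y

# Step 1: elementwise product with broadcasting, just like a forward kernel.
temp = y * z  # shape (2, 2, 2, 2, 2)

# Step 2: merge axes (bit string "00100" gives shape [4, 2, 4]), sum over the
# broadcasting axis, and restore x's original shape.
grad_x = temp.reshape(4, 2, 4).sum(axis=1).reshape(x.shape)

# Sanity check: the same result as reducing temp directly over x's broadcast axis.
assert np.allclose(grad_x, temp.sum(axis=2, keepdims=True))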
