...

There are two possible approaches to implementing the control flow operators.

Approach 1: The first approach is to keep all computation in a single graph: the subgraphs defined in the user-defined functions are stitched into the main graph, and special operators are inserted into the graph to change the executor behavior (e.g., skip the execution of a sequence of operators, or repeat the execution of a sequence of operators). These special operators resemble the jump instruction in a CPU. TensorFlow takes this approach. The detailed design can be found in http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf

Approach 2: The second approach is to maintain multiple computation graphs. There is a main graph that contains the control flow operators. Each control flow operator contains its own computation graph(s) and is responsible for executing the computation graphs inside the operator. This approach represents computation more like how a high-level programming language handles control flow (a piece of code is broken into basic blocks for compilation and execution).
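To make the structure concrete, below is a minimal sketch in plain Python (illustrative only, not MXNet code) of how approach 2 organizes a graph: ordinary operator nodes live in the main graph, while a control flow node additionally owns the subgraph(s) it is responsible for executing. All class and attribute names here are made up.

class Node:
    def __init__(self, op, inputs):
        self.op = op                  # e.g. "dot", "relu", or "foreach"
        self.inputs = inputs          # other nodes in the *main* graph

class ControlFlowNode(Node):
    def __init__(self, op, inputs, subgraphs):
        super().__init__(op, inputs)
        self.subgraphs = subgraphs    # separate graphs, analogous to basic blocks

    def execute(self, args):
        # The operator itself decides how many times, and with which inputs,
        # to run its subgraphs (e.g., once per slice for a foreach operator).
        raise NotImplementedError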

Both approaches have their pros and cons. I think the second approach will be preferred in MXNet because the execution of MXNet is more static than that of TensorFlow (e.g., MXNet requires static shape and data type inference and static memory planning). The second approach also allows easier graph-level optimization and easier implementation. From here on, I'll discuss the second approach in more detail.
Given the API (shown at the end of the proposal), the first step for approach 2 is to build a subgraph from the user-defined functions of a control flow operator and pass it to the operator. Below, I'll use ``foreach’’ as an example to describe the implementation of a symbolic control flow operator. To help Gluon hybridization, we also implement an NDArray version of these operators.

Building a subgraph: a control flow operator invokes the user-defined functions (UDFs) once to create a computation subgraph. The subgraph of ``foreach’’ has four sets of inputs: slices from the input arrays we iterate over, the states from the previous iteration or the initial states, the variables defined outside the UDF (the closure) and the variables defined inside the UDF. The initial states and the variables defined outside the UDF refer to symbols in the main graph and, thus, connect the subgraph in the UDF with the main graph. However, the main graph and the subgraph have to be completely disjoint (otherwise, stateful operators may have different behaviors). To build a completely separate graph, we create new variable symbols to replace the ones in ``state'' and in the closure. It's easy to find the symbols passed in as ``state''. However, it's more difficult to identify the symbols in the closure and the ones defined inside the UDF, which aren't inputs of the control flow operator.
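As an illustration, a ``foreach'' operator could cut the UDF's computation out of the main graph roughly as follows. This is a hypothetical sketch (the helper name and variable names are made up), assuming a single input array and a list of states:

import mxnet as mx

def build_foreach_subgraph(body, num_states):
    # Create fresh variable symbols so the subgraph shares no nodes with the
    # main graph; the symbols in ``state'' and in the closure are replaced.
    data = mx.sym.var('subg_data')                         # a slice of the input array
    states = [mx.sym.var('subg_state%d' % i) for i in range(num_states)]
    # Invoke the UDF exactly once to record the computation it performs.
    out, new_states = body(data, states)
    # Group the step output and the new states into a single symbol that the
    # control flow operator stores and executes on every iteration.
    return mx.sym.Group([out] + list(new_states))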

...

The proposed APIs of the control flow operators are listed below:

...

``cond’’

``cond’’ invokes different computations based on a certain condition.

cond(pred, then_func, else_func, inputs)


Input arguments:

  • ``pred’’ is a user-defined function that returns a symbol/NDArray containing a boolean scalar, which defines which function to compute.
  • ``then_func’’ and ``else_func’’ are also user-defined functions whose signatures are defined below.
  • ``inputs’’ is a list of symbols/NDArrays.

Return value:

  • A list of symbols/NDArrays returned from one of the user-defined functions.


The signature of ``pred’’ is

def func(inputs): boolean symbol

where ``inputs’’ is the same as ``inputs’’ of ``cond’’ and the return value is a boolean scalar symbol.
The signature of ``then_func’’ and ``else_func’’ is

def func(inputs): outputs

where ``inputs’’ is the same as ``inputs’’ of ``cond’’ and ``outputs’’ is a list of symbols/NDArrays.
``then_func’’ and ``else_func’’ should return the same number of outputs with the same types and shapes. This is compatible with ``cond'' in TensorFlow, except for the restriction on shapes.
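A hypothetical usage sketch of the proposed API (``cond'' is written as a plain function call here; the array names and branch functions are made up):

import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')

def pred(inputs):
    # Boolean scalar symbol: does ``a'' have a larger sum than ``b''?
    return mx.sym.broadcast_greater(mx.sym.sum(inputs[0]), mx.sym.sum(inputs[1]))

def then_func(inputs):
    return [inputs[0] * 2]

def else_func(inputs):
    return [inputs[1] * 2]

# Only one of the two branches is executed at runtime; both branches return
# one output of the same type and shape.
outs = cond(pred, then_func, else_func, [a, b])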

...

``foreach’’ is a special form of loop. It's designed to allow easy shape inference and other optimizations. It iterates over the first dimension of the input NDArray/Symbol, so the number of iterations is determined before entering the loop.

foreach(body, input, state)


Input arguments:

  • ``input'' is a symbol/NDArray or a list of symbols/NDArrays.
  • ``body’’ is a user-defined function that defines the computation for each iteration.
  • ``state’’ is a list of symbols/NDArrays passed to ``body’’ as part of the inputs for the first iteration.

Return values:

  • A tuple (out_data, state), where ``out_data’’ is a symbol/NDArray or a list of symbols/NDArrays that is a concatenation of all outputs from ``body’’, and ``state’’ is the output state of the last iteration.


The signature of ``body’’ is

def body(input, state): output, new_state

``input'' is a symbol/NDArray or a list of symbols/NDArrays that is a slice from the input arrays of ``foreach’’; ``state'' is a list of symbols/NDArrays that represent data from the previous iteration; ``output'' is a symbol/NDArray or a list of symbols/NDArrays that contains the output data generated in this iteration; ``new_state'' is a list of symbols/NDArrays that contain data passed to the next iteration. ``output'' from all iterations is concatenated into a single output of ``foreach’’. As such, the shape and type of ``output'' from each iteration should always be the same. ``body'' is invoked once to generate a symbol that represents the computation in the function.
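For example, a running sum over the first dimension could be written as follows (a sketch against the proposed API; ``foreach'' is written as a plain function call and the data is made up):

import mxnet as mx

def body(x, states):
    # x: one slice along the first dimension of ``input''
    # states: [running_sum] carried over from the previous iteration
    new_sum = x + states[0]
    return new_sum, [new_sum]               # (output, new_state)

data = mx.nd.arange(6).reshape((3, 2))      # three slices of shape (2,)
init_state = [mx.nd.zeros((2,))]

# Per-step outputs are stacked along the first dimension, so ``outs'' has
# shape (3, 2); ``final_state'' holds the sum after the last iteration.
outs, final_state = foreach(body, data, init_state)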

``foreach’’ is similar to ``scan’’ in TensorFlow. The only difference is that ``body’’ in ``foreach’’ has two types of outputs: one is concatenated as the output of ``foreach’’; the other outputs are passed to ``body’’ as input for the next iteration. In contrast, ``scan’’ concatenates the outputs of its function as the output of ``scan’’ and also passes them to the function as one of the inputs for the next iteration. This difference makes the implementation of an LSTM with ``foreach’’ simpler and more efficient than with ``scan’’ in TensorFlow.
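For instance, an unrolled LSTM over a sequence could be expressed with ``foreach'' roughly as below (a sketch, assuming the Gluon LSTMCell and made-up dimensions; ``foreach'' is again written as a plain function call for the proposed operator):

import mxnet as mx

seq_len, batch, num_hidden = 10, 4, 16
lstm = mx.gluon.rnn.LSTMCell(num_hidden)
lstm.initialize()

seq = mx.nd.random.uniform(shape=(seq_len, batch, 8))   # (time, batch, feature)
init_states = lstm.begin_state(batch_size=batch)        # [h, c]

def body(x_t, states):
    # One time step: the cell output is concatenated across steps by
    # ``foreach'', while the new (h, c) states feed the next iteration.
    out, new_states = lstm(x_t, states)
    return out, new_states

outs, last_states = foreach(body, seq, init_states)     # outs: (seq_len, batch, num_hidden)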

In addition, ``foreach'' allows ``body'' to return an empty ``output'' or an empty ``state''. If ``state'' is empty, ``foreach'' becomes map. If ``output'' is empty, ``foreach'' becomes scan.

``while_loop’’

``while_loop’’ is the general form of a loop: at the beginning of each iteration, it checks a condition function to determine whether to terminate the loop. As such, it is difficult to determine the number of iterations in advance, and ``while_loop’’ is more difficult to optimize than ``foreach’’.

...

  • ``cond’’ is a user-defined function that takes ``loop_vars’’ as input and returns a boolean scalar symbol/NDArray to determine the termination of the loop.
  • ``func’’ is a user-defined function that takes ``loop_vars’’ as input and performs the computation of an iteration. There are two potential signatures, as described below.
  • ``loop_vars’’ is a list of symbols/NDArrays.
  • ``max_iterations’’ is a Python scalar or an MXNet scalar symbol that defines the maximal number of iterations. When ``max_iterations’’ is a Python scalar, the maximal number of iterations is defined statically (when the computation graph is constructed); when ``max_iterations’’ is an MXNet scalar symbol, the maximal number of iterations is determined at runtime (when the computation graph is executed).

Return value:

  • Depending on the signature of ``func’’, there are two potential ways of returning values.

The signature of ``cond’’:

def cond(loop_vars): boolean scalar

...

The signature of ``func’’:

Option 1:

def func(loop_vars): new_loop_vars

In this option, we only require ``loop_vars'' to have the same type across iterations; their shapes can change. ``while_loop’’ returns the return values of the last invocation of ``func’’. This interface is similar to the one in TensorFlow and is very flexible.
Option 2:

def func(loop_vars): (output, new_loop_vars)

In this option, ``output'' from each iteration is concatenated and returned as the output of ``while_loop’’. We can require ``output'' to have the same shape and data type across iterations, and we probably require the arrays in ``new_loop_vars'' to have the same shapes and data types as ``loop_vars''. This interface is similar to the definition of ``loop'' in ONNX and is more restrictive.
It is difficult to infer the shape of the output of ``while_loop’’ statically because we cannot determine the number of iterations required by ``while_loop’’ in advance. For the second option, even though we can't infer the shape, we may still be able to determine the maximal memory size used by the output of the loop; as such, the second option might be preferred even though it's more restrictive. Because MXNet symbols don't support dynamic shape, we currently use ``max_iterations'' to determine the first dimension of the output arrays.
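As an illustration of the second option, a loop that accumulates a running total might look like this (a sketch against the proposed API; ``while_loop'' is written as a plain function call and all values are made up):

import mxnet as mx

def cond(loop_vars):
    i, total = loop_vars
    return i < 5                          # boolean scalar

def func(loop_vars):
    i, total = loop_vars
    out = total                           # per-iteration output (Option 2)
    return out, [i + 1, total + i]        # (output, new_loop_vars)

i0 = mx.nd.array([0])
total0 = mx.nd.array([0])

# Because symbols don't support dynamic shape, the first dimension of
# ``outputs'' is set by ``max_iterations''; only the first five rows come
# from actual iterations.
outputs, final_vars = while_loop(cond, func, [i0, total0], max_iterations=10)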