...

There are two possible approaches to implementing symbolic control flow operators.

Approach 1: The approach is to keep all computation in a single graph and insert special operators into the graph to change the executor behavior (e.g., skip executing a sequence of operators, or repeat execution of a sequence of operators). These special operators resemble the jump instructions in a CPU. TensorFlow takes this approach. The detailed design can be found at http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf

Approach 2: The approach is to maintain multiple computation graphs. There is a main graph that contains control flow operators. Each control flow operator contains its own computation graph(s) and is responsible for executing the computation graph(s) inside the operator. This approach represents computation more like how a high-level programming language handles control flow (the code is broken into basic blocks for compilation and execution).

Both approaches have their pros and cons. I think the second approach is preferable for MXNet because MXNet's execution is more static than TensorFlow's (e.g., MXNet requires static shape and data type inference and static memory planning). The second approach also allows easier graph-level optimization and a simpler implementation. From here on, I'll discuss the second approach in more detail.

Given the API (shown at the end of the proposal), the first step for approach 2 is to build a subgraph from the user-defined Python functions of a control flow operator and pass it to the operator. I'll use ``foreach'' as an example to describe the implementation of a symbolic control flow operator.

Building a subgraph: a control flow operator invokes the user-defined function (UDF) once to create a computation subgraph. The UDF takes arguments as input and can also access any variables in its scope. Therefore, the subgraph of ``foreach'' has four sets of inputs: slices from the input arrays we iterate over, the states from the previous iteration or the initial states, the variables defined outside the UDF (its closure), and the variables defined inside the UDF. The initial states and the variables defined outside the UDF reference symbols in the main graph and, thus, connect the subgraph in the UDF with the main graph. However, the main graph and the subgraph have to be completely disjoint (otherwise, stateful operators may have different behaviors). To build a completely separate graph, we create new variable symbols to replace the ones in ``state'' and in the closure. It's easy to find the symbols passed in as ``state''. However, it's more difficult to identify the symbols in the closure and the ones defined inside the UDF, which aren't the inputs of the control flow operator. (A small usage sketch follows the list below.)

  • For symbols in the closure, we perform a graph transformation. When we create a computation graph from the UDF, we mark the nodes with a special label. Then we traverse the nodes in the computation graph to find the unmarked nodes, which are the ones not created in the UDF. We remove these nodes from the computation graph and create new variable symbols to reference them. Later on, we pass the outputs of the original nodes to the control flow operator as its inputs.
  • For symbols defined inside the UDF, we make a copy of the original symbols and pass them to the control flow operator as inputs.
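
As a rough illustration, here is a hedged sketch of a ``foreach'' UDF whose body reads a variable from the enclosing scope. Under this proposal, ``weight'' below would be detected by the graph transformation above and turned into an extra input of the operator. The call form ``mx.sym.contrib.foreach'' and the argument order are assumptions for the example, not settled API.

    import mxnet as mx

    data = mx.sym.var('data')             # iterated over its first axis
    init_state = mx.sym.var('init_state')
    weight = mx.sym.var('weight')         # defined outside the UDF, captured by the closure

    def step(x, states):
        # x: one slice of ``data''; states: states carried over from the previous iteration
        h = mx.sym.dot(x, weight) + states[0]   # ``weight'' comes from the closure
        return h, [h]                           # (per-iteration output, new states)

    # The subgraph built from ``step'' has four kinds of inputs: slices of ``data'',
    # the states, closure variables such as ``weight'', and variables created inside ``step''.
    outs, final_states = mx.sym.contrib.foreach(step, data, [init_state])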


Pass the subgraph to a control flow operator: The previous process creates a symbol that represents the computation graph in the user-defined Python function. Now we need to pass it to the operator as a graph. By default, NNVM interprets input symbols of an operator as input data (see Symbol::Compose in NNVM) and uses them to build connections between nodes, which results in a computation graph. To pass a symbol as a subgraph to an operator, we need to distinguish graph symbols from data symbols in Symbol::Compose. We create a new operator attribute (FInputGraph) to identify the graph symbols in the inputs of an operator. Once a graph symbol is spotted, NNVM stores it in the NodeAttrs of the created node.
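
The sketch below is a minimal, runnable Python illustration of this idea (it is not the actual NNVM code; the registry name and node layout are made up for the example): inputs flagged as graph inputs are stashed in the node's attributes instead of becoming data edges.

    # hypothetical registry: operator name -> positions of graph inputs (FInputGraph)
    FINPUT_GRAPH = {'_foreach': [0]}

    def compose(op_name, inputs):
        graph_idx = set(FINPUT_GRAPH.get(op_name, []))
        return {
            'op': op_name,
            # only non-graph inputs become edges of the computation graph
            'inputs': [s for i, s in enumerate(inputs) if i not in graph_idx],
            # graph symbols (subgraphs) are kept in the node attributes (NodeAttrs)
            'attrs': {'subgraphs': [s for i, s in enumerate(inputs) if i in graph_idx]},
        }

    # e.g. compose('_foreach', [body_subgraph_symbol, data, init_state])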

Shape and type inference:

  • ``foreach'':

...

  • Shape and type inference on ``foreach'' is easy. Because the input shapes and types in each iteration are the same, we only need to infer the subgraph once. After having the shape and the type of the output of the subgraph, we can easily calculate the shape and the type of the output of ``foreach'' (see the sketch after this list).
  • ``cond'': If we require the outputs of ``then_func'' and ``else_func'' to have the same shape and data type, the shape and type inference can also be simple.
  • ``while_loop'': We can't perform static shape inference on ``while_loop'' because the shape of its output arrays depends on the number of iterations governed by the condition subgraph. As such, we need to extend MXNet to support dynamic shape inference.
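
Below is a minimal sketch of how the output shapes of ``foreach'' could be derived from a single inference pass over the subgraph. It assumes the first output of the UDF is the per-iteration output and the remaining outputs are the carried states; the function name and layout are illustrative only.

    def foreach_infer_shape(num_iterations, subgraph_out_shapes):
        per_iter_out, state_shapes = subgraph_out_shapes[0], subgraph_out_shapes[1:]
        # per-iteration outputs are stacked along a new leading axis
        stacked_out = (num_iterations,) + tuple(per_iter_out)
        # the final states have the same shapes as the per-iteration states
        return [stacked_out] + [tuple(s) for s in state_shapes]

    print(foreach_infer_shape(10, [(4, 5), (4, 5)]))   # [(10, 4, 5), (4, 5)]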

...

  • We can execute the subgraph outside the execution engine, in the main thread, like a normal function call. In this way, the data in the output arrays of a control flow operator is marked valid only by the operators in the subgraph.
  • In the case of ``cond'' and ``while_loop'', we need to wait for the condition subgraph to complete before we can proceed with ``then_func'', ``else_func'' or the loop body. Potentially, we can still wait in the main thread for the condition subgraph to complete. However, blocking the main thread prevents multi-GPU parallelism (MXNet has a single execution thread and uses asynchronous execution for multi-GPU parallelism). As such, we have to allow subgraph execution inside the threaded engine.
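
For concreteness, here is a hedged sketch of how a ``while_loop'' operator could drive its condition and body subgraphs. This is illustrative Python, not MXNet internals; ``cond_graph'' and ``body_graph'' are assumed to be callables that execute the corresponding subgraphs.

    def run_while_loop(cond_graph, body_graph, loop_vars, max_iterations):
        outputs = []
        for _ in range(max_iterations):
            # the condition subgraph must finish before the body can start,
            # so its boolean result is waited on here
            if not bool(cond_graph(*loop_vars)):
                break
            step_out, loop_vars = body_graph(*loop_vars)
            outputs.append(step_out)
        return outputs, loop_vars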


Dynamic shape inference and execution: To execute a computation graph with operators that don't support static shape inference, the first step is to perform graph partitioning. Operators before ``while_loop’’ and the ones after ``while_loop’’ are grouped into separate subgraphs. As such, the original graph is split into multiple subgraphs and we can execute these subgraphs imperatively (just like Gluon does in Python). In this way, we can perform shape inference right before a subgraph is executed. Even though the shape of a subgraph may change in each mini-batch, we can still avoid memory allocation by reusing memory as long as the memory used in the previous execution is large enough.
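
The memory-reuse idea can be sketched as follows (a hypothetical helper, not MXNet code): a buffer allocated for a previous mini-batch is reused whenever it is large enough for the newly inferred shape.

    import numpy as np

    _buffers = {}

    def get_buffer(key, shape):
        need = int(np.prod(shape))
        buf = _buffers.get(key)
        if buf is None or buf.size < need:
            buf = np.empty(need)          # allocate only when the old buffer is too small
            _buffers[key] = buf
        return buf[:need].reshape(shape)  # view with the shape inferred for this mini-batch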

Memory management: Inference and training tasks require very different memory management strategies. For inference, we only need to allocate memory for one iteration and the final output NDArrays. The memory used by the previous iteration can be reused in the next iteration. However, it's a completely different story for training. Backward propagation for an iteration requires the intermediate computation results from the forward path. One approach is to save all intermediate computation results in the forward path. If a model requires many iterations (e.g., long sequences) or uses the attention mechanism, we need to allocate an astonishing amount of memory, which can be a problem for GPUs. To enable training on these workloads, we may have to use small batch sizes. Another possibility is to drop some memory in the forward path, e.g., only saving the outputs of each iteration and recomputing the intermediate results of the forward path whenever they are needed. This can save a significant amount of memory at the cost of computing the forward path twice. In the case of attention, we should develop a mechanism that systematically identifies the memory generated by cheap computation and drops it in the forward path (PyTorch allows developers to mark such memory manually).
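
The recomputation strategy can be sketched as below (a conceptual sketch, not the proposed implementation): only the per-iteration outputs are kept in the forward pass, and the intermediates needed by backward are recomputed from them. ``forward_step'' and ``backward_step'' are assumed callables for a single loop iteration.

    def forward(forward_step, state, inputs):
        saved = [state]
        for x in inputs:
            state, _ = forward_step(x, state)   # intermediates are dropped to save memory
            saved.append(state)
        return state, saved

    def backward(forward_step, backward_step, saved, inputs, grad_state):
        for x, state_in in zip(reversed(inputs), reversed(saved[:-1])):
            _, intermediates = forward_step(x, state_in)   # second forward pass for this iteration
            grad_state = backward_step(grad_state, intermediates)
        return grad_state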

...

  • ``pred'' is a symbol/NDArray that contains a boolean scalar to determine which function to compute.
  • ``then_func'' and ``else_func'' are also user-defined Python functions whose signatures are defined below.

...

  • A list of symbols/NDArrays returned from one of the user-defined Python functions.
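
For illustration, here is a hedged usage sketch of the ``cond'' API; the exposure as ``mx.sym.contrib.cond'' is an assumption for the example, not settled API.

    import mxnet as mx

    pred = mx.sym.var('pred')   # boolean scalar
    a = mx.sym.var('a')
    b = mx.sym.var('b')

    def then_func():
        return a + b

    def else_func():
        return a - b

    out = mx.sym.contrib.cond(pred, then_func, else_func)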


The signature of ``then_func’’ and ``else_func’’ is

...

  • ``input'' is a symbol/NDArray or a list of symbols/NDArrays.
  • ``body’’ is a user-defined Python function that defines computation for each iteration.
  • ``state’’ is a list of symbols/NDArrays passed to ``body’’ as part of the inputs for the first iteration.

...


Input arguments:

  • ``cond'' is a user-defined Python function that takes ``loop_vars'' as input and returns a boolean scalar symbol/NDArray to determine the termination of the loop.
  • ``func'' is a user-defined Python function that takes ``loop_vars'' as input and performs the computation of an iteration.
  • ``loop_vars'' is a list of symbols that represent NDArrays.
  • ``max_iterations'' is a Python scalar that defines the maximal number of iterations.
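
For illustration, here is a hedged usage sketch of the ``while_loop'' API; the exposure as ``mx.sym.contrib.while_loop'' and the exact return form are assumptions for the example.

    import mxnet as mx

    def cond(i, s):
        return i < 5                    # boolean scalar symbol

    def func(i, s):
        out = s + i
        return out, [i + 1, out]        # (per-iteration output, new loop_vars)

    i = mx.sym.var('i')
    s = mx.sym.var('s')
    outputs, final_vars = mx.sym.contrib.while_loop(cond, func, loop_vars=[i, s],
                                                    max_iterations=10)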

...