...

Add an NNVM pass for the backend. This pass would use the AMP op lists selected based on the original_dtype and target_dtype.
The pass performs a graph traversal and adds amp_cast and amp_multicast layers for FP16 and FP32 ops based on the op whitelists and excluded_sym_names. Some of the ideas have been borrowed from the quantization pass added as part of quantization support [2].
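
For illustration, the snippet below sketches how such per-dtype op lists and the user-supplied exclusion set might look. The list names and the ops in them are placeholders chosen for this example, not the actual lists shipped with the pass.

Code Block
# Illustrative placeholders for the per-dtype op lists and the exclusion set;
# the names and contents here are examples only, not the actual AMP lists.
TARGET_DTYPE_FUNCS = ["Convolution", "FullyConnected"]  # run in target_dtype (e.g. float16)
FP32_FUNCS = ["exp", "log", "softmax"]                   # always run in float32
WIDEST_TYPE_CASTS = ["add_n", "elemwise_add"]            # cast inputs to the widest dtype present
excluded_sym_names = ["final_fc"]                        # user-specified symbols left untouched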

...

Outline of the algorithm (a pseudocode sketch follows the outline):

1. Three additional data structures are used:

  1. a map from a node to its copy node in the casted graph (mirror_map)
  2. a map from an input entry to the corresponding FP32-casted entry (mirror_entry_fp32_map)
  3. a map from an input entry to the target dtype (e.g. FP16) casted entry (mirror_entry_target_map) (please see the figure below for why the second and third maps are needed)

Consider the following script:

...

Without mirror_entry_target_map there would have been three cast nodes instead of two: one for the amp_cast to float16, and two others providing the amp_cast FP32 inputs to exp0 and sqrt0. The two additional data structures thus help optimize and reuse such commonly used cast operators.


2. Visit the nodes of the graph in topologically sorted order and, for each visited node, do the following:

  1. Create a copy node.
  2. Clear the inputs of the copy node.
  3. Handle the node based on the following cases:
    1. If the node is a variable:
      1. Add a mapping from the node to the copy node.
    2. Else if the node is not in any whitelist:
      1. Find all inputs of the original node, look up their corresponding mirror entries, and add them as inputs to the copy node.
    3. Else if the node's op_name is in the FP32 whitelist:
      1. Iterate through the inputs:
        1. If an input is already in mirror_entry_fp32_map, add it to the copy node's inputs.
        2. If not, insert a cast node between copy_node and the mirror node of the previous node, and add a mapping from the input node to the FP32 cast node in mirror_entry_fp32_map.
    4. Else if the node's op_name is in the target dtype whitelist:
      1. Iterate through the inputs:
        1. If an input is already in mirror_entry_target_map, add it to the copy node's inputs.
        2. If not, insert a cast node between copy_node and the mirror node of the previous node, and add a mapping from the input node to the target dtype cast node in mirror_entry_target_map.
    5. Else if the node's op_name is in the widest_precision_ops whitelist:
      1. Add an amp_multicast between the mirrors of the node's inputs and the copy of the current node.
  4. Add a mapping from the node to the copy node.

3. Create a new graph using the copy nodes with the node to copy node mapping.

4. Return the newly created graph. 
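
To make the traversal concrete, below is a minimal, self-contained Python sketch of steps 1-4 over a toy graph representation. It is illustrative only: the actual pass is implemented as an NNVM graph pass in the backend (see [1]), and the Node class and helper names here are hypothetical stand-ins, not MXNet APIs.

Code Block
# Minimal, self-contained sketch of the casting pass (steps 1-4 above) over a toy
# graph representation. Illustrative only: the real pass is an NNVM graph pass in
# the MXNet backend, and the Node class and helpers below are hypothetical.

class Node:
    def __init__(self, op, name, inputs=None):
        self.op = op                # op name, or "var" for variables
        self.name = name
        self.inputs = inputs or []  # list of input Nodes


def topo_sort(outputs):
    """Return the nodes reachable from `outputs` in topologically sorted order."""
    order, seen = [], set()

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    for out in outputs:
        visit(out)
    return order


def convert(outputs, target_dtype_ops, fp32_ops, widest_precision_ops,
            excluded_sym_names=(), target_dtype="float16"):
    mirror_map = {}         # node -> copy node in the casted graph
    fp32_cast_map = {}      # input node -> reused amp_cast-to-FP32 node
    target_cast_map = {}    # input node -> reused amp_cast-to-target-dtype node
    whitelisted = set(target_dtype_ops) | set(fp32_ops) | set(widest_precision_ops)

    def cast_input(inp, dtype, cache):
        # Reuse a previously inserted amp_cast for this input if one exists.
        if inp not in cache:
            cache[inp] = Node("amp_cast", inp.name + "_amp_cast_" + dtype,
                              [mirror_map[inp]])
        return cache[inp]

    for node in topo_sort(outputs):                    # step 2
        copy = Node(node.op, node.name)                # 2.1 copy node, 2.2 inputs cleared
        mirrored = [mirror_map[inp] for inp in node.inputs]
        if node.op == "var":                           # 2.3.1
            pass
        elif node.name in excluded_sym_names or node.op not in whitelisted:
            copy.inputs = mirrored                     # 2.3.2 keep inputs uncasted
        elif node.op in fp32_ops:                      # 2.3.3 cast inputs to FP32
            copy.inputs = [cast_input(inp, "float32", fp32_cast_map)
                           for inp in node.inputs]
        elif node.op in target_dtype_ops:              # 2.3.4 cast inputs to target dtype
            copy.inputs = [cast_input(inp, target_dtype, target_cast_map)
                           for inp in node.inputs]
        else:                                          # 2.3.5 widest-precision op
            copy.inputs = [Node("amp_multicast", node.name + "_amp_multicast", mirrored)]
        mirror_map[node] = copy                        # 2.4
    return [mirror_map[out] for out in outputs]        # steps 3 and 4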

Example Usage


Code Block
import mxnet as mx

# Simple demo model
data = mx.sym.var("data")
data2 = mx.sym.var("data2")
data3 = mx.sym.var("data3")
x = mx.sym.exp(data)
x2 = mx.sym.sin(data)
x3 = mx.sym.cos(data)
sym = x + x2 + x3
result = mx.sym.add_n(sym, data2, data3)
casted_result = mx.contrib.amp._convert_symbol(result, target_dtype="float16",
                                               target_dtype_ops=["sin", "cos", "exp"],
                                               fp32_ops=["elemwise_add"],
                                               widest_dtype_ops=["add_n"],
                                               conditional_fp32_ops=None)
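
Continuing from the script above, one quick way to check the conversion is to list the operators in the converted symbol's JSON; the amp_cast and amp_multicast nodes inserted by the pass should appear in the list (the exact output depends on the graph).

Code Block
import json

# Operators in the converted graph (variables appear with op "null" and are skipped).
ops = [n["op"] for n in json.loads(casted_result.tojson())["nodes"] if n["op"] != "null"]
print(ops)  # expected to include "amp_cast" and "amp_multicast" entries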

Before/After Conversion:

[Attachment: model_before.pdf (model graph before conversion)]

As shown in the figure above, the converted graph contains amp_cast and amp_multicast nodes, which allow the inputs to be cast to the appropriate dtypes.
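
For reference, the inserted nodes correspond to the amp_cast and amp_multicast operators added in [1]. The following is a minimal sketch of how they might be used on standalone symbols, assuming they are exposed under mx.sym with the signatures shown; treat the exact call forms as an assumption.

Code Block
import mxnet as mx

a = mx.sym.var("a")
b = mx.sym.var("b")

# amp_cast casts a single input to the requested dtype (assumed signature).
a_fp16 = mx.sym.amp_cast(a, dtype="float16")

# amp_multicast casts all of its inputs to the widest dtype among them (assumed signature).
out = mx.sym.amp_multicast(a, b, num_outputs=2)
a_widest, b_widest = out[0], out[1]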

Frontend Bindings

AMP model conversion support also needs to be added for the other frontend bindings, such as C++, Scala, etc.

Gluon User Experience Improvement

TO BE ADDED

References

  1. https://github.com/apache/incubator-mxnet/pull/14173
  2. https://github.com/apache/incubator-mxnet/pull/9552