...
Add an NNVM pass for the backend. This pass uses the AMP lists selected based on the original_dtype and target_dtype.
The pass performs a graph traversal and inserts amp_cast and amp_multicast layers for FP16 and FP32 ops based on the op whitelists and excluded_sym_names. Some of the ideas are borrowed from the quantization pass added as part of quantization support [2].
...
Outline of algorithm:
1. Three additional data structures are used:
- a map from each node to its copy in the casted graph (mirror_map)
- a map from an input entry to the corresponding fp32-casted entry (mirror_entry_fp32_map)
- a map from an input entry to the corresponding target-dtype (e.g. fp16) casted entry (mirror_entry_target_map)
(please see the figure below for why the last two maps are needed)
Consider the below script:
...
Without mirror_entry_target_map there would be 3 cast nodes instead of 2: one amp_cast to float16, plus two more producing the amp_casted fp32 inputs to exp0 and sqrt0. The two entry maps thus let the pass reuse commonly needed cast operators instead of duplicating them.
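To make the reuse concrete, here is a minimal Python sketch of how the two entry maps deduplicate casts. The Node class and get_cast helper are hypothetical stand-ins for illustration, not NNVM's actual API:

```python
# Hypothetical minimal sketch of cast reuse via the two entry maps.
# Node and get_cast are illustrative stand-ins, not NNVM internals.

class Node:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)

mirror_entry_fp32_map = {}    # input entry -> fp32 amp_cast entry
mirror_entry_target_map = {}  # input entry -> target-dtype amp_cast entry

def get_cast(entry, dtype):
    # Reuse an existing amp_cast for this (entry, dtype) pair if one exists.
    cache = mirror_entry_fp32_map if dtype == "float32" else mirror_entry_target_map
    if entry not in cache:
        cache[entry] = Node("amp_cast_" + dtype, [entry])
    return cache[entry]

a = Node("some_op")
# Two fp32 consumers (like exp0 and sqrt0 above) share one cast node:
cast1 = get_cast(a, "float32")
cast2 = get_cast(a, "float32")
assert cast1 is cast2            # reused, so 2 casts total instead of 3
cast3 = get_cast(a, "float16")
assert cast3 is not cast1        # separate cast per target dtype
```

The key point is that each map is keyed by the input entry, so any later consumer requesting the same dtype for the same entry gets the already-created cast node.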
2. Visit the nodes of the graph in topologically sorted order and, for each node:
- Create a copy node.
- Clear the inputs of the copy node.
- If the node is a variable:
- Add a mapping from the node to the copy node.
- Else if the node is not in any whitelist:
- For each input of the original node, find the corresponding mirror entry and add it as an input to the copy node.
- Else if the node's op_name is in the fp32 whitelist:
- Iterate through the inputs: if an input is already in mirror_entry_fp32_map, add the mapped entry to the copy node's inputs; if not, insert a cast node between the copy node and the mirror of the previous node, and add a mapping from the input node to the fp32 cast node in mirror_entry_fp32_map.
- Else if the node's op_name is in the target dtype whitelist:
- Iterate through the inputs:
- If an input is already in mirror_entry_target_map, add the mapped entry to the copy node's inputs.
- If not, insert a cast node between the copy node and the mirror of the previous node, and add a mapping from the input node to the target dtype cast node in mirror_entry_target_map.
- Else if the node's op_name is in the widest_precision_ops whitelist:
- Add an amp_multicast node between the mirrors of the node's inputs and the copy of the current node.
3. Create a new graph using the copy nodes with the node to copy node mapping.
4. Return the newly created graph.
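The steps above can be sketched as a small runnable pass over a toy graph. The Node class, topo_sort helper, and whitelist arguments below are simplified stand-ins for NNVM's actual structures, not MXNet's real implementation:

```python
# Runnable sketch of the conversion pass above on a toy DAG.
# Node and topo_sort are simplified stand-ins for NNVM structures.

class Node:
    def __init__(self, op, inputs=(), name=""):
        self.op, self.inputs, self.name = op, list(inputs), name

def topo_sort(outputs):
    order, seen = [], set()
    def visit(n):
        if n in seen:
            return
        seen.add(n)
        for i in n.inputs:
            visit(i)
        order.append(n)          # inputs first, node after: topological order
    for o in outputs:
        visit(o)
    return order

def convert(outputs, target_ops, fp32_ops, widest_ops, target_dtype="float16"):
    mirror_map = {}               # node -> copy node
    mirror_entry_fp32_map = {}    # input entry -> fp32 amp_cast entry
    mirror_entry_target_map = {}  # input entry -> target-dtype amp_cast entry

    def casted(inp, dtype, cache):
        # Insert (or reuse) an amp_cast between the mirrored input and the copy.
        if inp not in cache:
            cache[inp] = Node("amp_cast", [mirror_map[inp]],
                              name=inp.name + "_amp_cast_" + dtype)
        return cache[inp]

    for node in topo_sort(outputs):
        copy = Node(node.op, name=node.name)   # copy node, inputs cleared
        if node.op == "var":
            pass                               # variables keep no inputs
        elif node.op in fp32_ops:
            copy.inputs = [casted(i, "float32", mirror_entry_fp32_map)
                           for i in node.inputs]
        elif node.op in target_ops:
            copy.inputs = [casted(i, target_dtype, mirror_entry_target_map)
                           for i in node.inputs]
        elif node.op in widest_ops:
            # amp_multicast casts all inputs to their widest dtype at runtime
            copy.inputs = [Node("amp_multicast",
                                [mirror_map[i] for i in node.inputs])]
        else:
            copy.inputs = [mirror_map[i] for i in node.inputs]
        mirror_map[node] = copy
    return [mirror_map[o] for o in outputs]

# Two target-dtype consumers of the same input share one amp_cast node:
data = Node("var", name="data")
out = convert([Node("exp", [data]), Node("sin", [data])],
              target_ops={"exp", "sin"}, fp32_ops=set(), widest_ops=set())
assert out[0].inputs[0].op == "amp_cast"
assert out[0].inputs[0] is out[1].inputs[0]
```

Visiting nodes in topological order guarantees that every input's mirror already exists when a node is processed, which is what allows the cast caches to be consulted in a single pass.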
Example Usage
```python
import mxnet as mx

# Simple demo model
data = mx.sym.var("data")
data2 = mx.sym.var("data2")
data3 = mx.sym.var("data3")
x = mx.sym.exp(data)
x2 = mx.sym.sin(data)
x3 = mx.sym.cos(data)
sym = x + x2 + x3
result = mx.sym.add_n(sym, data2, data3)
casted_result = mx.contrib.amp._convert_symbol(result, target_dtype="float16",
                                               target_dtype_ops=["sin", "cos", "exp"],
                                               fp32_ops=["elemwise_add"],
                                               widest_dtype_ops=["add_n"],
                                               conditional_fp32_ops=None)
```
Before/After Conversion:
As shown above, the converted graph contains amp_cast and amp_multicast nodes, which cast the inputs to the appropriate dtypes.
Frontend Bindings
AMP model-conversion support needs to be added for the other frontend bindings, such as C++ and Scala.
Gluon User Experience Improvement
TO BE ADDED