...

4. Return the newly created graph. 

Please take a look at the PoC[3] for more details on the NNVM pass.

Example Usage


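import mxnet as mx

# Simple demo model: exp, sin and cos of the same input are summed
# elementwise, then combined with two more inputs through add_n
data = mx.sym.var("data")
data2 = mx.sym.var("data2")
data3 = mx.sym.var("data3")
x = mx.sym.exp(data)
x2 = mx.sym.sin(data)
x3 = mx.sym.cos(data)
sym = x + x2 + x3
result = mx.sym.add_n(sym, data2, data3)

# Run sin, cos and exp in float16, keep elemwise_add in float32, and let
# add_n cast its inputs to the widest input dtype
casted_result = mx.contrib.amp._convert_symbol(result, target_dtype="float16",
                                               target_dtype_ops=["sin", "cos", "exp"],
                                               fp32_ops=["elemwise_add"],
                                               widest_dtype_ops=["add_n"],
                                               conditional_fp32_ops=None)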
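
The following plain-Python toy gives a rough picture of what this call does to the demo model. It walks a topologically sorted graph and inserts cast nodes according to the three op lists. It is a simplified model built on stated assumptions, not the NNVM pass from the PoC [3]:

- The Node class, convert_graph and the amp_cast node naming are hypothetical.
- conditional_fp32_ops is ignored.
- Ops that appear in none of the lists keep their input dtype.
- add_n is modeled as casting all of its inputs to the widest input dtype.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple


# Hypothetical toy graph node; NOT MXNet's NNVM node structure.
@dataclass
class Node:
    name: str
    op: Optional[str]                       # None marks a graph input (variable)
    inputs: List[str] = field(default_factory=list)
    dtype: str = "float32"                  # dtype of this node's output


_BITS = {"float16": 16, "float32": 32}


def convert_graph(nodes: Sequence[Node], target_dtype: str = "float16",
                  target_dtype_ops: Sequence[str] = (),
                  fp32_ops: Sequence[str] = (),
                  widest_dtype_ops: Sequence[str] = ()) -> List[Node]:
    """Return a new node list with amp_cast nodes inserted where dtypes differ.

    Assumes `nodes` is topologically sorted and every op node has >= 1 input.
    """
    out: List[Node] = []
    dtype_of: Dict[str, str] = {}
    cast_cache: Dict[Tuple[str, str], str] = {}

    def cast(src: str, dtype: str) -> str:
        if dtype_of[src] == dtype:
            return src                       # already the right dtype
        key = (src, dtype)
        if key not in cast_cache:            # one shared cast per (input, dtype)
            name = f"amp_cast{len(cast_cache)}"
            out.append(Node(name, "amp_cast", [src], dtype))
            dtype_of[name] = dtype
            cast_cache[key] = name
        return cast_cache[key]

    for node in nodes:
        if node.op is None:
            new = Node(node.name, None, [], node.dtype)
        else:
            if node.op in target_dtype_ops:
                want = target_dtype
            elif node.op in fp32_ops:
                want = "float32"
            elif node.op in widest_dtype_ops:
                want = max((dtype_of[i] for i in node.inputs), key=_BITS.__getitem__)
            else:
                want = None                  # leave inputs untouched
            inputs = [cast(i, want) for i in node.inputs] if want else list(node.inputs)
            new = Node(node.name, node.op, inputs, want or dtype_of[inputs[0]])
        out.append(new)
        dtype_of[new.name] = new.dtype
    return out


# The demo model above, with x + x2 + x3 expanded into two elemwise_add nodes.
demo = [
    Node("data", None), Node("data2", None), Node("data3", None),
    Node("exp", "exp", ["data"]), Node("sin", "sin", ["data"]),
    Node("cos", "cos", ["data"]),
    Node("plus0", "elemwise_add", ["exp", "sin"]),
    Node("plus1", "elemwise_add", ["plus0", "cos"]),
    Node("result", "add_n", ["plus1", "data2", "data3"]),
]
converted = convert_graph(demo, target_dtype_ops=["sin", "cos", "exp"],
                          fp32_ops=["elemwise_add"], widest_dtype_ops=["add_n"])
for n in converted:
    print(f"{n.name:8} {str(n.op):13} {n.dtype:8} {n.inputs}")
```

The toy run behaves as follows:

- data is cast to float16 once, and that cast is shared by exp, sin and cos.
- Each float16 output that feeds an elemwise_add is cast back to float32.
- add_n needs no cast, because all of its inputs are already float32.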

...

For Gluon code, we need to add an internal API to retrieve sym, arg_params and aux_params from a HybridBlock. With these, convert_model can be used to convert the symbol, the model params and the auxiliary params. After conversion, the symbolic model (symbol JSON, arg_params, aux_params) can be saved and imported back into Gluon with SymbolBlock.imports. The returned SymbolBlock is ready to use for inference.

Frontend Bindings

Support for the AMP convert_model API needs to be added to other language bindings, such as C++ and Scala.

...

Casting both the inputs and the params of a Gluon model to FP16 ensures that you can execute the model in FP16 precision. In general, however, some ops should run in FP16 while others should stay in FP32, for accuracy and performance reasons. This is where the AMP APIs are useful.
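
The stdlib-only illustration below shows one reason accumulation-heavy ops are often kept in FP32. It emulates a float16 accumulator by rounding every partial sum to IEEE 754 half precision, using Python's struct "e" format. This is a general floating-point effect. It is not a statement about which ops AMP actually places in its FP32 list.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp16_sum(values):
    """Sum with a float16 accumulator: every partial sum is rounded to float16."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + to_fp16(v))
    return acc


ones = [1.0] * 3000
print(fp16_sum(ones))  # 2048.0: above 2048, float16 values are spaced 2 apart
print(sum(ones))       # 3000.0 with a double-precision accumulator
```

Once the float16 accumulator reaches 2048, adding 1.0 rounds back to 2048, so the sum stalls. The same sum is exact with a wider accumulator.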

Will the dtype attribute in the serialized model change after convert_model is called?

Yes, the dtype attribute in the serialized model can change after convert_model is called. Whether it does depends on how the whitelist affects the model in question and on whether type inference decides that certain params need to be in float16.
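
One way to see which dtypes actually changed is to compare the serialized symbol JSON before and after conversion. The sketch below uses only the standard library and rests on assumptions about MXNet's symbol JSON layout:

- Variables are nodes with op "null".
- A variable's dtype is stored as an integer code in a "__dtype__" attribute (2 for float16).
- amp_cast nodes carry a "dtype" attribute.

The helper names and the sample JSON are hypothetical.

```python
import json

# Assumed MXNet type codes stored in a variable's "__dtype__" attribute.
MX_DTYPE_CODES = {"0": "float32", "1": "float64", "2": "float16"}


def dtype_attrs(symbol_json: str) -> dict:
    """Map node name -> dtype for typed variables and amp_cast nodes."""
    found = {}
    for node in json.loads(symbol_json)["nodes"]:
        # Assumption: attributes live under "attrs" (older files may use "attr").
        attrs = node.get("attrs") or node.get("attr") or {}
        if node["op"] == "null" and "__dtype__" in attrs:
            code = attrs["__dtype__"]
            found[node["name"]] = MX_DTYPE_CODES.get(code, "code " + code)
        elif node["op"] == "amp_cast":
            found[node["name"]] = attrs.get("dtype")
    return found


def dtype_changes(before_json: str, after_json: str) -> dict:
    """Return name -> (old dtype, new dtype) for every dtype that differs."""
    before, after = dtype_attrs(before_json), dtype_attrs(after_json)
    names = sorted(set(before) | set(after))
    return {n: (before.get(n), after.get(n)) for n in names
            if before.get(n) != after.get(n)}


# Hypothetical, hand-written symbol JSON before and after conversion.
before = json.dumps({"nodes": [
    {"op": "null", "name": "data", "inputs": []},
    {"op": "null", "name": "fc_weight", "inputs": []},
]})
after = json.dumps({"nodes": [
    {"op": "null", "name": "data", "inputs": []},
    {"op": "null", "name": "fc_weight", "attrs": {"__dtype__": "2"}, "inputs": []},
    {"op": "amp_cast", "name": "amp_cast0", "attrs": {"dtype": "float16"},
     "inputs": [[0, 0, 0]]},
]})
print(dtype_changes(before, after))
```

With real models, you would pass in the contents of the original and converted symbol JSON files instead of the hand-written samples.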

Is there a need to hybridize and run a forward pass for the AMP-converted Gluon model?

...

  1. https://github.com/apache/incubator-mxnet/pull/14173
  2. https://github.com/apache/incubator-mxnet/pull/9552
  3. https://github.com/apache/incubator-mxnet/pull/14702