...

This serialization/deserialization module should have the following methods supporting different formats:

 

Import a model into MXNet.

sym, params = mx.serde.import(input_file, input_format='onnx', output_format='gluon')

  • input_file : input model file (e.g., a protobuf model file for ONNX)
  • input_format : (optional) onnx, coreml
  • output_format : (optional) "gluon" or "symbolic"; by default, gluon will be used.

...

Note: Currently gluon does not provide an easy way to import a pre-trained model (although there is a workaround by which this can be done).
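Internally, the import API above would need to dispatch on the (input_format, output_format) pair. A minimal, self-contained sketch of such a dispatch registry is below; the function and registry names (`import_model`, `register_importer`, `_IMPORTERS`) are illustrative assumptions, not real MXNet or serde APIs:

```python
# Hypothetical sketch of how mx.serde.import could dispatch on format pairs.
# All names here are illustrative; the real converters would parse the
# ONNX/CoreML protobuf and build an MXNet graph.

_IMPORTERS = {}

def register_importer(input_format, output_format):
    """Decorator registering a converter for a (input, output) format pair."""
    def wrapper(func):
        _IMPORTERS[(input_format, output_format)] = func
        return func
    return wrapper

@register_importer('onnx', 'gluon')
def _onnx_to_gluon(input_file):
    # Placeholder: a real converter would return a Gluon HybridBlock
    # built from the parsed ONNX graph.
    return ('gluon_model_from', input_file)

def import_model(input_file, input_format='onnx', output_format='gluon'):
    """Look up and invoke the registered converter for the format pair."""
    try:
        converter = _IMPORTERS[(input_format, output_format)]
    except KeyError:
        raise ValueError('unsupported conversion: %s -> %s'
                         % (input_format, output_format))
    return converter(input_file)
```

This keeps each format pair independent, so adding CoreML support later is just another registered converter.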

 

Export an MXNet model to the specified output format.

1) mx.serde.export(sym, params, input_format='symbolic', output_format='onnx', filename_prefix="model_name")

2) mx.serde.export(module, input_format='symbolic', output_format='onnx', filename_prefix="model_name")

3) mx.serde.export(gluon_model, input_format='gluon', output_format='onnx', filename_prefix="model_name")

  • sym : model definition
  • module : mxnet module object
  • gluon_model : model definition (HybridBlock)
  • params : weights
  • input_format : symbolic/gluon 
  • output_format : onnx, coreml
  • filename_prefix : a filename prefix used to name the saved model files. E.g., for ONNX, a binary protobuf will be written to an output file with the ".onnx" extension.
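The filename_prefix behavior described above can be sketched as a small helper that appends a per-format extension. The extension table and the helper name `output_filename` are assumptions for illustration (the ".onnx" extension is stated in the proposal; ".mlmodel" for CoreML is a guess):

```python
# Sketch of combining filename_prefix with a per-format file extension.
# The table below is an assumption, not part of the proposal.
_EXTENSIONS = {
    'onnx': '.onnx',      # stated in the proposal
    'coreml': '.mlmodel', # assumed
}

def output_filename(filename_prefix, output_format='onnx'):
    """Return the on-disk filename for an exported model."""
    if output_format not in _EXTENSIONS:
        raise ValueError('unknown output_format: %r' % output_format)
    return filename_prefix + _EXTENSIONS[output_format]
```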

...

For example,

  •  Gluon

mx.gluon.import_from(input_file, input_format='onnx')

mx.gluon.export(output_format='onnx', filename_prefix="model_name")

If any pre-processing/post-processing logic is required for gluon specific models, it can go under this gluon wrapper. These functions will internally call `serde` module APIs.

e.g. `mx.gluon.import(input_file, input_format='onnx')` will internally call `mx.serde.import(input_file, input_format='onnx', output_format='gluon')`

 

  •  For Symbolic interface

sym, params = mx.mod.Module.import_from(input_file, input_format='onnx')

mx.mod.Module.export(output_format='onnx', filename_prefix="model_name")

 

This function will directly save a file called "model_name.onnx" on disk.

Implementation approaches

There are two approaches that can be taken to import/export ONNX models.

...

The whole implementation will go under MXNet repo.

 

def import(..., input_format='onnx'):
    # returns mxnet graph with parameters
    ...
    return sym, params

def export(..., output_format='onnx'):
    # returns ONNX protobuf object
    ...
    return onnx_proto

...

import nnvm.frontend

def import(..., input_format='onnx'):
    # convert from onnx to nnvm graph
    nnvm_graph, params = nnvm.frontend.from_onnx(...)  # Exists
    # convert from nnvm graph to mxnet graph
    mxnet_graph, params = nnvm.frontend.to_mxnet(...)  # Need to implement
    return mxnet_graph, params

def export(..., output_format='onnx'):
    # convert from mxnet to nnvm graph
    nnvm_graph, params = nnvm.frontend.from_mxnet(...)  # Exists
    # convert from nnvm graph to onnx proto format
    onnx_proto = nnvm.frontend.to_onnx(...)  # Need to implement
    return onnx_proto

 


Suggested approach:

As a middle ground between the two implementation choices above, I propose taking the first approach and implementing MXNet->ONNX conversion for the export functionality. Anyone who wants to take advantage of the NNVM/TVM optimized engine can still do so by leveraging the import functionality provided in the NNVM/TVM package.
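The core of a direct MXNet->ONNX exporter (the first approach) is a per-operator translation table: each MXNet op name maps to an ONNX op type plus an attribute-name renaming. The sketch below is illustrative only; the node dicts, helper name `convert_node`, and the three table entries are assumptions, not the actual converter:

```python
# Minimal sketch of the per-operator translation an MXNet->ONNX exporter
# would perform. The mapping table and node representation are illustrative.

# Maps an MXNet op name to (ONNX op type, attribute-name renames).
_MX_TO_ONNX_OPS = {
    'Convolution': ('Conv', {'kernel': 'kernel_shape', 'stride': 'strides'}),
    'FullyConnected': ('Gemm', {}),
    'Activation': ('Relu', {}),
}

def convert_node(mx_node):
    """Translate one MXNet graph node (as a dict) into an ONNX-style node dict."""
    op_type, attr_map = _MX_TO_ONNX_OPS[mx_node['op']]
    # Rename attributes where ONNX uses a different name; pass others through.
    attrs = {attr_map.get(k, k): v for k, v in mx_node.get('attrs', {}).items()}
    return {'op_type': op_type, 'inputs': mx_node['inputs'], 'attrs': attrs}
```

A full exporter would walk the serialized symbol graph node by node, apply this translation, and assemble the resulting nodes into an ONNX GraphProto.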

...

data_iter = data_process(mnist)
sym = construct_model_def(...)
# create module for training
mod = mx.mod.Module(symbol=sym, ...)
# train model
mod.fit(data_iter, optimizer='...', optimizer_params='...', ...)
# get parameters of trained model
params = mod.get_params()

# save into different format using sym/params

# This will internally call serde package API `serde.export`
mx.mod.Module.export(sym, params, output_format='onnx', filename="model.onnx")

# OR save into different format using module directly

mx.mod.Module.export(mod, output_format='onnx', filename="model.onnx")

 

...

The general structure of an ONNX import would look something like this:

# Import model

# This will internally call serde package API `serde.import`

sym, params = mx.mod.Module.import_from(input_file, input_format='onnx')

# create module for inference
mod = mx.mod.Module(symbol=sym, data_names=..., context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=..., label_shapes=None)

# set parameters
mod.set_params(arg_params=params, aux_params=None, allow_missing=True)

# forward on the provided data batch
mod.forward(Batch([data_batch]))

...