Versions Compared

Key

  • This line was added.
  • This line was removed.
  • Formatting was changed.

...

Code Block
languagepy
# Trained network
net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=mx.cpu())

# Data transformations applicable during inference: resize, convert to a
# tensor, then normalize with per-channel mean/std.
inference_input_transforms = gluon.nn.HybridSequential()
inference_input_transforms.add(transforms.Resize((224, 224)))
inference_input_transforms.add(transforms.ToTensor())
inference_input_transforms.add(
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

# HybridBlock.export() needs a cached symbolic graph: hybridize the network
# and run one forward pass first, otherwise export() raises a RuntimeError.
net.hybridize()
net(mx.nd.zeros((1, 3, 224, 224), ctx=mx.cpu()))

# Export the model (my_model-symbol.json + my_model-0000.params).
# Limitation: only the network is exported -- the data transformations above
# and the input/output signature cannot be exported with it.
net.export(path="./my_model", epoch=0)

...

Code Block
languagepy
# Context to run inference on.
ctx = mx.cpu()

# Load the exported network and its parameters (epoch 0).
sym, arg_params, aux_params = mx.model.load_checkpoint('my_model', 0)

# Inference-only module: no label inputs.
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
# The input name and shape must be spelled out by hand, because the exported
# files carry no input/output signature.
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))],
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)

# The input passed here must already be transformed (resized, normalized)
# by the caller: the data transformations are not part of the exported model.
mod.forward(...)

Import the model in the Java Predictor API

Code Block
languagejava
// The input signature must be declared by hand: input "data", float32,
// NCHW layout, shape {1, 3, 224, 224} (batch 1, 3 channels, 224x224).
Shape inputShape = new Shape(new int[] {1,3,224,224});
DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW"); 
List<DataDesc> inputDescList = new ArrayList<DataDesc>();
inputDescList.add(inputDescriptor);
// Run inference on the CPU.
List<Context> context = new ArrayList<>();
context.add(Context.cpu()); 
// Prefix of the exported model files.
String modelPathPrefix = "path-to-model";
Predictor predictor = new Predictor(modelPathPrefix, inputDescList, context);

// NOTE(review): inputNDArray is not defined in this snippet; presumably a
// List<NDArray> that the caller has already preprocessed, since the data
// transformations are not exported with the model -- confirm.
List<NDArray> result = predictor.predictWithNDArray(inputNDArray);

After

Step 1 - Train and export the model from Gluon

Code Block
languagepy
# Trained network
net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=mx.cpu())

# Data transformations applicable during inference: resize, convert to a
# tensor, then normalize with per-channel mean/std.
inference_input_transforms = gluon.nn.HybridSequential()
inference_input_transforms.add(transforms.Resize((224, 224)))
inference_input_transforms.add(transforms.ToTensor())
inference_input_transforms.add(
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))


# Export the model together with its input/output signature and the
# inference-time input transforms, so importers need not re-declare them.
# resnet18_v1 ends in a 1000-way classifier, so the output is (1, 1000).
# NOTE(review): (1, 3, 224, 224) is the shape *after* input_transforms;
# confirm whether INPUT_DESC should describe the raw or transformed input.
gluon.contrib.utils.export(net, path="./my_model",
                           epoch=0,
                           signature={constants.INPUT_DESC: [("data", (1, 3, 224, 224))],
                                      constants.OUTPUT_DESC: [("softmax_label", (1, 1000))]},
                           input_transforms={"data": inference_input_transforms},
                           output_transforms=None)

...

Code Block
languagepy
# Named `imports` (like the existing SymbolBlock.imports) because `import` is
# a reserved Python keyword: `utils.import(...)` would be a SyntaxError.
# Keep the returned model; load_transforms=True also restores the exported
# input transforms.
net = gluon.contrib.utils.imports(symbol_file="my_model-symbol.json",
                                  param_file="my_model-0000.params",
                                  load_transforms=True,
                                  ctx='cpu')

...

Import the model in the Module API

...

Code Block
languagepy
# Named `imports` because `import` is a reserved Python keyword:
# `Module.import(...)` would be a SyntaxError. The input signature comes from
# the export, so only the batch size has to be given here.
mod = mx.contrib.Module.imports(
    symbol_file="my_model-symbol.json",
    param_file="my_model-0000.params",
    load_transforms=True,
    ctx='cpu',
    batch_size=1)
mod.forward(...)

...

Import the model in the Java Predictor API

Code Block
languagejava
// Run inference on the CPU.
List<Context> context = new ArrayList<>();
context.add(Context.cpu());
// Prefix of the exported files (my_model-symbol.json / my_model-0000.params).
// The input signature is expected to come from the export, so no DataDesc
// list is passed.
String modelPathPrefix = "my_model";
// Java has no keyword arguments (and its boolean literal is `true`), so the
// load-transforms flag is passed positionally.
boolean loadTransforms = true;
Predictor predictor = new Predictor(modelPathPrefix, context, loadTransforms);

// With loadTransforms set, the exported input transforms are meant to be
// applied by the predictor itself.
List<NDArray> result = predictor.predictWithNDArray(inputNDArray);

...

Before

Code Block
languagepy
# Existing API: input names must be supplied by hand, and only the network is
# loaded -- the inference transforms have to be rebuilt separately.
# ctx must be an mx.Context; a plain string like 'cpu' is not accepted when
# the parameters are loaded.
net = SymbolBlock.imports(symbol_file="my_model-symbol.json",
                          input_names=["data"],
                          param_file="my_model-0000.params",
                          ctx=mx.cpu())

After

Code Block
languagepy
# Named `imports` (like the existing SymbolBlock.imports) because `import` is
# a reserved Python keyword: `utils.import(...)` would be a SyntaxError.
# Keep the returned model; load_transforms=True also restores the exported
# input transforms.
net = gluon.contrib.utils.imports(symbol_file="my_model-symbol.json",
                                  param_file="my_model-0000.params",
                                  load_transforms=True,
                                  ctx='cpu')

Import API - Module

Before

After
Supports creating a module for inference only.

Code Block
languagepy
# Named `imports` because `import` is a reserved Python keyword:
# `Module.import(...)` would be a SyntaxError. The input signature comes from
# the export, so only the batch size has to be given here.
mod = mx.contrib.Module.imports(
    symbol_file="my_model-symbol.json",
    param_file="my_model-0000.params",
    load_transforms=True,
    ctx='cpu',
    batch_size=1)
mod.forward(...)

Import API - Java Predictor

Before

Code Block
languagejava
// The input signature must be declared by hand: input "data", float32,
// NCHW layout, shape {1, 3, 224, 224} (batch 1, 3 channels, 224x224).
Shape inputShape = new Shape(new int[] {1,3,224,224});
DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW"); 
List<DataDesc> inputDescList = new ArrayList<DataDesc>();
inputDescList.add(inputDescriptor);
// Run inference on the CPU.
List<Context> context = new ArrayList<>();
context.add(Context.cpu()); 
// Prefix of the exported model files.
String modelPathPrefix = "path-to-model";
Predictor predictor = new Predictor(modelPathPrefix, inputDescList, context);

// NOTE(review): inputNDArray is not defined in this snippet; presumably a
// List<NDArray> that the caller has already preprocessed, since the data
// transformations are not exported with the model -- confirm.
List<NDArray> result = predictor.predictWithNDArray(inputNDArray);

...

Code Block
languagejava
// Run inference on the CPU.
List<Context> context = new ArrayList<>();
context.add(Context.cpu());
// Prefix of the exported files (my_model-symbol.json / my_model-0000.params).
// The input signature is expected to come from the export, so no DataDesc
// list is passed.
String modelPathPrefix = "my_model";
// Java has no keyword arguments (and its boolean literal is `true`), so the
// load-transforms flag is passed positionally.
boolean loadTransforms = true;
Predictor predictor = new Predictor(modelPathPrefix, context, loadTransforms);

// With loadTransforms set, the exported input transforms are meant to be
// applied by the predictor itself.
List<NDArray> result = predictor.predictWithNDArray(inputNDArray);

...