
2. In order to provide the weights from MXNet (NNVM) to the TensorRT graph converter before the symbol is fully bound (before memory is allocated, etc.), the arg_params and aux_params need to be provided to the symbol's simple_bind method. The weights and other values (e.g. moments learned from data by batch normalization, supplied via aux_params) are passed in through the shared_buffer argument to simple_bind as follows:

    executor = sym.simple_bind(ctx=ctx, data=data_shape,
                               softmax_label=sm_shape, grad_req='null',
                               shared_buffer=all_params, force_rebind=True)
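
In the call above, ctx, data_shape, sm_shape, and all_params must already be defined (all_params is constructed in step 3 below). As a rough sketch for an MNIST-sized model, where every concrete value is an assumption:

    import mxnet as mx

    ctx = mx.gpu(0)                       # assumption: run inference on GPU 0
    batch_size = 128                      # assumption
    data_shape = (batch_size, 1, 28, 28)  # assumption: MNIST-shaped NCHW input
    sm_shape = (batch_size,)              # shape of the softmax label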

3. To collect arg_params and aux_params from the dictionaries loaded by mx.model.load_checkpoint(), we need to merge them into a single dictionary:

    def merge_dicts(*dict_args):
        result = {}
        for dictionary in dict_args:
            result.update(dictionary)
        return result

    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
    all_params = merge_dicts(arg_params, aux_params)

This all_params dictionary can be seen in use in the simple_bind call in step 2.

4. Once the symbol is bound, we need to feed in the data and run the forward() method. Let's say we're using a test set data iterator called test_iter. We can run inference as follows:

    for idx, dbatch in enumerate(test_iter):
        data = dbatch.data[0]
        executor.arg_dict["data"][:] = data
        executor.forward(is_train=False)
        preds = executor.outputs[0].asnumpy()
        top1 = np.argmax(preds, axis=1)

5. Note: One can choose between running inference with and without TensorRT by changing the state of the MXNET_USE_TENSORRT environment variable. Let's first write a convenience function to toggle this variable:

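A minimal sketch of such a helper (the name set_use_tensorrt matches its use in the snippet below; the "0"/"1" encoding of the flag is an assumption):

    import os

    def set_use_tensorrt(status=False):
        # Assumption: MXNet reads this variable as "0" or "1".
        os.environ["MXNET_USE_TENSORRT"] = str(int(status))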

Now, assuming that the logic to bind a symbol and run inference in batches of batch_size on a given dataset is wrapped in a run_inference function, we can do the following (a sketch of such a function appears after the snippet):

 print("Running inference in MXNet")
    set_use_tensorrt(False)
    mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size=batch_size)

 print("Running inference in MXNet-TensorRT")
    set_use_tensorrt(True)
    trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,  batch_size=batch_size)
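
For completeness, here is a minimal, hypothetical sketch of what such a run_inference function might look like, stitching together steps 2 through 4 above. The dataset layout (a dict with a 'test_data' array), the GPU context, and the accuracy calculation are all assumptions, not part of the original tutorial:

    import mxnet as mx
    import numpy as np

    def run_inference(sym, arg_params, aux_params, dataset, all_test_labels,
                      batch_size):
        # Bind the symbol with the merged weights (steps 2-3).
        all_params = merge_dicts(arg_params, aux_params)
        ctx = mx.gpu(0)  # assumption: GPU 0
        data_shape = (batch_size,) + dataset['test_data'].shape[1:]
        executor = sym.simple_bind(ctx=ctx, data=data_shape,
                                   softmax_label=(batch_size,),
                                   grad_req='null', shared_buffer=all_params,
                                   force_rebind=True)
        # Iterate over the test set in batches (step 4).
        test_iter = mx.io.NDArrayIter(dataset['test_data'],
                                      batch_size=batch_size)
        correct = 0
        for idx, dbatch in enumerate(test_iter):
            executor.arg_dict["data"][:] = dbatch.data[0]
            executor.forward(is_train=False)
            top1 = np.argmax(executor.outputs[0].asnumpy(), axis=1)
            labels = all_test_labels[idx * batch_size:(idx + 1) * batch_size]
            # The last batch may be padded by the iterator; trim predictions.
            correct += int((top1[:len(labels)] == labels).sum())
        return 100.0 * correct / len(all_test_labels)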
