...
Since shared_buffer needs both arg_params (weights) and aux_params (e.g. BatchNorm moving mean and variance), we merge the two into a single dictionary. Here’s a Python example:
```python
def merge_dicts(*dict_args):
    """Merge arg_params and aux_params to populate shared_buffer"""
    result = {}
    for dictionary in dict_args:
        result.update(dictionary)
    return result
```
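To show what the merge produces, here is a small standalone check of the merge_dicts function above. It uses plain strings as hypothetical stand-ins for the NDArrays that mx.model.load_checkpoint returns, and the parameter names are made up for illustration. It also shows that a key present in both dictionaries takes the value from the later one, and that on Python 3.5+ dictionary unpacking builds the same result.

```python
def merge_dicts(*dict_args):
    """Merge arg_params and aux_params to populate shared_buffer"""
    result = {}
    for dictionary in dict_args:
        result.update(dictionary)  # later dicts overwrite earlier keys
    return result

# Hypothetical stand-ins: real values would be MXNet NDArrays.
arg_params = {"fc1_weight": "W1", "fc1_bias": "b1"}
aux_params = {"bn1_moving_mean": "mu", "bn1_moving_var": "var"}

merged = merge_dicts(arg_params, aux_params)
print(sorted(merged))
# → ['bn1_moving_mean', 'bn1_moving_var', 'fc1_bias', 'fc1_weight']

# Python 3.5+ equivalent using dictionary unpacking.
assert merged == {**arg_params, **aux_params}

# On a key collision, the value from the later dict wins.
print(merge_dicts({"x": 1}, {"x": 2}))  # → {'x': 2}
```

In practice arg_params and aux_params have disjoint names, so the override rule rarely matters, but it is worth knowing if you merge in extra dictionaries.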
Now let’s look at a usage example:
```python
import mxnet as mx

# model_name, num_epochs, data_shape and batch_size are defined elsewhere
device = mx.gpu(0)
sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)
executor = sym.simple_bind(ctx=device,
                           data=data_shape,
                           softmax_label=(batch_size,),
                           shared_buffer=merge_dicts(arg_params, aux_params),
                           grad_req='null',
                           force_rebind=True)
```
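simple_bind looks up arrays in shared_buffer by name, so a parameter whose name is missing from the merged dictionary will not be taken from the checkpoint. A quick sanity check can catch this before inference. The sketch below is a minimal, hypothetical helper (missing_params) that assumes you pass in the name lists returned by sym.list_arguments() and sym.list_auxiliary_states(); the example names are made up.

```python
def missing_params(arg_names, aux_names, buffer,
                   input_names=("data", "softmax_label")):
    """Return names the symbol needs that are neither inputs nor in buffer.

    arg_names / aux_names: e.g. sym.list_arguments() and
    sym.list_auxiliary_states() from the loaded symbol.
    buffer: the merged dict passed as shared_buffer.
    input_names: arguments fed per batch, not loaded from the checkpoint.
    """
    needed = [n for n in list(arg_names) + list(aux_names)
              if n not in input_names]
    return [n for n in needed if n not in buffer]

# Hypothetical names for a tiny network, for illustration only.
arg_names = ["data", "fc1_weight", "fc1_bias", "softmax_label"]
aux_names = ["bn1_moving_mean", "bn1_moving_var"]
shared = {"fc1_weight": 0, "fc1_bias": 0, "bn1_moving_mean": 0}

print(missing_params(arg_names, aux_names, shared))  # → ['bn1_moving_var']
```

With a real model the call would be missing_params(sym.list_arguments(), sym.list_auxiliary_states(), merge_dicts(arg_params, aux_params)), and an empty list means every parameter has a checkpoint entry.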
Now we can simply update data in the executor’s arg_dict and run the forward pass:
```python
executor.arg_dict["data"][:] = my_data_batch
executor.forward(is_train=False)
predictions = executor.outputs[0].asnumpy()
```
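Because the executor was bound with a fixed data_shape, the array copied into arg_dict["data"] has to match that shape, which is a problem when the last batch has fewer samples than batch_size. One common workaround is to zero-pad the batch and discard the padded rows from the predictions. The following is a minimal NumPy sketch with a hypothetical helper pad_batch; the executor calls appear only as comments and are not run here.

```python
import numpy as np

def pad_batch(batch, batch_size):
    """Zero-pad batch along axis 0 to batch_size rows.

    Returns the padded array and the number of real (unpadded) rows.
    """
    n = batch.shape[0]
    if n > batch_size:
        raise ValueError("batch has %d rows, executor expects %d" % (n, batch_size))
    if n == batch_size:
        return batch, n
    pad = np.zeros((batch_size - n,) + batch.shape[1:], dtype=batch.dtype)
    return np.concatenate([batch, pad], axis=0), n

# Example: executor bound for 4 rows, but only 3 samples remain.
last_batch = np.arange(6, dtype=np.float32).reshape(3, 2)
padded, n_valid = pad_batch(last_batch, batch_size=4)
print(padded.shape, n_valid)  # → (4, 2) 3

# With a real executor you would then do (not run here):
#   executor.arg_dict["data"][:] = padded
#   executor.forward(is_train=False)
#   predictions = executor.outputs[0].asnumpy()[:n_valid]
```

Slicing the outputs with [:n_valid] keeps only the predictions for real samples, so the padding never leaks into the results.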
Limitations and future work:
...