...

  • Adapt TVM registry for MXNet PackedFunc registration in backend

  • Adapt TVM function as MXNet PackedFunc wrapper in frontend

  • Extend TVMValue for customized and efficient MXNet argument passing.

Example

As an example, np.zeros is implemented with the new FFI interface. The files mentioned in the example can be found in the POC mentioned above. Reading both the frontend and the backend registration helps when registering other ops.

First, in the frontend, np.zeros is registered in python/mxnet/ndarray/numpy/_op.py as usual:

Code Block (language: py, title: Front End)
@set_module('mxnet.ndarray.numpy')
def zeros(shape, dtype=None, order='C', ctx=None):  # pylint: disable=redefined-outer-name
    if order != 'C':
        raise NotImplementedError
    # If the following code (4 lines) regarding ctx is removed
    # np.zeros((3, 4)) can be as fast as 4.96 us
    if ctx is None:
        ctx = str(current_context())
    else:
        ctx = str(ctx)
    if dtype is not None and not isinstance(dtype, str):
        dtype = _np.dtype(dtype).name
    return _api_internal.zeros(shape, dtype, ctx)



Second, in the backend, _npi.zeros is registered in src/api/operator/numpy/np_init_op.cc as follows:

Code Block (language: cpp, title: Backend)
MXNET_REGISTER_API("_npi.zeros")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  // Part1: populate NodeAttrs
  using namespace runtime;
  const static nnvm::Op* op = Op::Get("_npi_zeros");
  nnvm::NodeAttrs attrs;
  op::InitOpParam param;
  if (args[0].type_code() == kDLInt) {
    param.shape = TShape(1, args[0].operator int64_t());
  } else {
    param.shape = TShape(args[0].operator ObjectRef());
  }
  if (args[1].type_code() == kNull) {
    param.dtype = mshadow::kFloat32;
  } else {
    param.dtype = runtime::String2MXNetTypeWithBool(args[1].operator std::string());
  }
  attrs.parsed = std::move(param);
  attrs.op = op;
  if (args[2].type_code() != kNull) {
    attrs.dict["ctx"] = args[2].operator std::string();
  }
  // part2: invoke
  int num_outputs = 0;
  auto ndoutputs = Invoke<op::InitOpParam>(op, &attrs, 0, nullptr, &num_outputs, nullptr);
  *ret = ndoutputs[0];
});



The first part of _npi.zeros is mainly responsible for populating NodeAttrs, and the second part invokes the op imperatively, similar to the API MXImperativeInvokeImpl.

As shown, _npi.zeros is not directly exposed at the DLL boundary, yet it can still be invoked from the frontend. If we look deeper into how TShape is constructed from args[0], we will see that TShape is initialized from an ADTObj, which in this case behaves like an array of integers. Because this works without (de)serialization between string and int, some speedup is obtained here.
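
The saving from skipping string (de)serialization can be illustrated with a small Python sketch. It is only an analogy, not MXNet code: `backend_from_string` and `backend_from_ints` are hypothetical stand-ins for a backend that receives the shape as text and one that receives the integers directly, as the ADT-based path does. The timings it prints depend on the machine, so none are quoted here.

```python
import ast
import timeit


def backend_from_string(shape_text):
    """Hypothetical stand-in: the shape arrives as text and must be parsed."""
    parsed = ast.literal_eval(shape_text)
    if isinstance(parsed, int):
        parsed = (parsed,)
    return tuple(int(dim) for dim in parsed)


def backend_from_ints(shape):
    """Hypothetical stand-in: the shape arrives as integers, no parsing needed."""
    if isinstance(shape, int):
        return (shape,)
    return tuple(shape)


def call_via_string(shape):
    # The frontend serializes the shape to a string and the backend parses it back.
    return backend_from_string(str(shape))


def call_via_ints(shape):
    # The frontend hands the integers over unchanged.
    return backend_from_ints(shape)


if __name__ == "__main__":
    shape = (3, 4)
    # Both paths must agree on the result; only their cost differs.
    assert call_via_string(shape) == call_via_ints(shape) == (3, 4)
    for fn in (call_via_string, call_via_ints):
        seconds = timeit.timeit(lambda: fn(shape), number=100_000)
        print(f"{fn.__name__}: {seconds / 100_000 * 1e6:.2f} us per call")
```

The scalar branch mirrors the kDLInt case in the backend code above, where a single integer is turned into a one-dimensional shape.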

Also, here are a few things to note:

  • In the original np.zeros, the default value for dtype is specified in the frontend, while in np.zeros with the new FFI interface, as shown above, it is specified in the backend, which slightly speeds it up. Generally, I think it may be better to move such code to the backend, as Python is sometimes slow.
  • ctx is still (de)serialized. A joint effort with the engine overhead optimization may be required, as ctx gets deserialized in the engine.

...

  1. python/mxnet/ndarray/numpy/_op.py: def zeros. The Python entry.
  2. python/mxnet/_ffi/function.py: class Function. The frontend wrapper that hides the details of ctypes and Cython.
  3. python/mxnet/_ffi/_cython/function.pxi: cdef class FunctionBase. The cython entry.
  4. python/mxnet/_ffi/_cython/function.pxi: def __call__. Here we call make_ret to convert MXNetValue into python object.
  5. python/mxnet/_ffi/_cython/function.pxi: cdef inline int FuncCall. Here we call make_arg to convert python object into MXNetValue.
  6. src/runtime/c_runtime_api.cc: int MXNetFuncCall. The C++ entry. It is an API exposed across the DLL boundary. It wraps the ABI-compatible C struct MXNetValue into the easy-to-use C++ classes MXNetArgs and MXNetRetValue.
  7. include/mxnet/runtime/packed_func.h: inline void PackedFunc::CallPacked.
  8. src/api/operator/numpy/np_init_op.cc: MXNET_REGISTER_API("_npi.zeros"). Here we arrive at the user-defined backend function.

...