...

Backend:

```cpp
MXNET_REGISTER_API("_npi.zeros")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  // Part1: populate NodeAttrs
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_zeros");
  nnvm::NodeAttrs attrs;
  op::InitOpParam param;
  // args[0] is the shape: either a single int or a tuple of ints (ADT object).
  if (args[0].type_code() == kDLInt) {
    param.shape = TShape(1, args[0].operator int64_t());
  } else {
    param.shape = TShape(args[0].operator ObjectRef());
  }
  // args[1] is the dtype; default to float32 when it is None.
  if (args[1].type_code() == kNull) {
    param.dtype = mshadow::kFloat32;
  } else {
    param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
  }
  attrs.parsed = std::move(param);
  attrs.op = op;
  // Also record the parsed parameters in attrs.dict.
  SetAttrDict<op::InitOpParam>(&attrs);
  // args[2] is the optional context string.
  if (args[2].type_code() != kNull) {
    attrs.dict["ctx"] = args[2].operator std::string();
  }
  // Part2: invoke the operator imperatively, with no input arrays
  int num_outputs = 0;
  auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr);
  *ret = ndoutputs[0];
});
```

The first part of _npi.zeros populates NodeAttrs from the packed arguments; the second part invokes the operator imperatively, much as MXImperativeInvokeImpl does.
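
To make the two-part structure concrete outside of MXNet, here is a minimal, self-contained sketch of the same pattern. Everything in it is hypothetical: Arg stands in for the packed FFI arguments, ZerosParam for op::InitOpParam, the std::map registry for MXNET_REGISTER_API, and InvokeZeros for Invoke; the ctx argument is omitted. Part 1 dispatches on each argument's runtime type to fill a parameter struct, and Part 2 runs the operator and returns its single output.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for a packed FFI argument:
// None, an int, a string, or an integer array (tuple).
struct Null {};
using Arg = std::variant<Null, std::int64_t, std::string, std::vector<std::int64_t>>;

// Hypothetical analogue of op::InitOpParam.
struct ZerosParam {
  std::vector<std::int64_t> shape;
  std::string dtype;
};

// Hypothetical output; data is kept as float regardless of dtype for simplicity.
struct Tensor {
  std::vector<std::int64_t> shape;
  std::string dtype;
  std::vector<float> data;
};

using PackedFunc = std::function<Tensor(const std::vector<Arg>&)>;

// Registry keyed by API name, loosely mirroring MXNET_REGISTER_API.
std::map<std::string, PackedFunc>& Registry() {
  static std::map<std::string, PackedFunc> registry;
  return registry;
}

// Part 2 stand-in: the imperative "kernel" that allocates the zero-filled output.
Tensor InvokeZeros(const ZerosParam& param) {
  std::size_t size = 1;
  for (std::int64_t d : param.shape) {
    assert(d >= 0 && "dimensions must be non-negative");
    size *= static_cast<std::size_t>(d);
  }
  return Tensor{param.shape, param.dtype, std::vector<float>(size, 0.0f)};
}

// Register the API once, during static initialization.
const bool kZerosRegistered = [] {
  Registry()["_npi.zeros"] = [](const std::vector<Arg>& args) {
    // Part 1: fill the parameter struct by dispatching on each argument's type.
    ZerosParam param;
    if (const auto* n = std::get_if<std::int64_t>(&args.at(0))) {
      param.shape = {*n};  // a scalar n means the 1-D shape (n,)
    } else {
      param.shape = std::get<std::vector<std::int64_t>>(args.at(0));
    }
    if (std::holds_alternative<Null>(args.at(1))) {
      param.dtype = "float32";  // default when dtype is None
    } else {
      param.dtype = std::get<std::string>(args.at(1));
    }
    // Part 2: invoke the operator and return its single output.
    return InvokeZeros(param);
  };
  return true;
}();
```

For example, calling Registry().at("_npi.zeros") with a shape argument of std::vector<std::int64_t>{2, 3} and a Null dtype yields a 2x3 float32 tensor of zeros, while a scalar argument of 4 yields the 1-D shape (4,), matching the TShape(1, n) branch in the backend code.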

As shown, _npi.zeros is not exported directly across the DLL boundary, yet it can still be invoked from the frontend. Looking into how TShape is constructed from args[0], we see that it is initialized from an ADTObj, which in this case behaves like an array of integers. Because the shape is passed as integers rather than being serialized to a string and parsed back, this path avoids that conversion overhead and gains some speed.
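
The saving comes from skipping a text round trip. The sketch below contrasts the two routes for a shape; the function names are hypothetical and this is not MXNet's actual parameter parser. In the string route the shape is formatted as text such as "(2, 3)" and parsed back into integers, roughly what happens when operator parameters travel as key/value strings (as with the older MXImperativeInvoke C API). In the array route, standing in for the ADTObj path, the integers are copied directly.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

using Shape = std::vector<std::int64_t>;

// String route (hypothetical): the frontend formats the shape as text...
std::string SerializeShape(const Shape& shape) {
  std::ostringstream os;
  os << '(';
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (i) os << ", ";
    os << shape[i];
  }
  if (shape.size() == 1) os << ',';  // Python-style 1-tuple, e.g. "(5,)"
  os << ')';
  return os.str();
}

// ...and the backend parses the digits back into integers.
// Handles non-negative dimensions only.
Shape ParseShape(const std::string& text) {
  Shape shape;
  std::size_t i = 0;
  while (i < text.size()) {
    if (std::isdigit(static_cast<unsigned char>(text[i]))) {
      std::size_t consumed = 0;
      shape.push_back(std::stoll(text.substr(i), &consumed));
      i += consumed;
    } else {
      ++i;  // skip '(', ')', ',' and spaces
    }
  }
  return shape;
}

// Array route (stand-in for the ADTObj path): the integers are already in a
// contiguous array, so building the shape is a plain copy with no text involved.
Shape ShapeFromArray(const std::int64_t* dims, std::size_t ndim) {
  assert(dims != nullptr || ndim == 0);
  return Shape(dims, dims + ndim);
}
```

Both routes yield the same shape; the array route simply never formats or parses any digits.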

Also, here are a few things to note:

...