...

Later, when a CustomOp operator is bound and executed, the functions from the shared library are called. During the bind step, the operator's attributes are analyzed by the custom operator's parseAttrs function in the shared library. Type and shape inference are likewise delegated to the library's inferType and inferShape functions. Lastly, when executing the forward pass, the operator's Forward function is called from the shared library.
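The call order described above can be modeled as a small driver. This is a hypothetical sketch, not MXNet's internal code: `LibOp`, `Attrs`, and `bindAndRun` are illustrative names, each library function is reduced to a callback that reports success, and the driver stops at the first stage that fails.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

using Attrs = std::map<std::string, std::string>;

// Hypothetical bundle of the functions MXNet would look up for one operator.
struct LibOp {
  std::function<bool(const Attrs&)> parseAttrs;
  std::function<bool()> inferType;
  std::function<bool()> inferShape;
  std::function<bool()> forward;
};

// Bind: parse attributes, then run type and shape inference.
// Execute: run the forward pass. Returns false at the first failing stage,
// so later stages are never called.
bool bindAndRun(const LibOp& op, const Attrs& attrs) {
  if (!op.parseAttrs(attrs)) return false;
  if (!op.inferType()) return false;
  if (!op.inferShape()) return false;
  return op.forward();
}
```

Short-circuiting mirrors the idea that an operator whose attributes fail validation should never reach inference or execution.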


New MXNet APIs

These are the new APIs added to MXNet.

...



C APIs

  • MXLoadLib - API to load operator libraries
    • Checks version number
    • Calls initialize on
    • Loads the customOp library
    • Goes through the operators in the library
    • Checks that each operator defines the required functions
      • ParseAttrs, InferType, InferShape, Forward
    • Registers each operator found
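The required-function check and registration steps listed above can be sketched as follows. This is a hypothetical model, not MXNet's `MXLoadLib` implementation: `OpEntry`, `GenericFn`, `missingFunctions`, and `registerValidOps` are illustrative names, all function slots share one pointer type for simplicity, and in the real flow the entries would come from the loaded library rather than being built in-process. The fourth slot, named `forward` here, stands for the compute function.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical: every slot uses one generic function-pointer type.
using GenericFn = int (*)();

struct OpEntry {
  std::string name;
  GenericFn parseAttrs = nullptr;
  GenericFn inferType = nullptr;
  GenericFn inferShape = nullptr;
  GenericFn forward = nullptr;
};

// Returns the names of required functions the operator did not set.
// An empty result means the operator can be registered.
std::vector<std::string> missingFunctions(const OpEntry& op) {
  std::vector<std::string> missing;
  if (!op.parseAttrs) missing.push_back("parseAttrs");
  if (!op.inferType) missing.push_back("inferType");
  if (!op.inferShape) missing.push_back("inferShape");
  if (!op.forward) missing.push_back("forward");
  return missing;
}

// Registers only operators that define every required function and
// returns the names of the operators that were accepted.
std::vector<std::string> registerValidOps(const std::vector<OpEntry>& ops) {
  std::vector<std::string> registered;
  for (const OpEntry& op : ops) {
    if (missingFunctions(op).empty()) registered.push_back(op.name);
  }
  return registered;
}
```

Reporting which functions are missing, rather than a bare pass/fail, makes it easier to produce a useful error message for the library author.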

Python APIs

  • load_op_lib - API to load operator libraries
    • Takes a path to the operator library
    • Checks that the path exists and points to a file
    • Calls the C API MXLoadLib to perform the actual loading
    • mx.operator.load_op_lib('/path/to/libtest.so')

New CustomOp Operator

  • CustomOp - a new operator that executes custom operators loaded from the library
    • Takes op_type to identify the custom operator by name
    • Takes any number of kwargs as attributes/parameters
    • Takes any number of positional args, in order, as input arrays
    • b = mx.nd.CustomOp(a, op_type='sam', myParam='2')
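Because the attributes reach the library as a `std::map<std::string, std::string>`, a keyword argument such as `myParam='2'` arrives as the string "2", and the library has to convert it itself. The helper below is a hypothetical sketch (`getIntAttr` is an illustrative name, not part of MXNet) that reads an integer attribute and rejects values that are missing or not a complete base-10 integer.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical helper: reads an integer attribute from the string map a
// custom operator receives. Returns false if the key is missing or the
// value is not a complete integer; *out is only written on success.
bool getIntAttr(const std::map<std::string, std::string>& attrs,
                const std::string& key, int* out) {
  auto it = attrs.find(key);
  if (it == attrs.end()) return false;
  try {
    std::size_t consumed = 0;
    int value = std::stoi(it->second, &consumed);
    if (consumed != it->second.size()) return false;  // reject e.g. "2abc"
    *out = value;
    return true;
  } catch (const std::exception&) {  // std::invalid_argument or std::out_of_range
    return false;
  }
}
```

Checking how many characters `std::stoi` consumed matters because on its own it accepts a numeric prefix and silently ignores trailing text.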

APIs for implementing Custom Operators

...

  • parseAttrs - takes a set of key/value pairs for attributes and gives users an opportunity to validate the attributes passed to their custom operator.
    • int parseAttrs(std::map<std::string, std::string> attrs, 
      int* num_in,
      int* num_out);
    • Inputs: the map of attributes passed to the operator from the user
    • Outputs: num_in, num_out - the number of input/output arrays required for this operator
    • Returns 1 on success or 0 on failure
  • inferType - performs type inference for this operator
    • int inferType(std::map<std::string, std::string> attrs, 
      std::vector<int> &intypes,
      std::vector<int> &outtypes);
    • Inputs: the map of attributes
    • Inputs/Outputs: intypes, outtypes - the lists of input/output types to be inferred. Entries of -1 are unknown and must be set to a specific type by this operator
    • Returns 1 on success or 0 on failure
  • inferShape - performs shape inference for this operator
    • int inferShape(std::map<std::string, std::string> attrs, 
      std::vector<std::vector<unsigned int>> &inshapes,
      std::vector<std::vector<unsigned int>> &outshapes);
    • Inputs: the map of attributes
    • Inputs: inshapes - the shapes of the input arrays
    • Outputs: outshapes - the shapes of output arrays
  • forward - performs the forward-pass computation of this operator
    • int forward(std::map<std::string, std::string> attrs,
      std::vector<MXTensor> inputs,
      std::vector<MXTensor> outputs,
      OpResource res);
    • Inputs: the map of attributes
    • Input data: inputs, the input tensors
    • Output data: outputs, the output tensors
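The parseAttrs, inferType, and inferShape signatures described above can be illustrated with a hypothetical two-input matrix-multiply operator (C = A x B). The operator, its lack of attributes, and the requirement of 2-D inputs with matching inner dimensions are assumptions for this sketch; the 1/0 return convention follows the descriptions above. The sketch also assumes the output vectors already hold one entry per output, as the -1 convention for unknown types implies.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// parseAttrs: this hypothetical operator has no attributes to validate;
// it reports two inputs (A, B) and one output (C).
int parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
  (void)attrs;
  *num_in = 2;
  *num_out = 1;
  return 1;
}

// inferType: an unknown (-1) input type is taken from the other input;
// both inputs must end up with the same known type, which the output gets.
int inferType(std::map<std::string, std::string> attrs,
              std::vector<int>& intypes, std::vector<int>& outtypes) {
  (void)attrs;
  if (intypes.size() != 2 || outtypes.size() != 1) return 0;
  if (intypes[0] == -1) intypes[0] = intypes[1];
  if (intypes[1] == -1) intypes[1] = intypes[0];
  if (intypes[0] == -1 || intypes[0] != intypes[1]) return 0;
  outtypes[0] = intypes[0];
  return 1;
}

// inferShape: A is (n, k) and B is (k, m), so C is (n, m).
int inferShape(std::map<std::string, std::string> attrs,
               std::vector<std::vector<unsigned int>>& inshapes,
               std::vector<std::vector<unsigned int>>& outshapes) {
  (void)attrs;
  if (inshapes.size() != 2 || outshapes.size() != 1) return 0;
  const auto& a = inshapes[0];
  const auto& b = inshapes[1];
  if (a.size() != 2 || b.size() != 2 || a[1] != b[0]) return 0;
  outshapes[0] = {a[0], b[1]};
  return 1;
}
```

Rejecting a mismatched inner dimension at shape-inference time means the error surfaces during bind rather than inside the forward pass.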

...
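To make the forward signature described above concrete, here is a sketch of a forward function for a hypothetical matrix-multiply operator (C = A x B, row-major, float data). MXNet's real `MXTensor` and `OpResource` types are not reproduced: `ToyTensor` is a hypothetical stand-in holding only a data pointer and a shape, and the `OpResource` parameter is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor: a float buffer plus its shape.
// This is NOT MXNet's MXTensor; it only carries what this sketch needs.
struct ToyTensor {
  float* data;
  std::vector<int64_t> shape;
};

// outputs[0] (n x m) = inputs[0] (n x k) * inputs[1] (k x m), row-major.
// outputs is taken by value as in the signature above; writes still reach
// the caller's buffer because ToyTensor only holds a pointer.
// Returns 1 on success and 0 if the tensors do not match that layout.
int forward(std::map<std::string, std::string> attrs,
            std::vector<ToyTensor> inputs, std::vector<ToyTensor> outputs) {
  (void)attrs;
  if (inputs.size() != 2 || outputs.size() != 1) return 0;
  const ToyTensor& a = inputs[0];
  const ToyTensor& b = inputs[1];
  ToyTensor& c = outputs[0];
  if (a.shape.size() != 2 || b.shape.size() != 2 || c.shape.size() != 2) return 0;
  const int64_t n = a.shape[0], k = a.shape[1], m = b.shape[1];
  if (b.shape[0] != k || c.shape[0] != n || c.shape[1] != m) return 0;
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t j = 0; j < m; ++j) {
      float sum = 0.0f;
      for (int64_t p = 0; p < k; ++p) sum += a.data[i * k + p] * b.data[p * m + j];
      c.data[i * m + j] = sum;
    }
  }
  return 1;
}
```

For example, multiplying the 2 x 3 matrix [[1, 2, 3], [4, 5, 6]] by the 3 x 2 matrix [[7, 8], [9, 10], [11, 12]] fills the output buffer with 58, 64, 139, 154.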

  • REGISTER_OP - registers the operator in the library
    • REGISTER_OP(sam)
      .setForward(forward)
      .setParseAttrs(parseAttrs)
      .setInferType(inferType)
      .setInferShape(inferShape);
    • REGISTER_OP - macro that defines a custom operator object with the given name
    • setForward - sets the forward function
    • setParseAttrs - sets the parse attributes function
    • setInferType - sets the type inference function
    • setInferShape - sets the shape inference function
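The chained-setter registration shown above can be built from a macro that creates a named registry entry whose setters each return the entry. The sketch below shows one way such a macro can work; it is not MXNet's actual `REGISTER_OP` definition, and `ToyOp`, `ToyFn`, `toyRegistry`, `addToyOp`, `REGISTER_TOY_OP`, and `fakeFn` are all hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <string>

using ToyFn = int (*)();

// Each setter returns *this so calls can be chained after the macro.
struct ToyOp {
  std::string name;
  ToyFn forward = nullptr;
  ToyFn parseAttrs = nullptr;
  ToyFn inferType = nullptr;
  ToyFn inferShape = nullptr;

  ToyOp& setForward(ToyFn fn) { forward = fn; return *this; }
  ToyOp& setParseAttrs(ToyFn fn) { parseAttrs = fn; return *this; }
  ToyOp& setInferType(ToyFn fn) { inferType = fn; return *this; }
  ToyOp& setInferShape(ToyFn fn) { inferShape = fn; return *this; }
};

// A function-local static avoids static-initialization-order problems when
// registrations run from static initializers in other translation units.
inline std::map<std::string, ToyOp>& toyRegistry() {
  static std::map<std::string, ToyOp> ops;
  return ops;
}

// std::map never moves its nodes, so the returned reference stays valid.
inline ToyOp& addToyOp(const std::string& name) {
  ToyOp& op = toyRegistry()[name];
  op.name = name;
  return op;
}

// Expands to a static reference whose initializer runs the whole setter
// chain during static initialization.
#define REGISTER_TOY_OP(Name) static ToyOp& toy_op_##Name = addToyOp(#Name)

static int fakeFn() { return 1; }

REGISTER_TOY_OP(sam)
    .setForward(fakeFn)
    .setParseAttrs(fakeFn)
    .setInferType(fakeFn)
    .setInferShape(fakeFn);
```

Because registration happens in a static initializer, simply loading the library is enough to populate the registry, which a loader can then walk to find every operator.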

Goals/Usecases


Open Questions

Proposed Approach

Example Custom Operators

Examples of creating custom operators, building them into a library, and loading them at runtime to test them can be found here:

https://github.com/apache/incubator-mxnet/tree/master/example/extensions/lib_custom_op

The GEMM example contains two operators. The stateless operator shows a regular operator here:

https://github.com/apache/incubator-mxnet/blob/master/example/extensions/lib_custom_op/gemm_lib.cc#L169-L174

The example GEMM stateful operator is here:

https://github.com/apache/incubator-mxnet/blob/master/example/extensions/lib_custom_op/gemm_lib.cc#L220-L225

Addition of New APIs

Backward compatibility

Performance Considerations

Test Plan

Alternative Approaches

Technical Challenges

Milestones

...