...

With this project, we want to build a new set of APIs that are Java friendly, compatible with Java 87+, and easy to use for inference, and that lower the entry barrier to consuming MXNet for production use cases.

...

As a user, I’d like to have a Java Inference API that allows me to use deep learning models from my existing Java application.

As a user, I'd like for the new Java Inference API to be thread safe.

As a user, I’d like for the new Java Inference API to be idiomatic and easy to use so that I can quickly learn to deploy models.

...

As a user already familiar with MXNet, I’d like for the new API to be similar to existing implementations so that it’s easy for me to use.

As a user, I'd like to have examples and tutorials available to help learn how to use the new Java Inference API.

Proposed Approach

The proposed implementation for the new Java API is to create a Java-friendly wrapper around the existing Scala API. The Scala API is already fully implemented and is undergoing significant improvements (most notably, simplifying the management of off-heap memory). By utilizing the existing Scala API, the development effort required for the new Java API is greatly decreased. Additionally, the Java API would benefit automatically (or with minimal work) from new features and code improvements, allowing development efforts to remain focused. This is very similar to how Apache Spark developed its Java API.

Since both Java and Scala are JVM languages, it is already possible to call the Scala bindings from Java code by loading the jar into the classpath. Due to differences between the languages, this process is currently very painful for users. Most notably, the difficulty comes from the liberal use of default parameter values in the Scala code, which Java does not support, and from converting between Java and Scala collections. To improve this experience, a Java wrapper would be created that calls the Scala bindings. The wrapper would be designed to abstract away the complexities of the Java/Scala interaction by automating the conversions, simplifying the method calls, and making the API more idiomatic for the Java inference use case.

  1. Advantages
    • Fastest time to market requiring the least amount of engineering effort.
    • Interaction with the native code is already done.
    • The Scala API is already designed and decided. Implementing a wrapper limits the design decisions that need to be made and keeps the APIs consistent.
    • Allows development to remain focused on a single JVM implementation which can be utilized by other JVM languages.
    • Implementation, the addition of new features, and maintenance would all be greatly simplified.
    • The decision is not one-way. In the future we retain the ability to walk this decision back and go with another implementation.
  2. Disadvantages
    • Interaction with the Scala code could be complicated due to differences between the languages. Known issues are constructing Scala collections and leveraging default values. Changes to the Scala API or the introduction of a builder function seem like the most obvious solutions here.
    • Some overhead in converting collections should be expected.
    • The JAR files will be larger than they would be without Scala in the middle. Theoretically, this could be an issue for some memory-constrained edge devices.
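
The builder-function idea mentioned under Disadvantages can be sketched in plain Java. Java has no default parameter values, so a Java-facing wrapper might collect optional settings in a builder and fill in the defaults itself. Everything below is hypothetical and only illustrates the pattern: `PredictorOptions`, its accessors, and the use of a plain string as a stand-in for a device context are assumptions, not part of MXNet or of the proposed API. The default values (CPU, epoch 0) follow the defaults described for the Predictor parameters in this document.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical options holder; names are illustrative and not part of MXNet.
final class PredictorOptions {
    private final String modelPathPrefix;
    private final List<String> contexts; // strings stand in for device contexts
    private final int epoch;

    private PredictorOptions(Builder b) {
        this.modelPathPrefix = b.modelPathPrefix;
        // Defensive copy so later changes to the caller's list do not leak in.
        this.contexts = Collections.unmodifiableList(new ArrayList<String>(b.contexts));
        this.epoch = b.epoch;
    }

    static Builder builder(String modelPathPrefix) {
        return new Builder(modelPathPrefix);
    }

    String modelPathPrefix() { return modelPathPrefix; }
    List<String> contexts() { return contexts; }
    int epoch() { return epoch; }

    static final class Builder {
        private final String modelPathPrefix;
        // Defaults that Scala would express as default parameter values.
        private List<String> contexts = Collections.singletonList("cpu");
        private int epoch = 0;

        private Builder(String modelPathPrefix) {
            if (modelPathPrefix == null) {
                throw new IllegalArgumentException("modelPathPrefix is required");
            }
            this.modelPathPrefix = modelPathPrefix;
        }

        Builder contexts(List<String> contexts) { this.contexts = contexts; return this; }
        Builder epoch(int epoch) { this.epoch = epoch; return this; }
        PredictorOptions build() { return new PredictorOptions(this); }
    }

    public static void main(String[] args) {
        // Only the required argument is given; everything else uses defaults.
        PredictorOptions defaults = PredictorOptions.builder("model-dir/resnet-152").build();
        System.out.println(defaults.epoch() + " " + defaults.contexts());

        // Optional settings are overridden one at a time.
        PredictorOptions custom = PredictorOptions.builder("model-dir/resnet-152").epoch(5).build();
        System.out.println(custom.epoch() + " " + custom.contexts());
    }
}
```

A set of overloaded constructors would be an alternative, but the number of overloads grows quickly as optional parameters are added, which is one reason a builder may be the more maintainable choice.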

...

Limited by existing Scala Inference API - The current Scala Inference API lacks support for some models, such as RNNs. Since this API will be utilized by the new Java Inference API, it will be necessary to improve and expand the Scala Inference API. This work can be done in parallel and should undergo its own design process. On the plus side, this will serve as a forcing function to improve the Scala API.

Performance

Performance should be very similar to Scala. Since both are JVM languages, inference from Java will execute the same bytecode as inference from Scala. The only known source of a performance difference is the conversion of Java collections into Scala collections. Preliminary testing with simple models shows negligible to nonexistent impact on performance. Java performance should be measured via Benchmark Scripts in a manner similar to how it is measured in Scala. More details on Scala benchmarks are available here.
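
To get a rough feel for the collection-conversion cost mentioned above, the copy step can be timed on its own. The sketch below uses only the JDK. Unboxing a `List<Float>` into a `float[]` stands in for the Java-to-Scala conversion, which is an assumption, since the actual conversion path is not specified here. The class name `ConversionCostSketch` and the input size (a 3 x 224 x 224 image) are likewise illustrative. `System.nanoTime` gives only a coarse indication; real measurements would need a proper benchmark harness.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical micro-measurement; not one of the benchmark scripts referenced above.
final class ConversionCostSketch {

    // Copies boxed floats into a primitive array, one element at a time.
    static float[] toPrimitive(List<Float> values) {
        float[] out = new float[values.size()];
        int i = 0;
        for (Float v : values) {
            out[i++] = v;
        }
        return out;
    }

    public static void main(String[] args) {
        // Assumed input size: one 3 x 224 x 224 image flattened to one dimension.
        int n = 3 * 224 * 224;
        List<Float> input = new ArrayList<Float>(n);
        for (int i = 0; i < n; i++) {
            input.add((float) i);
        }

        // Warm up so the JIT has a chance to compile the loop before timing.
        for (int r = 0; r < 20; r++) {
            toPrimitive(input);
        }

        long start = System.nanoTime();
        float[] converted = toPrimitive(input);
        long elapsedMicros = (System.nanoTime() - start) / 1000;

        System.out.println("Converted " + converted.length
                + " floats in about " + elapsedMicros + " microseconds");
    }
}
```

Comparing this figure with the time of a single forward pass would indicate whether the conversion is worth optimizing for a given model.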

...

The new Java inference API can be distributed alongside the existing Scala API. Currently, the Scala API is distributed via a jar file using a Maven repository. There is ongoing work to automate this process and ideally this work will include the new Java API as well. The design for the Automated Scala Release is available here. Releases for the Java Inference API will be aligned with the MXNet release schedule and follow the same versioning.

Improving Scala Inference API

...

Known improvements that could be made to the Scala API include:

  • Support for RNNs (Scala 
  • Adding domain specific use cases
  • Improving interface of existing APIs (for example, it should be possible to do batch inference using just an NDArray)

Existing Scala Infer API Class Diagram

[draw.io diagram: JavaAPI Class Diagram]

...

Predictor (Java)
/**
 * Implementation of prediction routines.
 *
 * @param modelPathPrefix     Path prefix from where to load the model artifacts.
 *                            These include the symbol, parameters, and synset.txt
 *                            Example: file://model-dir/resnet-152 (containing
 *                            resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
 * @param inputDescriptors    Descriptors defining the input node names, shape,
 *                            layout and type parameters
 *                            <p>Note: If the input descriptors are missing batchSize
 *                            ('N' in layout), a batchSize of 1 is assumed for the model.
 * @param contexts            Device contexts on which you want to run inference; defaults to CPU
 * @param epoch               Model epoch to load; defaults to 0
 */
Predictor(String modelPathPrefix, List<DataDesc> inputDescriptors,
          List<Context> contexts, int epoch)

/**
 * Takes input as a List of one-dimensional arrays and creates the NDArray needed for inference.
 * The array will be reshaped based on the input descriptors.
 *
 * @param input             A List of one-dimensional arrays.
 *                          A List is needed when the model has more than one input.
 * @return                  A List of one-dimensional output arrays
 */
List<List<Float>> predict(List<List<Float>> input)

/**
 * Predict using NDArray as input
 * This method is useful when the input is a batch of data
 * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
 *
 * @param inputBatch        List of NDArrays
 * @return                  Output of predictions as NDArrays
 */
List<NDArray> predictWithNDArray(List<NDArray> inputBatch)
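
The call shape of `predict` can be illustrated without MXNet on the classpath. In the sketch below, `FloatPredictor` is a hypothetical stand-in interface that copies only the `predict` signature shown above, and `SummingPredictor` is a fake model that returns the sum of each input so that the example runs. Neither is part of the proposed API. The point is the nesting: the outer List holds one entry per model input, and each inner List is that input flattened to one dimension.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class PredictCallSketch {

    // Hypothetical stand-in with the same shape as the predict signature above.
    interface FloatPredictor {
        List<List<Float>> predict(List<List<Float>> input);
    }

    // Fake model used only to make the sketch runnable: one output per input,
    // each output holding the sum of that input's values.
    static final class SummingPredictor implements FloatPredictor {
        @Override
        public List<List<Float>> predict(List<List<Float>> input) {
            List<List<Float>> outputs = new ArrayList<List<Float>>();
            for (List<Float> oneInput : input) {
                float sum = 0f;
                for (Float v : oneInput) {
                    sum += v;
                }
                outputs.add(Collections.singletonList(Float.valueOf(sum)));
            }
            return outputs;
        }
    }

    public static void main(String[] args) {
        FloatPredictor predictor = new SummingPredictor();

        // A single model input, already flattened to one dimension.
        List<Float> flattenedInput = Arrays.asList(1f, 2f, 3f, 4f);

        // The outer List has one element because this fake model has one input.
        List<List<Float>> result =
                predictor.predict(Collections.singletonList(flattenedInput));

        System.out.println(result);
    }
}
```

With a real Predictor, the inner List would be reshaped according to the input descriptors, so its length would need to match the product of the descriptor's shape dimensions.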

...