Problem

Applying deep learning to a Machine Learning (ML) problem has two phases. In the first phase, a neural network is created and trained on training data to produce a pre-trained model. In the second phase, this pre-trained model is put to work by running inference (a forward pass) on new data in the customer's production application. Model creation and training are typically performed by data scientists, who prefer Python as their primary language because it provides a rich set of libraries (NumPy, pandas, Pillow, etc.) for setting up the training pipeline. MXNet already has very good Python support for quickly prototyping and developing models.

Inference, on the other hand, is run and managed by software engineers in production ecosystems built with tools and frameworks that use Java/Scala as their primary language.

Inference on a trained model has two use cases:

  1. Real-time or online inference - tasks that require immediate feedback, such as fraud detection
  2. Batch or offline inference - tasks that don't require immediate feedback, where you have massive amounts of data and want to run inference or pre-compute inference results


Batch inference is performed on big data platforms such as Spark using Scala or Java, while real-time inference is typically deployed on popular Java-based web servers and frameworks such as Tomcat, Netty, and Jetty.

With this project, we want to build Java APIs that are easy to use for inference and that lower the barrier to entry for using MXNet in production use cases.

Goals

  1. New Inference Java API
  2. Easy to use (idiomatic) - this likely means optimizing for the Java inference use case rather than copying how the Python API works
  3. Limited dependencies
  4. Full test coverage
  5. Performance - should be at least as performant as the Python API

Proposed Approach

Two approaches are currently being considered and researched:

  1. Java wrapper around the new Scala API - This approach uses the new Scala API bindings from inside Java. Calling the Scala bindings from Java is already possible, but the current process would be very painful for users. To improve this experience, a Java wrapper would be created to call the Scala bindings. The wrapper could abstract away the complexities of the Java/Scala interaction and be more idiomatic for the Java inference use case.
    1. Advantages
      1. Interaction with the native code is already implemented, which should simplify both implementation and maintenance.
    2. Disadvantages
      1. At a minimum this would introduce a dependency upon the Scala compiler.
      2. Interaction with the Scala code could be complicated by differences between the languages. Known issues include constructing Scala collections from Java and using Scala default parameter values. Changes to the Scala API or the introduction of builder functions seem like the most obvious solutions here.
  2. Java API that calls the native code directly - This means designing and implementing a Java API that interacts with the native code through JNI/JNA. As with the first approach, the API would be designed to make Java inference simple and idiomatic.
    1. Advantages
      1. No added dependencies
      2. No surprises from interacting with the Scala code
    2. Disadvantages
      1. It is difficult to write and test JNI code.
      2. Possible duplication of efforts between this and the Scala API
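The builder functions mentioned under the first approach can be sketched in Java. This is a minimal sketch only, assuming a hypothetical `InferenceOptions` class. Its fields (`modelPath`, `batchSize`, `useGpu`) and their defaults are illustrative and are not taken from the MXNet Scala API. Required values go into the builder's constructor, and every optional setter starts from a default. This lets a Java caller avoid restating Scala default parameter values.

```java
import java.util.Objects;

/**
 * Hypothetical options object showing how a Java-facing builder can stand in
 * for Scala default parameter values. Field names and defaults are
 * illustrative assumptions, not part of any MXNet API.
 */
final class InferenceOptions {
    private final String modelPath;
    private final int batchSize;
    private final boolean useGpu;

    private InferenceOptions(Builder b) {
        this.modelPath = b.modelPath;
        this.batchSize = b.batchSize;
        this.useGpu = b.useGpu;
    }

    String modelPath() { return modelPath; }
    int batchSize() { return batchSize; }
    boolean useGpu() { return useGpu; }

    /** Entry point: the only required value is passed up front. */
    static Builder builder(String modelPath) {
        return new Builder(modelPath);
    }

    static final class Builder {
        private final String modelPath;  // required, so it is a constructor argument
        private int batchSize = 1;       // assumed default, plays the role of a Scala default value
        private boolean useGpu = false;  // assumed default: run on CPU

        private Builder(String modelPath) {
            this.modelPath = Objects.requireNonNull(modelPath, "modelPath");
        }

        Builder batchSize(int batchSize) {
            if (batchSize < 1) {
                throw new IllegalArgumentException("batchSize must be >= 1");
            }
            this.batchSize = batchSize;
            return this;
        }

        Builder useGpu(boolean useGpu) {
            this.useGpu = useGpu;
            return this;
        }

        InferenceOptions build() {
            return new InferenceOptions(this);
        }
    }
}
```

With this design, a Java caller only states what differs from the defaults. For example, `InferenceOptions.builder("/local/path/to/model").batchSize(8).build()` leaves `useGpu` at its default.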

Class Diagram

Sequence Diagram

Java API Design

The Java interface needs to be a high-level interface that supports loading a model and running forward passes. Additionally, it should support RNN models.

loadModel
/**
 * Loads a model from a model file into this instance of the class.
 *
 * @param modelSource the location of the model file; supports local disk access and S3
 * @throws AccessDeniedException
 * @throws AmazonS3Exception
 */
void loadModel(String modelSource);
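`modelSource` may point at either local disk or S3, so `loadModel` first has to decide where to read from. The sketch below assumes an `s3://bucket/key` convention for S3 locations and treats everything else as a local path. `ModelSource` and its fields are hypothetical names. The actual download (for example, through the AWS SDK) is deliberately left out.

```java
/** Hypothetical helper that decides where loadModel should read a model from. */
final class ModelSource {
    enum Kind { LOCAL, S3 }

    private static final String S3_PREFIX = "s3://";

    final Kind kind;
    final String bucket; // S3 bucket name; null for local paths
    final String key;    // S3 object key, or the local filesystem path

    private ModelSource(Kind kind, String bucket, String key) {
        this.kind = kind;
        this.bucket = bucket;
        this.key = key;
    }

    static ModelSource parse(String modelSource) {
        if (modelSource == null || modelSource.isEmpty()) {
            throw new IllegalArgumentException("modelSource must be non-empty");
        }
        if (modelSource.startsWith(S3_PREFIX)) {
            String rest = modelSource.substring(S3_PREFIX.length());
            int slash = rest.indexOf('/');
            // Require a non-empty bucket (slash > 0) and a non-empty key after the slash.
            if (slash <= 0 || slash == rest.length() - 1) {
                throw new IllegalArgumentException("Expected s3://bucket/key but got: " + modelSource);
            }
            return new ModelSource(Kind.S3, rest.substring(0, slash), rest.substring(slash + 1));
        }
        // Anything without the s3:// prefix is treated as a local filesystem path.
        return new ModelSource(Kind.LOCAL, null, modelSource);
    }
}
```

This keeps parsing separate from I/O. Under that design, the S3-specific exception listed for `loadModel` would only be thrown on the S3 branch.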
predict
/**
 * Takes input that will be fed to the model for forward passes to produce predictions.
 *
 * @param modelInput any type of Java collection (Set, Queue, List, etc.); it is fed
 *                   into the model to get the predictions
 */
<E> void predict(Collection<E> modelInput);
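`predict` accepts any Java collection, so the implementation needs a way to turn that collection into the contiguous numeric buffer the native forward pass consumes. This is a minimal sketch under two assumptions. First, the native side takes a flat, row-major `float[]` (float32 is MXNet's default NDArray type). Second, inputs are numbers or nested collections of numbers. `InputFlattener` is a hypothetical helper. Element order follows the collection's iteration order, which is well defined for a `List` but not for a plain `Set`.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/** Hypothetical helper: flattens predict() input into a float buffer for the native side. */
final class InputFlattener {
    private InputFlattener() {}

    /**
     * Flattens a (possibly nested) collection of numbers into a row-major float[].
     * Nested collections are walked in their iteration order.
     */
    static float[] flatten(Collection<?> input) {
        List<Float> values = new ArrayList<>();
        collect(input, values);
        float[] out = new float[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i); // unboxes Float to float
        }
        return out;
    }

    private static void collect(Object node, List<Float> sink) {
        if (node instanceof Number) {
            // Any numeric type (Integer, Double, ...) is narrowed to float.
            sink.add(((Number) node).floatValue());
        } else if (node instanceof Collection) {
            for (Object child : (Collection<?>) node) {
                collect(child, sink);
            }
        } else {
            throw new IllegalArgumentException("Unsupported input element: " + node);
        }
    }
}
```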
getPredictions
/**
 * Gets the predictions resulting from the last call to predict.
 *
 * @return the results from the latest call to predict, as a {@code List<List<T>>}
 */
List<List<T>> getPredictions();
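On the way back, `getPredictions` has to turn the forward-pass output into `List<List<T>>`. This is a minimal sketch with `T = Float`. It assumes the native output arrives as one flat `float[]` of `batchSize * outputsPerExample` values laid out example by example. `OutputReshaper` is a hypothetical helper. Each inner list holds one example's predictions, which matches how the examples use `getPredictions().get(0)` for a single input.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical helper: reshapes a flat output buffer into one row per input example. */
final class OutputReshaper {
    private OutputReshaper() {}

    static List<List<Float>> toRows(float[] flat, int batchSize) {
        // Check batchSize first so the modulo below never divides by zero.
        if (batchSize <= 0 || flat.length % batchSize != 0) {
            throw new IllegalArgumentException(
                "Output length " + flat.length + " is not divisible by batch size " + batchSize);
        }
        int width = flat.length / batchSize; // outputs per example
        List<List<Float>> rows = new ArrayList<>(batchSize);
        for (int r = 0; r < batchSize; r++) {
            List<Float> row = new ArrayList<>(width);
            for (int c = 0; c < width; c++) {
                row.add(flat[r * width + c]);
            }
            rows.add(Collections.unmodifiableList(row));
        }
        return Collections.unmodifiableList(rows);
    }
}
```

Returning unmodifiable lists keeps callers from mutating stored results between calls to `predict`.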

Example Uses

Example
import java.util.Arrays;
import java.util.List;

List<Float> doSinglePrediction(){
    // MXNetInferenceModel is the proposed class; T = Float for a model that outputs floats.
    MXNetInferenceModel<Float> model = new MXNetInferenceModel<>();
    model.loadModel("/local/path/to/model");
    List<Float> singleInput = Arrays.asList(3.14159f, 0.5772f, 2.71828f);
    model.predict(singleInput);
    // A single input yields a single row of predictions.
    return model.getPredictions().get(0);
}


List<List<Float>> doBatchPrediction(){
    MXNetInferenceModel<Float> model = new MXNetInferenceModel<>();
    model.loadModel("/local/path/to/model");
    List<List<Float>> batchInput = Arrays.asList(
        Arrays.asList(3.14159f, 0.5772f, 2.71828f),
        Arrays.asList(0.5772f, 2.71828f, 3.14159f),
        Arrays.asList(2.71828f, 3.14159f, 0.5772f));
    model.predict(batchInput);
    // One row of predictions per input example.
    return model.getPredictions();
}
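Given the "Full test coverage" goal, a test double that follows the same loadModel/predict/getPredictions sequence without native libraries could be useful. This is a minimal sketch with several simplifications. `InferenceModel` and `EchoInferenceModel` are hypothetical names. `T` is fixed to `Float`, and `predict` takes `Collection<?>` instead of the generic `Collection<E>` proposed above. The stub echoes each input row back as its "prediction" and enforces call order by throwing `IllegalStateException`. To decide between a single input and a batch, it checks whether the first element is itself a collection, as in the two examples.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/** Hypothetical interface mirroring the three proposed methods, with T fixed to Float. */
interface InferenceModel {
    void loadModel(String modelSource);
    void predict(Collection<?> modelInput);
    List<List<Float>> getPredictions();
}

/** Test double: loads nothing and echoes each input row back as its prediction. */
final class EchoInferenceModel implements InferenceModel {
    private boolean loaded;
    private List<List<Float>> lastPredictions;

    @Override
    public void loadModel(String modelSource) {
        if (modelSource == null || modelSource.isEmpty()) {
            throw new IllegalArgumentException("modelSource must be non-empty");
        }
        loaded = true; // a real implementation would read the model here
    }

    @Override
    public void predict(Collection<?> modelInput) {
        if (!loaded) {
            throw new IllegalStateException("loadModel must be called before predict");
        }
        List<List<Float>> out = new ArrayList<>();
        // A collection of collections is a batch; a flat collection is one example.
        boolean isBatch = !modelInput.isEmpty()
            && modelInput.iterator().next() instanceof Collection;
        if (isBatch) {
            for (Object row : modelInput) {
                out.add(toFloats((Collection<?>) row));
            }
        } else {
            out.add(toFloats(modelInput));
        }
        lastPredictions = Collections.unmodifiableList(out);
    }

    @Override
    public List<List<Float>> getPredictions() {
        if (lastPredictions == null) {
            throw new IllegalStateException("predict has not been called");
        }
        return lastPredictions;
    }

    private static List<Float> toFloats(Collection<?> row) {
        List<Float> values = new ArrayList<>(row.size());
        for (Object v : row) {
            values.add(((Number) v).floatValue());
        }
        return Collections.unmodifiableList(values);
    }
}
```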



Alternative Approaches Considered

 

Technical Challenges

 

Open Questions

Milestones

 

Glossary

References
