



Problem

There are two phases to applying Deep Learning to an ML problem. In the first phase, a neural network is created and trained on training data to produce a pre-trained model. In the second phase, this pre-trained model is put to work by running inference (a forward pass) on new data in the customer's production application. Model creation and training are typically performed by data scientists, who prefer Python as their primary language because it provides a rich set of libraries (numpy, pandas, pillow, etc.) for setting up the training pipeline. MXNet already has very good Python support for quickly prototyping and developing models.

Inference, on the other hand, is run and managed by software engineers in a production ecosystem built with tools and frameworks that use Java/Scala as their primary language.

Inference on a trained model has two different use-cases:

  1. Real-time or Online Inference - tasks that require immediate feedback, such as fraud detection
  2. Batch or Offline Inference - tasks that don't require immediate feedback; these are use cases where you have massive amounts of data and want to run inference or pre-compute inference results


Batch inference is performed on big data platforms such as Spark using Scala or Java, while real-time inference is typically performed and deployed on popular Java web servers and frameworks such as Tomcat, Netty, and Jetty.

With this project, we want to build Java APIs that are easy to use for inference and that lower the barrier to entry for consuming MXNet in production use cases.

Goals

  1. New Inference Java API
  2. Easy to use (idiomatic) - this likely means optimizing for the Java inference use case rather than just copying how the Python API works
  3. Limited dependencies 
  4. Full test coverage
  5. Performance - should be at least as performant as the Python API
  6. RNN support
  7. New API should be similar to existing implementations

Proposed Approach

Java wrapper around new Scala API - One approach being considered is using the new Scala API bindings from inside Java. It is already possible to call the Scala bindings from Java, but the current process is very painful for users. To improve this experience, a Java wrapper would be created to call the Scala bindings. The wrapper could be designed to abstract away the complexities of the Java/Scala interaction and to be more idiomatic for the Java inference use case.

  1. Advantages
    1. Interaction with the native code is already done.
    2. The Scala API is already designed and decided. Implementing a wrapper limits the design decisions that need to be made and keeps the APIs consistent.
    3. Allows development to remain focused on a single JVM implementation that can be used by other JVM languages.
    4. Implementation, adding new features, and maintenance would be greatly simplified.
  2. Disadvantages
    1. Interaction with the Scala code could be complicated due to differences between the languages. Known issues are constructing Scala collections and leveraging default argument values. Changes to the Scala API or the introduction of builder functions seem like the most obvious solutions here.
    2. Some overhead in converting collections should be expected.

Possible Alternative Approach

Writing a Java API that directly calls the native code - This would mean designing and implementing a Java API that interacts with the native code using JNI/JNA. As with the other solution, the API would be designed to make Java inference simple and idiomatic.

  1. Advantages
    1. No overhead from converting collections.
    2. No surprises from interacting with the Scala code.
  2. Disadvantages
    1. Duplication of effort between this and the Scala API (this means reimplementing Executor, NDArray, Module, etc., which is a significant effort).
    2. Off-heap memory management would have to be reimplemented.
    3. Added design effort to decide on the Java API.

Known Difficulties

Converting Java collections into Scala collections - Scala and Java use different collection libraries. Generally, these can be converted with the scala.collection.JavaConverters library. Ideally, this will be done automatically on behalf of the user: the Java methods should take Java collections, perform the necessary conversion, and then call the corresponding Scala method.
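The conversion pattern described above can be sketched as a thin Java-facing wrapper that accepts ordinary Java collections, converts them, and delegates. This is a minimal sketch under stated assumptions: `ScalaSidePredictor` and `JavaPredictor` are hypothetical names, and the Scala method is replaced by a plain Java interface over float arrays so the example compiles with the JDK alone. In the real wrapper the conversion step would instead build a Scala collection, for example with `JavaConverters.asScalaBufferConverter(list).asScala()`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

public class CollectionConversionSketch {

    // Hypothetical stand-in for the Scala-side method. In the real wrapper this
    // would be the Scala class, and the conversion below would produce a Scala
    // collection (e.g. via JavaConverters). Plain arrays keep the sketch JDK-only.
    interface ScalaSidePredictor {
        float[][] predict(float[][] batch);
    }

    // Java-facing wrapper: callers pass ordinary Java collections and never see
    // the representation the delegate expects.
    static final class JavaPredictor {
        private final ScalaSidePredictor delegate;

        JavaPredictor(ScalaSidePredictor delegate) {
            this.delegate = delegate;
        }

        List<List<Float>> predict(Collection<? extends List<Float>> input) {
            // Convert the Java input into the delegate's representation.
            float[][] batch = new float[input.size()][];
            int i = 0;
            for (List<Float> row : input) {
                float[] converted = new float[row.size()];
                for (int j = 0; j < converted.length; j++) {
                    converted[j] = row.get(j);
                }
                batch[i++] = converted;
            }
            float[][] raw = delegate.predict(batch);
            // Convert the delegate's output back into Java collections.
            List<List<Float>> result = new ArrayList<>(raw.length);
            for (float[] row : raw) {
                List<Float> out = new ArrayList<>(row.length);
                for (float v : row) {
                    out.add(v);
                }
                result.add(out);
            }
            return result;
        }
    }

    public static void main(String[] args) {
        // Stub delegate that echoes its input, standing in for a real model.
        JavaPredictor predictor = new JavaPredictor(batch -> batch);
        System.out.println(predictor.predict(Arrays.asList(Arrays.asList(1f, 2f), Arrays.asList(3f))));
    }
}
```

Both conversion loops are the per-call overhead this approach accepts in exchange for reusing the Scala implementation.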

Java doesn’t support methods with default arguments - The current Scala implementation makes liberal use of default arguments. For class instantiation, a simple builder pattern will work. Class methods with default values will likely need to be overloaded.
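The two workarounds above (a builder for constructor defaults, overloads for method defaults) can be sketched as follows. `PredictorConfig`, its fields and default values, and the `topK` method are all hypothetical, chosen only to illustrate the pattern; they are not taken from the MXNet Scala API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DefaultArgumentsSketch {

    // Hypothetical configuration; field names and defaults are illustrative only.
    static final class PredictorConfig {
        final String modelPath;
        final int batchSize;
        final String context;

        private PredictorConfig(Builder b) {
            this.modelPath = b.modelPath;
            this.batchSize = b.batchSize;
            this.context = b.context;
        }

        // Required arguments go here; optional ones start at their defaults.
        static Builder builder(String modelPath) {
            return new Builder(modelPath);
        }

        static final class Builder {
            private final String modelPath;
            private int batchSize = 1;      // mirrors a Scala default argument
            private String context = "cpu"; // mirrors a Scala default argument

            private Builder(String modelPath) {
                this.modelPath = modelPath;
            }

            Builder batchSize(int batchSize) {
                this.batchSize = batchSize;
                return this;
            }

            Builder context(String context) {
                this.context = context;
                return this;
            }

            PredictorConfig build() {
                return new PredictorConfig(this);
            }
        }
    }

    // Overloads standing in for a hypothetical Scala method with a default,
    // e.g. def topK(scores: Seq[Float], k: Int = 1): Seq[Int]
    static List<Integer> topK(List<Float> scores) {
        return topK(scores, 1); // the shorter overload supplies the default
    }

    // Returns the indices of the k highest scores, highest first.
    static List<Integer> topK(List<Float> scores, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < scores.size(); i++) {
            indices.add(i);
        }
        indices.sort((a, b) -> Float.compare(scores.get(b), scores.get(a)));
        return new ArrayList<>(indices.subList(0, Math.min(k, indices.size())));
    }

    public static void main(String[] args) {
        PredictorConfig defaults = PredictorConfig.builder("/local/path/to/model").build();
        PredictorConfig custom = PredictorConfig.builder("/local/path/to/model").batchSize(32).build();
        System.out.println(defaults.batchSize + " " + custom.batchSize + " " + custom.context);
        System.out.println(topK(Arrays.asList(0.1f, 0.7f, 0.2f)));
        System.out.println(topK(Arrays.asList(0.1f, 0.7f, 0.2f), 2));
    }
}
```

Overloads stay manageable for one or two defaulted parameters; once a method has several, routing it through a builder like the one above avoids a combinatorial number of overloads.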

Performance

Performance should be very similar to Scala. Since both are JVM languages, running inference from Java executes the same bytecode as running it from Scala. The only known source of a performance difference is converting Java collections into Scala collections. Preliminary testing with simple models shows negligible to nonexistent performance impact. Java performance should be measured via Benchmark AI in the same way Scala performance is measured.

Class Diagram

Sequence Diagram

Java API Design

The Java interface needs to be a high-level interface that supports loading a model and running forward passes. Additionally, there should be support for RNN models.

LoadModel
/**
 * Loads a model from a model file into this instance of the class.
 *
 * @param modelSource location of the model file; local disk and S3 are supported
 * @throws AccessDeniedException if access to the model source is denied
 * @throws AmazonS3Exception if the model cannot be retrieved from S3
 */
void loadModel(String modelSource);
Predict
/**
 * Feeds the input to the model and runs forward passes to produce predictions.
 * The results are retrieved with getPredictions().
 *
 * @param modelInput any type of Java collection (Set, Queue, List, etc.) to be fed into the model
 */
void predict(Collection<E> modelInput);
getPredictions
/**
 * Gets the predictions resulting from the most recent call to predict.
 *
 * @return the results of the latest call to predict, as a {@code List<List<T>>}
 */
List<List<T>> getPredictions();
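To make the contract of these three methods concrete, here is a minimal sketch of one possible typing plus a toy implementation. The interface name `InferenceModel`, the type parameters `E` (input element) and `T` (prediction value), and the `SummingModel` class are all hypothetical; the toy "model" just sums each input row instead of running a network, so the sketch runs with the JDK alone.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

public class InferenceModelSketch {

    // One possible typing of the three methods. E (input element type) and
    // T (prediction value type) are assumptions; the design leaves them open.
    interface InferenceModel<E, T> {
        void loadModel(String modelSource);
        void predict(Collection<E> modelInput);
        List<List<T>> getPredictions();
    }

    // Toy implementation that only exercises the contract: "loading" records the
    // source, and "inference" sums each input row instead of running a network.
    static final class SummingModel implements InferenceModel<List<Float>, Float> {
        private String source;
        private List<List<Float>> lastPredictions = Collections.emptyList();

        @Override
        public void loadModel(String modelSource) {
            this.source = modelSource;
        }

        @Override
        public void predict(Collection<List<Float>> modelInput) {
            if (source == null) {
                throw new IllegalStateException("loadModel must be called before predict");
            }
            List<List<Float>> results = new ArrayList<>();
            for (List<Float> row : modelInput) {
                float sum = 0f;
                for (float v : row) {
                    sum += v;
                }
                results.add(Collections.singletonList(sum));
            }
            // Replaces the results of any earlier call to predict.
            lastPredictions = results;
        }

        @Override
        public List<List<Float>> getPredictions() {
            return lastPredictions;
        }
    }

    public static void main(String[] args) {
        SummingModel model = new SummingModel();
        model.loadModel("/local/path/to/model");
        model.predict(Arrays.asList(Arrays.asList(1f, 2f), Arrays.asList(3f, 4f)));
        System.out.println(model.getPredictions()); // [[3.0], [7.0]]
    }
}
```

Because getPredictions() returns state left behind by the last predict() call, an instance like this one holds per-call state and would need external synchronization if shared between threads, for example in a web server handling concurrent requests.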

Example Uses

Example
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Intended usage; MXNetInferenceModel is the proposed class and is not implemented yet.
class InferenceExamples {

    List<Float> doSinglePrediction() {
        MXNetInferenceModel model = new MXNetInferenceModel();
        model.loadModel("/local/path/to/model");
        List<Float> singleInput = Arrays.asList(3.14159f, 0.5772f, 2.71828f);
        // A single input is passed as a batch of one, matching the batch case below.
        model.predict(Collections.singletonList(singleInput));
        return model.getPredictions().get(0);
    }

    List<List<Float>> doBatchPrediction() {
        MXNetInferenceModel model = new MXNetInferenceModel();
        model.loadModel("/local/path/to/model");
        List<List<Float>> batchInput = Arrays.asList(
            Arrays.asList(3.14159f, 0.5772f, 2.71828f),
            Arrays.asList(0.5772f, 2.71828f, 3.14159f),
            Arrays.asList(2.71828f, 3.14159f, 0.5772f));
        model.predict(batchInput);
        return model.getPredictions();
    }
}


Open Questions

How should Option[T] fields be handled in Java when calling into Scala?

On Java side:

  • Option 1: Create a wrapper class in Java that allows users to work with Scala Option fields smoothly, something like this.

On Scala side:

  • Option 1: Use null in place of some Option fields defined in Scala to match Java's needs
  • Option 2: Add overloaded methods in Scala that Java users can call
  • Option 3: Create a builder on the Scala side that lets Java users omit optional fields


SCALA/JAVA INTEGRATION TIP

Construct interfaces in Java that define all types that will be passed between Java and Scala. Place these interfaces in a project that can be shared between the Java and Scala portions of the code. By limiting the features used at the integration points, feature mismatch issues are avoided. (Referred from "Scala in Depth", page 242.)
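Applied to this project, the tip might look like the sketch below: a boundary type declared in Java that exposes only primitives and arrays, which mean the same thing in both languages. The names `TensorData` and `DenseTensor` are hypothetical and are not part of the MXNet API; a Scala class could implement or consume the same interface without any collection conversion or Option handling.

```java
import java.util.Arrays;

public class SharedBoundaryTypesSketch {

    // Hypothetical boundary type, declared in Java so that the Java and Scala
    // modules can both compile against it. Only primitives and arrays cross
    // this boundary, so neither side needs collection or Option conversions.
    interface TensorData {
        float[] values();
        int[] shape();
    }

    // Java-side implementation; a Scala class could implement the same interface.
    static final class DenseTensor implements TensorData {
        private final float[] values;
        private final int[] shape;

        DenseTensor(float[] values, int... shape) {
            long expected = 1;
            for (int dim : shape) {
                expected *= dim;
            }
            if (expected != values.length) {
                throw new IllegalArgumentException(
                    "shape " + Arrays.toString(shape) + " needs " + expected
                    + " values but got " + values.length);
            }
            // Defensive copies keep callers on either side from mutating the tensor.
            this.values = values.clone();
            this.shape = shape.clone();
        }

        @Override
        public float[] values() {
            return values.clone();
        }

        @Override
        public int[] shape() {
            return shape.clone();
        }
    }

    public static void main(String[] args) {
        TensorData t = new DenseTensor(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, 2, 3);
        System.out.println(Arrays.toString(t.shape())); // [2, 3]
    }
}
```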

Glossary

References
