...
[Diagram: Java API Sequence Diagram]
Java Inference API Design for Predictor Class
The Java Inference API will be a wrapper around the high-level Scala Inference interface. The following example shows what the Java wrapper for the Scala inference Predictor class will look like.
```java
/**
 * Implementation of prediction routines.
 *
 * @param modelPathPrefix  Path prefix from which to load the model artifacts.
 *                         These include the symbol, parameters, and synset.txt.
 *                         Example: file://model-dir/resnet-152 (containing
 *                         resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
 * @param inputDescriptors Descriptors defining the input node names, shape,
 *                         layout, and type parameters.
 *                         <p>Note: If an input descriptor has no batch axis
 *                         ('N' in the layout), a batch size of 1 is assumed for the model.
 * @param contexts         Device contexts on which to run inference; defaults to CPU.
 * @param epoch            Model epoch to load; defaults to 0.
 */
Predictor(String modelPathPrefix, List<DataDesc> inputDescriptors,
          List<Context> contexts, int epoch);

/**
 * Predicts using NDArrays as input.
 * This method is useful when the input is a batch of data.
 * Note: The caller is responsible for allocating and deallocating the input/output NDArrays.
 *
 * @param inputBatch List of NDArrays
 * @return Output of predictions as NDArrays
 */
List<NDArray> predictWithNDArray(List<NDArray> inputBatch);
```
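The note on `inputDescriptors` states that a batch size of 1 is assumed when the layout has no 'N' axis. The following is a minimal sketch of that rule, assuming the layout string has exactly one character per shape dimension (as in "NCHW"). The class and method names are hypothetical; this is not the Scala implementation.

```java
import java.util.Arrays;

// Hypothetical helper illustrating the batch-size rule from the Predictor docs.
public final class BatchSizeSketch {
    /**
     * Returns the dimension at the position of 'N' in the layout,
     * or 1 if the layout has no batch axis.
     * Assumes one layout character per shape dimension.
     */
    static int inferBatchSize(String layout, int[] shape) {
        if (layout.length() != shape.length) {
            throw new IllegalArgumentException(
                    "layout " + layout + " does not match shape " + Arrays.toString(shape));
        }
        int batchAxis = layout.indexOf('N');
        return batchAxis < 0 ? 1 : shape[batchAxis];
    }

    public static void main(String[] args) {
        System.out.println(inferBatchSize("NCHW", new int[] {4, 3, 224, 224})); // prints 4
        System.out.println(inferBatchSize("CHW", new int[] {3, 224, 224}));     // prints 1
    }
}
```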
Java Inference API usage
A primary goal of the Java Inference API is to give Java users a simple way to load an existing model and run inference on it. In the typical case, this should only require defining the context (CPU or GPU) to use, describing what the input looks like, and setting up the model. Once the model is set up, running inference on it should be straightforward.
```java
/*
 * Pseudocode for how the ObjectDetector class can be used to do SSD detection.
 * A full working SSD example will be included in the release.
 * modelPathPrefix and inputImagePath are assumed to be defined by the caller.
 */
// Set the context to be used
List<Context> context = new ArrayList<>();
context.add(Context.cpu());

// Define the shape and data type of the input
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
List<DataDesc> inputDescriptors = new ArrayList<>();
inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));

// Instantiate the object detector with the model, input descriptors, context, and epoch
JavaObjectDetector objDetector = new JavaObjectDetector(modelPathPrefix, inputDescriptors, context, 0);

// Load an image and run inference on it (the second argument is the number of top detections)
BufferedImage img = JavaImageClassifier.loadImageFromFile(inputImagePath);
objDetector.imageObjectDetect(img, 3);
```
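Java has no default arguments, so the usage example passes the CPU context and epoch 0 explicitly even though the Predictor documentation lists both as defaults. One way a Java wrapper could still offer those defaults is constructor overloading, where each shorter constructor delegates to the full one. This is a minimal sketch under that assumption: `PredictorDefaultsSketch` is hypothetical, and plain strings stand in for the real `DataDesc` and `Context` types so the example compiles without MXNet.

```java
import java.util.Collections;
import java.util.List;

// Hypothetical sketch: emulating Scala default arguments with Java overloads.
// Strings stand in for MXNet DataDesc and Context objects.
public class PredictorDefaultsSketch {
    final String modelPathPrefix;
    final List<String> inputDescriptors;
    final List<String> contexts;
    final int epoch;

    // Full constructor: every argument is explicit.
    public PredictorDefaultsSketch(String modelPathPrefix, List<String> inputDescriptors,
                                   List<String> contexts, int epoch) {
        this.modelPathPrefix = modelPathPrefix;
        this.inputDescriptors = inputDescriptors;
        this.contexts = contexts;
        this.epoch = epoch;
    }

    // Epoch omitted: defaults to 0.
    public PredictorDefaultsSketch(String modelPathPrefix, List<String> inputDescriptors,
                                   List<String> contexts) {
        this(modelPathPrefix, inputDescriptors, contexts, 0);
    }

    // Contexts and epoch omitted: defaults to a single CPU context and epoch 0.
    public PredictorDefaultsSketch(String modelPathPrefix, List<String> inputDescriptors) {
        this(modelPathPrefix, inputDescriptors, Collections.singletonList("cpu"), 0);
    }

    public static void main(String[] args) {
        PredictorDefaultsSketch p = new PredictorDefaultsSketch(
                "file://model-dir/resnet-152", Collections.singletonList("data"));
        System.out.println(p.contexts + " epoch=" + p.epoch); // prints [cpu] epoch=0
    }
}
```

The trade-off is that every optional trailing parameter adds another overload, which stays manageable for two defaults but grows quickly for longer parameter lists.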
Open Questions
How should Scala Option[T] fields and parameters be handled when they are called from Java?
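One possible approach, sketched only to illustrate the question rather than settle it, is to keep Option[T] out of the Java-facing signatures entirely and expose overloads instead, with the wrapper converting to Some or None internally (from Java, for example via `scala.Option.apply(...)` and `scala.Option.empty()`). The sketch below uses `java.util.Optional` as a stand-in for `scala.Option` so it compiles without the Scala library; `detect`, `detectImpl`, and the string "detections" are all hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

public class OptionalArgSketch {
    // Stand-in for a Scala method whose signature takes an Option[Int];
    // java.util.Optional plays the role of scala.Option here.
    static List<String> detectImpl(List<String> candidates, Optional<Integer> topK) {
        int limit = topK.orElse(candidates.size());
        return candidates.stream().limit(limit).collect(Collectors.toList());
    }

    // Java-facing overload: no topK argument maps to None (return everything).
    public static List<String> detect(List<String> candidates) {
        return detectImpl(candidates, Optional.empty());
    }

    // Java-facing overload: a plain int maps to Some(topK).
    public static List<String> detect(List<String> candidates, int topK) {
        return detectImpl(candidates, Optional.of(topK));
    }

    public static void main(String[] args) {
        List<String> dets = Arrays.asList("dog", "bicycle", "car", "person");
        System.out.println(detect(dets));    // prints [dog, bicycle, car, person]
        System.out.println(detect(dets, 3)); // prints [dog, bicycle, car]
    }
}
```

Java callers never see an Option type this way; the alternatives (accepting `java.util.Optional` or nullable boxed values directly) would keep fewer overloads at the cost of a less conventional Java signature.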
...