
Enable MXNet to compute second (or higher) order gradients for a list of operators using the autograd package.

Link to dev List discussion

https://lists.apache.org/thread.html/464712f0136fb51916ca9f1b702b99847e108dbdbd0b6a2b73fc91f1@%3Cdev.mxnet.apache.org%3E

Feature Shepherd

Lin Yuan

Problem


Currently, only a very limited number of operators (such as exp) support second or higher order gradient calculation. For other operators, if users try to compute the second order gradient, MXNet issues an error message such as "mxnet.base.MXNetError: [23:15:34] src/pass/gradient.cc:192: Operator _backward_XXX is non-differentiable because it didn't register FGradient attribute." This happens because the MXNet backend does not implement the FGradient function for the operator's backward node, so the backward node cannot itself be differentiated to produce second (or higher) order gradients.

However, higher order gradient calculation for certain operators has many applications, such as adaptive learning rate optimization, network architecture search, and W-GAN networks. Implementing higher order gradients can unlock these applications and improve the usability and popularity of the Apache MXNet framework.
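The FGradient requirement described above can be illustrated with a toy example. The following is a minimal sketch in plain Python, not MXNet's implementation: the names `Var`, `grad`, `add`, `mul`, `neg`, `sin`, and `cos` are hypothetical and exist only for this illustration. The point it tries to show is that when the backward rule of each operator is itself built from differentiable operators, the gradient can be differentiated again.

```python
import math


class Var:
    """Scalar node in a toy computation graph (illustrative only)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (parent, grad_fn); grad_fn maps the upstream gradient
        # (a Var) to this parent's gradient contribution (also a Var).
        self.parents = parents


def add(a, b):
    return Var(a.value + b.value, [(a, lambda g: g), (b, lambda g: g)])


def mul(a, b):
    return Var(a.value * b.value,
               [(a, lambda g: mul(g, b)), (b, lambda g: mul(g, a))])


def neg(a):
    return Var(-a.value, [(a, lambda g: neg(g))])


def sin(a):
    # The backward rule is expressed with mul and cos, which have their own
    # gradient rules, so the result stays differentiable.
    return Var(math.sin(a.value), [(a, lambda g: mul(g, cos(a)))])


def cos(a):
    return Var(math.cos(a.value), [(a, lambda g: mul(g, neg(sin(a))))])


def grad(y, x):
    """Return dy/dx as a Var that can be differentiated again."""
    order, seen = [], set()

    def visit(node):
        # Post-order DFS: parents are appended before the nodes that use them.
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(y)
    grads = {id(y): Var(1.0)}
    # Walk from the output back to the inputs, accumulating contributions.
    for node in reversed(order):
        g = grads.get(id(node))
        if g is None:
            continue
        for parent, grad_fn in node.parents:
            contrib = grad_fn(g)
            if id(parent) in grads:
                grads[id(parent)] = add(grads[id(parent)], contrib)
            else:
                grads[id(parent)] = contrib
    return grads.get(id(x), Var(0.0))


x = Var(0.5)
y = sin(x)
dy = grad(y, x)     # first order gradient, built from mul and cos
d2y = grad(dy, x)   # second order gradient, obtained by differentiating dy
print(dy.value, d2y.value)
```

If the backward rule of `sin` returned a plain number instead of a `Var` built from other operators, the second call to `grad` would have nothing to traverse. That is roughly the situation the error message reports for a `_backward_XXX` node without a registered FGradient.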

User Experience

We will support second order gradient calculation in the autograd package as shown in the example below:

import mxnet.ndarray as nd
from mxnet import autograd

x = nd.array([1, 2, 3])
x.attach_grad()
with autograd.record():
    y = nd.sin(x)
    # y_grad is the first order gradient of y and should be cos(x)
    y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
# this call should calculate the second order gradient of y w.r.t. x, which should be -sin(x)
y_grad.backward()
print(x.grad)  # Should be -sin(x)
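The values named in the comments above, cos(x) for the first order gradient and -sin(x) for the second, can be sanity-checked without MXNet. The sketch below uses central finite differences from the Python standard library. The helper names `central_diff` and `second_central_diff`, the step sizes, and the tolerance are assumptions chosen for this illustration and are not part of the proposal.

```python
import math


def central_diff(f, x, h=1e-5):
    """Approximate f'(x) with a central difference."""
    return (f(x + h) - f(x - h)) / (2 * h)


def second_central_diff(f, x, h=1e-4):
    """Approximate f''(x) with a second order central difference."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / (h * h)


# Same input values as the autograd example: x = [1, 2, 3]
xs = [1.0, 2.0, 3.0]
first = [central_diff(math.sin, x) for x in xs]
second = [second_central_diff(math.sin, x) for x in xs]

for x, d1, d2 in zip(xs, first, second):
    # Compare against the analytic derivatives of sin.
    assert math.isclose(d1, math.cos(x), abs_tol=1e-6)
    assert math.isclose(d2, -math.sin(x), abs_tol=1e-6)
```

A comparison of this kind could serve as a reference when checking the output of `x.grad` for operators whose analytic second derivative is known.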

Goals/Usecases


Open Questions

Proposed Approach


Addition of New APIs

Backward compatibility

Performance Considerations

Test Plan

Alternative Approaches


Technical Challenges 


Milestones

References
