Note: an extended design for callbacks is available here.

Problem and Goals

Background

Training a model in Gluon requires users to write the training loop themselves. This is useful because of Gluon's imperative nature, but repeating the same boilerplate code across multiple models quickly becomes tedious. The training loop can also be overwhelming for users who are new to deep learning.
Users have asked for a simple Fit API, similar to the APIs available in scikit-learn and Keras (example forum request), as a way to simplify model training and reduce boilerplate code and complexity.

...

  • Introduce a new Gluon “Fit” API that eliminates the need to code a training loop for simple model use cases, thereby reducing manual errors and friction.
  • Support Fit API handlers, inspired by Keras callbacks, that let users customize the training loop for things like checkpointing, logging, early stopping and metrics.
  • Maintain backwards compatibility: the existing Gluon way of training a model will remain supported and maintained, since it is needed for complex models and for full imperative control by the user.
  • The new Fit API will cover beginner use cases, including canonical CV and NLP models; the full list is in the Appendix. For advanced users and complex models, the recommended path is to use the existing training loop.
  • Test coverage: 100% unit test coverage and 100% integration test coverage for the example models in the Appendix.
  • Educate Gluon users via: (1) a blog post, (2) examples, (3) a tutorial.

...

Currently, because of Gluon's imperative style of programming, users write the entire training loop themselves, which requires multiple steps. For a code example, see Appendix A.
Writing the custom loop involves:

...

The new API described below reduces the number of lines of code the user has to write. When the user implements logging and checkpointing, the code shrinks from ~40 lines to ~6.

Below is an example of the Fit API implementing functionality similar to the existing training loop in Appendix A.

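from mxnet import gluon
import mxnet.gluon.estimator as est  # proposed module, not yet part of MXNet

# get_model, train_data, val_data, epochs and context are placeholders defined by the user
net = get_model()  # get the network
loss = gluon.loss.SoftmaxCrossEntropyLoss()
e = est.Estimator(net, lossfn=loss)
# training
trainers = [gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})]
e.fit(train_data, val_data, epochs, trainers, context)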
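To make the division of labor concrete, here is a framework-free sketch of the kind of loop `fit` is meant to hide from the user. It does not use MXNet: the `ToyEstimator` class, its `fit(train_data, val_data, epochs)` signature and the 1-D linear model trained with plain SGD are hypothetical stand-ins for `net`, `loss` and `trainers`, chosen only so that the sketch runs with the standard library.

```python
import random


class ToyEstimator:
    """Hypothetical stand-in for the proposed estimator (not an MXNet API).

    The "network" is y = w * x + b, the loss is mean squared error and the
    "trainer" is plain SGD, so the sketch runs with the standard library only.
    """

    def __init__(self, learning_rate=0.05):
        self.w = 0.0
        self.b = 0.0
        self.learning_rate = learning_rate
        # Per-epoch statistics, in the spirit of the proposed train_stats.
        self.train_stats = {"train_loss": [], "val_loss": []}

    def _batch_loss(self, batch):
        # Mean squared error over one batch of (x, y) pairs.
        return sum((self.w * x + self.b - y) ** 2 for x, y in batch) / len(batch)

    def _mean_loss(self, batches):
        return sum(self._batch_loss(batch) for batch in batches) / len(batches)

    def fit(self, train_data, val_data, epochs):
        # Everything below is the boilerplate a user would otherwise write by hand.
        for _ in range(epochs):
            for batch in train_data:
                # Gradients of the batch MSE with respect to w and b.
                n = len(batch)
                grad_w = sum(2 * (self.w * x + self.b - y) * x for x, y in batch) / n
                grad_b = sum(2 * (self.w * x + self.b - y) for x, y in batch) / n
                # Parameter update: the job gluon.Trainer.step does in a real loop.
                self.w -= self.learning_rate * grad_w
                self.b -= self.learning_rate * grad_b
            # Epoch-level bookkeeping that event handlers could later read.
            self.train_stats["train_loss"].append(self._mean_loss(train_data))
            self.train_stats["val_loss"].append(self._mean_loss(val_data))
        return self.train_stats


def make_batches(n, batch_size, seed):
    # Synthetic data drawn from y = 3x + 1 plus a little uniform noise.
    rng = random.Random(seed)
    points = []
    for _ in range(n):
        x = rng.uniform(-1, 1)
        points.append((x, 3 * x + 1 + rng.uniform(-0.1, 0.1)))
    return [points[i:i + batch_size] for i in range(0, n, batch_size)]


# The user-facing part shrinks to a handful of lines, as in the proposal.
train_data = make_batches(200, 20, seed=0)
val_data = make_batches(50, 25, seed=1)
estimator = ToyEstimator()
stats = estimator.fit(train_data, val_data, epochs=30)
```

The design point is that the per-batch gradient/update code and the per-epoch bookkeeping live inside `fit`, while the caller only prepares data and chooses hyperparameters.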


...

class EventHandler:
    def __init__(self, estimator):
        self._train_stats = estimator.train_stats

    def train_begin(self):
        pass

    def train_end(self):
        pass

    def batch_begin(self):
        pass

    def batch_end(self):
        pass

    def epoch_begin(self):
        pass

    def epoch_end(self):
        pass

class LoggingHandler(EventHandler):
    def __init__(self, estimator, log_loc='./'):
        super().__init__(estimator)
        # set up logging to log_loc

    def epoch_end(self):
        # log the train stats to the log location
        pass

class CheckpointHandler(EventHandler):
    def __init__(self, estimator, checkpoint_interval=5, ckpt_loc='./', monitor="val_loss"):
        super().__init__(estimator)
        # train_stats has the form:
        # {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # save the model params to the checkpointing location
        pass

class MetricHandler(EventHandler):
    def __init__(self, estimator):
        super().__init__(estimator)
        # train_stats has the form:
        # {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # calculate and update metrics for the training dataset
        # update_metrics(pred, labels): the default implementation can be
        # overridden for multi-output cases
        # update the validation metrics on the validation dataset
        pass

class EarlyStopping(EventHandler):
    def __init__(self, estimator, monitor="val_loss", min_delta=0, patience=0,
                 mode="auto", baseline=None, restore_best_params=False):
        super().__init__(estimator)
        # set up early stopping rules based on the monitored metric/loss and the mode,
        # e.g. if the monitor is an accuracy ("acc"), use "greater" mode, else use "lesser"

    def epoch_end(self):
        # if the metric improved, record the best value;
        # otherwise wait n epochs (n = patience), then stop training
        # and, if restore_best_params is set, restore the net parameters from the best epoch
        pass

    def train_end(self):
        # let the user know whether early stopping was triggered
        pass
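As a hedged illustration of how these handlers could interact with an estimator, the sketch below fills in the early-stopping rule described in the comments (improvement test with `min_delta`, a `patience` counter, and "auto" mode picking max for accuracy-like monitors) and drives it from a toy estimator. `ReplayEstimator`, its `fit(handlers, epochs)` signature and the `stop_training` flag are hypothetical; batch hooks are omitted for brevity, and `baseline`/`restore_best_params` are left out.

```python
class EventHandler:
    """Base handler using the proposed hook names (batch hooks omitted); all no-ops."""

    def __init__(self, estimator):
        self.estimator = estimator

    def train_begin(self):
        pass

    def train_end(self):
        pass

    def epoch_begin(self):
        pass

    def epoch_end(self):
        pass


class EarlyStopping(EventHandler):
    def __init__(self, estimator, monitor="val_loss", min_delta=0.0, patience=0, mode="auto"):
        super().__init__(estimator)
        self.monitor = monitor
        self.min_delta = abs(min_delta)
        self.patience = patience
        # "auto" mode: accuracy-like monitors are maximized, anything else is minimized.
        if mode == "auto":
            mode = "max" if "acc" in monitor else "min"
        self.mode = mode
        self.best = None
        self.wait = 0
        self.stopped_epoch = None

    def _improved(self, current):
        if self.best is None:
            return True
        if self.mode == "max":
            return current > self.best + self.min_delta
        return current < self.best - self.min_delta

    def epoch_end(self):
        history = self.estimator.train_stats[self.monitor]
        current = history[-1]
        if self._improved(current):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            # Tolerate `patience` non-improving epochs, stop on the next one.
            if self.wait > self.patience:
                self.estimator.stop_training = True  # hypothetical flag
                self.stopped_epoch = len(history) - 1

    def train_end(self):
        if self.stopped_epoch is not None:
            print(f"Early stopping triggered at epoch {self.stopped_epoch}")


class ReplayEstimator:
    """Hypothetical estimator that replays a fixed validation-loss curve."""

    def __init__(self, val_losses):
        self.val_losses = val_losses
        self.train_stats = {"val_loss": []}
        self.stop_training = False

    def _fire(self, handlers, event):
        # Call the same hook on every registered handler, in order.
        for handler in handlers:
            getattr(handler, event)()

    def fit(self, handlers, epochs):
        self._fire(handlers, "train_begin")
        for epoch in range(epochs):
            self._fire(handlers, "epoch_begin")
            # A real estimator would run the batches (and batch hooks) here.
            self.train_stats["val_loss"].append(self.val_losses[epoch])
            self._fire(handlers, "epoch_end")
            if self.stop_training:
                break
        self._fire(handlers, "train_end")
        return len(self.train_stats["val_loss"])


est = ReplayEstimator([1.0, 0.8, 0.7, 0.71, 0.72, 0.69, 0.75, 0.76, 0.77, 0.78])
stopper = EarlyStopping(est, patience=2)
epochs_run = est.fit([stopper], epochs=10)
```

With `patience=2`, the dip to 0.69 at epoch 5 resets the counter, so training stops after epoch 8 (nine epochs run) and prints `Early stopping triggered at epoch 8`.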

...

We will add integration tests covering all of the models in the release goals (Appendix). See the comments section for the integration test plan.


Technical Challenges / Open Questions

...

APPENDIX A - Gluon Training Loop Example

...

Current user experience:

# Current training loop in Gluon
# net, optimizer, optimizer_params, train_data, val_data, batch_fn, ctx, batch_size,
# lr_scheduler, train_metric, log_interval, logger, test, save_dir, save_frequency
# and model_name are defined elsewhere in the training script.
import time

from mxnet import autograd as ag
from mxnet import gluon

###########################
# Only one epoch
###########################
num_epochs = 1

trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
L = gluon.loss.SoftmaxCrossEntropyLoss()
best_val_score = 1  # tracks top-1 validation error, so lower is better

for epoch in range(num_epochs):
    tic = time.time()
    train_metric.reset()
    btic = time.time()

    for i, batch in enumerate(train_data):
        data, label = batch_fn(batch, ctx)

        with ag.record():
            outputs = [net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        lr_scheduler.update(i, epoch)
        trainer.step(batch_size)

        train_metric.update(label, outputs)

        if log_interval and not (i + 1) % log_interval:
            train_metric_name, train_metric_score = train_metric.get()
            logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f' % (
                epoch, i, batch_size * log_interval / (time.time() - btic),
                train_metric_name, train_metric_score, trainer.learning_rate))
            btic = time.time()

    train_metric_name, train_metric_score = train_metric.get()
    # i is the last batch index, so i + 1 batches were processed this epoch
    throughput = int(batch_size * (i + 1) / (time.time() - tic))

    err_top1_val, err_top5_val = test(ctx, val_data)

    logger.info('[Epoch %d] training: %s=%f' % (epoch, train_metric_name, train_metric_score))
    logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' % (epoch, throughput, time.time() - tic))
    logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' % (epoch, err_top1_val, err_top5_val))

    if err_top1_val < best_val_score:
        best_val_score = err_top1_val
        net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params' % (save_dir, best_val_score, model_name, epoch))
        trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states' % (save_dir, best_val_score, model_name, epoch))

    if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
        net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, model_name, epoch))
        trainer.save_states('%s/imagenet-%s-%d.states' % (save_dir, model_name, epoch))

if save_frequency and save_dir:
    net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, model_name, num_epochs - 1))
    trainer.save_states('%s/imagenet-%s-%d.states' % (save_dir, model_name, num_epochs - 1))
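The checkpointing part of the loop above mixes two policies: save whenever top-1 validation error improves, and save every `save_frequency` epochs, plus a final save after training. The framework-free sketch below isolates that decision logic in the shape of a checkpoint handler, using the same file-name patterns. `CheckpointPolicy` and its methods are hypothetical; it only returns the `.params` paths that would be written, and the actual `net.save_parameters` / `trainer.save_states` calls are left out.

```python
class CheckpointPolicy:
    """Hypothetical helper: decides which .params files the loop above would write."""

    def __init__(self, model_name, save_dir, save_frequency, num_epochs):
        self.model_name = model_name
        self.save_dir = save_dir
        self.save_frequency = save_frequency
        self.num_epochs = num_epochs
        self.best_val_score = 1.0  # top-1 error, so lower is better

    def epoch_end(self, epoch, err_top1_val):
        """Return the checkpoint paths to save at the end of `epoch`."""
        paths = []
        # Policy 1: save whenever validation error improves.
        if err_top1_val < self.best_val_score:
            self.best_val_score = err_top1_val
            paths.append('%s/%.4f-imagenet-%s-%d-best.params'
                         % (self.save_dir, self.best_val_score, self.model_name, epoch))
        # Policy 2: save every `save_frequency` epochs.
        if self.save_frequency and self.save_dir and (epoch + 1) % self.save_frequency == 0:
            paths.append('%s/imagenet-%s-%d.params' % (self.save_dir, self.model_name, epoch))
        return paths

    def train_end(self):
        """Return the final checkpoint path written after the last epoch."""
        if self.save_frequency and self.save_dir:
            return ['%s/imagenet-%s-%d.params'
                    % (self.save_dir, self.model_name, self.num_epochs - 1)]
        return []


policy = CheckpointPolicy('resnet18', 'ckpt', save_frequency=2, num_epochs=4)
saved = []
for epoch, err in enumerate([0.60, 0.55, 0.58, 0.50]):
    saved += policy.epoch_end(epoch, err)
saved += policy.train_end()
```

With these settings the last path appears twice: when `num_epochs` is a multiple of `save_frequency`, the periodic save and the final save target the same file, which a handler-based design could deduplicate.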


...

By supporting the following models, we believe we can cover most basic use cases for Gluon users.

Domain | Category | Model | Reference | Feature Required | Note
CV | Image Classification | AlexNet | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | mlp, lenet, vgg are similar; example: train_ch5()
CV | Image Augmentation + Classification | ResNet18 | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | example: train_ch5()
CV | Semantic Segmentation | FCN | Gluon Book | more data transformation, multi-GPU | example: train()
CV | Object Detection | SSD | Gluon Book | multiple labels, losses, and metrics | training script from GluonCV
NLP | Text Sentiment Classification | BiRNN | Gluon Book | same as rows 1 and 2 | example: train()
NLP | Text Sentiment Classification | TextCNN | Gluon Book | same as rows 1 and 2 | example: train()
NLP | Neural Machine Translation | Encoder-decoder with attention mechanism | Gluon Book | multiple trainers, different inputs for loss |
Various | Various | LR | Kaggle Blog | LR and XGBoost are the most used models besides CV and NLP models | XGBoost is not in scope and not supported


APPENDIX C - TensorFlow estimators

...


To customize a specific part of a pre-defined (canned) estimator, we need to create a new estimator with the customized module and use it instead.
https://towardsdatascience.com/how-to-extend-a-canned-tensorflow-estimator-to-add-more-evaluation-metrics-and-to-pass-through-ddf66cd3047d

Custom estimator examples:

...

The Keras fit API is implemented using callbacks (user-extensible Callback objects), which expose methods that are called at:
(i) the beginning of training
(ii) the end of training
(iii) the beginning of each epoch
(iv) the end of each epoch
(v) the beginning of each batch
(vi) the end of each batch
Keras provides a set of default callback implementations, such as History, BaseLogger, CSVLogger, ModelCheckpoint, EarlyStopping, LearningRateScheduler and TensorBoard. Users can also define their own custom callbacks as required.
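For comparison with the proposed `EventHandler` hooks, the framework-free sketch below records the order in which the six hook points listed above fire for a run of 2 epochs with 2 batches each. It does not use Keras: the `run` driver and `RecordingCallback` class are hypothetical, and only the method names and signatures follow Keras's `on_*` convention (`on_train_begin`, `on_epoch_begin(epoch, logs)`, `on_batch_end(batch, logs)`, and so on).

```python
class RecordingCallback:
    """Custom callback (Keras-style hook names) that records every hook call in order."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_train_end(self, logs=None):
        self.events.append("train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end {epoch}")

    def on_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin {batch}")

    def on_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end {batch}")


def run(callbacks, epochs, batches_per_epoch):
    """Hypothetical fit driver: it only fires hooks and does no training."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_begin(batch)
            # forward pass, loss, backward pass and parameter update would go here
            for cb in callbacks:
                cb.on_batch_end(batch)
        for cb in callbacks:
            cb.on_epoch_end(epoch)
    for cb in callbacks:
        cb.on_train_end()


recorder = RecordingCallback()
run([recorder], epochs=2, batches_per_epoch=2)
```

The nesting (train wraps epochs, epochs wrap batches) is the same structure the proposed `train_begin`/`epoch_begin`/`batch_begin` hooks map onto; only the method names differ.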