Versions Compared

Key

  • This line was added.
  • This line was removed.
  • Formatting was changed.

...

APPENDIX A - Gluon Training Loop Example

Use case survey: Gluon Fit Method Use Cases

Current user experience

# Current training loop in Gluon.
#
# This example runs a single epoch; increase num_epochs for a longer run.
num_epochs = 1

# Loss function and the optimizer wrapper over the network's parameters.
L = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

# Lowest top-1 validation error seen so far. It starts at 1, presumably
# meaning 100% error (assumes the error is a fraction in [0, 1]; confirm).
best_val_score = 1

# Main training loop: one training pass plus one validation pass per epoch,
# with logging and checkpointing.
for epoch in range(num_epochs):
    tic = time.time()          # start of the epoch (throughput / time cost)
    train_metric.reset()
    btic = time.time()         # start of the current logging window

    for i, batch in enumerate(train_data):
        # batch_fn is expected to return iterables of data and labels
        # (presumably one entry per device in ctx -- confirm against batch_fn).
        data, label = batch_fn(batch, ctx)

        # Record the forward pass and loss so autograd can differentiate them.
        with ag.record():
            outputs = [net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
        for loss_part in loss:
            loss_part.backward()
        lr_scheduler.update(i, epoch)
        # Trainer.step normalises the accumulated gradients by 1/batch_size.
        trainer.step(batch_size)

        train_metric.update(label, outputs)

        # Every log_interval batches, report the speed of the last window.
        if log_interval and not (i + 1) % log_interval:
            train_metric_name, train_metric_score = train_metric.get()
            logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                        epoch, i, batch_size*log_interval/(time.time()-btic),
                        train_metric_name, train_metric_score, trainer.learning_rate))
            btic = time.time()

    train_metric_name, train_metric_score = train_metric.get()
    # i is the zero-based index of the last batch, so i + 1 batches were processed.
    throughput = int(batch_size * (i + 1) / (time.time() - tic))

    err_top1_val, err_top5_val = test(ctx, val_data)

    logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
    logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
    logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

    # Track the best (lowest) top-1 validation error. Only write files when a
    # save directory is configured, like the periodic checkpoint below.
    if err_top1_val < best_val_score:
        best_val_score = err_top1_val
        if save_dir:
            net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
            trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

    # Periodic checkpoint every save_frequency epochs.
    if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
        net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
        trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

# After training, when checkpointing is enabled, write a final checkpoint
# labelled with the last epoch index (num_epochs - 1). This uses the
# example's own num_epochs; the previous `opt.num_epochs` referred to an
# `opt` object that this example never defines.
if save_frequency and save_dir:
    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, num_epochs - 1))
    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, num_epochs - 1))


...