...

Let's start with an example to discuss the ways to implement callbacks. We'll implement a stopping criterion that stops training after a certain number of batches or epochs; it tells the training loop whether to stop.

stop = StopTrainingHandlerV2(max_batch=100, max_epoch=10)

for epoch in range(20):
    for batch in range(25):
        print('epoch: ', epoch)
        print('batch: ', batch)
        batch_result = stop.batch_end()
        if batch_result:
            break
    epoch_result = stop.epoch_end()
    if epoch_result:
        break

Multi-inheritance

One base class per type of event; each base class keeps only the state that is specific to that event.
Rule of thumb: state common to several base classes should be managed by the subclass that inherits from those base classes.

class BatchEnd(object):
    def __init__(self, max_batch=None):
        self.batch_idx = 0
        self.total_batch = 0
        self.max_batch = max_batch

    def batch_end(self, batch_result={}):
        self.batch_idx += 1
        self.total_batch += 1


class EpochEnd(object):
    def __init__(self, max_epoch=None):
        self.epoch = 0
        self.max_epoch = max_epoch

    def epoch_end(self, epoch_result={}):
        self.epoch += 1

class StopTrainingHandler(BatchEnd, EpochEnd):
    def __init__(self, max_batch, max_epoch):
        super().__init__(max_batch)
        super(BatchEnd, self).__init__(max_epoch)
        self.stop_training = False

    def batch_end(self, batch_result={}):
        super(StopTrainingHandler, self).batch_end(batch_result)
        if self.total_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, epoch_result={}):
        super(StopTrainingHandler, self).epoch_end(epoch_result)
        # reset batch index at end
        self.batch_idx = 0
        if self.epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training
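
For reference, a minimal sketch of how this multi-inheritance handler can be driven, using the same kind of loop as the StopTrainingHandlerV2 example at the top (the loop bounds are illustrative only):

stop = StopTrainingHandler(max_batch=100, max_epoch=10)

for epoch in range(20):
    for batch in range(25):
        if stop.batch_end():
            break
    if stop.epoch_end():
        break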

Method override

One base class with all event methods; handlers override only the methods they need.

class EventHandler(object):
    """Base class for event handlers.

    :py:class:`EventHandler` can perform user-defined functions at
    different stages of training: train begin, epoch begin, batch begin,
    batch end, epoch end, train end.

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics from
    """

    def __init__(self):
        self._estimator = None

    def train_begin(self, *args, **kwargs):
        pass

    def epoch_begin(self, *args, **kwargs):
        pass

    def batch_begin(self, *args, **kwargs):
        pass

    def batch_end(self, batch_id, batch_results=None, *args, **kwargs):
        return False

    def epoch_end(self, epoch, epoch_results=None, *args, **kwargs):
        return False

    def train_end(self, *args, **kwargs):
        pass


class StopTrainingHandlerV2(EventHandler):
    def __init__(self, max_batch, max_epoch):
        super().__init__()
        self.batch_idx = 0
        self.epoch = 0
        self.total_batch = 0
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.stop_training = False

    def batch_end(self, batch_result={}, *args, **kwargs):
        self.batch_idx += 1
        self.total_batch += 1
        if self.total_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, epoch_result={}, *args, **kwargs):
        self.epoch += 1
        # reset batch index at end
        self.batch_idx = 0
        if self.epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training

Conclusion:

  1. Base classes should not keep any state; each concrete child class maintains all of its own state. It is very common that state initialized in batch begin is needed again in batch end or epoch end, so that state has to be managed by a child class that inherits BatchBegin, BatchEnd, and EpochEnd.
  2. There is no difference in efficiency when it comes to avoiding empty method calls; both approaches can do that. The key idea is to categorize all callbacks into 6 lists, depending on whether a handler overrides its parent class (see the categorization sketch after the example classes below).
    1. With multi-inheritance, handlers can be categorized using isinstance(handler, TrainBegin), etc.
    2. With method override, handlers can be categorized by comparing methods, e.g. handler.__class__.train_begin == EventHandler.train_begin holds when train_begin was not overridden.
  3. The only difference is that methods can be changed and modified at run time.
  4. We will go with multi-inheritance, but keep all state in each concrete event handler class. For example:
class TrainBegin(object):
    def train_begin(self, estimator, *args, **kwargs):
        pass


class TrainEnd(object):
    def train_end(self, estimator, *args, **kwargs):
        pass


class EpochBegin(object):
    def epoch_begin(self, estimator, *args, **kwargs):
        pass


class EpochEnd(object):
    def epoch_end(self, estimator, *args, **kwargs):
        return False


class BatchBegin(object):
    def batch_begin(self, estimator, *args, **kwargs):
        pass


class BatchEnd(object):
    def batch_end(self, estimator, *args, **kwargs):
        return False


from mxnet.metric import Loss


class MetricHandler(EpochBegin, BatchEnd):
    def __init__(self, train_metrics):
        self.train_metrics = train_metrics
        # order to be called among all callbacks
        # metrics need to be calculated before other callbacks can access them
        self.rank = 1

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.train_metrics:
            metric.reset()

    def batch_end(self, estimator, *args, **kwargs):
        pred = kwargs['pred']
        label = kwargs['label']
        loss = kwargs['loss']
        for metric in self.train_metrics:
            if isinstance(metric, Loss):
                # metric wrapper for loss values
                metric.update(0, loss)
            else:
                metric.update(label, pred)
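
To illustrate point 2 of the conclusion, here is a minimal sketch (the function name and local variables are hypothetical, not an existing API) of how an Estimator-style training loop could sort handlers into the 6 per-event lists with isinstance checks, so that only non-empty callbacks are ever invoked:

def categorize_handlers(event_handlers):
    # one list per event; a handler that inherits several event
    # base classes simply lands in several lists
    train_begin, epoch_begin, batch_begin = [], [], []
    batch_end, epoch_end, train_end = [], [], []
    for handler in event_handlers:
        if isinstance(handler, TrainBegin):
            train_begin.append(handler)
        if isinstance(handler, EpochBegin):
            epoch_begin.append(handler)
        if isinstance(handler, BatchBegin):
            batch_begin.append(handler)
        if isinstance(handler, BatchEnd):
            batch_end.append(handler)
        if isinstance(handler, EpochEnd):
            epoch_end.append(handler)
        if isinstance(handler, TrainEnd):
            train_end.append(handler)
    return (train_begin, epoch_begin, batch_begin,
            batch_end, epoch_end, train_end)

Each list could additionally be sorted by the handler's rank attribute, so that, for example, MetricHandler (rank 1) updates the metrics before other callbacks read them.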

Bookkeeping training states

...