# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
from .. import cntk_py
from ..device import use_default_device
from cntk.internal import sanitize_var_map, sanitize_function, typemap, \
_value_as_sequence_or_array
from cntk.internal.utils import _py_dict_to_cntk_dict
from ..io import MinibatchData
__doc__ = '''\
A trainer encapsulates the overall training process and employs one or more
:mod:`~cntk.learners` to tune the parameters of a specified model
using gradients of parameters w.r.t. a training objective.
'''
[docs]class Trainer(cntk_py.Trainer):
'''
Class for training the parameters of a model with respect to its specified loss function,
using the specified set of ``parameter_learners`` to update the model's parameters
from the computed gradients.
An optional specified metric function, which can be non-differentiable,
can be used for tracking the trained model's quality.
Args:
model (:class:`~cntk.ops.functions.Function`): root node of the function to train
criterion (tuple of :class:`~cntk.ops.functions.Function` or :class:`~cntk.variables.Variable`):
Function with one or two outputs, representing loss and, if given, evaluation metric
(in this order). Alternatively, a tuple(loss Function, evaluation Function) is also
accepted.
parameter_learners (list): list of learners from :mod:`cntk.learners`
progress_writers (progress writer or list of them): optionally, list of
progress writers from :mod:`cntk.logging` to automatically track training
progress.
Todo:
Allow to skip some parameters that should not be updated.
'''
@staticmethod
def _get_loss_metric(criterion):
    '''
    Interpret the ``criterion`` parameter as a ``(loss, metric)`` pair.

    Args:
        criterion: either a Function, in which case its ``outputs`` are
         used, or an iterable whose elements are Functions and/or Variables.

    Returns:
        tuple: a tuple of exactly two elements. If only one element was
        given, the second entry is ``None``.

    Raises:
        ValueError: if ``criterion`` has neither one nor two elements.
    '''
    # a tuple-valued Function is broken up into its individual outputs
    if isinstance(criterion, cntk_py.Function):
        criterion = criterion.outputs

    from cntk import combine
    entries = []
    for output in criterion:
        if isinstance(output, cntk_py.Variable):
            # map Variable to Function, keeping the Variable's name
            output = combine([output], name=output.name)
        entries.append(output)
    criterion = tuple(entries)

    count = len(criterion)
    if count == 2:
        return criterion
    if count == 1:
        # single value: pad the missing metric with None
        return criterion + (None,)
    raise ValueError("criterion parameter must be a singleton or a tuple of 2 elements")
def __init__(self, model, criterion, parameter_learners, progress_writers=None):
    '''
    Build the underlying trainer object and adopt its state.

    Args:
        model: root node of the function to train; ``None`` means a dummy
         model that is, e.g., the same as a criterion.
        criterion: loss and optional metric, see :meth:`_get_loss_metric`.
        parameter_learners: a learner or a list of learners.
        progress_writers: ``None``, a progress writer or a list of them.
    '''
    loss_function, eval_function = Trainer._get_loss_metric(criterion)

    # TODO sanitizing should be removed once Swig's typemaps are in place
    model = sanitize_function(model) if model is not None else None
    loss_function = sanitize_function(loss_function)
    eval_function = sanitize_function(eval_function) if eval_function is not None else None

    # anything that is not already a list is wrapped into a one-element list
    learners = parameter_learners if isinstance(parameter_learners, list) else [parameter_learners]
    if progress_writers is None:
        writers = []
    elif isinstance(progress_writers, list):
        writers = progress_writers
    else:
        writers = [progress_writers]

    impl = cntk_py.trainer_impl(model, loss_function, eval_function, learners, writers)
    # transplant into this class instance
    self.__dict__ = impl.__dict__
# TODO: bring this back once the design has been settled
def _train_test_mb_map_args(self, *args, **kwargs):
    '''
    Helper that mimics the Python calling convention for
    train/test_minibatch().

    A single positional argument that is a `dict` or a `tuple` is returned
    unchanged. Otherwise the arguments are mapped through the loss
    function's ``argument_map``.

    Raises:
        ValueError: if the model or the evaluation function has an
         argument that is not an argument of the loss function.
    '''
    # one argument, which is an arg map or a (map, bool) tuple
    if len(args) == 1:
        only = args[0]
        if isinstance(only, (dict, tuple)):
            return only

    # map to function arguments
    arg_map = self.loss_function.argument_map(*args, **kwargs)

    # all parts are called as a single combined function, so their inputs
    # must be subsets of the inputs of the loss
    model = self.model
    if model:
        if any(arg not in self.loss_function.arguments for arg in model.arguments):
            raise ValueError("model function must share its arguments with the loss function")
    metric = self.evaluation_function
    if metric:
        if any(arg not in self.loss_function.arguments for arg in metric.arguments):
            raise ValueError("evaluation function must have the same signature and inputs as the loss function")
    return arg_map
def train_minibatch(self, arguments, outputs=None, device=None, is_sweep_end=None):
    '''
    Optimize model parameters using the specified 'arguments' minibatch of training samples.

    Args:
        arguments: maps variables to their input data. Empty map signifies
         end of local training data.
         The interpretation depends on the input type:

           * `dict`: keys are input variable or names, and values are the input data.
           * any other type: if node has a unique input, ``arguments`` is mapped to this input.

         For nodes with more than one input, only `dict` is allowed.

         In both cases, every sample in the data will be interpreted
         as a new sequence. To mark samples as continuations of the
         previous sequence, specify ``arguments`` as `tuple`: the
         first element will be used as ``arguments``, and the second one will
         be used as a list of bools, denoting whether a sequence is a new
         one (`True`) or a continuation of the previous one (`False`).
         Data should be either NumPy arrays or a
         :class:`~cntk.io.MinibatchData` instance.
        outputs (iterable): outputs to fetch values for.
        device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor that
         contains the type and id of the device on which the computation is
         to be performed.
        is_sweep_end (bool): indicates whether this minibatch is at the end of a sweep
         (of an epoch), defaults to None. This is used when ``arguments`` is fed with
         NumPy array data; when the data is from :class:`~cntk.io.MinibatchData`,
         `is_sweep_end` is provided by :class:`~cntk.io.MinibatchData`, so there is
         no need to specify it manually.

    Note:
         See :meth:`~cntk.ops.functions.Function.forward` for examples on
         passing input data.

    Returns:
        `bool` or `tuple`:
        If ``outputs`` have not been provided, the returned value is `True`
        if updates have been performed, `False` if all parameter learners
        indicate end of learning (through their update). Otherwise, the
        return value is a tuple of that `bool` and a dictionary that
        maps the variables in `outputs` to their respective NumPy arrays.

    Raises:
        ValueError: if ``is_sweep_end`` is specified although the data is
         fed as :class:`~cntk.io.MinibatchData`.
    '''
    if not device:
        device = use_default_device()

    if arguments:
        # arguments must feed all inputs (model, loss, eval)
        required = set(self.loss_function.arguments)
        if self.model:
            required.update(self.model.arguments)
        if self.evaluation_function:
            required.update(self.evaluation_function.arguments)
        arguments = sanitize_var_map(tuple(required), arguments,
                                     extract_values_from_minibatch_data=False,
                                     device=device)

    # the first value decides whether the minibatch is fed as MinibatchData
    fed_with_minibatch_data = (len(arguments) > 0
                               and isinstance(next(iter(arguments.values())), MinibatchData))

    if fed_with_minibatch_data:
        if is_sweep_end is not None:
            raise ValueError("is_sweep_end is ignored by Trainer::train_minibatch when it is fed with MinibatchData!")
    elif is_sweep_end is None:
        # for legacy code when is_sweep_end is not specified.
        is_sweep_end = False

    base = super(Trainer, self)

    if not outputs:
        if fed_with_minibatch_data:
            return base.train_minibatch_overload_for_minibatchdata(arguments, device)
        return base.train_minibatch(arguments, is_sweep_end, device)

    # one placeholder per requested output, filled in by the call below
    fetched = dict.fromkeys(outputs)
    if fed_with_minibatch_data:
        updated = base.train_minibatch_overload_for_minibatchdata(
            arguments, fetched, device)
    else:
        updated = base.train_minibatch(arguments, is_sweep_end, fetched, device)

    for var, value in list(fetched.items()):
        fetched[var] = _value_as_sequence_or_array(value, var)
    return updated, fetched
def test_minibatch(self, arguments, device=None):
    '''
    Test the model on the specified batch of samples using the evaluation
    Function specified during construction of the Trainer.

    Args:
        arguments: maps variables to their
         input data. The interpretation depends on the input type:

           * `dict`: keys are input variable or names, and values are the input data.
             See :meth:`~cntk.ops.functions.Function.forward` for details on passing input data.
           * any other type: if node has a unique input, ``arguments`` is mapped to this input.

         For nodes with more than one input, only `dict` is allowed.

         In both cases, every sample in the data will be interpreted
         as a new sequence. To mark samples as continuations of the
         previous sequence, specify ``arguments`` as `tuple`: the
         first element will be used as ``arguments``, and the second one will
         be used as a list of bools, denoting whether a sequence is a new
         one (`True`) or a continuation of the previous one (`False`).
         Data should be either NumPy arrays or a
         :class:`~cntk.io.MinibatchData` instance.
        device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor that
         contains the type and id of the device on which the computation is
         to be performed. If not given, the default device is used.

    Note:
         See :meth:`~cntk.ops.functions.Function.forward` for examples on
         passing input data.

    Returns:
        `float`: the average evaluation criterion value per sample for the
        tested minibatch.
    '''
    if not device:
        device = use_default_device()

    # pass all args of all parts (model, loss, eval)
    all_args = set(self.loss_function.arguments)
    if self.model:
        all_args |= set(self.model.arguments)
    if self.evaluation_function:
        all_args |= set(self.evaluation_function.arguments)

    # Forward the resolved device so the input values are prepared for the
    # same device the evaluation runs on, consistent with train_minibatch().
    arguments = sanitize_var_map(tuple(all_args), arguments, device=device)
    return super(Trainer, self).test_minibatch(arguments, device)
def save_checkpoint(self, filename, external_state=None):
    '''
    Saves a checkpoint of the model and other Trainer state at the
    specified file location.

    In distributed environment the checkpointing is done by
    the main worker.

    Args:
        filename (str): filename to store the checkpoint.
        external_state (dict): additional external state, default is empty.
    '''
    # A fresh dict per call avoids sharing one mutable default object
    # between all invocations.
    if external_state is None:
        external_state = {}
    super(Trainer, self).save_checkpoint(filename, _py_dict_to_cntk_dict(external_state))
def restore_from_checkpoint(self, filename):
    '''
    Restores a checkpoint of the model and Trainer state from the
    specified file location.

    Args:
        filename (str): filename to restore the checkpoint from

    Returns:
        whatever the underlying trainer's ``restore_from_checkpoint`` returns.
    '''
    base = super(Trainer, self)
    return base.restore_from_checkpoint(filename)
@property
@typemap
def model(self):
    '''
    The model that the trainer is training.
    '''
    base = super(Trainer, self)
    return base.model()
@property
@typemap
def loss_function(self):
    '''
    The loss function that the trainer is using.
    '''
    base = super(Trainer, self)
    return base.loss_function()
@property
@typemap
def evaluation_function(self):
    '''
    The evaluation function that the trainer is using.
    '''
    base = super(Trainer, self)
    return base.evaluation_function()
@property
@typemap
def parameter_learners(self):
    '''
    The parameter learners that the trainer is using.
    '''
    base = super(Trainer, self)
    return base.parameter_learners()
@property
def previous_minibatch_loss_average(self):
    '''
    The average training loss per sample for the last minibatch trained.
    '''
    base = super(Trainer, self)
    return base.previous_minibatch_loss_average()
@property
def previous_minibatch_evaluation_average(self):
    '''
    The average evaluation criterion value per sample for the last minibatch trained.
    '''
    base = super(Trainer, self)
    return base.previous_minibatch_evaluation_average()
@property
def previous_minibatch_sample_count(self):
    '''
    The number of samples in the last minibatch that was trained with.
    '''
    base = super(Trainer, self)
    return base.previous_minibatch_sample_count()
@property
def total_number_of_samples_seen(self):
    '''
    The number of samples seen globally between all workers from the beginning of training.
    '''
    base = super(Trainer, self)
    return base.total_number_of_samples_seen()
def summarize_training_progress(self):
    '''
    Updates the progress writers with the summary of training progress
    since start and resets the internal accumulators.
    '''
    base = super(Trainer, self)
    return base.summarize_training_progress()
def summarize_test_progress(self):
    '''
    Updates the progress writers with the summary of test progress
    since start and resets the internal accumulators.
    '''
    base = super(Trainer, self)
    return base.summarize_test_progress()
def print_node_timing(self):
    '''
    Prints the per-node average timing per minibatch for each primitive
    function. The statistics are reset after printing.
    '''
    base = super(Trainer, self)
    return base.print_node_timing()