# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
from .. import cntk_py
from ..device import use_default_device
from cntk.internal import sanitize_var_map, sanitize_function, typemap
from ..io import MinibatchData
# Explicitly assign the module docstring. The backslash right after the opening
# quotes is a line continuation inside the literal, so the text starts without
# a leading newline.
__doc__= '''\
An evaluator provides functionality to evaluate minibatches against the specified evaluation function.
'''
class Evaluator(cntk_py.Evaluator):
    '''
    Class for evaluation of minibatches against the specified evaluation function.

    Args:
        eval_function (:class:`~cntk.ops.functions.Function`): evaluation function.
        progress_writers (list): optionally, list of progress writers from
          :mod:`cntk.utils` to track training progress. A single writer that
          is not wrapped in a list is accepted as well.
    '''

    def __init__(self, eval_function, progress_writers=None):
        # ``None`` is forwarded untouched; anything else is sanitized first.
        if eval_function is not None:
            eval_function = sanitize_function(eval_function)

        # Normalize ``progress_writers`` so that a list is always passed on.
        if progress_writers is None:
            progress_writers = []
        elif not isinstance(progress_writers, list):
            progress_writers = [progress_writers]

        evaluator = cntk_py.create_evaluator(eval_function, progress_writers)
        # transplant into this class instance: the object created by the
        # factory donates its state, so ``self`` behaves like that object.
        self.__dict__ = evaluator.__dict__

    def test_minibatch(self, arguments, device=None, distributed=False):
        '''
        Test the evaluation function on the specified batch of samples.

        Args:
            arguments: maps variables to their
             input data. The interpretation depends on the input type:

               * `dict`: keys are input variable or names, and values are the input data.
                 See :meth:`~cntk.ops.functions.Function.forward` for details on passing input data.
               * any other type: if node has a unique input, ``arguments`` is mapped to this input.
                 For nodes with more than one input, only `dict` is allowed.

             In both cases, every sample in the data will be interpreted
             as a new sequence. To mark samples as continuations of the
             previous sequence, specify ``arguments`` as `tuple`: the
             first element will be used as ``arguments``, and the second one will
             be used as a list of bools, denoting whether a sequence is a new
             one (`True`) or a continuation of the previous one (`False`).
             Data should be either NumPy arrays or a
             :class:`~cntk.io.MinibatchData` instance.
            device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor that
             contains the type and id of the device on which the computation is
             to be performed. When omitted (or falsy), the default device is used.
            distributed (`bool`, optional): flag indicating if evaluation results should
             be aggregated across workers.

        Note:
            See :meth:`~cntk.ops.functions.Function.forward` for examples on
            passing input data.

        Returns:
            `float`: the average evaluation criterion value per sample for the
            tested minibatch.
        '''
        # Fall back to the default device when the caller did not supply one.
        if not device:
            device = use_default_device()

        # Map the user-supplied data onto the evaluation function's arguments.
        arguments = sanitize_var_map(
            tuple(self.evaluation_function.arguments), arguments)

        return super(Evaluator, self).test_minibatch(
            arguments, device, distributed)

    @property
    @typemap
    def evaluation_function(self):
        '''
        The evaluation function that the evaluator is using.
        '''
        return super(Evaluator, self).evaluation_function()

    def summarize_test_progress(self):
        '''
        Updates the progress writers with the summary of test progress since
        start and resets the internal accumulators.
        '''
        return super(Evaluator, self).summarize_test_progress()

    def print_node_timing(self):
        '''
        Prints the per-node average timing per minibatch for each primitive
        function. The statistics are reset after printing.
        '''
        return super(Evaluator, self).print_node_timing()