# Source code for cntk.debugging.debug

# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See file in the project root
# for full license information.
# ==============================================================================

import sys
from collections import defaultdict

from cntk import cntk_py, user_function, output_variable, CloneMethod

from cntk.ops.functions import UserFunction
from cntk.internal import map_if_possible

DEBUG_USAGE = '''\
        n - execute the next node
        n <number> - execute the next <number> nodes

        u f - execute until forward pass (like 'n' when already in forward pass)
        u b - execute until backward pass (like 'n' when already in backward pass)
        u name - execute until a node with that name is hit
        u <lambda> - execute until the lambda expression is True. Examples:
                     Until a Times node is hit:
                         lambda arg, node: node.op_name == 'Times'
                     Until a node is hit that has 3 dimensions:
                         lambda arg, node: len(node.shape) == 3
                     Until the variance of the input exceeds 1 (np = numpy):
                         lambda arg, node: np.var(arg) > 1

        c - execute until end
        p - print input (forward) or root gradients (backward)
        d - drop into a pdb shell
        q - quit\
'''

# Module documentation. It is assigned explicitly (instead of being a leading
# string literal) because it is defined after the imports.
__doc__ = '''\
In order to debug a graph one simply needs to wrap the root node as follows::

    # ... setting up the model in z
    from cntk.debugging import debug_model
    z = debug_model(z)

Then, when ``z`` is evaluated or trained (i.e. when either
:meth:`~cntk.ops.functions.Function.forward` or
:meth:`~cntk.ops.functions.Function.backward` is called), you will see the
following command-line interface::

    =================================== forward  ===================================
    Parameter node with uid='Parameter28' shape=[](2,)
    [CNTK forward] >>> help

    [CNTK backward] >>> n

    Times node with uid='Times29' shape=[*,*](2,)
    [CNTK forward] >>> n
    =================================== backward ===================================
    Times node with uid='Times29' shape=[*,*](2,)
    [CNTK backward] >>> p
    State: None
    Root gradients:
    [[[-0.79412955  0.79412955]]
     [[-0.79412955  0.79412955]]
     [[ 0.20587046 -0.20587045]]
     [[ 0.20587046 -0.20587045]]
     [[ 0.20587046 -0.20587045]]
     [[ 0.20587046 -0.20587045]]
     [[-0.79412955  0.79412955]]
     [[ 0.20587046 -0.20587045]]
     [[ 0.20587039 -0.20587039]]
     [[-0.79412961  0.79412961]]]

At every stop the following information is given:
 * Forward or backward pass
 * Node type (e.g. 'Times')
 * Name if given, otherwise it is omitted
 * uid, which is a unique reference within the graph
 * shape having the format [dynamic axis](static axes). E.g. ``[*,*](2,)``
   means that the node's output has two dynamic axes (batch and sequence) and
   one static axis (2 dimensions)
'''

def save_as_legacy_model(root_op, filename):
    '''
    Write the network rooted at ``root_op`` to ``filename`` in the legacy
    model format.

    Intended for debugging only; very likely to be deprecated in the future.

    Args:
        root_op (:class:`~cntk.ops.functions.Function`): op of the graph to
         save
        filename (str): filename to store the model in.
    '''
    # Pure delegation to the native binding; nothing is returned.
    cntk_py.save_as_legacy_model(root_op, filename)
class _DebugState(object):
    '''
    Mutable state shared by all debug nodes of one debugged model.

    Attributes are plain fields that the debug nodes read and update:

     * ``commands``: list of pending commands; the *last* element is the
       next one to be processed.
     * ``last_pass``: marker of the pass that stopped last (``'<start>'``
       until a debug node has run).
     * ``all_nodes``: the nodes handed to the constructor.
     * ``name_to_node``: maps a node name to the list of nodes carrying that
       name, used to resolve the ``u name`` command.

    Args:
        all_nodes (iterable): graph nodes being debugged; each must expose a
         ``name`` attribute.
    '''

    def __init__(self, all_nodes):
        self.commands = []
        self.last_pass = '<start>'
        self.all_nodes = all_nodes

        # Several nodes may share one name, hence a list per key.
        self.name_to_node = defaultdict(list)
        for n in self.all_nodes:
            self.name_to_node[n.name].append(n)
def set_checked_mode(enable):
    '''
    Switch checked mode on or off.

    Checked mode enables additional runtime verification such as:

     - Tracking NaN occurrences in sequence gaps.
     - Function graph verification after binding of free static axes to
       actual values at runtime

    Enabling checked mode incurs additional runtime costs and is meant to be
    used as a debugging aid.

    Args:
        enable (bool): whether to enable checked mode (with performance
         impact)
    '''
    # Forwarded unchanged to the native binding.
    cntk_py.set_checked_mode(enable)
def set_computation_network_trace_level(level):
    '''
    Set the trace level of the computation network.

    Currently supported values:

     * 0: turn off trace
     * 1: output nodes' dimensions and some other static info
     * 1000: output each node's abs sum of elements in its value matrix for
       every forward/backward
     * 1000000: output each node's full matrix for every forward/backward

    Args:
        level (int): trace level
    '''
    # Forwarded unchanged to the native binding.
    cntk_py.set_computation_network_trace_level(level)
def set_node_timing(enable):
    '''
    Turn per-node timing on or off.

    Node-timing records per-node average execution time per-minibatch.
    Enabling node timing incurs a small runtime cost and is meant to be used
    as a debugging aid.

    Args:
        enable (bool): whether to enable per-node timing
    '''
    # Two separate native entry points exist, so branch explicitly instead of
    # using a conditional expression for its side effects.
    if enable:
        cntk_py.enable_node_timing()
    else:
        cntk_py.disable_node_timing()
class _DebugNode(UserFunction):
    '''
    A user function node that exposes a command line interface. With that one
    can step through the graph and investigate data, shapes, etc.

    In order to use it, call :func:`debug_model` on the model.

    Args:
        arg (graph node): the node in the graph after which this Debug Node is
         to be inserted
        debug_state (:class:`_DebugState`): state that is shared among all
         debug nodes
        in_stream (object behaving like sys.stdin): `readline()` will be
         called on it to obtain user input
        out_stream (object behaving like sys.stdout): `write()` and `flush()`
         will be called on it to output debug info to the user
        exit_func (callable): callable that takes an exit code and is called,
         when the user exits the debugging process
        name (str): name of the node
    '''
    # Not referenced by the code in this class; kept so the class attribute
    # remains available to anything that may rely on it.
    _commands = []

    PROMPT_FORWARD = '[CNTK forward] >>> '
    PROMPT_BACKWARD = '[CNTK backward] >>> '

    def __init__(self, arg, debug_state,
                 in_stream=sys.stdin, out_stream=sys.stdout,
                 exit_func=sys.exit,
                 name='D'):
        if hasattr(arg, 'is_composite') and arg.is_composite:
            arg = arg.root_function

        # Shorten the name a bit: drop the 'Output' component of the uid
        # when it is the second-to-last part.
        arg_uid_parts = arg.uid.split('_')
        if len(arg_uid_parts) > 2 and arg_uid_parts[-2] == 'Output':
            del arg_uid_parts[-2]
        name += '_%s' % '_'.join(arg_uid_parts)

        super(_DebugNode, self).__init__([arg], as_numpy=True, name=name)
        self.after = arg
        self.debug_state = debug_state

        self._in, self._out = in_stream, out_stream
        self._exit = exit_func

    def clone(self, cloned_inputs):
        '''
        Create a debug node for the first cloned input that shares this
        node's debug state, streams and exit function.
        '''
        arg = cloned_inputs[0]
        map_if_possible(arg)
        return _DebugNode(arg, self.debug_state, self._in, self._out,
                          self._exit)

    # TODO:
    # Breakpoint handling
    # u h - until here

    def _wait_for_input(self, prompt):
        '''
        Prompt until the user enters a command that is understood.

        Args:
            prompt (str): text written to the output stream before each read

        Returns:
            list: non-empty command list. Entries are ``'c'``, ``'d'``,
            ``'p'``, ``'n'`` (repeated for ``n <number>``), ``'uf'``,
            ``'ub'`` or a callable taking ``(arg, node)`` for ``u name`` and
            ``u <lambda>``.
        '''
        understood = False
        while not understood:
            self._out.write(prompt)
            self._out.flush()

            raw_line = self._in.readline()
            if not raw_line:
                # readline() returns '' only at end of input. Re-prompting
                # would spin forever, so run to the end as if 'c' was typed.
                return ['c']

            new_input = raw_line.strip()
            if not new_input:
                # Blank line: just prompt again.
                continue

            # Only commands that forward()/backward() actually handle are
            # accepted here; anything else falls through to the usage text.
            if len(new_input) == 1 and new_input in 'cdp':
                understood = [new_input]

            elif new_input[0] == 'n':
                if len(new_input) > 1:
                    try:
                        # A count < 1 yields an empty list, which is treated
                        # as "not understood" below.
                        understood = ['n'] * int(new_input[1:])
                    except ValueError:
                        pass
                else:
                    understood = ['n']

            elif new_input[0] == 'u':
                what = new_input[1:].strip()
                try:
                    if what.startswith('lambda'):
                        # NOTE: eval() of text typed at the debugger prompt.
                        # This is meant for an interactive developer session
                        # only; never connect in_stream to untrusted input.
                        # NOTE(review): the usage text advertises `np`; it is
                        # only resolvable if this module's globals define it,
                        # which the imports visible here do not - confirm.
                        code = eval(what)
                        if callable(code):
                            understood = [code]
                    elif what in self.debug_state.name_to_node:
                        def code(arg, n):
                            return n.name == what
                        understood = [code]
                    elif "backward".startswith(what):
                        understood = ['ub']
                    elif "forward".startswith(what):
                        understood = ['uf']
                    else:
                        self._out.write('Your model does not contain a '
                                        'node with name "%s"\n' % what)
                        self._out.flush()
                except (SyntaxError, NameError):
                    # Malformed lambda, or a word that merely starts with
                    # 'lambda' and is not a known node name.
                    understood = False

            elif new_input == 'q':
                self._exit(0)

            if not understood:
                self._out.write(DEBUG_USAGE + '\n')
                self._out.flush()

        return understood

    def _print_status(self, current_pass):
        '''
        Write a banner when the pass differs from the last recorded one,
        followed by a description of the node this debug node sits after.

        Args:
            current_pass (str): ``'f'`` for forward, anything else is
             reported as backward
        '''
        if current_pass != self.debug_state.last_pass:
            if current_pass == 'f':
                self._out.write('\n')
                self._out.write('=' * 35 + ' forward ' + '=' * 35 + '\n')
            else:
                self._out.write('=' * 35 + ' backward ' + '=' * 35 + '\n')
            self._out.flush()

        # Report the owning function when `after` is an output variable.
        after = self.after.owner if self.after.is_output else self.after
        self._out.write("\n%s with uid '%s'\n" % (str(after), after.uid))
        self._out.flush()

    def forward(self, argument, device=None, outputs_to_retain=None):
        '''
        Process pending debug commands for the forward pass, then pass
        ``argument`` through unchanged.

        Returns:
            tuple: ``(None, argument)`` - no state and the unmodified input.
        '''
        self._print_status('f')

        done = False
        while not done:
            if not self.debug_state.commands:
                self.debug_state.commands = self._wait_for_input(
                    _DebugNode.PROMPT_FORWARD)

            commands = self.debug_state.commands

            # The last element is the next command to process.
            next_command = commands[-1]
            if next_command == 'c':
                # Never popped: stays active for all following nodes.
                done = True

            elif isinstance(next_command, str) and \
                    next_command.startswith('n'):
                if len(next_command) == 1:
                    commands.pop()
                done = True

            elif isinstance(next_command, str) and \
                    next_command.startswith('u'):
                if next_command == "uf":
                    commands.pop()
                    if self.debug_state.last_pass == 'b':
                        # Just switched from backward to forward: stop here.
                        self.debug_state.commands = self._wait_for_input(
                            _DebugNode.PROMPT_FORWARD)
                        done = False
                    else:
                        done = True
                elif next_command == "ub":
                    # Keep running until a backward node consumes it.
                    done = True

            elif next_command == 'p':
                self._out.write('Input with shape %s: \n' %
                                str(argument.shape))
                self._out.write(str(argument))
                self._out.write('\n')
                self._out.flush()
                commands.pop()

            elif next_command == 'd':
                commands.pop()
                import pdb
                pdb.set_trace()
                done = True

            elif callable(next_command):
                if next_command(argument, self.after):
                    # Condition met: drop it so the user is prompted here.
                    commands.pop()
                else:
                    done = True

        self.debug_state.last_pass = 'f'

        return None, argument

    def backward(self, state, root_gradients):
        '''
        Process pending debug commands for the backward pass, then pass
        ``root_gradients`` through unchanged.

        Returns:
            ``root_gradients`` as received.
        '''
        self._print_status('b')

        done = False
        while not done:
            if not self.debug_state.commands:
                self.debug_state.commands = self._wait_for_input(
                    _DebugNode.PROMPT_BACKWARD)

            commands = self.debug_state.commands

            # The last element is the next command to process.
            next_command = commands[-1]
            if next_command == 'c':
                # Never popped: stays active for all following nodes.
                done = True

            elif isinstance(next_command, str) and \
                    next_command.startswith('n'):
                if len(next_command) == 1:
                    commands.pop()
                done = True

            elif isinstance(next_command, str) and \
                    next_command.startswith('u'):
                if next_command == "uf":
                    # Keep running until a forward node consumes it.
                    done = True
                elif next_command == "ub":
                    commands.pop()
                    if self.debug_state.last_pass == 'f':
                        # Just switched from forward to backward: stop here.
                        # We are in the backward pass, so show that prompt.
                        self.debug_state.commands = self._wait_for_input(
                            _DebugNode.PROMPT_BACKWARD)
                        done = False
                    else:
                        done = True

            elif next_command == 'p':
                if state is not None:
                    self._out.write('State: %s\n' % str(state))
                self._out.write('Root gradients with shape %s: \n' %
                                str(root_gradients.shape))
                self._out.write(str(root_gradients))
                self._out.write('\n')
                self._out.flush()
                commands.pop()

            elif next_command == 'd':
                # Consume the command (as forward() does); otherwise every
                # following backward node would drop into pdb again.
                commands.pop()
                import pdb
                pdb.set_trace()
                done = True

            elif callable(next_command):
                if next_command(root_gradients, self.after):
                    # Condition met: drop it so the user is prompted here.
                    commands.pop()
                else:
                    done = True

        self.debug_state.last_pass = 'b'

        return root_gradients

    def infer_outputs(self):
        '''
        Declare a single output that mirrors the shape, dtype and dynamic
        axes of the first input.
        '''
        return [output_variable(self.inputs[0].shape, self.inputs[0].dtype,
                                self.inputs[0].dynamic_axes)]

    def __str__(self):
        return "_DebugNode(after=%s)" % str(self.after)


def _nodes_to_debug(model):
    '''
    Collect the nodes of ``model`` that do not yet have a debug node.

    Every node found by a depth-first search is returned, except nodes whose
    ``op_name`` is ``'UserFunction'`` and the node feeding the first input of
    each such user function.

    Args:
        model: root of the graph to scan

    Returns:
        list: nodes still to be wrapped (order is not defined, the search
        result goes through a set)
    '''
    from cntk.logging.graph import depth_first_search

    nodes = set(depth_first_search(model, lambda x: True))

    uf_nodes = [n for n in nodes
                if hasattr(n, 'op_name') and n.op_name == 'UserFunction']

    # The node in front of each user function: its owner when the first
    # input is an output variable, otherwise the input itself.
    already_covered = [n.inputs[0].owner if n.inputs[0].is_output
                       else n.inputs[0] for n in uf_nodes]

    # Compare by uid; a set keeps the final filter linear.
    to_remove = {n.uid for n in (already_covered + uf_nodes)}

    return [n for n in nodes if n.uid not in to_remove]
def debug_model(model, in_stream=sys.stdin, out_stream=sys.stdout,
                exit_func=sys.exit):
    '''
    debug_model(model, in_stream=sys.stdin, out_stream=sys.stdout, exit_func=sys.exit)

    Returns a cloned model that has debug nodes inserted everywhere. When the
    graph is evaluated or trained, those nodes will allow to inspect the graph.

    Args:
        model (root node): root node until which the nodes are to be debugged
        in_stream (object behaving like sys.stdin, default stdin): `readline()`
         will be called on it to obtain user input
        out_stream (object behaving like sys.stdout, default stdout): `write()`
         and `flush()` will be called on it to output debug info to the user
        exit_func (callable, default sys.exit): callable that takes an exit
         code and is called, when the user exits the debugging process

    Returns:
        a clone of the model that has debugging enabled

    Raises:
        ValueError: if more cloning rounds are needed than there were nodes
         to debug initially
    '''
    pending = _nodes_to_debug(model)
    shared_state = _DebugState(pending)

    max_rounds = len(pending)
    rounds = 0

    # We cannot add the DebugNodes in one clone because the replacements will
    # hide parent nodes. So: wrap what is visible, clone, then scan again.
    while pending:
        replacements = {}
        for node in pending:
            replacements[node] = user_function(
                _DebugNode(node, shared_state, in_stream, out_stream,
                           exit_func))

        model = model.clone(CloneMethod.share, replacements)
        rounds += 1

        # Each round should cover at least one node; more rounds than
        # initial nodes means the process is not converging.
        if rounds > max_rounds:
            raise ValueError('cannot debug this graph')

        pending = _nodes_to_debug(model)

    return model