Source code for cntk.logging.graph

# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import os
import sys
from cntk.variables import Variable



def find_all_with_name(node, node_name, depth=0):
    '''
    Collects every function in the graph whose name equals ``node_name``,
    searching depth-first from ``node``.

    Args:
        node (:class:`~cntk.ops.functions.Function` or :class:`~cntk.variables.Variable`): the node
         at which the search starts
        node_name (`str`): the name to look for
        depth (int, default 0): how deep into the block hierarchy the DFS
         algorithm should go into. Set to -1 for infinite depth.

    Returns:
        List of primitive (or block) functions having the specified name

    See also:
        :func:`~cntk.ops.functions.Function.find_all_with_name` in class
        :class:`~cntk.ops.functions.Function`.
    '''
    def has_requested_name(candidate):
        # predicate handed to the search: keep nodes whose name matches exactly
        return candidate.name == node_name

    return depth_first_search(node, has_requested_name, depth)
def find_by_name(node, node_name, depth=0):
    '''
    Looks up the single function named ``node_name`` in the graph, searching
    depth-first from ``node``. The name is expected to be unique.

    Args:
        node (:class:`~cntk.ops.functions.Function` or :class:`~cntk.variables.Variable`): the node
         at which the search starts
        node_name (`str`): the name to look for
        depth (int, default 0): how deep into the block hierarchy the DFS
         algorithm should go into. Set to -1 for infinite depth.

    Returns:
        Primitive (or block) function having the specified name, or `None`
        if no function matches

    Raises:
        ValueError: if ``node_name`` is not a `str`, or if more than one
         function carries that name

    See also:
        :func:`~cntk.ops.functions.Function.find_by_name` in class
        :class:`~cntk.ops.functions.Function`.
    '''
    if not isinstance(node_name, str):
        raise ValueError('node name has to be a string. You gave a %s'
                         % type(node_name))

    matches = depth_first_search(node, lambda x: x.name == node_name, depth)

    # more than one hit means the caller wanted find_all_with_name instead
    if len(matches) > 1:
        raise ValueError('found multiple functions matching "%s". '
                         'If that was expected call find_all_with_name'
                         % node_name)

    return matches[0] if matches else None
def find_by_uid(node, node_uid, depth=0):
    '''
    Looks up the function with UID ``node_uid`` in the graph, searching
    depth-first from ``node``. The UID is expected to occur only once.

    Args:
        node (:class:`~cntk.ops.functions.Function` or :class:`~cntk.variables.Variable`): the node
         at which the search starts
        node_uid (`str` or `unicode` (in Python 2)): the UID to look for
        depth (int, default 0): how deep into the block hierarchy the DFS
         algorithm should go into. Set to -1 for infinite depth.

    Returns:
        Primitive (or block) function having the specified UID, or `None`
        if no function matches

    Raises:
        ValueError: if ``node_uid`` is neither `str` nor `unicode`, or if
         more than one function carries that UID

    See also:
        :func:`~cntk.ops.functions.Function.find_by_uid` in class
        :class:`~cntk.ops.functions.Function`.
    '''
    # Python 2 callers may pass a 'unicode' UID. Python 3 has no such type,
    # so looking the name up raises NameError there and we record its absence.
    try:
        unicode_type = unicode
    except NameError:
        unicode_type = None

    given_as_unicode = (unicode_type is not None
                        and isinstance(node_uid, unicode_type))

    if not given_as_unicode and not isinstance(node_uid, str):
        raise ValueError('node_uid must be string of type str or unicode (Python 2.7). You gave a %s'
                         % type(node_uid))

    if given_as_unicode:
        # compare against the graph's UIDs as a plain 'str'
        node_uid = node_uid.encode('ascii')

    matches = depth_first_search(node, lambda x: x.uid == node_uid, depth)

    if len(matches) > 1:
        raise ValueError('found multiple functions matching "%s". '
                         'This should not happen as UIDs are unique.'
                         % node_uid)

    return matches[0] if matches else None
def plot(root, filename=None):
    '''
    Walks through every node of the graph starting at ``root``,
    creates a network graph, and returns a network description. If ``filename`` is
    specified, it outputs a DOT, PNG, PDF, or SVG file depending on the file name's suffix.

    Requirements:

     * for DOT output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`__
     * for PNG, PDF, and SVG output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`__
       and `graphviz <http://graphviz.org>`__ (GraphViz executable has to be in the system's PATH).

    Args:
        root (graph node): the node to start the journey from
        filename (`str`, default None): file with extension '.dot', '.png', '.pdf', or '.svg'
         to denote what format should be written. If `None` then nothing
         will be plotted, and the returned string can be used to debug the graph.

    Returns:
        `str` describing the graph: one ``op(input_uids) -> output_uid;`` entry
        per function output, listed in reverse visiting order

    Raises:
        ValueError: if ``filename`` has an unsupported extension
        ImportError: if ``filename`` is given but ``pydot_ng`` cannot be imported
    '''
    if filename:
        suffix = os.path.splitext(filename)[1].lower()
        if suffix not in ('.svg', '.pdf', '.png', '.dot'):
            raise ValueError('only file extensions ".svg", ".pdf", ".png", and ".dot" are supported')

        try:
            import pydot_ng as pydot
        except ImportError:
            raise ImportError("Unable to import pydot_ng, which is required to output SVG, PDF, PNG, and DOT format.")

        # initialize a dot object to store vertices and edges
        dot_object = pydot.Dot(graph_name="network_graph", rankdir='TB')
        dot_object.set_node_defaults(shape='rectangle', fixedsize='false',
                                     style='filled',
                                     fillcolor='lightgray',
                                     height=.85, width=.85, fontsize=12)
        dot_object.set_edge_defaults(fontsize=10)
    else:
        suffix = None

    # Function-scope import, resolved once per call rather than once per
    # visited input.
    from cntk import cntk_py

    # text lines describing the model, in visiting order
    model = []

    root = root.root_function
    root_uid = root.uid
    # Pending nodes; the top of the stack is the END of the list. Inputs are
    # pushed in reverse so that the first input is the next node visited.
    stack = [root]
    visited = set()  # [uid] instead of node object itself, as this gives us duplicate entries for nodes with multiple outputs

    primitive_op_map = {
        'Plus': '+',
        'Minus': '-',
        'ElementTimes': '*',
        'Times': '@',
    }
    function_nodes = {}  # [uid] -> dot node

    def shape_desc(node):
        # the '#' indicates the batch axis, while * indicate dynamic axes (which can be sequences)
        dyn_axes = node.dynamic_axes
        dyn = '[#' + ',*' * (len(dyn_axes) - 1) + ']' if len(dyn_axes) > 0 else ''
        return dyn + str(node.shape)

    def lazy_create_node(node):
        # Returns the dot node of a Function, creating and registering it on first use.
        if node.uid in function_nodes:  # dot node already exists
            return function_nodes[node.uid]

        penwidth = 1 if node.op_name in ('Pass', 'ParameterOrder') else 4
        if node.is_primitive and not node.is_block and len(node.outputs) == 1 and node.output.name == node.name:
            # skip the node name if redundant
            op_name = primitive_op_map.get(node.op_name, node.op_name)
            render_as_primitive = len(op_name) <= 4
            size = 0.4 if render_as_primitive else 0.6
            cur_node = pydot.Node(node.uid, label='"' + op_name + '"',
                                  shape='ellipse' if render_as_primitive else 'box',
                                  fixedsize='true' if render_as_primitive else 'false',
                                  height=size, width=size,
                                  fontsize=20 if render_as_primitive and len(op_name) == 1 else 12,
                                  penwidth=penwidth)
            # TODO: Would be cool, if the user could pass a dictionary with overrides. But maybe for a later version.
        else:
            f_name = '\n' + node.name + '()' if node.name else ''
            cur_node = pydot.Node(node.uid, label='"' + node.op_name + f_name + '"',
                                  fixedsize='true', height=1, width=1.3,
                                  penwidth=penwidth)
        dot_object.add_node(cur_node)
        function_nodes[node.uid] = cur_node
        return cur_node

    def create_variable_node(var):
        # Builds the dot node for an input, placeholder, parameter or constant.
        if var.is_parameter:
            name = 'Parameter'
        elif var.is_constant:
            name = 'Constant'
        elif var.is_input:
            name = 'Input'
        else:
            name = 'Placeholder'
        if var.name:
            if name == 'Parameter':  # don't say 'Parameter' for named parameters, it's already indicated by being a box
                name = var.name
            else:
                name = name + '\n' + var.name
        name += '\n' + shape_desc(var)

        if var.is_input or var.is_placeholder:
            # graph inputs are eggs (since dot has no oval)
            return pydot.Node(var.uid, shape='egg', label=name, fixedsize='true',
                              height=1, width=1.3, penwidth=4)
        if not var.name and var.is_constant and (var.shape == () or var.shape == (1,)):
            # unnamed scalar constants are just shown as values
            return pydot.Node(var.uid, shape='box', label=str(var.as_constant().value),
                              color='white', fillcolor='white', height=0.3, width=0.4)
        # parameters and constants are boxes
        return pydot.Node(var.uid, shape='box', label=name, height=0.6, width=1)

    while stack:
        node = stack.pop()

        if node.uid in visited:
            continue

        try:
            # Function node; a Variable has no root_function and falls
            # through to the AttributeError handler below.
            node = node.root_function

            stack.extend(reversed(list(node.root_function.inputs)))

            # add current node
            line_start = node.op_name + '('

            if filename:
                cur_node = lazy_create_node(node)
                # NOTE(review): lazy_create_node already registers newly created
                # nodes; this extra add_node looks redundant -- confirm against
                # pydot_ng before removing, as it may affect the emitted DOT.
                dot_object.add_node(cur_node)

            # add node's inputs
            shown_uids = []
            for var in node.inputs:
                # Suppress Constants inside BlockFunctions, since those are really private to the BlockFunction.
                # Still show Parameters, so users know what parameters it learns, e.g. a layer.
                if node.is_block and isinstance(var, cntk_py.Variable) and var.is_constant:
                    continue

                shown_uids.append(var.uid)

                if filename:
                    if isinstance(var, cntk_py.Variable) and not var.is_output:
                        input_node = create_variable_node(var)
                    else:  # output variables never get drawn except the final output
                        assert(isinstance(var, cntk_py.Variable))
                        # connect to where the output comes from directly, no need to draw it
                        input_node = lazy_create_node(var.owner)
                    dot_object.add_node(input_node)
                    label = var.name if var.name else var.uid  # the Output variables have no name if the function has none
                    label += '\n' + shape_desc(var)
                    dot_object.add_edge(pydot.Edge(input_node, cur_node, label=label))

            # Join only the inputs that are actually shown, so suppressed
            # trailing constants do not leave a dangling ', ' behind.
            line = line_start + ', '.join(shown_uids) + ') -> '

            # add node's output
            for out in node.outputs:
                model.append(line + out.uid + ';\n')

            if filename and node.uid == root_uid:
                # only final network outputs are drawn
                for output in node.outputs:
                    final_node = pydot.Node(output.uid, shape='egg',
                                            label=output.name + '\n' + shape_desc(output),
                                            fixedsize='true', height=1, width=1.3, penwidth=4)
                    dot_object.add_node(final_node)
                    dot_object.add_edge(pydot.Edge(cur_node, final_node, label=shape_desc(output)))

        except AttributeError:
            # OutputVariable node: continue the walk at the function that owns it
            try:
                if node.is_output:
                    stack.append(node.owner)
            except AttributeError:
                pass

        visited.add(node.uid)

    if filename:
        if suffix == '.svg':
            dot_object.write_svg(filename, prog='dot')
        elif suffix == '.pdf':
            dot_object.write_pdf(filename, prog='dot')
        elif suffix == '.png':
            dot_object.write_png(filename, prog='dot')
        else:
            dot_object.write_raw(filename)

    return "\n".join(reversed(model))
def get_node_outputs(node, depth=0):
    '''
    Gathers the outputs of every node found by a depth-first search that
    starts at ``node``.

    Args:
        node (graph node): the node to start the journey from
        depth (int, default 0): passed through unchanged to the depth-first
         search

    Returns:
        A list of all node outputs
    '''
    collected = []
    for graph_node in depth_first_search(node, lambda _: True, depth):
        try:
            for out_var in graph_node.outputs:
                collected.append(out_var)
        except AttributeError:
            # nodes that expose no ``outputs`` contribute nothing
            continue
    return collected