Source code for cntk.default_options

# ==============================================================================
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

'''
Default options of CNTK functions. 


Usage: 

* ``with default_options():``, and
* ``with default_options_for():``
'''

# context manager for overriding defaults, use through default_options() or default_options_for() below
class _OptionsContextManager: # implement Python's 'with' protocol
    # this static member variable holds a linked list of default overrides, managed by _OptionsContextManager
    _current_default_overrides = None
    # constructor remembers the options-override record
    def __init__(self, scope, **kwargs):
        if '_scope' in kwargs or '_outer' in kwargs:
            raise ValueError("default_options: _scope or _outer are invalid (reserved) names.")
        self.scope = scope
        self.kwargs = kwargs
    # entering with block: link in a new default-options record at head
    def __enter__(self):
        from .variables import Record
        _OptionsContextManager._current_default_overrides = Record(_scope = self.scope, _outer = _OptionsContextManager._current_default_overrides, **self.kwargs) # insert new scope at head of link
        return self
    # exiting with block: restore previous remembered defaults
    def __exit__(self, type, value, traceback):
        _OptionsContextManager._current_default_overrides = _OptionsContextManager._current_default_overrides._outer # restore outer scope

# options scope without limit to specific functions, e.g.:
#   with default_options(activation=relu, init=he_normal(), pad=True):
#       model = Convolution((3,3), 32) >> Convolution((3,3), 64)  # will have relu activation and padding
[docs]def default_options(**kwargs): return _OptionsContextManager(None, **kwargs)
# options scope with limit to specific functions # functions = function or list of functions # with default_options(activation=relu, init=he_normal()): # with default_options_for(Convolution, pad=True): # model = Convolution((3,3), 32) >> Convolution((3,3), 64) >> MaxPooling((2,2)) # Convolution will pad, MaxPooling won't
[docs]def default_options_for(functions, **kwargs): if not isinstance(functions, list): functions = [functions] return _OptionsContextManager(set(functions), **kwargs)
# simple wrapper to hold a default value # meant to be both an indicator in a function signature that shows the default, e.g.: # def Convolution(args, init=default_override_or(glorot_uniform()), activation=default_override_or(identity), pad=default_override_or(False)):
[docs]class default_override_or: def __init__(self, value): self.value = value # This has a single member 'value' to hold the value.
# check if a parameter was given # meant to be used inside functions that use this facility # (No if it is still of type default_override_or.)
[docs]def is_default_override(value): return isinstance(value, default_override_or)
[docs]def get_default_override(function_or_class, **kwargs): ''' Looks up an option default override. Meant to be used inside functions that use this facility. Args: function: the function that calls this. For example:: def Convolution(args, init=default_override_or(glorot_uniform()), activation=default_override_or(identity), pad=default_override_or(False)): init = _get_default_override(Convolution, init=init) # pass default under the same name ''' # parameter checking and casting if len(kwargs) != 1: raise TypeError("get_default_override() takes 1 keyword argument but %s were given" % len(kwargs)) key, value = next(iter(kwargs.items())) # this is the keyword argument that the user passed in if function_or_class is not None: # first arg, unless None, must be an actual Python function... from inspect import isfunction, isclass if not (isfunction(function_or_class) or isclass(function_or_class)): raise ValueError('get_default_override() expects the first argument to be a Python function or class') # ...that has an arg with the same name as the given parameter from inspect import getargspec args, _, _, _ = getargspec(function_or_class) if isfunction(function_or_class) else getargspec(function_or_class.__init__) if key not in args: raise TypeError("{0}() has no argument named '{1}'".format(function_or_class.__name__, key)) # if the value passed in is not a default, then use that value if not is_default_override(value): return value # traverse linked list of scopes inside-out until an override was found, else fall back to default opts = _OptionsContextManager._current_default_overrides while opts is not None: if opts._scope is None or function_or_class is None or function_or_class in opts._scope: # we are in the right scope if hasattr(opts, key): return opts[key] # look up the option override and return it if present in this scope opts = opts._outer # step out one scope and try again return value.value # no override found: use the default as passed in
# In some case we couldn't use "with" expression, we need a global setting. # This is the global dictionary with key / value format _GlobalOptions = {} # set a global option with the input initial value
[docs]def set_global_option(key, value): global _GlobalOptions # overwrite previous setting _GlobalOptions[key] = value
# get the global option value, is not set yet, return the value user set in input
[docs]def get_global_option(key, default_value): if key in _GlobalOptions: return _GlobalOptions[key] else: # return the default value user passed in return default_value