# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import cntk as C
import numpy as np
from cntk.contrib import crosstalk as cstk
DictParameterType = 'DictParameter'
def find_func_param(func, name=None, shape=None, allow_not_found=False):
    '''
    Find a single parameter in a function by name or by shape when the function has multiple parameters.
    If the function has only one parameter, it is returned directly.

    Args:
        func (:class:`~cntk.ops.functions.Function`): The function to search for the parameter in
        name (str): The name of the parameter
        shape (tuple): The shape of the parameter
        allow_not_found (bool): Set to True to return None instead of raising an exception when the parameter is not found

    Returns:
        The matching :class:`~cntk.variables.Parameter`
    '''
    if len(func.parameters) == 1:
        return func.parameters[0]
    found = [p for p in func.parameters if (shape and p.shape == shape) or name == p.name]
    if not found:
        if allow_not_found:
            return None
        else:
            raise Exception('parameter ({} {}) not found'.format(name, shape))
    if len(found) > 1:
        raise Exception('more than one matching parameter found')
    return found[0]
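
# A minimal usage sketch (the Dense layer and its shapes below are
# hypothetical): a layer with several parameters can be searched by the
# parameter's name or, when names are ambiguous, by its shape.
def _example_find_func_param():
    dense = C.layers.Dense(5)(C.input_variable(3))
    W = find_func_param(dense, name='W')      # weight matrix, shape (3, 5)
    b = find_func_param(dense, shape=(5,))    # bias located by shape
    return W, b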
def _parameter_setter(p, raw_value, attr=None):
    if p.shape != raw_value.shape:
        raise Exception('different shape, expected {} actual {}'.format(p.shape, raw_value.shape))
    p.value = raw_value

def _parameter_getter(p, attr=None):
    return p.value
def _dict_parameter_setter(pd, raw_value, attr=None):
    if len(pd) != len(raw_value):
        raise Exception('mismatched lengths')
    if pd.keys() != raw_value.keys():
        raise Exception('mismatched keys')
    for k in pd.keys():
        _parameter_setter(pd[k], raw_value[k])

def _dict_parameter_getter(pd, attr=None):
    return {k: p.value for k, p in pd.items()}
def _function_getter(data):
    # closure capturing the evaluation data set via CNTKCrosstalk.set_data
    def _get(f, attr=None):
        return f.eval(data)
    return _get

def _variable_getter(data):
    def _get(f, attr=None):
        # wrap the variable's owner in a composite so it can be evaluated
        return C.as_composite(f.owner).eval(data)[f]
    return _get
def _conv2d_getter(f, attr):
    # CNTK stores the 2D convolution kernel as (num_filters, in_channels, H, W)
    W = _parameter_getter(find_func_param(f, shape=(attr.num_filters, 1,) + attr.filter_shape))
    bias_param = find_func_param(f, shape=(attr.num_filters, 1, 1,), allow_not_found=True)
    if bias_param:
        b = _parameter_getter(bias_param)
    else:
        b = None
    return cstk.Conv2DArgs(W=W[:,0,:,:], b=None if b is None else b.reshape(-1))

def _conv2d_setter(f, raw_value, attr):
    W = find_func_param(f, shape=(attr.num_filters, 1,) + attr.filter_shape)
    _parameter_setter(W, raw_value.W.reshape(W.shape))
    if raw_value.b is not None:
        b = find_func_param(f, shape=(attr.num_filters, 1, 1,))
        _parameter_setter(b, raw_value.b.reshape(b.shape))
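
# A hedged sketch of the shape convention assumed above (the layer
# construction and sizes are hypothetical): CNTK keeps the kernel as
# (num_filters, in_channels=1, H, W), while Conv2DArgs carries W as
# (num_filters, H, W) and b flattened to (num_filters,).
def _example_conv2d_shapes():
    attr = cstk.Conv2DAttr(filter_shape=(3, 3), num_filters=8)
    conv = C.layers.Convolution2D(attr.filter_shape, attr.num_filters)(C.input_variable((1, 28, 28)))
    args = _conv2d_getter(conv, attr)
    return args.W.shape, args.b.shape         # (8, 3, 3), (8,)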
def _get_rnn_gates(op_type):
    num_gates = 1
    if op_type == 'lstm':
        num_gates = 4
    elif op_type == 'gru':
        # NOTE that the cudnn GRU implementation differs from the standard
        # one: the cell gets its projection/bias before the element_times.
        # From the cuDNN docs, watch out for the difference in the h't
        # calculation:
        #
        #   it  = sigmoid(Wi*xt + Ri*ht-1 + bWi + bRi)
        #   rt  = sigmoid(Wr*xt + Rr*ht-1 + bWr + bRr)
        #   h't = tanh(Wh*xt + rt .* (Rh*ht-1 + bRh) + bWh)
        #   ht  = (1 - it) .* h't + it .* ht-1
        #
        # so converting cudnn weights to CPU requires a matching GRU variant
        num_gates = 3
    else:
        raise NotImplementedError()
    return num_gates
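
# A numpy sketch of the h't difference noted above (all tensors here are
# made-up toy values): cudnn applies the reset gate rt after the recurrent
# projection and keeps a separate recurrent bias, whereas a standard GRU
# applies rt to the hidden state before projecting it.
def _example_gru_candidate_difference():
    h_dim = 4
    x_proj = np.random.rand(h_dim)            # Wh*xt + bWh, precomputed
    Rh = np.random.rand(h_dim, h_dim)
    bRh = np.random.rand(h_dim)
    h_prev = np.random.rand(h_dim)
    rt = np.random.rand(h_dim)                # reset gate activations
    cudnn_candidate = np.tanh(x_proj + rt * (Rh.dot(h_prev) + bRh))
    standard_candidate = np.tanh(x_proj + Rh.dot(rt * h_prev))
    return cudnn_candidate, standard_candidate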
# return split offsets for the flattened cudnn parameter of shape
# (_inferred, hidden_dim), splitting along _inferred
def _get_cudnn_rnn_splitter(attr):
    in_dim = attr.input_dim
    h_dim = attr.hidden_dim
    gates = _get_rnn_gates(attr.op_type)
    # for unidirectional: W, H, b1, b2
    # for bidirectional: fw_W, fw_H, bw_W, bw_H, fw_b1, fw_b2, bw_b1, bw_b2
    multiplier = 2 if attr.bidirectional else 1
    splitter = [in_dim*h_dim*gates, h_dim*h_dim*gates] * multiplier + [h_dim*gates, h_dim*gates] * multiplier
    # np.split expects offsets, so drop the last segment length and accumulate
    splitter = splitter[0:-1]
    return np.cumsum(splitter)
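
# A worked example with hypothetical dimensions: a bidirectional LSTM with
# input_dim=2 and hidden_dim=3 packs 2*(24 + 36 + 12 + 12) = 168 values into
# one flat cudnn blob; splitting at the cumulative offsets recovers the
# 8 segments fw_W, fw_H, bw_W, bw_H, fw_b1, fw_b2, bw_b1, bw_b2 in order.
def _example_cudnn_splitter():
    in_dim, h_dim, gates = 2, 3, 4            # 4 gates for 'lstm'
    segment_sizes = [in_dim*h_dim*gates, h_dim*h_dim*gates] * 2 + [h_dim*gates] * 4
    blob = np.arange(sum(segment_sizes))      # stand-in for the cudnn parameter
    offsets = np.cumsum(segment_sizes[0:-1])  # [24, 60, 84, 120, 132, 144, 156]
    return np.split(blob, offsets)            # 8 chunks in the order above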
def _get_birnn_param(f):
    if f.root_function.op_name != 'Splice':
        raise NotImplementedError()
    # assuming the forward/backward cells are the first/second inputs to Splice
    fw = f.root_function.inputs[0].owner
    bw = f.root_function.inputs[1].owner
    return cstk.RnnArgs(fw_W=find_func_param(fw, name='W'),
                        fw_H=find_func_param(fw, name='H'),
                        fw_b=find_func_param(fw, name='b'),
                        bw_W=find_func_param(bw, name='W'),
                        bw_H=find_func_param(bw, name='H'),
                        bw_b=find_func_param(bw, name='b'))
'''
cudnn orders lstm gates as input/forget/mem/output, while both CNTK and
tensorflow use input/mem/forget/output. The saved model uses the
CNTK/tensorflow order, so cudnn weights need adjusting.
NOTE: this function is its own inverse
'''
def _adjust_lstm_gate_order(W):
    if len(W.shape) == 2:
        i,f,m,o = np.hsplit(W, 4)
        return np.concatenate((i,m,f,o), axis=1)
    elif len(W.shape) == 1:
        i,f,m,o = np.split(W, 4)
        return np.concatenate((i,m,f,o))
    else:
        raise Exception('invalid input')
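
# A quick toy check (made-up 1-D weights) that the reorder undoes itself,
# since it only swaps the middle two of the four gate blocks:
def _example_gate_order_roundtrip():
    w = np.arange(8.0)                        # 4 gates x 2 units, flattened
    assert np.array_equal(_adjust_lstm_gate_order(_adjust_lstm_gate_order(w)), w)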
def _rnn_getter(f, attr):
    if not attr.bidirectional:
        raise NotImplementedError()
    use_cudnn = (len(f.parameters) == 1) # CNTK has only 1 big fat parameter when using cudnn
    if use_cudnn:
        gates = _get_rnn_gates(attr.op_type)
        fw_Wt, fw_Ht, bw_Wt, bw_Ht, fw_b1, fw_b2, bw_b1, bw_b2 = np.split(f.parameters[0].value.reshape(-1), _get_cudnn_rnn_splitter(attr))
        return cstk.RnnArgs(fw_W=_adjust_lstm_gate_order(fw_Wt.reshape(gates*attr.hidden_dim, -1).transpose()),
                            fw_H=_adjust_lstm_gate_order(fw_Ht.reshape(gates*attr.hidden_dim, -1).transpose()),
                            fw_b=_adjust_lstm_gate_order(fw_b1 + fw_b2),
                            bw_W=_adjust_lstm_gate_order(bw_Wt.reshape(gates*attr.hidden_dim, -1).transpose()),
                            bw_H=_adjust_lstm_gate_order(bw_Ht.reshape(gates*attr.hidden_dim, -1).transpose()),
                            bw_b=_adjust_lstm_gate_order(bw_b1 + bw_b2))
    else:
        param = _get_birnn_param(f)
        return cstk.RnnArgs(fw_W=_parameter_getter(param.fw_W),
                            fw_H=_parameter_getter(param.fw_H),
                            fw_b=_parameter_getter(param.fw_b),
                            bw_W=_parameter_getter(param.bw_W),
                            bw_H=_parameter_getter(param.bw_H),
                            bw_b=_parameter_getter(param.bw_b))
def _rnn_setter(f, raw_value, attr):
    if not attr.bidirectional:
        raise NotImplementedError()
    use_cudnn = (len(f.parameters) == 1)
    if use_cudnn:
        gates = _get_rnn_gates(attr.op_type)
        # cudnn keeps two bias vectors per direction, but the saved format has
        # only one, so the second bias of each direction is filled with zeros
        _parameter_setter(f.parameters[0],
                          np.concatenate((_adjust_lstm_gate_order(raw_value.fw_W).transpose().reshape(-1),
                                          _adjust_lstm_gate_order(raw_value.fw_H).transpose().reshape(-1),
                                          _adjust_lstm_gate_order(raw_value.bw_W).transpose().reshape(-1),
                                          _adjust_lstm_gate_order(raw_value.bw_H).transpose().reshape(-1),
                                          _adjust_lstm_gate_order(raw_value.fw_b).reshape(-1),
                                          np.zeros_like(raw_value.fw_b).reshape(-1),
                                          _adjust_lstm_gate_order(raw_value.bw_b).reshape(-1),
                                          np.zeros_like(raw_value.bw_b).reshape(-1)
                                         )).reshape(f.parameters[0].shape))
    else:
        param = _get_birnn_param(f)
        _parameter_setter(param.fw_W, raw_value.fw_W)
        _parameter_setter(param.fw_H, raw_value.fw_H)
        _parameter_setter(param.fw_b, raw_value.fw_b)
        _parameter_setter(param.bw_W, raw_value.bw_W)
        _parameter_setter(param.bw_H, raw_value.bw_H)
        _parameter_setter(param.bw_b, raw_value.bw_b)
def _embed_getter(p, attr):
    value = _parameter_getter(p)
    # map each token in attr.dict to its embedding row
    return {attr.dict[i]: value[i,:] for i in range(attr.input_dim)}

def _embed_setter(p, raw_value, attr):
    out = [None]*attr.input_dim
    for w in raw_value.keys():
        out[attr.dict.index(w)] = raw_value[w]
    _parameter_setter(p, np.asarray(out))
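
# A hypothetical round-trip through the embedding converters, assuming
# cstk.EmbedAttr takes a token list in `dict` plus `input_dim`: the getter
# returns {token: row} and the setter writes the same rows back.
def _example_embed_roundtrip():
    attr = cstk.EmbedAttr(dict=['a', 'b', 'c'], input_dim=3)
    p = C.Parameter((3, 2), init=C.glorot_uniform())
    table = _embed_getter(p, attr)            # {'a': row 0, 'b': row 1, 'c': row 2}
    _embed_setter(p, table, attr)             # restores the same values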
class CNTKCrosstalk(cstk.Crosstalk):
    '''
    CNTK implementation of Crosstalk
    '''
    def __init__(self):
        super(CNTKCrosstalk, self).__init__()
        super(CNTKCrosstalk, self).register_funcs(C.variables.Parameter, setter=_parameter_setter, getter=_parameter_getter)
        super(CNTKCrosstalk, self).register_funcs(DictParameterType, setter=_dict_parameter_setter, getter=_dict_parameter_getter)
        super(CNTKCrosstalk, self).register_funcs(cstk.Conv2DAttr, setter=_conv2d_setter, getter=_conv2d_getter)
        super(CNTKCrosstalk, self).register_funcs(cstk.RnnAttr, setter=_rnn_setter, getter=_rnn_getter)
        super(CNTKCrosstalk, self).register_funcs(cstk.EmbedAttr, setter=_embed_setter, getter=_embed_getter)
    def set_data(self, data):
        '''
        Set mapped data for variable evaluation

        Args:
            data: The input data, as in the arguments parameter of :func:`~cntk.ops.functions.Function.eval`
        '''
        super(CNTKCrosstalk, self).register_funcs(C.ops.functions.Function, getter=_function_getter(data))
        super(CNTKCrosstalk, self).register_funcs(C.variables.Variable, getter=_variable_getter(data))
    def is_param(self, name):
        '''
        Check whether the watched variable with the given name is a parameter

        Args:
            name (str): The variable name to check
        '''
        var_type = self.vars[name].type
        return var_type not in [C.ops.functions.Function, C.variables.Variable]
    def load_all_params(self):
        '''
        Load all parameters from files in the working directory
        '''
        super(CNTKCrosstalk, self).load([n for n in self.vars.keys() if self.is_param(n)])

    def save_all_params(self):
        '''
        Save all parameters to files in the working directory
        '''
        super(CNTKCrosstalk, self).save([n for n in self.vars.keys() if self.is_param(n)])
instance = CNTKCrosstalk()
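
# A hedged usage sketch of the shared instance (model, data, and workdir are
# placeholders; watch/set_workdir/fetch are assumed from the base
# cstk.Crosstalk API):
def _example_crosstalk_usage(model, data, workdir):
    instance.set_workdir(workdir)
    instance.set_data(data)                    # enables Function/Variable getters
    instance.watch(model, 'model_out', var_type=C.ops.functions.Function)
    instance.watch(model.parameters[0], 'p0')  # a Parameter, so is_param('p0') is True
    instance.fetch('model_out', save=True)     # evaluates the function and saves the result
    instance.save_all_params()                 # saves 'p0' but not 'model_out'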