record_single(conv, (1, 2, 5, 5), "conv_sb_1x1_kernel")
record_single(conv, (3, 2, 5, 5), "conv_mb_1x1_kernel")
+ attention = K.layers.Attention()
+ record_single(attention, [(1, 2, 2), (1, 2, 2)],
+ "attention_golden_shared_kv", {"training": False})
+
inspect_file("conv_sb_no_overlap.nnlayergolden")
def record_single(layer, input_shape, test_name, call_args={}):
layer = attach_trans_layer(layer)
layer.build(input_shape)
- inputs = _rand_like(input_shape)
+ if isinstance(input_shape, list):
+ inputs = [_rand_like(in_shape) for in_shape in input_shape]
+ else:
+ inputs = _rand_like(input_shape)
initial_weights = [tf.Variable(i) for i in layer.weights]
    layer.call(inputs, **call_args)  # warm up the layer before recording
with tf.GradientTape(persistent=True) as tape:
- tape.watch(inputs)
+        if isinstance(inputs, list):
+            for inp in inputs:
+                tape.watch(inp)
+        else:
+            tape.watch(inputs)
outputs = layer.call(inputs, **call_args)
dy_constant = outputs * 2 # set incoming derivative to 2 instead of 1
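        # Differentiating dy_constant rather than outputs scales every gradient
        # by 2 (linearity of the derivative), which is equivalent to
        # back-propagating an incoming derivative of 2 through the layer.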
with open(test_name + ".nnlayergolden", "wb") as f:
writer = _get_writer(f)
- def write_tensor(*tensors):
+ def write_tensor(tensors):
+ if not isinstance(tensors, list):
+ tensors = [tensors]
for tensor in tensors:
-            # print(tensor)
+            # print(tf.size(tensor))  # uncomment to inspect sizes while recording
writer(tf.size(tensor), tensor)
        ## @todo inputs, outputs, and derivatives can each be more than one tensor
## @note please update genLayerTests.py comments when updating below
- write_tensor(*initial_weights)
+ write_tensor(initial_weights)
write_tensor(inputs)
write_tensor(outputs)
- write_tensor(*gradients)
- write_tensor(*weights)
+ write_tensor(gradients)
+ write_tensor(weights)
write_tensor(derivatives)
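For orientation, a hypothetical reader for the golden file written above. It assumes `_get_writer` emits each tensor as a little-endian int32 element count followed by that many float32 values; that framing is inferred from the `writer(tf.size(tensor), tensor)` call and is an assumption, not something this diff confirms:

import struct

def read_golden(path):
    """Parse a .nnlayergolden file into flat float tuples, in write order:
    initial_weights, inputs, outputs, gradients, weights, derivatives."""
    tensors = []
    with open(path, "rb") as f:
        while True:
            header = f.read(4)
            if not header:
                break
            (count,) = struct.unpack("<i", header)
            tensors.append(struct.unpack(f"<{count}f", f.read(4 * count)))
    return tensors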
INSTANTIATE_TEST_CASE_P(Attention, LayerSemantics,
::testing::Values(semantic_attention));
+
+// nntrainer shapes are batch:channel:height:width; the TF inputs of shape
+// (1, 2, 2) recorded above therefore map to 1:1:2:2.
+auto attention_shared_kv = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AttentionLayer>, {}, "1:1:2:2,1:1:2:2",
+  "attention_golden_shared_kv.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
+
+INSTANTIATE_TEST_CASE_P(Attention, LayerGoldenTest,
+ ::testing::Values(attention_shared_kv));