# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Unidirectional Sequence LSTM Test:
# 1 Time Step, Layer Normalization, No Cifg, Peephole, Projection, and No Clipping.
#
# NOTE(review): Model, Input, Output, Int32Scalar, Float32Scalar, BoolScalar and
# Example are not defined in this file; presumably they are injected by the
# NNAPI test generator that executes this spec.

model = Model()

# Dimensions. These are pinned by the example data further down: each gate
# bias has 4 values (n_cell), each input_to_*_weights has 20 = n_cell * n_input
# values, each recurrent_to_*_weights has 12 = n_cell * n_output values, the
# single time step supplies 10 = n_batch * n_input inputs, and the golden
# output has 6 = n_batch * n_output values.
max_time = 1
n_batch = 2
n_input = 5
n_cell = 4
# Projection is enabled, so n_output (the projected size) differs from n_cell;
# without projection the two would have the same size.
n_output = 3

# Time-major input sequence: {max_time, n_batch, n_input}.
input = Input("input", "TENSOR_FLOAT32", "{%d, %d, %d}" % (max_time, n_batch, n_input))

# Input-to-gate weights, one {n_cell, n_input} matrix per gate. The input gate
# has its own weights because CIFG is not used.
input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32",
                               "{%d, %d}" % (n_cell, n_input))
input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32",
                                "{%d, %d}" % (n_cell, n_input))
input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32",
                              "{%d, %d}" % (n_cell, n_input))
input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32",
                                "{%d, %d}" % (n_cell, n_input))

# Recurrent (output-state-to-gate) weights, one {n_cell, n_output} matrix per
# gate.
recurrent_to_input_weights = Input("recurrent_to_input_weights",
                                   "TENSOR_FLOAT32",
                                   "{%d, %d}" % (n_cell, n_output))
recurrent_to_forget_weights = Input("recurrent_to_forget_weights",
                                    "TENSOR_FLOAT32",
                                    "{%d, %d}" % (n_cell, n_output))
recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32",
                                  "{%d, %d}" % (n_cell, n_output))
recurrent_to_output_weights = Input("recurrent_to_output_weights",
                                    "TENSOR_FLOAT32",
                                    "{%d, %d}" % (n_cell, n_output))

# Peephole (cell-to-gate) weights, one {n_cell} vector per gate.
cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32",
                              "{%d}" % (n_cell))
cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32",
                               "{%d}" % (n_cell))
cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32",
                               "{%d}" % (n_cell))

# Gate biases, one {n_cell} vector per gate.
input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32",
                         "{%d}" % (n_cell))
cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32",
                         "{%d}" % (n_cell))

# Projection from n_cell down to n_output. The bias is zero-length ({0}), so
# no projection bias values are supplied.
projection_weights = Input("projection_weights", "TENSOR_FLOAT32",
                           "{%d,%d}" % (n_output, n_cell))
projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")

# Initial states; both are fed as zeros in the example below.
output_state_in = Input("output_state_in", "TENSOR_FLOAT32",
                        "{%d, %d}" % (n_batch, n_output))
cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32",
                      "{%d, %d}" % (n_batch, n_cell))

activation_param = Int32Scalar("activation_param", 4)  # Tanh
cell_clip_param = Float32Scalar("cell_clip_param", 0.)  # 0: no cell clipping
proj_clip_param = Float32Scalar("proj_clip_param", 0.)  # 0: no projection clipping
# True: the first dimension of `input`/`output` is time (max_time).
time_major_param = BoolScalar("time_major_param", True)

# Layer-normalization weights, one {n_cell} vector per gate.
input_layer_norm_weights = Input("input_layer_norm_weights", "TENSOR_FLOAT32",
                                 "{%d}" % n_cell)
forget_layer_norm_weights = Input("forget_layer_norm_weights", "TENSOR_FLOAT32",
                                  "{%d}" % n_cell)
cell_layer_norm_weights = Input("cell_layer_norm_weights", "TENSOR_FLOAT32",
                                "{%d}" % n_cell)
output_layer_norm_weights = Input("output_layer_norm_weights", "TENSOR_FLOAT32",
                                  "{%d}" % n_cell)

# Time-major output sequence: {max_time, n_batch, n_output}.
output = Output("output", "TENSOR_FLOAT32", "{%d, %d, %d}" % (max_time, n_batch, n_output))

# Operands are passed positionally; their order is the operation's operand
# order (index noted on the right) and must not be rearranged.
model = model.Operation(
    "UNIDIRECTIONAL_SEQUENCE_LSTM",
    input,                                                    # 0
    input_to_input_weights, input_to_forget_weights,          # 1, 2
    input_to_cell_weights, input_to_output_weights,           # 3, 4
    recurrent_to_input_weights, recurrent_to_forget_weights,  # 5, 6
    recurrent_to_cell_weights, recurrent_to_output_weights,   # 7, 8
    cell_to_input_weights, cell_to_forget_weights,            # 9, 10
    cell_to_output_weights,                                   # 11
    input_gate_bias, forget_gate_bias,                        # 12, 13
    cell_gate_bias, output_gate_bias,                         # 14, 15
    projection_weights, projection_bias,                      # 16, 17
    output_state_in, cell_state_in,                           # 18, 19
    activation_param, cell_clip_param,                        # 20, 21
    proj_clip_param, time_major_param,                        # 22, 23
    input_layer_norm_weights, forget_layer_norm_weights,      # 24, 25
    cell_layer_norm_weights, output_layer_norm_weights,       # 26, 27
).To([output])

# Example 1. All weights, biases and layer-norm weights are fed as model
# inputs; the input sequence and the initial states are added further below.
input0 = {
    input_to_input_weights: [
        0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5, -0.8, 0.7, -0.6,
        0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1
    ],
    input_to_forget_weights: [
        -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8, -0.4, 0.3, -0.5,
        -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5
    ],
    input_to_cell_weights: [
        -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6, 0.6, -0.1,
        -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6
    ],
    input_to_output_weights: [
        -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2, 0.6, -0.2,
        0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4
    ],
    input_gate_bias: [0.03, 0.15, 0.22, 0.38],
    forget_gate_bias: [0.1, -0.3, -0.2, 0.1],
    cell_gate_bias: [-0.05, 0.72, 0.25, 0.08],
    output_gate_bias: [0.05, -0.01, 0.2, 0.1],
    recurrent_to_input_weights: [
        -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6
    ],
    recurrent_to_cell_weights: [
        -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2
    ],
    recurrent_to_forget_weights: [
        -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2
    ],
    recurrent_to_output_weights: [
        0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2
    ],
    cell_to_input_weights: [0.05, 0.1, 0.25, 0.15],
    cell_to_forget_weights: [-0.02, -0.15, -0.25, -0.03],
    cell_to_output_weights: [0.1, -0.1, -0.5, 0.05],
    projection_weights: [
        -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2
    ],
    # Matches the zero-length {0} shape declared for projection_bias.
    projection_bias: [],
    input_layer_norm_weights: [0.1, 0.2, 0.3, 0.5],
    forget_layer_norm_weights: [0.2, 0.2, 0.4, 0.3],
    cell_layer_norm_weights: [0.7, 0.2, 0.3, 0.8],
    output_layer_norm_weights: [0.6, 0.2, 0.2, 0.5]
}

# One time step for n_batch = 2 batches of n_input = 5 values each.
test_input = [0.7, 0.8, 0.1, 0.2, 0.3, 0.3, 0.2, 0.9, 0.8, 0.1]

# Expected output: n_batch = 2 rows of n_output = 3 values.
golden_output = [
    0.024407668039203, 0.128027379512787, -0.001709178090096,
    -0.006924282759428, 0.084874063730240, 0.063444979488850
]

output0 = {
    output: golden_output,
}

input0[input] = test_input
# Start from all-zero output and cell states.
input0[output_state_in] = [0] * (n_batch * n_output)
input0[cell_state_in] = [0] * (n_batch * n_cell)

Example((input0, output0))