2 # Copyright (C) 2017 The Android Open Source Project
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
# Operands of the RNN operation under test. `batches`, `input_size`, `units`
# and `model` are not bound in this section; they are assumed to be defined
# earlier in the spec.
input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (batches, input_size))
weights = Input("weights", "TENSOR_FLOAT32", "{%d, %d}" % (units, input_size))
recurrent_weights = Input("recurrent_weights", "TENSOR_FLOAT32", "{%d, %d}" % (units, units))
bias = Input("bias", "TENSOR_FLOAT32", "{%d}" % (units))
hidden_state_in = Input("hidden_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))

# Fused activation code passed to the operation.
activation_param = Int32Scalar("activation_param", 1)  # Relu

# The operation produces both the next hidden state and the output. The hidden
# state is declared with IgnoredOutput -- presumably excluded from result
# comparison; confirm against the test-generator docs.
hidden_state_out = IgnoredOutput("hidden_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))
output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))

# Wire all operands into one RNN operation; outputs are listed in the order
# (hidden_state_out, output).
model = model.Operation("RNN", input, weights, recurrent_weights, bias, hidden_state_in,
                        activation_param).To([hidden_state_out, output])
# NOTE(review): constant operand data whose opening statement is missing from
# this copy (copied line numbers jump from 35 to 39 above and from 59 to 63
# below). Judging by position it presumably feeds `weights`, declared as
# {units, input_size} -- confirm. Only 126 values are visible (21 rows of 6);
# check that against units * input_size, since a trailing partial row may have
# been lost with the skipped lines. Each line also starts with a stray copied
# line number, which is not valid inside a list literal.
39 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
40 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
41 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
42 -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
43 -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
44 -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
45 -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
46 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
47 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
48 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
49 -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
50 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
51 -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
52 -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
53 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
54 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
55 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
56 -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
57 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
58 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
59 -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
# NOTE(review): constant data with no visible opening statement (copied line
# numbers jump from 59 to 63 above and from 77 to 81 below). Each row is 0.1
# followed by sixteen zeros, i.e. 0.1 on the diagonal of a 16-wide matrix laid
# out row-major -- presumably `recurrent_weights` ({units, units}) with
# units == 16; confirm. 15 rows x 17 values = 255 values are visible, one
# short of 16 * 16, so the final diagonal 0.1 was likely on a skipped line.
# Each line also starts with a stray copied line number.
63 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
64 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
65 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
66 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
67 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
68 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
69 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
70 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
71 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
72 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
73 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
74 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
75 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
76 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
77 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# NOTE(review): 15 constant values with no visible opening statement (copied
# line numbers jump from 77 to 81 above and from 83 to 90 below). Presumably
# `bias`, declared as {units} -- confirm; if units is 16, one value was lost
# with the skipped lines. Each line also starts with a stray copied line number.
81 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
82 -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
83 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
# NOTE(review): 128 values (25 rows of 5, plus 3) with no visible opening
# statement (copied line numbers jump from 83 to 90 above and from 115 to 119
# below). Presumably this is `test_inputs`, the input sequence that the
# example-building code slices input_size values at a time -- confirm. Each
# line also starts with a stray copied line number.
90 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
91 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
92 -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
93 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
94 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
95 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
96 -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
97 -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
98 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
99 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
100 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
101 -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
102 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
103 -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
104 -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
105 -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
106 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
107 -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
108 -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
109 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
110 -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
111 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
112 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
113 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
114 -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
115 0.93455386, -0.6324693, -0.083922029
# NOTE(review): expected-output data with no visible opening statement (copied
# line numbers jump from 115 to 119 above). Values come in groups of 16 (rows
# of 6, 6 and 4); presumably this is `golden_outputs`, one group of `units`
# values per sequence step -- confirm. Two groups are missing their final
# 4-value row (marked inline below). Each line also starts with a stray copied
# line number.
119 0.496726, 0, 0.965996, 0, 0.0584254, 0,
120 0, 0.12315, 0, 0, 0.612266, 0.456601,
121 0, 0.52286, 1.16099, 0.0291232,
123 0, 0, 0.524901, 0, 0, 0,
124 0, 1.02116, 0, 1.35762, 0, 0.356909,
125 0.436415, 0.0355727, 0, 0,
127 0, 0, 0, 0.262335, 0, 0,
128 0, 1.33992, 0, 2.9739, 0, 0,
129 1.31914, 2.66147, 0, 0,
131 0.942568, 0, 0, 0, 0.025507, 0,
132 0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
133 0.8158, 1.21805, 0.586239, 0.25427,
135 1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
136 0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
137 0, 1.22031, 1.30117, 0.495867,
139 0.222187, 0, 0.72725, 0, 0.767003, 0,
140 0, 0.147835, 0, 0, 0, 0.608758,
141 0.469394, 0.00720298, 0.927537, 0,
143 0.856974, 0.424257, 0, 0, 0.937329, 0,
144 0, 0, 0.476425, 0, 0.566017, 0.418462,
145 0.141911, 0.996214, 1.13063, 0,
147 0.967899, 0, 0, 0, 0.0831304, 0,
148 0, 1.00378, 0, 0, 0, 1.44818,
149 1.01768, 0.943891, 0.502745, 0,
151 0.940135, 0, 0, 0, 0, 0,
152 0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
153 1.30225, 1.59644, 0.70222, 0,
155 0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
156 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
157 0.0454298, 0.300267, 0.562784, 0.395095,
159 0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
160 0, 0, 0, 0.735363, 0.0759267, 1.91017,
# NOTE(review): the 4-value row closing the group above (copied line 161) is
# missing from this copy.
163 0, 0, 1.5909, 0, 0, 0,
164 0, 0.5755, 0, 0.184687, 0, 1.56296,
# NOTE(review): the 4-value row closing the group above (copied line 165) is
# missing from this copy.
167 0, 0, 0.0857888, 0, 0, 0,
168 0, 0.488383, 0.252786, 0, 0, 0,
169 1.02817, 1.85665, 0, 0,
171 0.00981836, 0, 1.06371, 0, 0, 0,
172 0, 0, 0, 0.290445, 0.316406, 0,
173 0.304161, 1.25079, 0.0707152, 0,
175 0.986264, 0.309201, 0, 0, 0, 0,
176 0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
177 0.524981, 1.92076, 2.07013, 0.333244,
179 0.415153, 0.210318, 0, 0, 0, 0,
180 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
181 0.628881, 3.58099, 1.49974, 0
# Build test examples: for a step i, take step i's slice of the input sequence
# and of the golden outputs, and register them as one Example.
# NOTE(review): in this copy the statement that binds `i` (copied line 188,
# after the commented-out loop header), the opening of the `output0` literal
# (copied line 194) and its closing (copied line 196) are missing. Indentation
# was also lost, and each line starts with a stray copied line number.
# NOTE(review): each step consumes `input_size` values of test_inputs (see the
# slice below) and is then duplicated across batch rows, so the number of steps
# may be len(test_inputs) / input_size. Dividing by `batches` as well looks
# like it halves the count -- confirm before re-enabling the full loop.
184 input_sequence_size = int(len(test_inputs) / input_size / batches)
186 # TODO: enable the other data points after fixing reference issues
187 #for i in range(input_sequence_size):
# Step i's input_size values; extending the list with itself doubles it, which
# fills {batches, input_size} only if batches == 2 (assumption -- confirm).
189 input_begin = i * input_size
190 input_end = input_begin + input_size
191 input0[input] = test_inputs[input_begin:input_end]
192 input0[input].extend(input0[input])
# Initial hidden state: batches * units zeros.
193 input0[hidden_state_in] = [0 for x in range(batches * units)]
# Placeholder values for hidden_state_out, which is declared as an IgnoredOutput.
195 hidden_state_out: [0 for x in range(batches * units)],
# Golden output for step i: `units` values, doubled the same way as the input.
197 golden_start = i * units
198 golden_end = golden_start + units
199 output0[output] = golden_outputs[golden_start:golden_end]
200 output0[output].extend(output0[output])
# NOTE(review): `input0` is the same dict object on every step and is updated
# in place; if Example keeps a reference rather than a copy, earlier examples
# would see later steps' inputs once the full loop is enabled -- confirm.
201 Example((input0, output0))