// See the License for the specific language governing permissions and
// limitations under the License.
-#include "include/common.cl"
+#include "include/fetch.cl"
#include "include/data_types.cl"
{
for (uint x = 0; x < INPUT0_SIZE_X; x++)
{
+#if INPUT0_LAYOUT_BFZYX_F16
+ input_idx = GET_DATA_BFZYX_F16_INDEX(INPUT0, b, f, z, y, x);
+ mean += (float)input[input_idx];
+ }
+ }
+#else
mean += (float)input[input_idx];
input_idx += INPUT0_X_PITCH;
}
input_idx += INPUT0_Y_PITCH - INPUT0_SIZE_X*INPUT0_X_PITCH;
}
input_idx += INPUT0_Z_PITCH - INPUT0_SIZE_Y*INPUT0_Y_PITCH;
+#endif
}
mean /= INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_SIZE_Z;
+#if INPUT0_LAYOUT_BFZYX_F16
+ uint output_idx;
+#else
uint output_idx = OUTPUT_OFFSET + b * OUTPUT_BATCH_PITCH + f * OUTPUT_FEATURE_PITCH;
-
+#endif
#if NORMALIZE_VARIANCE == 0
//subtract mean
input_idx = input_first;
{
for (uint x = 0; x < INPUT0_SIZE_X; x++)
{
+#if INPUT0_LAYOUT_BFZYX_F16
+ input_idx = GET_DATA_BFZYX_F16_INDEX(INPUT0, b, f, z, y, x);
+ output_idx = GET_DATA_BFZYX_F16_INDEX(OUTPUT, b, f, z, y, x);
+ output[output_idx] = ACTIVATION(input[input_idx] - UNIT_CVT_FUNC(mean), ACTIVATION_PARAMS);
+ }
+ }
+#else
output[output_idx] = ACTIVATION(input[input_idx] - UNIT_CVT_FUNC(mean), ACTIVATION_PARAMS);
input_idx += INPUT0_X_PITCH;
output_idx += OUTPUT_X_PITCH;
}
input_idx += INPUT0_Z_PITCH - INPUT0_SIZE_Y*INPUT0_Y_PITCH;
output_idx += OUTPUT_Z_PITCH - INPUT0_SIZE_Y*OUTPUT_Y_PITCH;
-
+#endif
}
#else //NORMALIZE_VARIANCE
float variance = 0.f;
{
for (uint x = 0; x < INPUT0_SIZE_X; x++)
{
+#if INPUT0_LAYOUT_BFZYX_F16
+ input_idx = GET_DATA_BFZYX_F16_INDEX(INPUT0, b, f, z, y, x);
+ float res = (float)input[input_idx] - mean;
+ variance = fma(res, res, variance);
+ }
+ }
+#else
float res = (float)input[input_idx] - mean;
variance = fma(res, res, variance);
input_idx += INPUT0_X_PITCH;
input_idx += INPUT0_Y_PITCH - INPUT0_SIZE_X*INPUT0_X_PITCH;
}
input_idx += INPUT0_Z_PITCH - INPUT0_SIZE_Y*INPUT0_Y_PITCH;
+#endif
}
//normalize variance
{
for (uint x = 0; x < INPUT0_SIZE_X; x++)
{
+#if INPUT0_LAYOUT_BFZYX_F16
+ input_idx = GET_DATA_BFZYX_F16_INDEX(INPUT0, b, f, z, y, x);
+ output_idx = GET_DATA_BFZYX_F16_INDEX(OUTPUT, b, f, z, y, x);
+ output[output_idx] = ACTIVATION((input[input_idx] - UNIT_CVT_FUNC(mean)) * UNIT_CVT_FUNC(variance), ACTIVATION_PARAMS);
+ }
+ }
+#else
output[output_idx] = ACTIVATION((input[input_idx] - UNIT_CVT_FUNC(mean)) * UNIT_CVT_FUNC(variance), ACTIVATION_PARAMS);
input_idx += INPUT0_X_PITCH;
output_idx += OUTPUT_X_PITCH;
}
input_idx += INPUT0_Z_PITCH - INPUT0_SIZE_Y*INPUT0_Y_PITCH;
output_idx += OUTPUT_Z_PITCH - INPUT0_SIZE_Y*OUTPUT_Y_PITCH;
+#endif
}
#endif
}