Fix CBLAS Conv reference implementation in TFLite.
authorYu-Cheng Ling <ycling@google.com>
Mon, 5 Feb 2018 23:05:43 +0000 (15:05 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 23:09:51 +0000 (15:09 -0800)
PiperOrigin-RevId: 184592951

tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h

index 6578915..6acc513 100644 (file)
@@ -49,9 +49,12 @@ void cblas_sgemm(const enum CBLAS_ORDER order,
   TFLITE_DCHECK(order == CblasRowMajor);
   TFLITE_DCHECK(trans_a == CblasNoTrans);
   TFLITE_DCHECK(trans_b == CblasTrans);
+  TFLITE_DCHECK(beta == 0.0f);
   for (int row = 0; row < m; ++row) {
     for (int col = 0; col < n; ++col) {
-      float value = beta * c[stride_c * row + col];
+      // If `beta` non-zero, multiple it with the original values in output.
+      // Otherwise, ignore the original value in output completely.
+      float value = 0.0f;
       for (int idx = 0; idx < k; ++idx) {
         value += alpha * a[stride_a * row + idx] * b[stride_b * col + idx];
       }