Support 1x1x1xN bias sizes in TFLite's convolution and FC layers.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 17 May 2018 18:47:16 +0000 (11:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 18:50:17 +0000 (11:50 -0700)
PiperOrigin-RevId: 197027135

tensorflow/contrib/lite/kernels/conv.cc
tensorflow/contrib/lite/kernels/fully_connected.cc

index 3b467b3..2b7e455 100644 (file)
@@ -212,8 +212,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     } else {
       TF_LITE_ENSURE_EQ(context, bias->type, data_type);
     }
-    TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]);
+    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
   }
 
   int channels_out = filter->dims->data[0];
index 1ba3064..a486b81 100644 (file)
@@ -106,11 +106,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]);
   if (bias) {
-    TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
+    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
   }
 
   TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
 
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.