Add quantized uint8 L2Normalization Kernel.
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Wed, 4 Apr 2018 21:55:16 +0000 (14:55 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 4 Apr 2018 21:57:47 +0000 (14:57 -0700)
PiperOrigin-RevId: 191652174

tensorflow/contrib/lite/kernels/l2norm.cc
tensorflow/contrib/lite/kernels/l2norm_test.cc

diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index ee8bfe5..e67f4e0 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -45,10 +45,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
 
-  // TODO(ahentz): Our current implementations only support float32.
-  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  TF_LITE_ENSURE(
+      context, output->type == kTfLiteFloat32 || output->type == kTfLiteUInt8);
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
 
+  if (output->type == kTfLiteUInt8) {
+    TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
+  }
+
   // TODO(ahentz): For some reason our implementations don't support
   // activations.
   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@@ -75,6 +80,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_L2NORM(optimized_ops);
     }
 #undef TF_LITE_L2NORM
+  } else if (output->type == kTfLiteUInt8) {
+#define TF_LITE_L2NORM(type)                                               \
+  type::L2Normalization(GetTensorData<uint8>(input), GetTensorDims(input), \
+                        input->params.zero_point,                          \
+                        GetTensorData<uint8>(output), GetTensorDims(output))
+
+    if (kernel_type == kReference) {
+      TF_LITE_L2NORM(reference_ops);
+    }
+    if (kernel_type == kGenericOptimized) {
+      TF_LITE_L2NORM(optimized_ops);
+    }
+#undef TF_LITE_L2NORM
   } else {
     context->ReportError(context, "Inputs and outputs not all float types.");
     return kTfLiteError;
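
The Prepare check above pins the output quantization to scale = 1/128 and zero_point = 128 because an L2-normalized value always lies in [-1, 1], so that choice maps the representable uint8 range [0, 255] exactly onto [-1, 127/128]. Note also that the Eval macro passes only input->params.zero_point and no input scale: x / ||x|| is invariant under scaling, so normalizing the raw (q - zero_point) values gives the same unit vector as normalizing the dequantized floats. Below is a float-assisted sketch of the computation, for illustration only; the shipped reference_ops::L2Normalization works in fixed-point arithmetic, and L2NormalizeUint8Sketch is a hypothetical helper name:

#include <cmath>
#include <cstdint>

// Sketch: L2-normalize one row of `depth` quantized values. Assumes a
// nonzero input vector; the real kernel also iterates over the outer dims.
void L2NormalizeUint8Sketch(const uint8_t* input, int depth,
                            int32_t input_zero_point, uint8_t* output) {
  // The input scale cancels out of x / ||x||, so only the zero point is
  // needed to recover values proportional to the real inputs.
  int32_t square_sum = 0;
  for (int i = 0; i < depth; ++i) {
    const int32_t d = static_cast<int32_t>(input[i]) - input_zero_point;
    square_sum += d * d;
  }
  const float inv_norm = 1.0f / std::sqrt(static_cast<float>(square_sum));
  for (int i = 0; i < depth; ++i) {
    const int32_t d = static_cast<int32_t>(input[i]) - input_zero_point;
    // Requantize with the fixed output parameters: q = round(x * 128) + 128.
    int32_t q = static_cast<int32_t>(std::lround(d * inv_norm * 128.0f)) + 128;
    if (q < 0) q = 0;      // Clamp to the uint8 range.
    if (q > 255) q = 255;
    output[i] = static_cast<uint8_t>(q);
  }
}
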
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
index 30e103f..042314c 100644
--- a/tensorflow/contrib/lite/kernels/l2norm_test.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -25,10 +25,22 @@ using ::testing::ElementsAreArray;
 
 class L2NormOpModel : public SingleOpModel {
  public:
-  L2NormOpModel(std::initializer_list<int> input_shape,
-                ActivationFunctionType activation_type) {
-    input_ = AddInput(TensorType_FLOAT32);
-    output_ = AddOutput(TensorType_FLOAT32);
+  L2NormOpModel(const std::initializer_list<int> input_shape,
+                const TensorType tensor_type,
+                const ActivationFunctionType activation_type) {
+    TensorData data = TensorData{tensor_type};
+    if (tensor_type != TensorType_FLOAT32) {
+      data.min = -2.0;
+      data.max = 2.0;
+      data.scale = 2.0;
+      data.zero_point = 128;
+    }
+    input_ = AddInput(data);
+    if (tensor_type != TensorType_FLOAT32) {
+      data.min = -1.0;
+      data.max = 127.0 / 128.0;
+    }
+    output_ = AddOutput(data);
     SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
                  CreateL2NormOptions(builder_, activation_type).Union());
     BuildInterpreter({input_shape});
@@ -38,7 +50,17 @@ class L2NormOpModel : public SingleOpModel {
     PopulateTensor(input_, data);
   }
 
-  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
+
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+                               GetScale(output_), GetZeroPoint(output_));
+  }
+
+  int input() const { return input_; }
 
  private:
   int input_;
@@ -46,13 +68,26 @@ class L2NormOpModel : public SingleOpModel {
 };
 
 TEST(L2NormOpTest, SimpleTest) {
-  L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE);
+  L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
+                  ActivationFunctionType_NONE);
   m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
   m.Invoke();
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
 }
 
+TEST(L2NormOpTest, SimpleUint8Test) {
+  L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
+
+  m.QuantizeAndPopulate<uint8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<uint8_t>(),
+              ElementsAreArray({58, 166, 173, 205, 83, 134}));
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray(
+                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
+}
+
 }  // namespace
 }  // namespace tflite
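
The expected values in SimpleUint8Test follow directly from the fixed output quantization enforced in Prepare: q = round(x / scale) + zero_point with scale = 1/128 and zero_point = 128, so for example -0.55 maps to round(-70.4) + 128 = 58. GetDequantizedOutput inverts that mapping, and the 0.1 tolerance in ArrayFloatNear comfortably covers the worst-case quantization error of half a step (1/256, about 0.004). A hypothetical standalone check, not part of the test file:

#include <cmath>
#include <cstdio>

// Reproduces the SimpleUint8Test expectations from the output contract
// q = round(x / scale) + zero_point, with scale = 1/128 and zero_point = 128.
int main() {
  const float expected_float[] = {-0.55f, 0.3f, 0.35f, 0.6f, -0.35f, 0.05f};
  for (float x : expected_float) {
    const long q = std::lround(x * 128.0f) + 128;
    std::printf("%ld ", q);  // Prints: 58 166 173 205 83 134
  }
  std::printf("\n");
  return 0;
}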