[enco/tfl] Validate stride (#2455)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 30 Nov 2018 08:19:59 +0000 (17:19 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 30 Nov 2018 08:19:59 +0000 (17:19 +0900)
This commit introduces basic tfl model validation infrastructure to tfl
frontend.

The current implementation supports 'stride' validation for the following
operations:
 - Conv2D
 - DepthwiseConv2D
 - MaxPool2D
 - AveragePool2D

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/frontend/tflite/src/Frontend.cpp
contrib/enco/frontend/tflite/src/GraphBuilder.h
contrib/enco/frontend/tflite/src/Op/AveragePool2D.cpp
contrib/enco/frontend/tflite/src/Op/AveragePool2D.h
contrib/enco/frontend/tflite/src/Op/Conv2D.cpp
contrib/enco/frontend/tflite/src/Op/Conv2D.h
contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.cpp
contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.h
contrib/enco/frontend/tflite/src/Op/MaxPool2D.cpp
contrib/enco/frontend/tflite/src/Op/MaxPool2D.h
contrib/enco/frontend/tflite/src/Op/Padding.cpp

index 90620b3..5fe345d 100644 (file)
@@ -127,6 +127,11 @@ enco::Bundle Frontend::load(void) const
 
     if (const auto *graph_builder = tflimport::GraphBuilderRegistry::get().lookup(builtincode))
     {
+      if (!graph_builder->validate(op))
+      {
+        throw std::runtime_error{"Invalid operator"};
+      }
+
       graph_builder->build(op, &opbuilder_context);
     }
     else
index 613c5ff..f2cb578 100644 (file)
@@ -30,6 +30,13 @@ namespace tflimport
 class GraphBuilder
 {
 public:
+  /**
+   * TODO Declare "validate" method as a pure virtual method
+   *
+   * Q: Is it possible to validate T/F Lite model only with this interface?
+   */
+  virtual bool validate(const tflite::Operator *) const { return true; }
+
   virtual void build(const tflite::Operator *op, GraphBuilderContext *context) const = 0;
   virtual ~GraphBuilder() {}
 };
index fbc334a..380237e 100644 (file)
@@ -35,6 +35,18 @@ using namespace morph::tflite;
 namespace tflimport
 {
 
+bool AvgPool2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+  auto const options = op->builtin_options_as_Pool2DOptions();
+
+  if ((options->stride_h() == 0) || (options->stride_w() == 0))
+  {
+    return false;
+  }
+
+  return true;
+}
+
 void AvgPool2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const
 {
   assert(context != nullptr); // check if init(..) is called
index 1cd2b79..3e37e3c 100644 (file)
@@ -30,6 +30,7 @@ namespace tflimport
 class AvgPool2DGraphBuilder : public GraphBuilder
 {
 public:
+  bool validate(const tflite::Operator *op) const override;
   void build(const tflite::Operator *op, GraphBuilderContext *) const override;
 };
 
index 874c9ec..784906f 100644 (file)
@@ -36,6 +36,18 @@ using namespace morph::tflite;
 namespace tflimport
 {
 
+bool Conv2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+  auto const options = op->builtin_options_as_Conv2DOptions();
+
+  if ((options->stride_h() == 0) || (options->stride_w() == 0))
+  {
+    return false;
+  }
+
+  return true;
+}
+
 void Conv2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const
 {
   assert(context != nullptr);
index c2f5eab..018815b 100644 (file)
@@ -30,6 +30,7 @@ namespace tflimport
 class Conv2DGraphBuilder : public GraphBuilder
 {
 public:
+  bool validate(const tflite::Operator *op) const override;
   void build(const tflite::Operator *op, GraphBuilderContext *context) const override;
 };
 
index 5f2a7f6..3b0319e 100644 (file)
@@ -37,6 +37,18 @@ using namespace morph::tflite;
 namespace tflimport
 {
 
+bool DepthwiseConv2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+  auto const options = op->builtin_options_as_DepthwiseConv2DOptions();
+
+  if ((options->stride_h() == 0) || (options->stride_w() == 0))
+  {
+    return false;
+  }
+
+  return true;
+}
+
 void DepthwiseConv2DGraphBuilder::build(const tflite::Operator *op,
                                         GraphBuilderContext *context) const
 {
index c9e65aa..b36b36b 100644 (file)
@@ -30,6 +30,7 @@ namespace tflimport
 class DepthwiseConv2DGraphBuilder : public GraphBuilder
 {
 public:
+  bool validate(const tflite::Operator *op) const override;
   void build(const tflite::Operator *op, GraphBuilderContext *context) const override;
 };
 
index 176a177..ce6af2e 100644 (file)
@@ -35,6 +35,18 @@ using namespace morph::tflite;
 namespace tflimport
 {
 
+bool MaxPool2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+  auto const options = op->builtin_options_as_Pool2DOptions();
+
+  if ((options->stride_h() == 0) || (options->stride_w() == 0))
+  {
+    return false;
+  }
+
+  return true;
+}
+
 void MaxPool2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const
 {
   assert(context != nullptr); // check if init(..) is called
index 095085f..06a8285 100644 (file)
@@ -30,6 +30,7 @@ namespace tflimport
 class MaxPool2DGraphBuilder : public GraphBuilder
 {
 public:
+  bool validate(const tflite::Operator *op) const override;
   void build(const tflite::Operator *op, GraphBuilderContext *) const override;
 };
 
index f382490..3bd1f10 100644 (file)
@@ -38,6 +38,8 @@ coco::Padding2D get_padding(const tensor::Shape &ifm_shape, const int kernel_w,
                             tflite::Padding padding, int stride_w, int stride_h,
                             int dilation_w_factor, int dilation_h_factor)
 {
+  assert(stride_w != 0);
+  assert(stride_h != 0);
   assert(ifm_shape.rank() == 4);
 
   /**