From c713edd046c4eac1455b3f5ae50ce79a56e98341 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 30 Nov 2018 17:19:59 +0900 Subject: [PATCH] [enco/tfl] Validate stride (#2455) This commit introduces basic tfl model validation infrastructure to tfl frontend. The current implementation supports 'stride' validation for the following operations: - Conv2D - DepthwiseConv2D - MaxPool2D - AveragePool2D Signed-off-by: Jonghyun Park --- contrib/enco/frontend/tflite/src/Frontend.cpp | 5 +++++ contrib/enco/frontend/tflite/src/GraphBuilder.h | 7 +++++++ contrib/enco/frontend/tflite/src/Op/AveragePool2D.cpp | 12 ++++++++++++ contrib/enco/frontend/tflite/src/Op/AveragePool2D.h | 1 + contrib/enco/frontend/tflite/src/Op/Conv2D.cpp | 12 ++++++++++++ contrib/enco/frontend/tflite/src/Op/Conv2D.h | 1 + contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.cpp | 12 ++++++++++++ contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.h | 1 + contrib/enco/frontend/tflite/src/Op/MaxPool2D.cpp | 12 ++++++++++++ contrib/enco/frontend/tflite/src/Op/MaxPool2D.h | 1 + contrib/enco/frontend/tflite/src/Op/Padding.cpp | 2 ++ 11 files changed, 66 insertions(+) diff --git a/contrib/enco/frontend/tflite/src/Frontend.cpp b/contrib/enco/frontend/tflite/src/Frontend.cpp index 90620b3..5fe345d 100644 --- a/contrib/enco/frontend/tflite/src/Frontend.cpp +++ b/contrib/enco/frontend/tflite/src/Frontend.cpp @@ -127,6 +127,11 @@ enco::Bundle Frontend::load(void) const if (const auto *graph_builder = tflimport::GraphBuilderRegistry::get().lookup(builtincode)) { + if (!graph_builder->validate(op)) + { + throw std::runtime_error{"Invalid operator"}; + } + graph_builder->build(op, &opbuilder_context); } else diff --git a/contrib/enco/frontend/tflite/src/GraphBuilder.h b/contrib/enco/frontend/tflite/src/GraphBuilder.h index 613c5ff..f2cb578 100644 --- a/contrib/enco/frontend/tflite/src/GraphBuilder.h +++ b/contrib/enco/frontend/tflite/src/GraphBuilder.h @@ -30,6 +30,13 @@ namespace tflimport class GraphBuilder { public: + /** + * TODO Declare "validate" method as a pure virtual method + * + * Q: Is it possible to validate T/F Lite model only with this interface? + */ + virtual bool validate(const tflite::Operator *) const { return true; } + virtual void build(const tflite::Operator *op, GraphBuilderContext *context) const = 0; virtual ~GraphBuilder() {} }; diff --git a/contrib/enco/frontend/tflite/src/Op/AveragePool2D.cpp b/contrib/enco/frontend/tflite/src/Op/AveragePool2D.cpp index fbc334a..380237e 100644 --- a/contrib/enco/frontend/tflite/src/Op/AveragePool2D.cpp +++ b/contrib/enco/frontend/tflite/src/Op/AveragePool2D.cpp @@ -35,6 +35,18 @@ using namespace morph::tflite; namespace tflimport { +bool AvgPool2DGraphBuilder::validate(const tflite::Operator *op) const +{ + auto const options = op->builtin_options_as_Pool2DOptions(); + + if ((options->stride_h() == 0) || (options->stride_w() == 0)) + { + return false; + } + + return true; +} + void AvgPool2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const { assert(context != nullptr); // check if init(..) is called diff --git a/contrib/enco/frontend/tflite/src/Op/AveragePool2D.h b/contrib/enco/frontend/tflite/src/Op/AveragePool2D.h index 1cd2b79..3e37e3c 100644 --- a/contrib/enco/frontend/tflite/src/Op/AveragePool2D.h +++ b/contrib/enco/frontend/tflite/src/Op/AveragePool2D.h @@ -30,6 +30,7 @@ namespace tflimport class AvgPool2DGraphBuilder : public GraphBuilder { public: + bool validate(const tflite::Operator *op) const override; void build(const tflite::Operator *op, GraphBuilderContext *) const override; }; diff --git a/contrib/enco/frontend/tflite/src/Op/Conv2D.cpp b/contrib/enco/frontend/tflite/src/Op/Conv2D.cpp index 874c9ec..784906f 100644 --- a/contrib/enco/frontend/tflite/src/Op/Conv2D.cpp +++ b/contrib/enco/frontend/tflite/src/Op/Conv2D.cpp @@ -36,6 +36,18 @@ using namespace morph::tflite; namespace tflimport { +bool Conv2DGraphBuilder::validate(const tflite::Operator *op) const +{ + auto const options = op->builtin_options_as_Conv2DOptions(); + + if ((options->stride_h() == 0) || (options->stride_w() == 0)) + { + return false; + } + + return true; +} + void Conv2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const { assert(context != nullptr); diff --git a/contrib/enco/frontend/tflite/src/Op/Conv2D.h b/contrib/enco/frontend/tflite/src/Op/Conv2D.h index c2f5eab..018815b 100644 --- a/contrib/enco/frontend/tflite/src/Op/Conv2D.h +++ b/contrib/enco/frontend/tflite/src/Op/Conv2D.h @@ -30,6 +30,7 @@ namespace tflimport class Conv2DGraphBuilder : public GraphBuilder { public: + bool validate(const tflite::Operator *op) const override; void build(const tflite::Operator *op, GraphBuilderContext *context) const override; }; diff --git a/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.cpp b/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.cpp index 5f2a7f6..3b0319e 100644 --- a/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.cpp +++ b/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.cpp @@ -37,6 +37,18 @@ using namespace morph::tflite; namespace tflimport { +bool DepthwiseConv2DGraphBuilder::validate(const tflite::Operator *op) const +{ + auto const options = op->builtin_options_as_DepthwiseConv2DOptions(); + + if ((options->stride_h() == 0) || (options->stride_w() == 0)) + { + return false; + } + + return true; +} + void DepthwiseConv2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const { diff --git a/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.h b/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.h index c9e65aa..b36b36b 100644 --- a/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.h +++ b/contrib/enco/frontend/tflite/src/Op/DepthwiseConv2D.h @@ -30,6 +30,7 @@ namespace tflimport class DepthwiseConv2DGraphBuilder : public GraphBuilder { public: + bool validate(const tflite::Operator *op) const override; void build(const tflite::Operator *op, GraphBuilderContext *context) const override; }; diff --git a/contrib/enco/frontend/tflite/src/Op/MaxPool2D.cpp b/contrib/enco/frontend/tflite/src/Op/MaxPool2D.cpp index 176a177..ce6af2e 100644 --- a/contrib/enco/frontend/tflite/src/Op/MaxPool2D.cpp +++ b/contrib/enco/frontend/tflite/src/Op/MaxPool2D.cpp @@ -35,6 +35,18 @@ using namespace morph::tflite; namespace tflimport { +bool MaxPool2DGraphBuilder::validate(const tflite::Operator *op) const +{ + auto const options = op->builtin_options_as_Pool2DOptions(); + + if ((options->stride_h() == 0) || (options->stride_w() == 0)) + { + return false; + } + + return true; +} + void MaxPool2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const { assert(context != nullptr); // check if init(..) is called diff --git a/contrib/enco/frontend/tflite/src/Op/MaxPool2D.h b/contrib/enco/frontend/tflite/src/Op/MaxPool2D.h index 095085f..06a8285 100644 --- a/contrib/enco/frontend/tflite/src/Op/MaxPool2D.h +++ b/contrib/enco/frontend/tflite/src/Op/MaxPool2D.h @@ -30,6 +30,7 @@ namespace tflimport class MaxPool2DGraphBuilder : public GraphBuilder { public: + bool validate(const tflite::Operator *op) const override; void build(const tflite::Operator *op, GraphBuilderContext *) const override; }; diff --git a/contrib/enco/frontend/tflite/src/Op/Padding.cpp b/contrib/enco/frontend/tflite/src/Op/Padding.cpp index f382490..3bd1f10 100644 --- a/contrib/enco/frontend/tflite/src/Op/Padding.cpp +++ b/contrib/enco/frontend/tflite/src/Op/Padding.cpp @@ -38,6 +38,8 @@ coco::Padding2D get_padding(const tensor::Shape &ifm_shape, const int kernel_w, tflite::Padding padding, int stride_w, int stride_h, int dilation_w_factor, int dilation_h_factor) { + assert(stride_w != 0); + assert(stride_h != 0); assert(ifm_shape.rank() == 4); /** -- 2.7.4