Fix validate check in Conv2d (#7937)
authorAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 7 Oct 2019 01:55:20 +0000 (04:55 +0300)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 7 Oct 2019 01:55:20 +0000 (10:55 +0900)
Do not accept conv2d operation with non one dilations

Signed-off-by: Efimov Alexander <a.efimov@samsung.com>
compiler/moco-tf/src/Op/Conv2D.cpp

index 7e011a7..323cf71 100644 (file)
@@ -37,6 +37,7 @@
 
 #include <cassert>
 #include <stdexcept>
+#include <algorithm>
 
 namespace
 {
@@ -131,7 +132,16 @@ bool Conv2DGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
 
   // note: even though "data_format" is not entered when a model is written,
   //       TF seems to generate "data_format" field into a pb file
-  return plier::tf::has_attrs(node, {"T", "data_format", "dilations", "padding", "strides"});
+  bool has_mandatory_attrs = plier::tf::has_attrs(node, {"T", "data_format", "padding", "strides"});
+  // dilation attribute is not fully supported
+  bool supported_dilations = true;
+  if (plier::tf::has_attr(node, "dilations"))
+  {
+    auto dilation = plier::tf::get_list_attr(node, "dilations").i();
+    supported_dilations =
+        std::all_of(dilation.begin(), dilation.end(), [](std::int64_t dil) { return dil == 1; });
+  }
+  return has_mandatory_attrs && supported_dilations;
 }
 
 void Conv2DGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const