[moco/tf] Introduce TFMul IR (#4061)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 3 Jul 2019 01:40:15 +0000 (10:40 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 3 Jul 2019 01:40:15 +0000 (10:40 +0900)
This will add TFMul IR and related codes to support TensorFlow Mul node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Dialect/TFNodes.h
contrib/moco-tf/src/Dialect/TFNodes.lst
contrib/moco-tf/src/IR/TFMul.h [new file with mode: 0644]
contrib/moco-tf/src/IR/TFMul.test.cpp [new file with mode: 0644]
contrib/moco-tf/src/Transforms/FixPaddingTransform.cpp
contrib/moco-tf/src/Transforms/FixShapeTransform.cpp

index 3366cb8..7a7fd62 100644 (file)
@@ -19,5 +19,6 @@
 
 #include "IR/TFAdd.h"
 #include "IR/TFFusedBatchNorm.h"
+#include "IR/TFMul.h"
 
 #endif // __MOCO_TF_DIALECT_TFNODES_H__
index eb52619..fcd00bc 100644 (file)
@@ -9,3 +9,4 @@
 // TENSORFLOW_NODE(OPCODE, CLASS)
 TENSORFLOW_NODE(Add, TFAdd)
 TENSORFLOW_NODE(FusedBatchNorm, TFFusedBatchNorm)
+TENSORFLOW_NODE(Mul, TFMul)
diff --git a/contrib/moco-tf/src/IR/TFMul.h b/contrib/moco-tf/src/IR/TFMul.h
new file mode 100644 (file)
index 0000000..5390612
--- /dev/null
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFMUL_H__
+#define __MOCO_TF_IR_TFMUL_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFMul corresponds to the following GraphDef
+/*
+node {
+  name: "mul"
+  op: "Mul"
+  input: "x"
+  input: "y"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+*/
+
+class TFMul final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Mul>>
+{
+public:
+  TFMul() = default;
+
+public:
+  Node *x(void) const { return at(0)->node(); }
+  void x(Node *node) { at(0)->node(node); }
+
+  Node *y(void) const { return at(1)->node(); }
+  void y(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFMUL_H__
diff --git a/contrib/moco-tf/src/IR/TFMul.test.cpp b/contrib/moco-tf/src/IR/TFMul.test.cpp
new file mode 100644 (file)
index 0000000..cc7c588
--- /dev/null
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFMul.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFMulTest, constructor)
+{
+  moco::tf::TFMul mul_node;
+
+  ASSERT_EQ(mul_node.dialect(), moco::tf::TFDialect::get());
+  ASSERT_EQ(mul_node.opcode(), moco::tf::TFOpcode::Mul);
+
+  ASSERT_EQ(mul_node.x(), nullptr);
+  ASSERT_EQ(mul_node.y(), nullptr);
+}
index 0a6949d..8c9c358 100644 (file)
@@ -324,6 +324,12 @@ bool fix_padding(moco::tf::TFFusedBatchNorm *node)
   return false;
 }
 
+bool fix_padding(moco::tf::TFMul *node)
+{
+  // Nothing to do with padding
+  return false;
+}
+
 } // namespace
 
 namespace moco
index 79a50c5..34cf0eb 100644 (file)
@@ -563,6 +563,22 @@ bool fix_shape(moco::tf::TFFusedBatchNorm *node)
   return copy_shapedata(input, node);
 }
 
+bool fix_shape(moco::tf::TFMul *node)
+{
+  auto x = node->x();
+  auto y = node->y();
+  auto x_shapedata = x->annot<ShapeInferenceData>();
+  auto y_shapedata = y->annot<ShapeInferenceData>();
+  if (x_shapedata == nullptr || y_shapedata == nullptr)
+  {
+    return false;
+  }
+  // TODO check shape difference
+
+  // Output shape is same as the input
+  return copy_shapedata(x, node);
+}
+
 } // namespace
 
 namespace moco