[neurun] Introduce InstanceNorm op (#8526)
author여지환/On-Device Lab(SR)/Staff Engineer/삼성전자 <jihwan.yeo@samsung.com>
Fri, 1 Nov 2019 07:23:55 +0000 (16:23 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Fri, 1 Nov 2019 07:23:55 +0000 (16:23 +0900)
This commit introduces InstanceNorm op into runtime frontend.
  - Introduce InstanceNormNode into neurun
  - Introduce InstanceNorm op into frontend
  - Use default value if value of epsilon is zero

Signed-off-by: JiHwan Yeo <jihwan.yeo@samsung.com>
runtime/neurun/core/include/model/Operations.Include.h
runtime/neurun/core/include/model/Operations.lst
runtime/neurun/core/include/model/operation/InstanceNorm.h [new file with mode: 0644]
runtime/neurun/core/src/compiler/OperationValidator.cc
runtime/neurun/core/src/compiler/OperationValidator.h
runtime/neurun/core/src/graph/dumper/Dumper.cc
runtime/neurun/core/src/graph/dumper/Dumper.h
runtime/neurun/core/src/model/operation/InstanceNorm.cc [new file with mode: 0644]
runtime/neurun/frontend/base_loader/base_loader.h
runtime/neurun/frontend/circle/circle_loader.cc

index c778324..1bdcbf2 100644 (file)
@@ -57,6 +57,7 @@
 #include "operation/EmbeddingLookup.h"
 #include "operation/L2Normalization.h"
 #include "operation/HashtableLookup.h"
+#include "operation/InstanceNorm.h"
 #include "operation/PReLU.h"
 #include "operation/TransposeConv.h"
 #include "operation/SQRT.h"
index e4ba336..c44c232 100644 (file)
@@ -61,6 +61,7 @@ OP(L2Pool2D                   , true)
 OP(EmbeddingLookup            , true)
 OP(L2Normalization            , true)
 OP(HashtableLookup            , true)
+OP(InstanceNorm               , true)
 OP(PReLU                      , true)
 OP(TransposeConv              , true)
 OP(SQRT                       , true)
diff --git a/runtime/neurun/core/include/model/operation/InstanceNorm.h b/runtime/neurun/core/include/model/operation/InstanceNorm.h
new file mode 100644 (file)
index 0000000..882848d
--- /dev/null
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NEURUN_MODEL_OPERATION_INSTANCE_NORM_H__
+#define __NEURUN_MODEL_OPERATION_INSTANCE_NORM_H__
+
+#include "model/Operation.h"
+#include "model/InternalType.h"
+
+namespace neurun
+{
+namespace model
+{
+namespace operation
+{
+
+class InstanceNorm : public model::Operation
+{
+public:
+  enum Input
+  {
+    INPUT = 0,
+    GAMMA,
+    BETA
+  };
+
+  struct Param
+  {
+    Activation activation;
+    float epsilon;
+  };
+
+public:
+  InstanceNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
+               const Param &param);
+
+public:
+  void accept(OperationVisitor &v) const override;
+  std::string getName() const override { return "InstanceNorm"; }
+
+public:
+  const Param &param() const { return _param; }
+
+private:
+  Param _param;
+};
+
+} // namespace operation
+} // namespace model
+} // namespace neurun
+
+#endif // __NEURUN_MODEL_OPERATION_INSTANCE_NORM_H__
index 656fc3a..4ad8569 100644 (file)
@@ -103,6 +103,24 @@ void OperationValidator::visit(const model::operation::Softmax &node)
   assert(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
 }
 
+void OperationValidator::visit(const model::operation::InstanceNorm &node)
+{
+  const auto ofm_index{node.getOutputs().at(0)};
+  const auto ifm_index{node.getInputs().at(model::operation::InstanceNorm::Input::INPUT)};
+  const auto gamma_index{node.getInputs().at(model::operation::InstanceNorm::Input::GAMMA)};
+  const auto beta_index{node.getInputs().at(model::operation::InstanceNorm::Input::BETA)};
+
+  UNUSED_RELEASE(ofm_index);
+  UNUSED_RELEASE(ifm_index);
+  UNUSED_RELEASE(gamma_index);
+  UNUSED_RELEASE(beta_index);
+
+  assert(_ctx.at(ifm_index).shape().rank() == 4);
+  assert(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
+  assert(_ctx.at(gamma_index).shape().rank() == 1);
+  assert(_ctx.at(beta_index).shape().rank() == 1);
+}
+
 void OperationValidator::visit(const model::operation::Permute &node)
 {
   VERBOSE(Permute) << "Configure Permute operation" << std::endl;
index d007116..fd8b406 100644 (file)
@@ -25,7 +25,7 @@ namespace neurun
 namespace model
 {
 class Operands;
-} // namespace graph
+} // namespace model
 } // namespace neurun
 
 namespace neurun
@@ -47,6 +47,7 @@ public:
   void visit(const model::operation::Cast &node) override;
   void visit(const model::operation::Comparison &node) override;
   void visit(const model::operation::Softmax &node) override;
+  void visit(const model::operation::InstanceNorm &node) override;
   void visit(const model::operation::Permute &node) override;
   void visit(const model::operation::ReduceSum &node) override;
   void visit(const model::operation::Transpose &node) override;
index 04bc34b..d725ac1 100644 (file)
@@ -193,6 +193,16 @@ void Dumper::visit(const HashtableLookup &node)
                << node.getInputs().at(HashtableLookup::Output::HITS).value() << ")" << std::endl;
 }
 
+void Dumper::visit(const InstanceNorm &node)
+{
+  VERBOSE(LIR) << "* InstanceNorm" << std::endl;
+  VERBOSE(LIR) << "  - Inputs : IFM(" << node.getInputs().at(InstanceNorm::Input::INPUT).value()
+               << ") Gamma(" << node.getInputs().at(InstanceNorm::Input::GAMMA).value() << ") Beta("
+               << node.getInputs().at(InstanceNorm::Input::BETA).value() << ") Epsilon("
+               << node.param().epsilon << ")" << std::endl;
+  VERBOSE(LIR) << "  - Output : OFM(" << node.getOutputs().at(0).value() << ")" << std::endl;
+}
+
 void Dumper::visit(const L2Normalization &node)
 {
   VERBOSE(LIR) << "* L2Normalization" << std::endl;
index c6bb7ec..db8c57e 100644 (file)
@@ -50,6 +50,7 @@ public:
   void visit(const model::operation::FullyConnected &node) override;
   void visit(const model::operation::Gather &) override;
   void visit(const model::operation::HashtableLookup &) override;
+  void visit(const model::operation::InstanceNorm &) override;
   void visit(const model::operation::L2Normalization &) override;
   void visit(const model::operation::L2Pool2D &) override;
   void visit(const model::operation::LocalResponseNormalization &) override;
diff --git a/runtime/neurun/core/src/model/operation/InstanceNorm.cc b/runtime/neurun/core/src/model/operation/InstanceNorm.cc
new file mode 100644 (file)
index 0000000..66ffe77
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "model/operation/InstanceNorm.h"
+
+#include <cassert>
+
+#include "model/OperationVisitor.h"
+
+namespace neurun
+{
+namespace model
+{
+namespace operation
+{
+
+void InstanceNorm::accept(OperationVisitor &v) const { v.visit(*this); }
+
+InstanceNorm::InstanceNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
+                           const Param &param)
+    : model::Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+{
+}
+
+} // namespace operation
+} // namespace model
+} // namespace neurun
index 57d0b98..eb9923b 100644 (file)
@@ -90,6 +90,7 @@ protected:
   void loadSoftmax(const Operator *op);
   void loadMaxPool2D(const Operator *op);
   void loadConcatenation(const Operator *op);
+  void loadInstanceNorm(const Operator *op);
   void loadFC(const Operator *op);
   void loadAdd(const Operator *op);
   void loadSub(const Operator *op);
@@ -418,6 +419,29 @@ void BaseLoader<LoaderDomain, SpecificLoader>::loadConcatenation(const Operator
 }
 
 template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadInstanceNorm(const Operator *op)
+{
+  // This runtime_error will be removed if the one of backend supports this operation
+  throw std::runtime_error("NYI");
+
+  model::OperandIndexSequence inputs;
+  model::OperandIndexSequence outputs;
+
+  loadOperationIO(op, inputs, outputs);
+
+  model::operation::InstanceNorm::Param param;
+  const auto *options = op->builtin_options_as_InstanceNormOptions();
+
+  param.activation = convertActivation(options->fused_activation_function());
+  // Use default value 1e-5 if value of epsilon is zero
+  param.epsilon = options->epsilon() == 0.f ? 1e-5 : options->epsilon();
+
+  std::unique_ptr<model::Operation> new_op(
+      new model::operation::InstanceNorm(inputs, outputs, param));
+  _graph.addOperation(std::move(new_op));
+}
+
+template <typename LoaderDomain, typename SpecificLoader>
 void BaseLoader<LoaderDomain, SpecificLoader>::loadFC(const Operator *op)
 {
   model::OperandIndexSequence inputs;
index 18669cf..2383a17 100644 (file)
@@ -77,12 +77,27 @@ public:
     // Create operations
     for (const auto *op : *subgraph->operators())
     {
-      loadOperation(op);
+      CircleLoader::loadOperation(op);
     }
 
     (void)subgraph->data_format();
   }
 
+  void loadOperation(const circle::Operator *op)
+  {
+    const auto builtin_op = _op_code_to_builtin_op[op->opcode_index()];
+
+    switch (builtin_op)
+    {
+      case circle::BuiltinOperator::BuiltinOperator_INSTANCE_NORM:
+        loadInstanceNorm(op);
+        return;
+      default:
+        loadOperation(op);
+        return;
+    }
+  }
+
   void verify()
   {
     flatbuffers::Verifier verifier(reinterpret_cast<const std::uint8_t *>(_buffer.data()),