Support missing Caffe Ops in Visitors (#1396)
authorАндрей Шедько/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <a.shedko@partner.samsung.com>
Mon, 10 Sep 2018 14:30:28 +0000 (17:30 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Mon, 10 Sep 2018 14:30:28 +0000 (17:30 +0300)
Added stubs to Visitors for Missing Caffe Ops.
ShapeInference, DotDumper, Interpreter now have skeleton implementations

Signed-off-by: Andrei Shedko <a.shedko@partner.samsung.com>
13 files changed:
contrib/nnc/core/modelIR/ShapeInference.cpp
contrib/nnc/core/modelIR/ir_dot_dumper.cpp
contrib/nnc/core/modelIR/visitor.cpp
contrib/nnc/include/core/modelIR/ShapeInference.h
contrib/nnc/include/core/modelIR/ir_dot_dumper.h
contrib/nnc/include/core/modelIR/ir_dot_node_info.h
contrib/nnc/include/core/modelIR/visitor.h
contrib/nnc/include/plugin/interpreter/Interpreter.h
contrib/nnc/plugin/interpreter/Interpreter.cpp
contrib/nnc/plugin/soft_backend/model_analyzer.cpp
contrib/nnc/plugin/soft_backend/model_analyzer.h
contrib/nnc/plugin/soft_backend/serializer.cpp
contrib/nnc/plugin/soft_backend/serializer.h

index c9fc137..844531e 100644 (file)
@@ -13,6 +13,9 @@
 #include "core/modelIR/operations/concat_op.h"
 #include "core/modelIR/operations/bias_add_op.h"
 #include "core/modelIR/operations/reshape_op.h"
+#include "core/modelIR/operations/batch_norm.h"
+#include "core/modelIR/operations/scale_op.h"
+#include "core/modelIR/operations/dropout_op.h"
 
 namespace nncc
 {
@@ -249,6 +252,24 @@ void ShapeInference::visit(ADT::INode::Ref node, ops::ReshapeOp &op)
   op.setOutputShape(0, outShape);
 }
 
+void ShapeInference::visit(ADT::INode::Ref node, ops::ScaleOp &op)
+{
+  fillInputShapes(node, op);
+  op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::DropoutOp &op)
+{
+  fillInputShapes(node, op);
+  op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::BatchNormOp &op)
+{
+  fillInputShapes(node, op);
+  op.setOutputShape(0, op.getInputShape(0));
+}
+
 } // namespace model
 } // namespace IR
 } // namespace core
index 72fbc68..45336ee 100644 (file)
@@ -152,6 +152,35 @@ void IrDotDumper::visit(INode *node, ops::VariableOp &op)
   dotBuilder.updateWithNode(node, nodeInfo);
 }
 
+void IrDotDumper::visit(INode *node, ops::BatchNormOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("BatchNorm", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withMisc("Moving Average Fraction", op.getMovingAvgFraction())
+                                 .withMisc("Eps", op.getEps())
+                                 .withMisc("Spatial", op.getSpatial());
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::ScaleOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("ScaleOp", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withShape("Scale Tensor", op.getWeights().getShape());
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::DropoutOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("DropoutOp", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withMisc("DropRate",op.getRate());
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
 } // namespace dumper
 } // namespace core
 } // namespace contrib
index 090c96a..1c13332 100644 (file)
@@ -1,3 +1,6 @@
+
+#include <core/modelIR/visitor.h>
+
 #include "core/modelIR/visitor.h"
 
 namespace nncc {
@@ -17,6 +20,9 @@ void Visitor::visit(ADT::INode *node, ops::BiasAddOp &op) {(void)node; (void)op;
 void Visitor::visit(ADT::INode *node, ops::VariableOp &op) {(void)node; (void)op;};
 void Visitor::visit(ADT::INode *node, ops::ReluOp &op) {(void)node; (void)op;};
 void Visitor::visit(ADT::INode *node, ops::ReshapeOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::ScaleOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::BatchNormOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::DropoutOp &op) {(void)node; (void)op;};
 
 } // namespace model
 } // namespace IR
index 18ade01..3872922 100644 (file)
@@ -33,8 +33,11 @@ class ShapeInference : public IVisitor {
   void visit(ADT::INode::Ref node, ops::BiasAddOp &op) override;
   void visit(ADT::INode::Ref node, ops::ReshapeOp &op) override;
   void visit(ADT::INode::Ref node, ops::VariableOp &op) override;
+  void visit(ADT::INode *node, ops::ScaleOp &op) override;
+  void visit(ADT::INode *node, ops::BatchNormOp &op) override;
+  void visit(ADT::INode *node, ops::DropoutOp &op) override;
 
- protected:
+protected:
   void fillInputShapes(ADT::INode::Ref node, OpDescription &op);
 };
 
index 2de6102..15cdbf4 100644 (file)
@@ -14,6 +14,9 @@
 #include "core/modelIR/operations/concat_op.h"
 #include "core/modelIR/operations/bias_add_op.h"
 #include "core/modelIR/operations/reshape_op.h"
+#include "core/modelIR/operations/batch_norm.h"
+#include "core/modelIR/operations/scale_op.h"
+#include "core/modelIR/operations/dropout_op.h"
 
 #include "core/modelIR/ir_dot_builder.h"
 
@@ -47,6 +50,9 @@ public:
   void visit(INode *node, ops::BiasAddOp &op) override;
   void visit(INode *node, ops::VariableOp &op) override;
   void visit(INode *node, ops::ReshapeOp &op) override;
+  void visit(INode *node, ops::ScaleOp &op) override;
+  void visit(INode *node, ops::BatchNormOp &op) override;
+  void visit(INode *node, ops::DropoutOp &op) override;
 
   void writeDot(std::ostream &os) { dotBuilder.writeDot(os); };
 
index 4c7fe72..77bf297 100644 (file)
@@ -82,8 +82,6 @@ private:
 
   bool hasPool = false;
   PoolType poolType = PoolType::MAX;
-
-  float axis = -1;
 };
 
 } // namespace dumper
index cf6f9ae..cc89cf5 100644 (file)
@@ -26,6 +26,9 @@ namespace ops
   class VariableOp;
   class ReluOp;
   class ReshapeOp;
+  class ScaleOp;
+  class BatchNormOp;
+  class DropoutOp;
 }
 
 /**
@@ -44,6 +47,9 @@ class IVisitor {
   virtual void visit(ADT::INode *node, ops::VariableOp &op) = 0;
   virtual void visit(ADT::INode *node, ops::ReluOp &op) = 0;
   virtual void visit(ADT::INode *node, ops::ReshapeOp &op) = 0;
+  virtual void visit(ADT::INode *node, ops::ScaleOp &op) = 0;
+  virtual void visit(ADT::INode *node, ops::BatchNormOp &op) = 0;
+  virtual void visit(ADT::INode *node, ops::DropoutOp &op) = 0;
 
   virtual ~IVisitor() = default;
 };
@@ -68,6 +74,12 @@ public:
     void visit(ADT::INode *node, ops::VariableOp &op) override;
     void visit(ADT::INode *node, ops::ReluOp &op) override;
     void visit(ADT::INode *node, ops::ReshapeOp &op) override;
+    void visit(ADT::INode *node, ops::ScaleOp &op) override;
+    void visit(ADT::INode *node, ops::BatchNormOp &op) override;
+    void visit(ADT::INode *node, ops::DropoutOp &op) override;
+
+    ~Visitor() override = default;
+
 };
 
 } // namespace model
index cd2a255..4df43af 100644 (file)
@@ -42,6 +42,9 @@ public:
   void visit(ADT::INode::Ref node, ops::BiasAddOp &op) override;
   void visit(ADT::INode::Ref node, ops::VariableOp &op) override;
   void visit(ADT::INode::Ref node, ops::ReshapeOp &op) override;
+  void visit(ADT::INode::Ref node, ops::ScaleOp &op) override;
+  void visit(ADT::INode::Ref node, ops::BatchNormOp &op) override;
+  void visit(ADT::INode::Ref node, ops::DropoutOp &op) override;
 
   void setInput(const std::string &name, const TensorVariant& data);
   std::vector<TensorVariant> &getResult(ADT::INode::Ref node);
index 502a18c..4e42038 100644 (file)
@@ -1,4 +1,5 @@
 #include <cmath>
+#include <cassert>
 
 #include "plugin/interpreter/Interpreter.h"
 
@@ -12,6 +13,9 @@
 #include "core/modelIR/operations/relu_op.h"
 #include "core/modelIR/operations/concat_op.h"
 #include "core/modelIR/operations/bias_add_op.h"
+#include "core/modelIR/operations/batch_norm.h"
+#include "core/modelIR/operations/scale_op.h"
+#include "core/modelIR/operations/dropout_op.h"
 
 #include "ops/Bias.h"
 #include "ops/Concat.h"
@@ -161,6 +165,40 @@ void NNInterpreter::visit(ADT::INode *node, ops::BiasAddOp &op)
   var(node->getId()) = impl::BiasAdd(input, op.getWeights(), op.getOutputShape(0))();
 }
 
+void NNInterpreter::visit(ADT::INode *node, ops::BatchNormOp &op)
+{
+  mapByName(node);
+  auto operand = node->getPrevNodes()[0];
+  TensorVariant input(var(operand.node->getId())[operand.index]);
+  (void)input; (void)op;
+  // TODO implement this
+  //  var(node->getId()) = impl::BatchNormOp(input, op)();
+  assert("BatchNormOp Not implemented yet" == 0);
+
+}
+
+void NNInterpreter::visit(ADT::INode *node, ops::ScaleOp &op)
+{
+  mapByName(node);
+  auto operand = node->getPrevNodes()[0];
+  TensorVariant input(var(operand.node->getId())[operand.index]);
+  (void)input; (void)op;
+  // TODO implement this
+  // var(node->getId()) = impl::ScaleOp(input, op)();
+  assert("ScaleOp Not implemented yet" == 0);
+}
+
+void NNInterpreter::visit(ADT::INode *node, ops::DropoutOp &op)
+{
+  mapByName(node);
+  auto operand = node->getPrevNodes()[0];
+  TensorVariant input(var(operand.node->getId())[operand.index]);
+  (void)input; (void)op;
+  // TODO implement this
+  // var(node->getId()) = impl::DropoutOp(input, op)();
+  assert("DropoutOp Not implemented yet" == 0);
+}
+
 void NNInterpreter::mapByName(ADT::INode::Ref n) {
   auto &nodeName = n->getName();
   if (nodeByName.find(nodeName) != nodeByName.end())
index 4971ab6..a49d67f 100644 (file)
@@ -16,6 +16,9 @@
 #include "core/modelIR/operations/bias_add_op.h"
 #include "core/modelIR/operations/relu_op.h"
 #include "core/modelIR/operations/reshape_op.h"
+#include "core/modelIR/operations/batch_norm.h"
+#include "core/modelIR/operations/scale_op.h"
+#include "core/modelIR/operations/dropout_op.h"
 
 using namespace std;
 
@@ -162,6 +165,21 @@ void ModelAnalyzer::visit(ADT::INode *node, ops::ReshapeOp &op)
   addOpDescr(node, "reshape");
 }
 
+void ModelAnalyzer::visit(ADT::INode *node, ops::DropoutOp &op)
+{
+  addOpDescr(node, "dropout");
+}
+
+void ModelAnalyzer::visit(ADT::INode *node, ops::ScaleOp &op)
+{
+  addOpDescr(node, "scale");
+}
+
+void ModelAnalyzer::visit(ADT::INode *node, ops::BatchNormOp &op)
+{
+  addOpDescr(node, "batchNorm");
+}
+
 } // namespace soft
 } // namespace backend
 } // namespace contrib
index 458ae73..baa5deb 100644 (file)
@@ -41,6 +41,9 @@ public:
   void visit(ADT::INode *node, ops::VariableOp &op) override;
   void visit(ADT::INode *node, ops::ReluOp &op) override;
   void visit(ADT::INode *node, ops::ReshapeOp &op) override;
+  void visit(ADT::INode *node, ops::ScaleOp &op) override;
+  void visit(ADT::INode *node, ops::BatchNormOp &op) override;
+  void visit(ADT::INode *node, ops::DropoutOp &op) override;
 
   struct TensorDescription
   {
index 9580646..f0ad082 100644 (file)
@@ -12,6 +12,9 @@
 #include "core/modelIR/operations/bias_add_op.h"
 #include "core/modelIR/operations/relu_op.h"
 #include "core/modelIR/operations/reshape_op.h"
+#include "core/modelIR/operations/batch_norm.h"
+#include "core/modelIR/operations/scale_op.h"
+#include "core/modelIR/operations/dropout_op.h"
 #include "core/modelIR/ir_node.h"
 
 #include <algorithm>
@@ -223,6 +226,26 @@ void Serializer::visit(ADT::INode *node, ops::ReshapeOp &op)
   serializeShape(op.getOutputShape(0));
 }
 
+void Serializer::visit(ADT::INode *node, ops::BatchNormOp &op)
+{
+  _curOp->_paramStartOffset = _buffer.size();
+  serializeT<float>(op.getEps());
+  serializeT<float>(op.getMovingAvgFraction());
+  serializeT<bool>(op.getSpatial());
+}
+
+void Serializer::visit(ADT::INode *node, ops::ScaleOp &op)
+{
+  _curOp->_paramStartOffset = _buffer.size();
+  serializeTensor(op.getWeights());
+}
+
+void Serializer::visit(ADT::INode *node, ops::DropoutOp &op)
+{
+  _curOp->_paramStartOffset = _buffer.size();
+  serializeT<float>(op.getRate());
+}
+
 void Serializer::serialize(list<ModelAnalyzer::OpDescr> &inferenceSequence)
 {
   for (ModelAnalyzer::OpDescr &descr: inferenceSequence)
index 0c15098..ba8252a 100644 (file)
@@ -37,6 +37,9 @@ public:
   void visit(ADT::INode *node, ops::VariableOp &op) override;
   void visit(ADT::INode *node, ops::ReluOp &op) override;
   void visit(ADT::INode *node, ops::ReshapeOp &op) override;
+  void visit(ADT::INode *node, ops::ScaleOp &op) override;
+  void visit(ADT::INode *node, ops::BatchNormOp &op) override;
+  void visit(ADT::INode *node, ops::DropoutOp &op) override;
 
   void serialize(std::list<ModelAnalyzer::OpDescr> &inferenceSequence);