Add "axis" and "axis_w" arguments in FC to support customized axix to reduce dim...
authorGu, Jinghui <jinghui.gu@intel.com>
Wed, 21 Nov 2018 23:42:29 +0000 (15:42 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 21 Nov 2018 23:44:50 +0000 (15:44 -0800)
Summary:
Add "axis" and "axis_w" arguments in FC to support customized axix to reduce dim.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12971

Reviewed By: bddppq

Differential Revision: D12850675

Pulled By: yinghai

fbshipit-source-id: f1cde163201bd7add53b8475329db1f038a73019

caffe2/core/nomnigraph/include/nomnigraph/Generated/OpClasses.h
caffe2/core/nomnigraph/ops.def
caffe2/ideep/operators/fully_connected_op.cc
caffe2/opt/converter.cc
caffe2/python/ideep/fc_op_test.py

index 73c508b..f1e8d44 100644 (file)
@@ -565,13 +565,32 @@ class Clip : public NeuralNetOperator {
 
 class FC : public NeuralNetOperator {
  public:
-  FC() : NeuralNetOperator(NNKind::FC) {}
+  FC(int axis = 1, int axisW = 1)
+      : NeuralNetOperator(NNKind::FC), axis_(axis), axisW_(axisW) {}
 
   ~FC() {}
 
   NOMNIGRAPH_DEFINE_NN_RTTI(FC);
 
+  int getAxis() const {
+    return axis_;
+  }
+
+  int getAxisW() const {
+    return axisW_;
+  }
+
+  void setAxis(int axis) {
+    axis_ = axis;
+  }
+
+  void setAxisW(int axisW) {
+    axisW_ = axisW;
+  }
+
  private:
+  int axis_;
+  int axisW_;
 };
 
 class GivenTensorFill : public NeuralNetOperator {
index 0a3709e..2fa0729 100644 (file)
@@ -56,6 +56,9 @@ Clip
 - Max : float
 
 FC
+- Axis : int : 1
+- AxisW : int : 1
+
 GivenTensorFill
 Concat
 - Axis : int : -1
index 1bc75a9..80ed367 100644 (file)
@@ -2,6 +2,28 @@
 
 namespace caffe2 {
 
+USE_IDEEP_DEF_ALIASES();
+
+static inline itensor::dims CanonicalDims(itensor::dims adims, int32_t axis) {
+  CAFFE_ENFORCE(axis < (int32_t)adims.size(), "Invalid axis!");
+  CAFFE_ENFORCE(axis > (int32_t)-adims.size(), "Invalid axis!");
+  if (adims.size() == 2 || axis == 1) {
+    return adims;
+  }
+  if (axis < 0) {
+    axis += (int32_t)adims.size();
+  }
+
+  auto dim0 = std::accumulate(
+      adims.begin(),
+      adims.begin() + axis,
+      1,
+      std::multiplies<itensor::dim_t>());
+  auto dim1 = std::accumulate(
+      adims.begin() + axis, adims.end(), 1, std::multiplies<itensor::dim_t>());
+  return itensor::dims({dim0, dim1});
+}
+
 class IDEEPFullyConnectedOp final : public IDEEPOperator {
  public:
   USE_IDEEP_DEF_ALIASES();
@@ -9,8 +31,8 @@ class IDEEPFullyConnectedOp final : public IDEEPOperator {
 
   IDEEPFullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
       : IDEEPOperator(operator_def, ws),
-        float16_compute_(
-            OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
+        axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
+        axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)) {}
   virtual ~IDEEPFullyConnectedOp() {}
 
   bool RunOnDevice() override {
@@ -18,17 +40,30 @@ class IDEEPFullyConnectedOp final : public IDEEPOperator {
     const auto& filter = Input(FILTER);
     auto* Y = Output(OUTPUT);
 
+    itensor X_in = X;
+    auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
+    if (X_in.get_dims() != X_dims) {
+      X_in.reshape(X_dims);
+    }
+
+    itensor filter_in = filter;
+    auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
+    if (filter_in.get_dims() != filter_dims) {
+      filter_in.reshape(filter_dims);
+    }
+
     if (InputSize() > BIAS) {
-      ideep::inner_product_forward::compute(X, filter, Input(BIAS), *Y);
+      ideep::inner_product_forward::compute(X_in, filter_in, Input(BIAS), *Y);
     } else {
-      ideep::inner_product_forward::compute(X, filter, *Y);
+      ideep::inner_product_forward::compute(X_in, filter_in, *Y);
     }
 
     return true;
   }
 
  private:
-  bool float16_compute_;
+  size_t axis_{1};
+  size_t axis_w_{1};
 
   INPUT_TAGS(INPUT, FILTER, BIAS);
   OUTPUT_TAGS(OUTPUT);
@@ -41,8 +76,8 @@ class IDEEPFullyConnectedGradientOp final : public IDEEPOperator {
 
   IDEEPFullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
       : IDEEPOperator(operator_def, ws),
-        float16_compute_(
-            OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
+        axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
+        axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)) {}
   virtual ~IDEEPFullyConnectedGradientOp() {}
 
   bool RunOnDevice() override {
@@ -52,18 +87,31 @@ class IDEEPFullyConnectedGradientOp final : public IDEEPOperator {
     auto* dfilter = Output(FILTER_GRAD);
     auto* dbias = Output(BIAS_GRAD);
 
-    ideep::inner_product_backward_weights::compute(X, dY, *dfilter, *dbias);
+    itensor X_in = X;
+    auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
+    if (X_in.get_dims() != X_dims) {
+      X_in.reshape(X_dims);
+    }
+
+    itensor filter_in = filter;
+    auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
+    if (filter_in.get_dims() != filter_dims) {
+      filter_in.reshape(filter_dims);
+    }
+
+    ideep::inner_product_backward_weights::compute(X_in, dY, *dfilter, *dbias);
 
     if (OutputSize() > INPUT_GRAD) {
       ideep::inner_product_backward_data::compute(
-          dY, filter, X.get_dims(), *Output(INPUT_GRAD));
+          dY, filter_in, X_in.get_dims(), *Output(INPUT_GRAD));
     }
 
     return true;
   }
 
  private:
-  bool float16_compute_;
+  size_t axis_{1};
+  size_t axis_w_{1};
 
   INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
   OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD);
index 6f4d9da..4d701ed 100644 (file)
@@ -1,7 +1,7 @@
 #include <limits>
 
-#include "caffe2/opt/converter.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/opt/converter.h"
 
 #include "nomnigraph/Graph/Algorithms.h"
 
@@ -67,7 +67,8 @@ std::map<std::string, caffe2::Argument> Converter::getArgumentsFromOperator(
   return argMap;
 }
 
-repr::NeuralNetOperator::NNLayout getLayout(std::map<std::string, caffe2::Argument> argMap) {
+repr::NeuralNetOperator::NNLayout getLayout(
+    std::map<std::string, caffe2::Argument> argMap) {
   auto arg = argMap.find("order");
   if (arg != argMap.end()) {
     auto order = argMap["order"].s();
@@ -94,7 +95,8 @@ OperatorDef Converter::convertToOperatorDef(
   return op;
 }
 
-std::vector<int> getKernelShape(std::map<std::string, caffe2::Argument> argMap) {
+std::vector<int> getKernelShape(
+    std::map<std::string, caffe2::Argument> argMap) {
   // There are literally three ways to define shapes in Conv in Caffe2
   std::vector<int> kernelShape;
   if (argMap.count("kernel")) {
@@ -233,6 +235,33 @@ class ConcatConverter : public Converter {
 };
 REGISTER_CONVERTER(Concat, ConcatConverter);
 
+class FCConverter : public Converter {
+  std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator(
+      const OperatorDef& op) override {
+    std::unique_ptr<repr::NeuralNetOperator> nnOp =
+        util::make_unique<repr::FC>();
+    auto argMap = getArgumentsFromOperator(op);
+
+    auto c = dyn_cast<repr::FC>(nnOp.get());
+    if (argMap.count("axis")) {
+      CAFFE_ENFORCE(argMap["axis"].has_i(), "Invalid axis argument");
+      int axis = static_cast<int>(argMap["axis"].i());
+      c->setAxis(axis);
+    }
+    if (argMap.count("axis_w")) {
+      CAFFE_ENFORCE(argMap["axis_w"].has_i(), "Invalid axis_w argument");
+      int axis_w = static_cast<int>(argMap["axis_w"].i());
+      c->setAxisW(axis_w);
+    }
+
+    return nnOp;
+  }
+  // Does not override default converter to OperatorDef
+
+  virtual ~FCConverter() {}
+};
+REGISTER_CONVERTER(FC, FCConverter);
+
 } // namespace
 
 std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
@@ -267,7 +296,6 @@ std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
   return nnOp;
 }
 
-
 /// \brief Ingest a caffe2 protobuf model and output an NNModule.
 /// \param net The caffe2 protobuf NetDef
 repr::NNModule convertToNNModule(
@@ -298,7 +326,7 @@ repr::NNModule convertToNNModule(
   for (const auto& op : net.op()) {
     auto opNode = dfg.createNode(); // Create an empty node for the operator.
     // First calculate in-edges (data dependencies).
-    for (const auto &input : op.input()) {
+    for (const autoinput : op.input()) {
       // If we've never seen this tensor, make one.
       if (!blobMap.count(input)) {
         auto tensor = util::make_unique<repr::Tensor>(input);
@@ -315,7 +343,7 @@ repr::NNModule convertToNNModule(
     }
 
     // Then save outputs into the blobMap for later consumption.
-    for (const auto &output : op.output()) {
+    for (const autooutput : op.output()) {
       auto tensor = util::make_unique<repr::Tensor>(output);
       auto tensorNode =
           dfg.createNode(unique_dyn_cast<repr::NeuralNetData>(tensor));
@@ -346,7 +374,7 @@ repr::NNModule convertToNNModule(
           externalInputNames.size(),
           " unused blobs: ",
           os.str());
-    // Otherwise, we add the blobs to the graph as no-ops
+      // Otherwise, we add the blobs to the graph as no-ops
     } else {
       for (const auto& input : externalInputNames) {
         blobMap[input] = dfg.createNode(util::make_unique<repr::Tensor>(input));
@@ -368,9 +396,9 @@ repr::NNModule convertToNNModule(
 
 caffe2::OperatorDef convertToOperatorDef(
     const repr::NNGraph::NodeRef& instrNode) {
-  auto *nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
+  autonnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
   auto op_type = nnOp->getName();
-  auto *annotation = nnOp->getAnnotation();
+  autoannotation = nnOp->getAnnotation();
   caffe2::OperatorDef op;
 
   if (ConverterRegistry()->Has(op_type)) {
@@ -410,7 +438,7 @@ Caffe2Annotation* getOrAddCaffe2Annotation(
   return c2_annotation;
 }
 
-caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m) {
+caffe2::NetDef convertToCaffe2Proto(repr::NNModulem) {
   auto predictNet = caffe2::NetDef();
   return convertToCaffe2Proto(m, predictNet);
 }
@@ -443,7 +471,9 @@ std::vector<std::string> mergeExternalTensors(
   return out;
 }
 
-caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m, const caffe2::NetDef& oldNet) {
+caffe2::NetDef convertToCaffe2Proto(
+    repr::NNModule& m,
+    const caffe2::NetDef& oldNet) {
   auto predictNet = caffe2::NetDef();
   // We copy the old net rather than mutate it.
   predictNet.CopyFrom(oldNet);
@@ -453,7 +483,7 @@ caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m, const caffe2::NetDef& old
 
   // Simply iterate through the CFG and populate data dependencies
   // with the DFG
-  for (const auto &bbNode : m.controlFlow.getMutableNodes()) {
+  for (const autobbNode : m.controlFlow.getMutableNodes()) {
     if (bbNode->getOutEdges().size() > 1) {
       CAFFE_THROW("Control flow not yet supported in Caffe2 converter.");
     }
@@ -461,20 +491,19 @@ caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m, const caffe2::NetDef& old
     for (const auto& instrNode : bb.getInstructions()) {
       caffe2::OperatorDef op = convertToOperatorDef(instrNode);
 
-      for (const auto &inEdge : instrNode->getInEdges()) {
-        auto *tensorNode =
+      for (const autoinEdge : instrNode->getInEdges()) {
+        autotensorNode =
             dyn_cast<repr::NeuralNetData>(inEdge->tail()->data().get());
         *op.add_input() = tensorNode->getName();
       }
-      for (const auto &outEdge : instrNode->getOutEdges()) {
-        auto *tensorNode =
+      for (const autooutEdge : instrNode->getOutEdges()) {
+        autotensorNode =
             dyn_cast<repr::NeuralNetData>(outEdge->head()->data().get());
         *op.add_output() = tensorNode->getName();
       }
 
-      auto *nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
+      autonnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
       if (nnOp->getLayout() != repr::NeuralNetOperator::NNLayout::Undefined) {
-
         caffe2::Argument* arg = nullptr;
         for (int i = 0; i < op.arg_size(); ++i) {
           auto arg_ = op.mutable_arg(i);
index b656c90..03deedb 100644 (file)
@@ -4,6 +4,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 import unittest
+from functools import reduce
 import hypothesis.strategies as st
 from hypothesis import given, settings
 import numpy as np
@@ -31,6 +32,196 @@ class FcTest(hu.HypothesisTestCase):
         for i in range(3):
             self.assertGradientChecks(gc, op, [X, W, b], i, [0])
 
+    @given(n=st.integers(1, 5),
+           m=st.integers(1, 5),
+           c=st.integers(1, 5),
+           h=st.integers(1, 5),
+           w=st.integers(1, 5),
+           axis=st.integers(1, 3),
+           **mu.gcs)
+    def test_fc_with_axis(self, n, m, c, h, w, axis, gc, dc):
+        X = np.random.rand(n, c, h, w).astype(np.float32) - 0.5
+        k = reduce((lambda x, y: x * y), [n, c, h, w][axis - 4:])
+        nn = reduce((lambda x, y: x * y), [n, c, h, w][:axis])
+        W = np.random.rand(m, k).astype(np.float32) - 0.5
+        b = np.random.rand(m).astype(np.float32) - 0.5
+        dY = np.random.rand(nn, m).astype(np.float32) - 0.5
+
+        op0 = core.CreateOperator(
+            'FC',
+            ['X', 'W', 'b'],
+            ["Y"],
+            axis=axis,
+            device_option=dc[0]
+        )
+
+        op0_bw = core.CreateOperator(
+            'FCGradient',
+            ['X', 'W', 'dY'],
+            ["dW", "db"],
+            axis=axis,
+            device_option=dc[0]
+        )
+
+        workspace.ResetWorkspace()
+        workspace.FeedBlob('X', X, dc[0])
+        workspace.FeedBlob('W', W, dc[0])
+        workspace.FeedBlob('b', b, dc[0])
+        workspace.RunOperatorOnce(op0)
+        Y0 = workspace.FetchBlob('Y')
+
+        workspace.FeedBlob('dY', dY, dc[0])
+        workspace.RunOperatorOnce(op0_bw)
+        dW0 = workspace.FetchBlob('dW')
+        db0 = workspace.FetchBlob('db')
+
+        op1 = core.CreateOperator(
+            'FC',
+            ['X', 'W', 'b'],
+            ["Y"],
+            axis=axis,
+            device_option=dc[1]
+        )
+
+        op1_bw = core.CreateOperator(
+            'FCGradient',
+            ['X', 'W', 'dY'],
+            ["dW", "db"],
+            axis=axis,
+            device_option=dc[1]
+        )
+
+        workspace.SwitchWorkspace("_device_check_", True)
+        workspace.FeedBlob('X', X, dc[1])
+        workspace.FeedBlob('W', W, dc[1])
+        workspace.FeedBlob('b', b, dc[1])
+        workspace.RunOperatorOnce(op1)
+        Y1 = workspace.FetchBlob('Y')
+
+        workspace.FeedBlob('dY', dY, dc[1])
+        workspace.RunOperatorOnce(op1_bw)
+        dW1 = workspace.FetchBlob('dW')
+        db1 = workspace.FetchBlob('db')
+
+        Y0 = Y0.flatten()
+        Y1 = Y1.flatten()
+        if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
+            print(Y1)
+            print(Y0)
+            print(np.max(np.abs(Y1 - Y0)))
+            self.assertTrue(False)
+
+        dW0 = dW0.flatten()
+        dW1 = dW1.flatten()
+        if not np.allclose(dW0, dW1, atol=0.01, rtol=0.01):
+            print(dW1)
+            print(dW0)
+            print(np.max(np.abs(dW1 - dW0)))
+            self.assertTrue(False)
+
+        db0 = db0.flatten()
+        db1 = db1.flatten()
+        if not np.allclose(db0, db1, atol=0.01, rtol=0.01):
+            print(db1)
+            print(db0)
+            print(np.max(np.abs(db1 - db0)))
+            self.assertTrue(False)
+
+    @given(n=st.integers(1, 5),
+           o=st.integers(1, 5),
+           i=st.integers(1, 5),
+           h=st.integers(1, 5),
+           w=st.integers(1, 5),
+           axis_w=st.integers(1, 3),
+           **mu.gcs)
+    def test_fc_with_axis_w(self, n, o, i, h, w, axis_w, gc, dc):
+        W = np.random.rand(o, i, h, w).astype(np.float32) - 0.5
+        k = reduce((lambda x, y: x * y), [o, i, h, w][axis_w - 4:])
+        m = reduce((lambda x, y: x * y), [o, i, h, w][:axis_w])
+        X = np.random.rand(n, k).astype(np.float32) - 0.5
+        b = np.random.rand(m).astype(np.float32) - 0.5
+        dY = np.random.rand(n, m).astype(np.float32) - 0.5
+
+        op0 = core.CreateOperator(
+            'FC',
+            ['X', 'W', 'b'],
+            ["Y"],
+            axis_w=axis_w,
+            device_option=dc[0]
+        )
+
+        op0_bw = core.CreateOperator(
+            'FCGradient',
+            ['X', 'W', 'dY'],
+            ["dW", "db"],
+            axis_w=axis_w,
+            device_option=dc[0]
+        )
+
+        workspace.ResetWorkspace()
+        workspace.FeedBlob('X', X, dc[0])
+        workspace.FeedBlob('W', W, dc[0])
+        workspace.FeedBlob('b', b, dc[0])
+        workspace.RunOperatorOnce(op0)
+        Y0 = workspace.FetchBlob('Y')
+
+        workspace.FeedBlob('dY', dY, dc[0])
+        workspace.RunOperatorOnce(op0_bw)
+        dW0 = workspace.FetchBlob('dW')
+        db0 = workspace.FetchBlob('db')
+
+        op1 = core.CreateOperator(
+            'FC',
+            ['X', 'W', 'b'],
+            ["Y"],
+            axis_w=axis_w,
+            device_option=dc[1]
+        )
+
+        op1_bw = core.CreateOperator(
+            'FCGradient',
+            ['X', 'W', 'dY'],
+            ["dW", "db"],
+            axis_w=axis_w,
+            device_option=dc[1]
+        )
+
+        workspace.SwitchWorkspace("_device_check_", True)
+        workspace.FeedBlob('X', X, dc[1])
+        workspace.FeedBlob('W', W, dc[1])
+        workspace.FeedBlob('b', b, dc[1])
+        workspace.RunOperatorOnce(op1)
+        Y1 = workspace.FetchBlob('Y')
+
+        workspace.FeedBlob('dY', dY, dc[1])
+        workspace.RunOperatorOnce(op1_bw)
+        dW1 = workspace.FetchBlob('dW')
+        db1 = workspace.FetchBlob('db')
+
+        Y0 = Y0.flatten()
+        Y1 = Y1.flatten()
+        if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
+            print(Y1)
+            print(Y0)
+            print(np.max(np.abs(Y1 - Y0)))
+            self.assertTrue(False)
+
+        dW0 = dW0.flatten()
+        dW1 = dW1.flatten()
+        if not np.allclose(dW0, dW1, atol=0.01, rtol=0.01):
+            print(dW1)
+            print(dW0)
+            print(np.max(np.abs(dW1 - dW0)))
+            self.assertTrue(False)
+
+        db0 = db0.flatten()
+        db1 = db1.flatten()
+        if not np.allclose(db0, db1, atol=0.01, rtol=0.01):
+            print(db1)
+            print(db0)
+            print(np.max(np.abs(db1 - db0)))
+            self.assertTrue(False)
+
 
 if __name__ == "__main__":
     unittest.main()