Support conversion from Caffe2 MergeDim to ONNX Reshape + Squeeze. (#16189)
author Tongliang Liao <xkszltl@gmail.com>
Wed, 13 Feb 2019 22:57:27 +0000 (14:57 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Feb 2019 23:53:38 +0000 (15:53 -0800)
Summary:
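`MergeDim` can be exported as `Reshape([1, -1, 0, 0, ...])` followed by `Squeeze(axes=[0])`: each `0` copies the input dimension at the same index, the `-1` is inferred as `dim(0) * dim(1)`, and the leading `1` is a placeholder dimension that `Squeeze` removes.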
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16189

Differential Revision: D14070676

Pulled By: ezyang

fbshipit-source-id: 28d7e9b35cc2c1dcbd4afb3fbdf7383e219b1777

caffe2/onnx/onnx_exporter.cc
caffe2/onnx/onnx_exporter.h
caffe2/operators/prepend_dim_op.cc
caffe2/python/onnx/tests/c2_ref_test.py

index 8857898..b64a912 100644 (file)
@@ -298,6 +298,7 @@ OnnxExporter::get_special_operators() const {
           {"AveragePool", &OnnxExporter::CreateConvPoolNodes},
           {"FC", &OnnxExporter::CreateGemmNodes},
           {"Concat", &OnnxExporter::CreateConcatNodes},
+          {"MergeDim", &OnnxExporter::CreateMergeDimNodes},
           {"LRN", &OnnxExporter::CreateLrnNodes},
           {"Reshape", &OnnxExporter::CreateReshapeNodes},
           {"Slice", &OnnxExporter::CreateSliceNodes},
@@ -746,6 +747,40 @@ ConvertedResult OnnxExporter::CreateConcatNodes(
   return result;
 }
 
+ConvertedResult OnnxExporter::CreateMergeDimNodes(
+    const caffe2::OperatorDef& def,
+    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+  const auto& x = def.input(0);
+  const auto& y = def.output(0);
+
+  ConvertedResult result;
+  auto& nodes = result.first;
+  auto& const_tensors = result.second;
+
+  {
+    const auto ndim = shapes.at(x).dims().size();
+    CAFFE_ENFORCE_GE(ndim, 2, "No enough dims to merge.");
+    std::vector<int64_t> dims(ndim);
+    dims[0] = 1;
+    dims[1] = -1;
+    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
+  }
+
+  const auto reshaped = dummy_->NewDummyName();
+  nodes.emplace_back(MakeNode("Reshape",
+              { x, const_tensors.back().name() },
+              { reshaped }));
+
+  nodes.emplace_back(MakeNode("Squeeze",
+              { reshaped },
+              { y },
+              std::vector<AttributeProto>{
+                  MakeAttribute("axes", std::vector<int64_t>{ 0 }),
+              }));
+
+  return result;
+}
+
 ConvertedResult OnnxExporter::CreateChannelShuffleNodes(
     const caffe2::OperatorDef& def,
     const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
index 30a5233..f7f1643 100644 (file)
@@ -97,6 +97,10 @@ class CAFFE2_API OnnxExporter {
       const caffe2::OperatorDef& def,
       const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
 
+  // Converts a Caffe2 MergeDim op into an ONNX Reshape followed by a Squeeze.
+  ConvertedResult CreateMergeDimNodes(
+      const caffe2::OperatorDef& def,
+      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
   ConvertedResult CreateLrnNodes(
       const caffe2::OperatorDef& def,
       const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
index 2796fec..9cdc228 100644 (file)
@@ -25,7 +25,8 @@ OPERATOR_SCHEMA(MergeDim)
 Merge first two dimensions in a single dimension with size dim(0) * dim(1).
 )DOC")
     .Input(0, "data", "An input tensor.")
-    .Output(0, "reshaped", "Reshaped tensor.");
+    .Output(0, "reshaped", "Reshaped tensor.")
+    .InheritOnnxSchema("Reshape");
 
 class GetPrependDimGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
index 629cc76..df4df72 100644 (file)
@@ -474,6 +474,33 @@ class TestCaffe2Basic(DownloadingTestCase):
             op_names.append(op.type)
         self.assertEqual(op_names, ['Scale', 'Scale', 'MatMul', 'Add'])
 
+    def test_mergedim(self):
+        X = np.random.randn(2, 3, 1, 5).astype(np.float32)
+
+        predict_net = caffe2_pb2.NetDef()
+        predict_net.name = 'test-mergedim-net'
+        predict_net.external_input[:] = ['X']
+        predict_net.external_output[:] = ['Y']
+        predict_net.op.extend([
+            core.CreateOperator(
+                'MergeDim',
+                inputs=['X'],
+                outputs=['Y'],
+            ),
+        ])
+        ws, c2_outputs = c2_native_run_net(
+            init_net=None,
+            predict_net=predict_net,
+            inputs=[X])
+
+        onnx_model = c2_onnx.caffe2_net_to_onnx_model(
+            predict_net=predict_net,
+            value_info={
+                'X': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[X.dtype], X.shape),
+            })
+        onnx_outputs = c2.run_model(onnx_model, inputs=[X])
+        self.assertSameOutputs(c2_outputs, onnx_outputs)
+
     def test_tensor_filling_ops(self):
         for dtype in [
                 onnx.TensorProto.FLOAT,