From a670824fee57a2b261168a2bfcfc1f4e008d8d4a Mon Sep 17 00:00:00 2001
From: Tongliang Liao
Date: Wed, 13 Feb 2019 17:08:40 -0800
Subject: [PATCH] Support FC (Caffe2) -> Gemm (ONNX) with variable input
 shape. (#16184)

Summary:
For >2D input, previously the code uses static shape captured during tracing and reshape before/after `Gemm`.
Now we add `-1` to the first `Reshape`, and uses `Shape(X) => Slice(outer) => Concat(with -1 for inner) => Reshape` for the second.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16184

Differential Revision: D14070754

Pulled By: ezyang

fbshipit-source-id: 86c69e9b254945b3406c07e122e57a00dfeba3df
---
 caffe2/onnx/onnx_exporter.cc            | 79 ++++++++++++++++++++-------------
 caffe2/python/onnx/tests/c2_ref_test.py | 33 ++++++++++++++
 2 files changed, 81 insertions(+), 31 deletions(-)

diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc
index b64a912..087fbf9 100644
--- a/caffe2/onnx/onnx_exporter.cc
+++ b/caffe2/onnx/onnx_exporter.cc
@@ -920,7 +920,6 @@ ConvertedResult OnnxExporter::CreateUpsampleNodes(
       MakeTensor("resolved scale tensor", tmp_vector, TensorProto::FLOAT);
   auto node = MakeNode("Constant", {}, {resolved_scale});
-  MakeAttribute("value", resolved_scale_tensor);
   node.add_attribute()->CopyFrom(
       MakeAttribute("value", resolved_scale_tensor));
   nodes.emplace_back(node);
@@ -1060,16 +1059,18 @@ ConvertedResult OnnxExporter::CreateGemmNodes(
   if (has_axis) {
     axis = it->second->i();
   }
+
+  auto gemm_x_input = x;
   if (x_shape.dims().size() > 2) {
     // we need to reshape only when dimension is higher than 2
-    auto outer = DimProd(x_shape, 0, axis);
-    auto inner = DimProd(x_shape, axis, x_shape.dims().size());
-    std::vector<int64_t> dims = {outer, inner};
-    auto reshaped_x = dummy_->NewDummyName();
-    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
-    nodes.emplace_back(
-        MakeNode("Reshape", {x, const_tensors.back().name()}, {reshaped_x}));
-    x = reshaped_x;
+    const auto inner = DimProd(x_shape, axis,
        x_shape.dims().size());
+
+    gemm_x_input = dummy_->NewDummyName();
+    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_,
+        std::vector<int64_t>{ -1, inner }));
+    nodes.emplace_back(MakeNode("Reshape",
+        { x, const_tensors.back().name() },
+        { gemm_x_input }));
   }
 
   it = args.find("axis_w");
@@ -1081,32 +1082,48 @@ ConvertedResult OnnxExporter::CreateGemmNodes(
     // we need to reshape only when dimension is higher than 2
     auto outer = DimProd(w_shape, 0, axis_w);
     auto inner = DimProd(w_shape, axis_w, w_shape.dims().size());
-    std::vector<int64_t> dims = {outer, inner};
     auto reshaped_w = dummy_->NewDummyName();
-    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
-    nodes.emplace_back(
-        MakeNode("Reshape", {w, const_tensors.back().name()}, {reshaped_w}));
+    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_,
+        std::vector<int64_t>{ outer, inner }));
+    nodes.emplace_back(MakeNode("Reshape",
+        { w, const_tensors.back().name() },
+        { reshaped_w }));
     w = reshaped_w;
   }
 
-  auto gemm_y_output = (has_axis) ? dummy_->NewDummyName() : y;
-  std::vector<AttributeProto> attrs = {MakeAttribute("transB", 1L)};
-  nodes.emplace_back(MakeNode(
-      "Gemm",
-      {x, w, b},
-      {gemm_y_output},
-      attrs,
-      def.name()));
-
-  if (has_axis) {
-    std::vector<int64_t> dims;
-    for (int i = 0; i < axis; ++i) {
-      dims.push_back(x_shape.dims(i));
-    }
-    dims.push_back(-1);
-    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
-    nodes.emplace_back(
-        MakeNode("Reshape", {gemm_y_output, const_tensors.back().name()}, {y}));
+  auto gemm_y_output = axis > 1 ? dummy_->NewDummyName() : y;
+  nodes.emplace_back(MakeNode("Gemm",
+      { gemm_x_input, w, b },
+      { gemm_y_output },
+      { MakeAttribute("transB", 1L) },
+      def.name()));
+
+  // capture the outer shape if needed.
+  if (axis > 1) {
+    const auto x_shape = dummy_->NewDummyName();
+    nodes.emplace_back(MakeNode("Shape", {x}, {x_shape}));
+
+    const auto x_shape_outer = dummy_->NewDummyName();
+    nodes.emplace_back(MakeNode("Slice",
+        { x_shape },
+        { x_shape_outer },
+        std::vector<AttributeProto>{
+            MakeAttribute("starts", std::vector<int64_t>{ 0 }),
+            MakeAttribute("ends", std::vector<int64_t>{ axis }),
+        }));
+
+    const auto y_shape = dummy_->NewDummyName();
+    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, { -1 }));
+    nodes.emplace_back(MakeNode("Concat",
+        { x_shape_outer, const_tensors.back().name() },
+        { y_shape },
+        std::vector<AttributeProto>{
+            MakeAttribute("axis", static_cast<int64_t>(0)),
+        }));
+
+    nodes.emplace_back(MakeNode("Reshape",
+        { gemm_y_output, y_shape },
+        { y }));
   }
 
   return result;
diff --git a/caffe2/python/onnx/tests/c2_ref_test.py b/caffe2/python/onnx/tests/c2_ref_test.py
index df4df72..94a0256 100644
--- a/caffe2/python/onnx/tests/c2_ref_test.py
+++ b/caffe2/python/onnx/tests/c2_ref_test.py
@@ -186,6 +186,39 @@ class TestCaffe2Basic(DownloadingTestCase):
         onnx_outputs = c2.run_model(onnx_model, inputs=[X])
         self.assertSameOutputs(c2_outputs, onnx_outputs)
 
+    def test_fc(self):
+        X_fake = np.zeros((3, 1, 3, 1, 7), dtype=np.float32)
+        X = np.random.randn(5, 2, 3, 1, 7).astype(np.float32)
+        W = np.random.randn(11, 21).astype(np.float32)
+        B = np.random.randn(11).astype(np.float32)
+
+        predict_net = caffe2_pb2.NetDef()
+        predict_net.name = 'test-fc-net'
+        predict_net.external_input[:] = ['X', 'W', 'B']
+        predict_net.external_output[:] = ['Y']
+        predict_net.op.extend([
+            core.CreateOperator(
+                'FC',
+                inputs=['X', 'W', 'B'],
+                outputs=['Y'],
+                axis=2,
+            ),
+        ])
+        ws, c2_outputs = c2_native_run_net(
+            init_net=None,
+            predict_net=predict_net,
+            inputs=[X, W, B])
+
+        onnx_model = c2_onnx.caffe2_net_to_onnx_model(
+            predict_net=predict_net,
+            value_info={
+                'X': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[X.dtype], X_fake.shape),
+                'W': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[W.dtype], W.shape),
+                'B':
(onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[B.dtype], B.shape), + }) + onnx_outputs = c2.run_model(onnx_model, inputs=[X, W, B]) + self.assertSameOutputs(c2_outputs, onnx_outputs) + def test_gemm(self): # simple A = np.random.randn(3, 2).astype(np.float32) -- 2.7.4