From 2ebeb33697d7fb9acaac026cfa140af56e617d11 Mon Sep 17 00:00:00 2001 From: "Gu, Jinghui" Date: Mon, 7 Jan 2019 11:07:51 -0800 Subject: [PATCH] Fallback to CPU concat op to handle TensorCPU inputs (#15263) Summary: Fallback to CPU concat op to handle TensorCPU inputs Pull Request resolved: https://github.com/pytorch/pytorch/pull/15263 Differential Revision: D13587030 Pulled By: yinghai fbshipit-source-id: 010a8579d61c3beb8556eb92493a552b2ab0030c --- caffe2/ideep/operators/concat_split_op.cc | 44 ++++++++++++++++++----------- caffe2/python/ideep/concat_split_op_test.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 16 deletions(-) diff --git a/caffe2/ideep/operators/concat_split_op.cc b/caffe2/ideep/operators/concat_split_op.cc index 1a0de0c..b105aad 100644 --- a/caffe2/ideep/operators/concat_split_op.cc +++ b/caffe2/ideep/operators/concat_split_op.cc @@ -1,4 +1,6 @@ #include +#include +#include namespace caffe2 { @@ -6,9 +8,11 @@ class IDEEPConcatOp final : public IDEEPOperator { public: USE_IDEEP_DEF_ALIASES(); USE_IDEEP_OPERATOR_FUNCTIONS(); + using FALLBACK_OP = IDEEPFallbackOp, SkipIndices<0>>; IDEEPConcatOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPOperator(operator_def, ws) { + : IDEEPOperator(operator_def, ws), + fallback_(operator_def, ws) { CAFFE_ENFORCE( !(OperatorBase::HasArgument("axis") && OperatorBase::HasArgument("order")), "You shouldn't specify both the dim to concat, and the order " @@ -25,39 +29,47 @@ class IDEEPConcatOp final : public IDEEPOperator { virtual ~IDEEPConcatOp() {} bool RunOnDevice() override { - auto* output = Output(OUTPUT); + bool fallback_to_cpu = false; + vector inputs_itensor; - vector inputs; for (int i = 0; i < InputSize(); ++i) { if (OperatorBase::InputBlob(i).template IsType()) { - inputs.emplace_back(Input(i)); + auto& tensor_ideep = Input(i); + if (tensor_ideep.ndims() == 0 || tensor_ideep.get_nelems() == 0) + continue; + inputs_itensor.emplace_back(tensor_ideep); } else { CAFFE_ENFORCE( BlobIsTensorType(OperatorBase::InputBlob(i), CPU), "Expect cpu tensor if not itensor"); auto& tensor_cpu = OperatorBase::Input(i, CPU); - CAFFE_ENFORCE( - tensor_cpu.sizes().size() == 0 || tensor_cpu.numel() == 0, - "Expect zero dim tensor"); + if (tensor_cpu.sizes().size() == 0 || tensor_cpu.numel() == 0) + continue; + fallback_to_cpu = true; + break; } } - auto axis_vdata = ideep::concat::compute(inputs, axis_, add_axis_, *output); - Tensor* axis_info = OutputTensor( - AXIS_INFO, - vector(1, InputSize()), - at::dtype().device(CPU)); - int* axis_data = axis_info->template mutable_data(); - for (int i = 0; i < axis_vdata.size(); i++) { - axis_data[i] = axis_vdata[i]; + if (!fallback_to_cpu) { + auto* output = Output(OUTPUT); + Tensor* axis_info = OutputTensor(AXIS_INFO, + vector(1, InputSize()), at::dtype().device(CPU)); + auto* axis_data = axis_info->template mutable_data(); + auto axis_vdata = + ideep::concat::compute(inputs_itensor, axis_, add_axis_, *output); + for (int i = 0; i < axis_vdata.size(); i++) { + axis_data[i] = axis_vdata[i]; + } + return true; } - return true; + return fallback_.Run(0); } private: int axis_; int add_axis_; + FALLBACK_OP fallback_; INPUT_TAGS(INPUT0); OUTPUT_TAGS(OUTPUT, AXIS_INFO); diff --git a/caffe2/python/ideep/concat_split_op_test.py b/caffe2/python/ideep/concat_split_op_test.py index 3a5763f..29e7750 100644 --- a/caffe2/python/ideep/concat_split_op_test.py +++ b/caffe2/python/ideep/concat_split_op_test.py @@ -113,6 +113,49 @@ class TestConcatSplitOps(hu.HypothesisTestCase): self.assertGradientChecks(gc, op, splits, i, [0]) + @given(tensor_splits=_tensor_splits(add_axis=True), **mu.gcs) + def test_concat_with_TensorCPU(self, tensor_splits, gc, dc): + axis, _, splits = tensor_splits + op0 = core.CreateOperator( + "Concat", + ['X_{}'.format(i) for i in range(len(splits))], + ['concat_result0', 'split_info0'], + axis=axis, + add_axis=1, + device_option=dc[0] + ) + op1 = core.CreateOperator( + "Concat", + ['X_{}'.format(i) for i in range(len(splits))], + ['concat_result1', 'split_info1'], + axis=axis, + add_axis=1, + device_option=dc[1] + ) + + for i, X in enumerate(splits): + workspace.FeedBlob('X_{}'.format(i), X, dc[0]) + + workspace.RunOperatorOnce(op0) + res0 = workspace.FetchBlob('concat_result0') + inf0 = workspace.FetchBlob('split_info0') + + workspace.RunOperatorOnce(op1) + res1 = workspace.FetchBlob('concat_result1') + inf1 = workspace.FetchBlob('split_info1') + + if not np.allclose(res0, res1, atol=0.0, rtol=0.0): + print(res1.flatten()) + print(res0.flatten()) + print(np.max(np.abs(res1 - res0))) + self.assertTrue(False) + + if not np.allclose(inf0, inf1, atol=0.0, rtol=0.0): + print(inf1.flatten()) + print(inf0.flatten()) + print(np.max(np.abs(inf1 - inf0))) + self.assertTrue(False) + if __name__ == "__main__": unittest.main() -- 2.7.4