From 1488c5dd03a7a619b2f955ddd6997751b1149784 Mon Sep 17 00:00:00 2001 From: "Cheng,Penghui" Date: Fri, 4 Jan 2019 22:30:48 -0800 Subject: [PATCH] support 0 size in any of the tensor dimensions in mkldnn (#15295) Summary: support 0 size in any of the tensor dimensions in mkldnn Pull Request resolved: https://github.com/pytorch/pytorch/pull/15295 Differential Revision: D13573747 Pulled By: yinghai fbshipit-source-id: 5bf7a0b9e2567e80f44981a7823be5407fc94e53 --- caffe2/ideep/operators/operator_fallback_ideep.h | 3 +-- caffe2/python/ideep/copy_op_test.py | 23 +++++++++++------------ caffe2/python/pybind_state_ideep.cc | 8 ++++---- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index 87b0b80..2d31489 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -126,8 +126,7 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator { "output type who needs copying."); const auto& src = local_output_blobs_[i]->template Get(); auto src_dims = src.sizes().vec(); - if (src.template IsType() && src.sizes().size() != 0 && - src.numel() != 0 && base_op_->type() != "Python") { + if (src.template IsType() && src.dim() != 0 && base_op_->type() != "Python") { Blob* dst = OperatorBase::OutputBlob(i); // The output tensor must be ideep tensor with public format. // If reusing ideep tensor with non-public format, the tensor buffer diff --git a/caffe2/python/ideep/copy_op_test.py b/caffe2/python/ideep/copy_op_test.py index 12a6d46..4b0a15b 100644 --- a/caffe2/python/ideep/copy_op_test.py +++ b/caffe2/python/ideep/copy_op_test.py @@ -9,6 +9,7 @@ from random import randint from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace + @unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") class CopyTest(unittest.TestCase): def _get_deep_device(self): @@ -19,7 +20,7 @@ class CopyTest(unittest.TestCase): "CopyCPUToIDEEP", ["X"], ["X_ideep"], - ) + ) op.device_option.CopyFrom(self._get_deep_device()) n = randint(1, 128) c = randint(1, 64) @@ -31,13 +32,12 @@ class CopyTest(unittest.TestCase): X_ideep = workspace.FetchBlob("X_ideep") np.testing.assert_allclose(X, X_ideep) - @unittest.skipIf(True, "zero dim is NOT supported for now.") def test_copy_to_ideep_zero_dim(self): op = core.CreateOperator( - "CopyCPUToIDEEP", - ["X"], - ["X_ideep"], - ) + "CopyCPUToIDEEP", + ["X"], + ["X_ideep"], + ) op.device_option.CopyFrom(self._get_deep_device()) n = 0 c = randint(1, 128) @@ -52,7 +52,7 @@ class CopyTest(unittest.TestCase): "CopyIDEEPToCPU", ["X_ideep"], ["X"], - ) + ) op.device_option.CopyFrom(self._get_deep_device()) n = randint(1, 128) c = randint(1, 64) @@ -64,13 +64,12 @@ class CopyTest(unittest.TestCase): X_ideep = workspace.FetchBlob("X") np.testing.assert_allclose(X, X_ideep) - @unittest.skipIf(True, "zero dim is NOT supported for now.") def test_copy_from_ideep_zero_dim(self): op = core.CreateOperator( - "CopyIDEEPToCPU", - ["X_ideep"], - ["X"], - ) + "CopyIDEEPToCPU", + ["X_ideep"], + ["X"], + ) op.device_option.CopyFrom(self._get_deep_device()) n = 0 c = randint(1, 64) diff --git a/caffe2/python/pybind_state_ideep.cc b/caffe2/python/pybind_state_ideep.cc index 17d5d8f..38df5e1 100644 --- a/caffe2/python/pybind_state_ideep.cc +++ b/caffe2/python/pybind_state_ideep.cc @@ -55,7 +55,9 @@ public: FetchedBlob FetchTensor(const itensor &atensor, bool force_copy) { #ifdef USE_NUMPY FetchedBlob result; - CAFFE_ENFORCE(atensor.materialized(), + CAFFE_ENFORCE((atensor.ndims() != 0) && + (atensor.get_nelems() == 0 || + atensor.get_data_handle() != nullptr), "Trying to fetch uninitialized tensor"); const int numpy_type = CaffeToNumpyType(type_transform(atensor)); CAFFE_ENFORCE( @@ -152,9 +154,7 @@ public: bool ZeroDim(PyArrayObject *array) { #ifdef USE_NUMPY int ndim = PyArray_NDIM(array); - npy_intp *npy_dims = PyArray_DIMS(array); - return ndim == 0 || - std::find(npy_dims, npy_dims + ndim, 0) != npy_dims + ndim; + return ndim == 0; #else CAFFE_THROW("Caffe2 was compiled without NumPy support."); #endif -- 2.7.4