From 926e718d5fe129f67a10eef7ef8ce754b25c1e1e Mon Sep 17 00:00:00 2001 From: "Cheng,Penghui" Date: Fri, 11 Jan 2019 12:48:57 -0800 Subject: [PATCH] Add/fallback some operators for mkl-dnn (#11696) Summary: Implement the LeakyRelu operator for mkl-dnn; the speed-up of a single operation is up to 10X on BDW. Implement the Reshape operator for mkl-dnn; this resolves an occasional crash that occurs when the fallback Reshape operator is used. Implement the CreateBlobsQueue and SafeEnqueueBlobs operators; this resolves a crash that occurs when the fallback operators are used. Fall back the CreateBlobsQueueDBOp, TensorProtosDBInput and CloseBlobsQueue operators. Implement the Adam operator for mkl-dnn; the speed-up of a single operator is up to 6X on BDW. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11696 Reviewed By: yinghai Differential Revision: D10100438 Pulled By: wesolwsk fbshipit-source-id: 0b6e06897cc11e0a8e349d80a870b1e72e47f10d --- caffe2/ideep/operators/adam_op.cc | 179 ++++++++++++++++++++++ caffe2/ideep/operators/operator_fallback_ideep.cc | 19 +-- caffe2/ideep/operators/operator_fallback_ideep.h | 2 +- caffe2/ideep/operators/queue_ops.cc | 71 +++++++++ caffe2/ideep/operators/relu_op.cc | 38 ++++- caffe2/ideep/operators/reshape_op.cc | 121 +++++++++++++++ caffe2/python/ideep/adam_op_test.py | 82 ++++++++++ caffe2/python/ideep/blobs_queue_db_test.py | 109 +++++++++++++ caffe2/python/ideep/leaky_relu_op_test.py | 91 +++++++++++ caffe2/python/ideep/reshape_ops_test.py | 141 +++++++++++++++++ caffe2/queue/blobs_queue_db.cc | 11 ++ 11 files changed, 845 insertions(+), 19 deletions(-) create mode 100644 caffe2/ideep/operators/adam_op.cc create mode 100644 caffe2/ideep/operators/queue_ops.cc create mode 100644 caffe2/ideep/operators/reshape_op.cc create mode 100644 caffe2/python/ideep/adam_op_test.py create mode 100644 caffe2/python/ideep/blobs_queue_db_test.py create mode 100644 caffe2/python/ideep/leaky_relu_op_test.py create mode 100644 caffe2/python/ideep/reshape_ops_test.py diff --git 
a/caffe2/ideep/operators/adam_op.cc b/caffe2/ideep/operators/adam_op.cc new file mode 100644 index 0000000..732cf74 --- /dev/null +++ b/caffe2/ideep/operators/adam_op.cc @@ -0,0 +1,179 @@ +#include + +namespace caffe2 { + +void adam_ideep_update( + int N, + const float* g, + const float* m, + const float* v, + float* ng, + float* nm, + float* nv, + float beta1, + float beta2, + float eps_hat, + float correction, + const float* lr) { +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (auto i = 0; i < N; ++i) { + float gi = g[i]; + float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); + float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); + ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); + } +} + +void adam_ideep_compute( + int N, + const float* w, + const float* g, + const float* m, + const float* v, + float* nw, + float* nm, + float* nv, + float beta1, + float beta2, + float eps_hat, + float correction, + const float* lr) { +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (auto i = 0; i < N; ++i) { + float gi = g[i]; + float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); + float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); + nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); + } +} + +void adam_ideep_compute_output_grad( + int N, + const float* w, + const float* g, + const float* m, + const float* v, + float* nw, + float* nm, + float* nv, + float* ng, + float beta1, + float beta2, + float eps_hat, + float correction, + const float* lr) { + +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (auto i = 0; i < N; ++i) { + float gi = g[i]; + float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); + float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); + float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat); + nw[i] = w[i] + lr[0] * ngi; + } +} + +template +class IDEEPAdamOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + 
USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPAdamOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws), + beta1_(OperatorBase::GetSingleArgument("beta1", 0.9f)), + beta2_(OperatorBase::GetSingleArgument("beta2", 0.999f)), + epsilon_(OperatorBase::GetSingleArgument("epsilon", 1e-5f)) {} + bool RunOnDevice() override { + // Iter live on the CPU + CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU)); + const auto& params = Input(PARAM); + const auto& moment_1 = Input(MOMENT_1); + const auto& moment_2 = Input(MOMENT_2); + const auto& grad = Input(GRAD); + // TODO: Use itensor after 0-dim is supported. Now use CPU tensor. + const auto& lr = OperatorBase::Input(LR, CPU); + auto* out_params = Output(OUTPUT_PARAM); + auto* out_moment1 = Output(OUTPUT_MOMENT_1); + auto* out_moment2 = Output(OUTPUT_MOMENT_2); + + CAFFE_ENFORCE(lr.size() == 1); + CAFFE_ENFORCE(grad.get_nelems() == params.get_nelems()); + CAFFE_ENFORCE(grad.get_nelems() == moment_1.get_nelems()); + CAFFE_ENFORCE(grad.get_nelems() == moment_2.get_nelems()); + if (params != *out_params) + out_params->reinit(params.get_descriptor()); + if (moment_1 != *out_moment1) + out_moment1->reinit(moment_1.get_descriptor()); + if (moment_2 != *out_moment2) + out_moment2->reinit(moment_2.get_descriptor()); + const auto w = static_cast(params.get_data_handle()); + const auto g = static_cast(grad.get_data_handle()); + const auto m = static_cast(moment_1.get_data_handle()); + const auto v = static_cast(moment_2.get_data_handle()); + auto nw = static_cast(out_params->get_data_handle()); + auto nm = static_cast(out_moment1->get_data_handle()); + auto nv = static_cast(out_moment2->get_data_handle()); + const auto nlr = lr.template data(); + const auto iter = + OperatorBase::Input(ITER, CPU).template data()[0]; + const auto t = iter + 1; + const auto correction = + std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) 
- std::pow(beta1_, t)); + if (OutputSize() == 3) { + adam_ideep_compute( + grad.get_nelems(), + w, + g, + m, + v, + nw, + nm, + nv, + beta1_, + beta2_, + epsilon_, + correction, + nlr); + } else { + auto* out_grad = Output(OUTPUT_GRAD); + if (grad != *out_grad) + out_grad->reinit(grad.get_descriptor()); + auto ng = static_cast(out_grad->get_data_handle()); + adam_ideep_compute_output_grad( + grad.get_nelems(), + w, + g, + m, + v, + nw, + nm, + nv, + ng, + beta1_, + beta2_, + epsilon_, + correction, + nlr); + } + + return true; + } + + protected: + T beta1_{0.9}; + T beta2_{0.999}; + T epsilon_{1e-8}; + INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER); + OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD); +}; + +REGISTER_IDEEP_OPERATOR(Adam, IDEEPAdamOp); + +} // namespace caffe2 diff --git a/caffe2/ideep/operators/operator_fallback_ideep.cc b/caffe2/ideep/operators/operator_fallback_ideep.cc index d078a56..6016923 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.cc +++ b/caffe2/ideep/operators/operator_fallback_ideep.cc @@ -18,12 +18,10 @@ #include #include #include -#include #include #include #include #include -#include #include #include #include @@ -32,9 +30,10 @@ #include #include #include -#include #include #include +#include +#include // can add more non-IDEEP operators if needed namespace caffe2 { @@ -52,9 +51,6 @@ REGISTER_IDEEP_OPERATOR( REGISTER_IDEEP_OPERATOR(Flatten, IDEEPFallbackOp>); REGISTER_IDEEP_OPERATOR(ResizeLike, IDEEPFallbackOp>); REGISTER_IDEEP_OPERATOR(Transpose, IDEEPFallbackOp>); -REGISTER_IDEEP_OPERATOR( - Reshape, - IDEEPFallbackOp, SkipIndices<1>>); // filter operators REGISTER_IDEEP_OPERATOR( @@ -109,7 +105,7 @@ REGISTER_IDEEP_OPERATOR( REGISTER_IDEEP_OPERATOR( PRelu, IDEEPFallbackOp>); - + // ctc decoder operators REGISTER_IDEEP_OPERATOR( CTCGreedyDecoder, @@ -134,9 +130,6 @@ REGISTER_IDEEP_OPERATOR( LearningRate, IDEEPFallbackOp>); REGISTER_IDEEP_OPERATOR( - LeakyRelu, - IDEEPFallbackOp>); 
-REGISTER_IDEEP_OPERATOR( Mul, IDEEPFallbackOp< BinaryElementwiseOp>>); @@ -170,14 +163,12 @@ REGISTER_IDEEP_OPERATOR( ConvTransposeGradient, IDEEPFallbackOp>); REGISTER_IDEEP_OPERATOR( - LeakyReluGradient, - IDEEPFallbackOp>); -REGISTER_IDEEP_OPERATOR( MulGradient, IDEEPFallbackOp>>); -REGISTER_IDEEP_OPERATOR(Adam, IDEEPFallbackOp>); +REGISTER_IDEEP_OPERATOR(TensorProtosDBInput, IDEEPFallbackOp>); +REGISTER_IDEEP_OPERATOR(CloseBlobsQueue, IDEEPFallbackOp>); } // namespace caffe2 diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index 77001b9..4372807 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -111,7 +111,7 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator { } } - if (!base_op_->Run()) { + if (!base_op_->Run(0)) { LOG(ERROR) << "Base op run failed in IDEEPFallbackOp. Def: " << ProtoDebugString(this->debug_def()); return false; diff --git a/caffe2/ideep/operators/queue_ops.cc b/caffe2/ideep/operators/queue_ops.cc new file mode 100644 index 0000000..fb7887c --- /dev/null +++ b/caffe2/ideep/operators/queue_ops.cc @@ -0,0 +1,71 @@ +#include +#include + +namespace caffe2 { + +class IDEEPCreateBlobsQueueOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPCreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws), + ws_(ws), + name(operator_def.output().Get(0)) {} + + bool RunOnDevice() override { + const auto capacity = GetSingleArgument("capacity", 1); + const auto numBlobs = GetSingleArgument("num_blobs", 1); + const auto enforceUniqueName = + GetSingleArgument("enforce_unique_name", false); + const auto fieldNames = + OperatorBase::template GetRepeatedArgument("field_names"); + CAFFE_ENFORCE_EQ(this->OutputSize(), 1); + auto queuePtr = OperatorBase::Outputs()[0] + ->template GetMutable>(); + + 
CAFFE_ENFORCE(queuePtr); + *queuePtr = std::make_shared( + ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames); + return true; + } + + private: + Workspace* ws_{nullptr}; + const std::string name; +}; + +class IDEEPSafeEnqueueBlobsOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPSafeEnqueueBlobsOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws) {} + + bool RunOnDevice() override { + auto queue = + OperatorBase::Inputs()[0]->template Get>(); + CAFFE_ENFORCE(queue); + auto size = queue->getNumBlobs(); + CAFFE_ENFORCE( + OutputSize() == size + 1, + "Expected " + caffe2::to_string(size + 1) + ", " + + " got: " + caffe2::to_string(size)); + bool status = queue->blockingWrite(OperatorBase::Outputs()); + + auto st = OperatorBase::Output(1, CPU); + st->Resize(); + auto stat = st->template mutable_data(); + stat[0] = !status; + return true; + } +}; + +REGISTER_IDEEP_OPERATOR(CreateBlobsQueue, IDEEPCreateBlobsQueueOp); +SHOULD_NOT_DO_GRADIENT(IDEEPCreateBlobsQueueOp); + +REGISTER_IDEEP_OPERATOR(SafeEnqueueBlobs, IDEEPSafeEnqueueBlobsOp); +SHOULD_NOT_DO_GRADIENT(IDEEPSafeEnqueueBlobsOp); + +} // namespace caffe2 diff --git a/caffe2/ideep/operators/relu_op.cc b/caffe2/ideep/operators/relu_op.cc index 7f81d0e..7e591ff 100644 --- a/caffe2/ideep/operators/relu_op.cc +++ b/caffe2/ideep/operators/relu_op.cc @@ -8,19 +8,33 @@ class IDEEPReluOp final : public IDEEPOperator { USE_IDEEP_OPERATOR_FUNCTIONS(); IDEEPReluOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPOperator(operator_def, ws) {} + : IDEEPOperator(operator_def, ws), alpha_(0.0) { + // Figure out the Relu descriptor. 
+ if (operator_def.type().substr(0, 4) == "Relu") { + alpha_ = 0.0; + } else if (operator_def.type().substr(0, 9) == "LeakyRelu") { + if (HasArgument("alpha")) { + alpha_ = static_cast( + OperatorBase::GetSingleArgument("alpha", 0.01)); + } + } else { + LOG(FATAL) << "Unsupported Relu method: " << operator_def.type(); + } + } virtual ~IDEEPReluOp() {} bool RunOnDevice() override { const auto& X = Input(INPUT); auto* Y = Output(OUTPUT); - ideep::eltwise_forward::compute(X, *Y); + ideep::eltwise_forward::compute( + X, *Y, ialgo::eltwise_relu, iprop::forward_training, alpha_); return true; } private: + float alpha_; INPUT_TAGS(INPUT); OUTPUT_TAGS(OUTPUT); @@ -32,7 +46,19 @@ class IDEEPReluGradientOp final : public IDEEPOperator { USE_IDEEP_OPERATOR_FUNCTIONS(); IDEEPReluGradientOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPOperator(operator_def, ws) {} + : IDEEPOperator(operator_def, ws), alpha_(0.0) { + // Figure out the Relu descriptor. + if (operator_def.type().substr(0, 12) == "ReluGradient") { + alpha_ = 0.0; + } else if (operator_def.type().substr(0, 17) == "LeakyReluGradient") { + if (HasArgument("alpha")) { + alpha_ = static_cast( + OperatorBase::GetSingleArgument("alpha", 0.01)); + } + } else { + LOG(FATAL) << "Unsupported Relu method: " << operator_def.type(); + } + } virtual ~IDEEPReluGradientOp() {} bool RunOnDevice() override { @@ -40,12 +66,13 @@ class IDEEPReluGradientOp final : public IDEEPOperator { const auto& dY = Input(OUTPUT_GRAD); auto* dX = Output(INPUT_GRAD); - ideep::eltwise_backward::compute(Y, dY, *dX); + ideep::eltwise_backward::compute(Y, dY, *dX, ialgo::eltwise_relu, alpha_); return true; } private: + float alpha_; INPUT_TAGS(OUTPUT, OUTPUT_GRAD); OUTPUT_TAGS(INPUT_GRAD); @@ -54,4 +81,7 @@ class IDEEPReluGradientOp final : public IDEEPOperator { REGISTER_IDEEP_OPERATOR(Relu, IDEEPReluOp); REGISTER_IDEEP_OPERATOR(ReluGradient, IDEEPReluGradientOp); +REGISTER_IDEEP_OPERATOR(LeakyRelu, IDEEPReluOp); 
+REGISTER_IDEEP_OPERATOR(LeakyReluGradient, IDEEPReluGradientOp); + } // namespace caffe2 diff --git a/caffe2/ideep/operators/reshape_op.cc b/caffe2/ideep/operators/reshape_op.cc new file mode 100644 index 0000000..6a63359 --- /dev/null +++ b/caffe2/ideep/operators/reshape_op.cc @@ -0,0 +1,121 @@ +#include + +namespace caffe2 { + +// Takes a shape and data tensor and reshapes it +class IDEEPReshapeOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPReshapeOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws), + new_shape_(OperatorBase::GetRepeatedArgument("shape")) {} + + bool RunOnDevice() override { + ideep::tensor::dims actual_new_shape = new_shape_; + if (InputSize() == 2) { + CAFFE_ENFORCE( + !OperatorBase::HasArgument("shape"), + "New shape is specified by the input blob, do not pass in " + "the argument `shape`."); + + // shape info live on CPU + auto& shape = OperatorBase::Input(1, CPU); + CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D"); + const int* shape_data = shape.template data(); + + actual_new_shape.reserve(shape.size()); + actual_new_shape.assign(shape_data, shape_data + shape.size()); + } else { + CAFFE_ENFORCE( + OperatorBase::HasArgument("shape"), "Argument `shape` is missing."); + } + + auto& input = Input(0); + // Copy over the dimensions for those that are specified zero. + for (int i = 0; i < actual_new_shape.size() && i < input.ndims(); ++i) { + if (actual_new_shape[i] == 0) { + actual_new_shape[i] = input.get_dim(i); + } + } + + // Checks if the new shape is valid and fills in the missing dimension + // specified by -1. + // NOTE: At most one dimension can be -1. 
+ auto total_size = input.get_nelems(); + int size = 1; + int unknown_idx = -1; + for (int i = 0; i < actual_new_shape.size(); ++i) { + const auto dim = actual_new_shape[i]; + if (dim == -1) { + CAFFE_ENFORCE( + unknown_idx == -1, + "Argument `shape` has more than one missing dimension."); + unknown_idx = i; + } else { + size *= dim; + } + } + if (size == 0 && total_size != 0) { + CAFFE_THROW( + "Can not reshape a non-zero size (", + total_size, + ") tensor to zero size."); + } + + if (unknown_idx != -1) { + CAFFE_ENFORCE_NE( + size, + 0, + "New shape at dim ", + unknown_idx, + " can not be inferred since new size is zero."); + CAFFE_ENFORCE( + total_size % size == 0, + "Argument `shape` does not agree with the input data.", + " (", + total_size, + " vs ", + size, + ")"); + actual_new_shape[unknown_idx] = total_size / size; + } else { + CAFFE_ENFORCE_EQ( + total_size, + size, + "Argument `shape` does not agree with the input data.", + " (", + total_size, + " != ", + size, + ")"); + } + + // Write the original shape to the second output. + // shape info live on CPU + TensorCPU* old_shape = OperatorBase::Output(1, CPU); + old_shape->Resize(input.ndims()); + int* old_shape_data = old_shape->template mutable_data(); + for (int i = 0; i < input.ndims(); ++i) { + old_shape_data[i] = input.get_dim(i); + } + + auto* output = Output(0); + if (output != &input) { + // If we are not doing in-place computation, a copy is needed. 
+ output->reinit_like(input); + ideep::direct_copy::compute(input, *output); + } + + output->reshape(actual_new_shape); + return true; + } + + private: + ideep::tensor::dims new_shape_; +}; + +REGISTER_IDEEP_OPERATOR(Reshape, IDEEPReshapeOp); + +} // namespace caffe2 diff --git a/caffe2/python/ideep/adam_op_test.py b/caffe2/python/ideep/adam_op_test.py new file mode 100644 index 0000000..a0d9b2c --- /dev/null +++ b/caffe2/python/ideep/adam_op_test.py @@ -0,0 +1,82 @@ +from __future__ import unicode_literals +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import hypothesis.strategies as st +import unittest +import caffe2.python.hypothesis_test_util as hu +from caffe2.python import core, workspace +from hypothesis import given +import caffe2.python.ideep_test_util as mu + + +@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") +class TestAdamOps(hu.HypothesisTestCase): + @given(inputs=hu.tensors(n=4), + ITER=st.integers(min_value=0, max_value=10000), + LR=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + beta1=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + beta2=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + epsilon=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + **mu.gcs) + def test_adam(self, inputs, ITER, LR, beta1, beta2, epsilon, gc, dc): + param, mom1, mom2, grad = inputs + ITER = np.array([ITER], dtype=np.int64) + LR = np.array([LR], dtype=np.float32) + mom2 = np.absolute(mom2) + op = core.CreateOperator( + "Adam", + ["param", "mom1", "mom2", "grad", "lr", "iter"], + ["output_param", "output_mom1", "output_mom2"], + beta1=beta1, beta2=beta2, epsilon=epsilon) + # Iter lives on the CPU + input_device_options = {'iter': hu.cpu_do, 'lr': hu.cpu_do} + + self.assertDeviceChecks( + dc, op, + [param, mom1, mom2, grad, 
LR, ITER], + [0], + input_device_options=input_device_options, + threshold=0.001) + + @given(inputs=hu.tensors(n=4), + ITER=st.integers(min_value=0, max_value=10000), + LR=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + beta1=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + beta2=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + epsilon=st.floats(min_value=0.01, max_value=0.99, + allow_nan=False, allow_infinity=False), + **mu.gcs) + def test_adam_output_grad(self, inputs, ITER, LR, beta1, beta2, epsilon, gc, dc): + param, mom1, mom2, grad = inputs + ITER = np.array([ITER], dtype=np.int64) + LR = np.array([LR], dtype=np.float32) + mom2 = np.absolute(mom2) + + op = core.CreateOperator( + "Adam", + ["param", "mom1", "mom2", "grad", "lr", "iter"], + ["output_param", "output_mom1", "output_mom2", "output_grad"], + beta1=beta1, beta2=beta2, epsilon=epsilon) + + # Iter lives on the CPU + input_device_options = {'iter': hu.cpu_do, 'lr': hu.cpu_do} + + self.assertDeviceChecks( + dc, op, + [param, mom1, mom2, grad, LR, ITER], + [0], + input_device_options=input_device_options, + threshold=0.001) + +if __name__ == "__main__": + unittest.main() diff --git a/caffe2/python/ideep/blobs_queue_db_test.py b/caffe2/python/ideep/blobs_queue_db_test.py new file mode 100644 index 0000000..ded18e8 --- /dev/null +++ b/caffe2/python/ideep/blobs_queue_db_test.py @@ -0,0 +1,109 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +import numpy as np + +import caffe2.proto.caffe2_pb2 as caffe2_pb2 +from caffe2.python import core, workspace, timeout_guard + + +@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") +class BlobsQueueDBTest(unittest.TestCase): + def test_create_blobs_queue_db_string(self): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 
0) + with core.DeviceScope(device_opt): + def add_blobs(queue, num_samples): + blob = core.BlobReference("blob") + status = core.BlobReference("blob_status") + for i in range(num_samples): + self._add_blob_to_queue( + queue, self._create_test_tensor_protos(i), blob, status + ) + self._test_create_blobs_queue_db(add_blobs) + + def test_create_blobs_queue_db_tensor(self): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + def add_blobs(queue, num_samples): + blob = core.BlobReference("blob") + status = core.BlobReference("blob_status") + for i in range(num_samples): + data = self._create_test_tensor_protos(i) + data = np.array([data], dtype=str) + self._add_blob_to_queue( + queue, data, blob, status + ) + self._test_create_blobs_queue_db(add_blobs) + + def _test_create_blobs_queue_db(self, add_blobs_fun): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + num_samples = 10000 + batch_size = 10 + init_net = core.Net('init_net') + net = core.Net('test_create_blobs_queue_db') + queue = init_net.CreateBlobsQueue([], 'queue', capacity=num_samples) + reader = init_net.CreateBlobsQueueDB( + [queue], + 'blobs_queue_db_reader', + value_blob_index=0, + timeout_secs=0.1, + ) + workspace.RunNetOnce(init_net) + add_blobs_fun(queue, num_samples) + + net.TensorProtosDBInput( + [reader], + ['image', 'label'], + batch_size=batch_size + ) + workspace.CreateNet(net) + + close_net = core.Net('close_net') + close_net.CloseBlobsQueue([queue], []) + + for i in range(int(num_samples / batch_size)): + with timeout_guard.CompleteInTimeOrDie(2.0): + workspace.RunNet(net) + + images = workspace.FetchBlob('image') + labels = workspace.FetchBlob('label') + self.assertEqual(batch_size, len(images)) + self.assertEqual(batch_size, len(labels)) + for idx, item in enumerate(images): + self.assertEqual( + "foo{}".format(i * batch_size + idx).encode('utf-8'), item + ) + for item in labels: + self.assertEqual(1, item) + 
workspace.RunNetOnce(close_net) + + def _add_blob_to_queue(self, queue, data, blob, status): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + workspace.FeedBlob(blob, data, core.DeviceOption(caffe2_pb2.CPU, 0)) + op = core.CreateOperator( + "SafeEnqueueBlobs", + [queue, blob], + [blob, status], + ) + + workspace.RunOperatorOnce(op) + + def _create_test_tensor_protos(self, idx): + item = caffe2_pb2.TensorProtos() + data = item.protos.add() + data.data_type = core.DataType.STRING + data.string_data.append("foo{}".format(idx).encode('utf-8')) + label = item.protos.add() + label.data_type = core.DataType.INT32 + label.int32_data.append(1) + + return item.SerializeToString() + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/caffe2/python/ideep/leaky_relu_op_test.py b/caffe2/python/ideep/leaky_relu_op_test.py new file mode 100644 index 0000000..1af77b9 --- /dev/null +++ b/caffe2/python/ideep/leaky_relu_op_test.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +import hypothesis.strategies as st +from hypothesis import given +import numpy as np +from caffe2.python import core, workspace, model_helper +import caffe2.python.hypothesis_test_util as hu +import caffe2.python.ideep_test_util as mu + + +@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") +class LeakyReluTest(hu.HypothesisTestCase): + def _get_inputs(self, N, C, H, W, order): + input_data = np.random.rand(N, C, H, W).astype(np.float32) - 0.5 + + # default step size is 0.05 + input_data[np.logical_and( + input_data >= 0, input_data <= 0.051)] = 0.051 + input_data[np.logical_and( + input_data <= 0, input_data >= -0.051)] = -0.051 + + return input_data, + + def _get_op(self, device_option, alpha, order, inplace=False): + outputs = ['output' if not inplace else "input"] + op = 
core.CreateOperator( + 'LeakyRelu', + ['input'], + outputs, + alpha=alpha, + device_option=device_option) + return op + + def _feed_inputs(self, input_blobs, device_option): + names = ['input', 'scale', 'bias'] + for name, blob in zip(names, input_blobs): + self.ws.create_blob(name).feed(blob, device_option=device_option) + + @given(N=st.integers(2, 3), + C=st.integers(2, 3), + H=st.integers(2, 3), + W=st.integers(2, 3), + alpha=st.floats(0, 1), + seed=st.integers(0, 1000), + **mu.gcs) + def test_leaky_relu_gradients(self, gc, dc, N, C, H, W, alpha, seed): + np.random.seed(seed) + + op = self._get_op( + device_option=gc, + alpha=alpha, + order='NCHW') + input_blobs = self._get_inputs(N, C, H, W, "NCHW") + + self.assertDeviceChecks(dc, op, input_blobs, [0]) + self.assertGradientChecks(gc, op, input_blobs, 0, [0]) + + @given(N=st.integers(2, 10), + C=st.integers(3, 10), + H=st.integers(5, 10), + W=st.integers(7, 10), + alpha=st.floats(0, 1), + seed=st.integers(0, 1000)) + def test_leaky_relu_model_helper_helper(self, N, C, H, W, alpha, seed): + np.random.seed(seed) + order = 'NCHW' + arg_scope = {'order': order} + model = model_helper.ModelHelper(name="test_model", arg_scope=arg_scope) + model.LeakyRelu( + 'input', + 'output', + alpha=alpha) + + input_blob = np.random.rand(N, C, H, W).astype(np.float32) + + self.ws.create_blob('input').feed(input_blob) + + self.ws.create_net(model.param_init_net).run() + self.ws.create_net(model.net).run() + + output_blob = self.ws.blobs['output'].fetch() + + assert output_blob.shape == (N, C, H, W) + + +if __name__ == "__main__": + unittest.main() diff --git a/caffe2/python/ideep/reshape_ops_test.py b/caffe2/python/ideep/reshape_ops_test.py new file mode 100644 index 0000000..8136172 --- /dev/null +++ b/caffe2/python/ideep/reshape_ops_test.py @@ -0,0 +1,141 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from 
caffe2.python.test_util import TestCase +from caffe2.proto import caffe2_pb2 +import unittest +import numpy as np +from caffe2.python import core, workspace + + +@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") +class TestReShapeOps(TestCase): + def test_reshape_ops(self): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + workspace.FeedBlob('res', np.array([[0, 0, 0, 0]], dtype=np.float32)) + workspace.FeedBlob('shape', np.array([1, 4], dtype=np.int32), core.DeviceOption(caffe2_pb2.CPU, 0)) + workspace.FeedBlob('input', np.zeros((2, 2), dtype=np.float32)) + workspace.RunOperatorOnce(core.CreateOperator( + 'Reshape', ['input', 'shape'], ['output', 'old_shape'])) + assert ((workspace.FetchBlob('output') == + workspace.FetchBlob('res')).all()) + + def test_basic_reshape(self): + _test_reshape(old_shape=(4, 2, 1), new_shape=(2, 4)) + _test_reshape(old_shape=(4, 2, 1), new_shape=(2, 4), arg_shape=False) + + def test_missing_dim(self): + _test_reshape(old_shape=(4, 2, 1), new_shape=(-1, 8)) + _test_reshape(old_shape=(4, 2, 1), new_shape=(-1, 8), arg_shape=False) + + def test_in_place(self): + _test_reshape(old_shape=(4, 2, 1), new_shape=(-1, 8), in_place=True) + _test_reshape(old_shape=(4, 2, 1), new_shape=(-1, 8), + in_place=True, arg_shape=False) + + def test_zero_dim(self): + _test_reshape(old_shape=(4, 2, 1), new_shape=(0, 0, 0), + expected_shape=(4, 2, 1)) + _test_reshape(old_shape=(4, 2, 1), new_shape=(0, 0, 0), + expected_shape=(4, 2, 1), arg_shape=False) + _test_reshape(old_shape=(4, 2, 1), new_shape=(0, 2, 1), + expected_shape=(4, 2, 1)) + _test_reshape(old_shape=(4, 2, 1), new_shape=(0, 2, 1), + expected_shape=(4, 2, 1), arg_shape=False) + + def test_zero_dim_and_missing_dim(self): + _test_reshape(old_shape=(4, 2, 1), new_shape=(0, -1, 0), + expected_shape=(4, 2, 1)) + _test_reshape(old_shape=(4, 2, 1), new_shape=(0, -1, 0), + expected_shape=(4, 2, 1), arg_shape=False) + _test_reshape(old_shape=(4, 3, 
2), new_shape=(-1, 0), + expected_shape=(8, 3)) + _test_reshape(old_shape=(4, 3, 2), new_shape=(-1, 0), + expected_shape=(8, 3), arg_shape=False) + + def test_backprop(self): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + old_shape = (4, 2, 1) + new_shape = (1, 8) + X = np.random.rand(*old_shape).astype(np.float32) + Y = np.random.rand(*new_shape).astype(np.float32) + + net = core.Net('net') + + net.GivenTensorFill([], 'X', shape=old_shape, values=X.flatten()) + net.GivenTensorFill([], 'Y', shape=new_shape, values=Y.flatten()) + + net.Reshape(['X'], ['X_out', 'old_shape'], shape=new_shape) + net.Mul(['X_out', 'Y'], 'Z') + net.AddGradientOperators(['Z']) + + workspace.RunNetOnce(net) + + Z = workspace.FetchBlob('Z') + X_grad = workspace.FetchBlob('X_grad') + + # Check forward computation + np.testing.assert_allclose( + Z.squeeze(), (X.reshape(new_shape) * Y).squeeze(), rtol=1e-5) + + # Check the shape of the gradient + np.testing.assert_array_equal(X_grad.shape, X.shape) + + # Check the gradient + np.testing.assert_allclose(X_grad, Y.reshape(old_shape), rtol=1e-5) + + def test_input_shape_changes(self): + device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0) + with core.DeviceScope(device_opt): + workspace.FeedBlob( + 'input_blob', + np.array(np.random.rand(10, 20, 10), dtype=np.float32)) + net = core.Net('mynet') + z, _ = net.Reshape('input_blob', + ['z_reshape', 'dummy_size'], + shape=(-1, 10)) + workspace.CreateNet(net) + workspace.RunNet(net) + workspace.FeedBlob( + 'input_blob', + np.array(np.random.rand(10, 40, 10), dtype=np.float32)) + workspace.RunNet(net) + + +def _test_reshape(old_shape, new_shape, expected_shape=None, arg_shape=True, + in_place=False): + devices = [core.DeviceOption(caffe2_pb2.IDEEP, 0)] + + for device_opt in devices: + with core.DeviceScope(device_opt): + if expected_shape is None: + expected_shape = new_shape + X = np.random.rand(*old_shape).astype(np.float32) + + blob_in = 'X' + blob_out = 
blob_in if in_place else blob_in + '_out' + + if arg_shape: + op = core.CreateOperator('Reshape', + [blob_in], + [blob_out, 'old_shape'], + shape=new_shape) + else: + op = core.CreateOperator('Reshape', + [blob_in, 'new_shape'], + [blob_out, 'old_shape']) + workspace.FeedBlob('new_shape', np.asarray(new_shape, dtype=np.int32), + core.DeviceOption(caffe2_pb2.CPU, 0)) + + workspace.FeedBlob(blob_in, X) + workspace.RunOperatorOnce(op) + + Y = workspace.FetchBlob(blob_out) + np.testing.assert_allclose(Y, X.reshape(expected_shape)) + +if __name__ == "__main__": + unittest.main() diff --git a/caffe2/queue/blobs_queue_db.cc b/caffe2/queue/blobs_queue_db.cc index bd7795c..7d3806a 100644 --- a/caffe2/queue/blobs_queue_db.cc +++ b/caffe2/queue/blobs_queue_db.cc @@ -10,6 +10,11 @@ #include "caffe2/core/operator.h" #include "caffe2/queue/blobs_queue.h" +#ifdef CAFFE2_USE_MKLDNN +#include +#include +#endif + namespace caffe2 { namespace db { @@ -37,6 +42,12 @@ class CreateBlobsQueueDBOp : public Operator { REGISTER_CPU_OPERATOR(CreateBlobsQueueDB, CreateBlobsQueueDBOp); +#ifdef CAFFE2_USE_MKLDNN +REGISTER_IDEEP_OPERATOR( + CreateBlobsQueueDB, + IDEEPFallbackOp, SkipIndices<0>>); +#endif + OPERATOR_SCHEMA(CreateBlobsQueueDB) .NumInputs(1) .NumOutputs(1) -- 2.7.4