From 9696fee635d85a93722ee66f36f17a78b2b7625b Mon Sep 17 00:00:00 2001
From: Sebastian Messmer
Date: Tue, 12 Feb 2019 16:47:53 -0800
Subject: [PATCH] Register CUDA kernels for caffe2 operators (#16691)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16691

Previous diffs already introduced a macro that registers caffe2 CPU kernels
with c10. This now also registers the CUDA kernels with it.

Reviewed By: bwasti

Differential Revision: D13901619

fbshipit-source-id: c15e5b7081ff10e5219af460779b88d6e091a6a6
---
 caffe2/core/c10_operator.h                        | 110 +++++++++++++++++++---
 caffe2/operators/layer_norm_op.cu                 |   4 +
 caffe2/python/operator_test/layer_norm_op_test.py |  17 +++-
 3 files changed, 119 insertions(+), 12 deletions(-)

diff --git a/caffe2/core/c10_operator.h b/caffe2/core/c10_operator.h
index 98b7562..d154189 100644
--- a/caffe2/core/c10_operator.h
+++ b/caffe2/core/c10_operator.h
@@ -23,7 +23,11 @@ inline void _call_caffe2_op(const c10::FunctionSchema& schema, std::vector
 const c10::OperatorHandle& c10_op_handle_for_c2_op();
 
-template <class Caffe2Operator>
-void call_caffe2_op_from_c10(c10::Stack* stack, c10::KernelCache* cache) { // TODO Pass in correct cache type
-  _call_caffe2_op_from_c10(stack, c10_op_handle_for_c2_op().schema(), &_call_caffe2_op<Caffe2Operator>);
+template <class Caffe2Operator, at::DeviceType deviceType>
+void call_caffe2_op_from_c10(
+    c10::Stack* stack,
+    c10::KernelCache* cache) { // TODO Pass in correct cache type
+  _call_caffe2_op_from_c10(
+      stack,
+      c10_op_handle_for_c2_op().schema(),
+      deviceType,
+      &_call_caffe2_op<Caffe2Operator>);
 }
 
 inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName, std::vector<c10::Argument> inputs, std::vector<c10::Argument> outputs) {
@@ -85,14 +95,51 @@
 }
 }
 
-#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \
-  namespace caffe2 { namespace _c10_ops { \
-    C10_DECLARE_OP_SCHEMA(OperatorName); \
-  }}
 /**
- * Call this macro to register a caffe2 operator with the c10 dispatcher.
+ * To register a caffe2 operator caffe2::MyOperator with the c10 dispatcher,
+ * call:
+ *
+ * In caffe2/operators/MyOperator.h:
+ *
+ * > C10_DECLARE_CAFFE2_OPERATOR(C10MyOperator) // C10MyOperator is the name
+ *                                                 used by c10 for this operator
+ *
+ * In caffe2/operators/MyOperator.cc
+ *
+ * > C10_REGISTER_CAFFE2_OPERATOR_CPU(
+ * >    C10MyOperator,
+ * >    (std::vector<c10::Argument>{
+ * >      c10::Argument("input1"),
+ * >      c10::Argument("input2", c10::IntType::get()),
+ * >      c10::Argument("input3", c10::FloatType::get())
+ * >    }), (std::vector<c10::Argument>{
+ * >      c10::Argument("output1"),
+ * >      c10::Argument("output2")
+ * >    }),
+ * >    caffe2::MyOperator<caffe2::CPUContext> // This is the caffe2 operator
+ *                                                class template > )
+ *
+ * In caffe2/operators/MyOperator.cu
+ *
+ * > C10_REGISTER_CAFFE2_OPERATOR_CUDA(C10MyOperator,
+ *   caffe2::MyOperator<caffe2::CUDAContext>)
+ *
+ * Notes:
+ * - all macros must be defined in the top level namespace, not in namespace
+ *   caffe2.
+ * - all operators must call C10_DECLARE_CAFFE2_OPERATOR and
+ *   C10_REGISTER_CAFFE2_OPERATOR_CPU.
+ * - calling C10_REGISTER_CAFFE2_OPERATOR_CUDA is optional and can be omitted if
+ *   you don't want to expose the operator for CUDA operations.
 */
+#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \
+  namespace caffe2 {                              \
+  namespace _c10_ops {                            \
+  C10_DECLARE_OP_SCHEMA(OperatorName);            \
+  }                                               \
+  }
+
 // TODO This macro should take a JIT schema string instead of a vector of inputs and outputs.
 #define C10_REGISTER_CAFFE2_OPERATOR_CPU( \
     OperatorName, Inputs, Outputs, OperatorClass) \
@@ -117,6 +164,47 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName
   /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
   namespace c10 { \
   C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache()*/ \
-      .kernel<&caffe2::detail::call_caffe2_op_from_c10<OperatorClass>>() \
+      .kernel<&caffe2::detail::call_caffe2_op_from_c10< \
+          OperatorClass, \
+          at::DeviceType::CPU>>() \
       .dispatchKey(CPUTensorId()); \
   }
+
+#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass) \
+  /* Store the c10 operator handle so call_caffe2_op_from_c10 can access it */ \
+  namespace caffe2 { \
+  namespace detail { \
+  template <> \
+  const c10::OperatorHandle& c10_op_handle_for_c2_op<OperatorClass>() { \
+    return caffe2::_c10_ops::OperatorName(); \
+  } \
+  } \
+  } \
+  namespace c10 { \
+  C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache()*/ \
+      .kernel<&caffe2::detail::call_caffe2_op_from_c10< \
+          OperatorClass, \
+          at::DeviceType::CUDA>>() \
+      .dispatchKey(CUDATensorId()); \
+  }
+
+// You should never manually call the C10_REGISTER_CAFFE2_OPERATOR_HIP macro.
+// The C10_REGISTER_CAFFE2_OPERATOR_CUDA macro from above will be automatically
+// rewritten to C10_REGISTER_CAFFE2_OPERATOR_HIP by hipify.
+#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass) \
+  /* Store the c10 operator handle so call_caffe2_op_from_c10 can access it */ \
+  namespace caffe2 { \
+  namespace detail { \
+  template <> \
+  const c10::OperatorHandle& c10_op_handle_for_c2_op<OperatorClass>() { \
+    return caffe2::_c10_ops::OperatorName(); \
+  } \
+  } \
+  } \
+  namespace c10 { \
+  C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache()*/ \
+      .kernel<&caffe2::detail::call_caffe2_op_from_c10< \
+          OperatorClass, \
+          at::DeviceType::HIP>>() \
+      .dispatchKey(CUDATensorId()); \
+  }
diff --git a/caffe2/operators/layer_norm_op.cu b/caffe2/operators/layer_norm_op.cu
index 440783c..c87e267 100644
--- a/caffe2/operators/layer_norm_op.cu
+++ b/caffe2/operators/layer_norm_op.cu
@@ -267,3 +267,7 @@ void LayerNormGradientOp<CUDAContext>::LayerNormBackward(
 REGISTER_CUDA_OPERATOR(LayerNormGradient, LayerNormGradientOp<CUDAContext>);
 
 } // namespace caffe2
+
+C10_REGISTER_CAFFE2_OPERATOR_CUDA(
+    LayerNorm,
+    caffe2::LayerNormOp<caffe2::CUDAContext>)
diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py
index b971437..df77de4 100644
--- a/caffe2/python/operator_test/layer_norm_op_test.py
+++ b/caffe2/python/operator_test/layer_norm_op_test.py
@@ -3,7 +3,7 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
-from caffe2.python import brew, core
+from caffe2.python import brew, core, workspace
 from caffe2.python.model_helper import ModelHelper
 from hypothesis import given
 import caffe2.python.hypothesis_test_util as hu
@@ -170,6 +170,21 @@ class TestLayerNormOp(serial.SerializedTestCase):
         torch.testing.assert_allclose(expected_mean, actual_mean)
         torch.testing.assert_allclose(expected_stdev, actual_stdev)
 
+    # Test case is using workspace.has_cuda_support and not workspace.has_gpu_support
+    # to exclude it from HIP because tensor interop doesn't work for HIP tensors yet
+    @unittest.skipIf(not workspace.has_cuda_support, "No cuda support")
+    @given(X=hu.tensor(min_dim=2))
+    def test_layer_norm_op_pytorch_cuda(self, X):
+        axis = np.random.randint(0, len(X.shape))
+        epsilon = 1e-4
+
+        expected_norm, expected_mean, expected_stdev = _layer_norm_ref(axis, epsilon, X)
+        actual_norm, actual_mean, actual_stdev = torch.ops._caffe2.LayerNorm(torch.tensor(X).cuda(), axis, epsilon)
+
+        torch.testing.assert_allclose(expected_norm, actual_norm.cpu())
+        torch.testing.assert_allclose(expected_mean, actual_mean.cpu())
+        torch.testing.assert_allclose(expected_stdev, actual_stdev.cpu())
+
     @given(X=hu.tensor(min_dim=2), **hu.gcs)
     def test_layer_norm_op_jit(self, X, gc, dc):
         @torch.jit.script
-- 
2.7.4
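
For reference, a minimal sketch of how the kernels registered above are reached
through the c10 dispatcher from Python, mirroring test_layer_norm_op_pytorch_cuda.
It assumes a CUDA-enabled build; the input shape, axis, and epsilon below are
arbitrary example values, not taken from the patch.

    import numpy as np
    import torch

    X = np.random.rand(4, 8).astype(np.float32)
    axis, epsilon = 1, 1e-4

    # A CPU input tensor selects the kernel registered with
    # C10_REGISTER_CAFFE2_OPERATOR_CPU (dispatch key CPUTensorId).
    norm_cpu, mean_cpu, stdev_cpu = torch.ops._caffe2.LayerNorm(
        torch.tensor(X), axis, epsilon)

    # A CUDA input tensor selects the kernel registered with
    # C10_REGISTER_CAFFE2_OPERATOR_CUDA (dispatch key CUDATensorId).
    norm_gpu, mean_gpu, stdev_gpu = torch.ops._caffe2.LayerNorm(
        torch.tensor(X).cuda(), axis, epsilon)

    # Both devices should produce the same result.
    torch.testing.assert_allclose(norm_cpu, norm_gpu.cpu())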