From 1989716ae56a76168291f9f462d4af7852164b4f Mon Sep 17 00:00:00 2001 From: bddppq Date: Thu, 28 Mar 2019 18:07:10 -0700 Subject: [PATCH] Resubmit PR-18512: Improved onnx export for 3 onnx ops (#18571) Summary: Fix ROCm CI failure Pull Request resolved: https://github.com/pytorch/pytorch/pull/18571 Differential Revision: D14669323 Pulled By: bddppq fbshipit-source-id: 022afe5c20e680295c9cfdfe1ec14650305955a8 --- caffe2/contrib/CMakeLists.txt | 3 +++ caffe2/contrib/aten/CMakeLists.txt | 7 ++++++- .../contrib/aten/{aten_op_cuda.cc => aten_op_gpu.cc} | 0 caffe2/contrib/aten/aten_op_template.h | 13 ++++++++++++- caffe2/onnx/backend.cc | 18 +++++++++++++++++- caffe2/onnx/backend.h | 2 ++ caffe2/python/onnx/tests/onnx_backend_test.py | 1 - tools/amd_build/build_amd.py | 1 + torch/onnx/symbolic.py | 10 +++++++++- 9 files changed, 50 insertions(+), 5 deletions(-) rename caffe2/contrib/aten/{aten_op_cuda.cc => aten_op_gpu.cc} (100%) diff --git a/caffe2/contrib/CMakeLists.txt b/caffe2/contrib/CMakeLists.txt index 6034e4d..ba981c8 100644 --- a/caffe2/contrib/CMakeLists.txt +++ b/caffe2/contrib/CMakeLists.txt @@ -23,3 +23,6 @@ set(Caffe2_GPU_INCLUDE ${Caffe2_GPU_INCLUDE} PARENT_SCOPE) set(Caffe2_CUDA_DEPENDENCY_LIBS ${Caffe2_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE) set(Caffe2_GPU_BINARY_SRCS ${Caffe2_GPU_BINARY_SRCS} PARENT_SCOPE) + +# HIP source +set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE) diff --git a/caffe2/contrib/aten/CMakeLists.txt b/caffe2/contrib/aten/CMakeLists.txt index add3918..95e3b83 100644 --- a/caffe2/contrib/aten/CMakeLists.txt +++ b/caffe2/contrib/aten/CMakeLists.txt @@ -1,7 +1,12 @@ if(NOT BUILD_ATEN_MOBILE AND BUILD_CAFFE2_OPS) # Add source generated by Codegen.cmake and pass to parent list(APPEND Caffe2_CPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/aten_op.cc) - list(APPEND Caffe2_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/aten_op_cuda.cc) + list(APPEND Caffe2_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/aten_op_gpu.cc) set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE) + + if(USE_ROCM) + list(APPEND Caffe2_HIP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/hip/aten_op_gpu.cc) + set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE) + endif() endif() diff --git a/caffe2/contrib/aten/aten_op_cuda.cc b/caffe2/contrib/aten/aten_op_gpu.cc similarity index 100% rename from caffe2/contrib/aten/aten_op_cuda.cc rename to caffe2/contrib/aten/aten_op_gpu.cc diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index e2c0043..9354cd7 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -55,7 +55,13 @@ private: } at::Type& typeFor(const Tensor& ten) { - return at::getNonVariableType(backend(), typeMetaToScalarType(ten.meta())); + at::Backend b = backend(); +#ifdef __HIP_PLATFORM_HCC__ + if (b == at::Backend::HIP) { + b = at::Backend::CUDA; + } +#endif + return at::getNonVariableType(b, typeMetaToScalarType(ten.meta())); } at::Tensor tensorWrapping(const Tensor& ten_) { auto& ten = const_cast(ten_); @@ -80,6 +86,11 @@ private: auto at_sizes = src.sizes(); caffe2::TypeMeta type_meta = typeMetaFor(src); at::Device device = src.device(); +#ifdef __HIP_PLATFORM_HCC__ + if (device.type() == at::DeviceType::CUDA) { + device = at::Device(at::DeviceType::HIP, device.index()); + } +#endif at::TensorImpl* src_impl = src.unsafeReleaseTensorImpl(); std::vector dims(at_sizes.begin(), at_sizes.end()); dst->Resize(dims); diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index e7c512a..3564ebe 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -362,7 +362,8 @@ Caffe2Backend::get_special_operators() const { {"Dropout", &Caffe2Backend::CreateDropout}, {"LRN", &Caffe2Backend::CreateLRN}, {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice}, - {"RandomNormal", &Caffe2Backend::CreateRandomNormal}}; + {"RandomNormal", &Caffe2Backend::CreateRandomNormal}, + {"Where", &Caffe2Backend::CreateWhereOp}}; return kSpecialOperators; } @@ -580,6 +581,21 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal( return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); } +Caffe2Ops Caffe2Backend::CreateWhereOp( + OnnxNode* onnx_node, + const ConversionContext& ctx) { + // The native Caffe2 op doesn't support broadcasting, so we defer the handling + // of this op to the ATen library that does. + onnx::NodeProto converted; + converted.CopyFrom(onnx_node->node); + converted.set_op_type("ATen"); + onnx::AttributeProto* attr = converted.add_attribute(); + attr->set_name("operator"); + attr->set_s("where"); + OnnxNode new_node(converted); + return CommonOnnxNodeToCaffe2Ops(&new_node, ctx); +} + Caffe2Ops Caffe2Backend::CreateReciprocal( OnnxNode* onnx_node, const ConversionContext& ctx) { diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h index d61af29..8ee33ef 100644 --- a/caffe2/onnx/backend.h +++ b/caffe2/onnx/backend.h @@ -236,6 +236,8 @@ class CAFFE2_API Caffe2Backend { OnnxNode* onnx_node, const ConversionContext& ctx); + Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx); + Caffe2Ops CreateBatchNormalization( OnnxNode* onnx_node, const ConversionContext& ctx); diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index 75d4b5a..f353e22 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -52,7 +52,6 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid. '|test_isnan.*' # Needs implementation '|test_scatter.*' # Should be similar to ScatterAssign '|test_constantofshape_int.*' # Needs implementation - '|test_where.*' # Needs implementation '|test_shrink.*' # Needs implementation '|test_strnorm.*' # Needs implementation '|test_nonzero.*' # Needs implementation diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 20cb46b..a33559a 100644 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -56,6 +56,7 @@ includes = [ "caffe2/video/*", "caffe2/distributed/*", "caffe2/queue/*", + "caffe2/contrib/aten/*", "binaries/*", "caffe2/**/*_test*", "caffe2/core/*", diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index fbb8d97..9a1911f 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -548,6 +548,14 @@ def relu(g, input): return g.op("Relu", input) +def ceil(g, input): + return g.op("Ceil", input) + + +def floor(g, input): + return g.op("Floor", input) + + @parse_args('v', 't', 't') def threshold(g, self, threshold, value): # See Note [Export inplace] @@ -922,7 +930,7 @@ def le(g, input, other): def where(g, condition, self, other): - return g.op("ATen", condition, self, other, operator_s="where") + return g.op("Where", condition, self, other) @parse_args('v', 'i', 'i') -- 2.7.4