Resubmit PR-18512: Improved onnx export for 3 onnx ops (#18571)
author: bddppq <bai@in.tum.de>
Fri, 29 Mar 2019 01:07:10 +0000 (18:07 -0700)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 01:12:49 +0000 (18:12 -0700)
Summary:
Fix ROCm CI failure
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18571

Differential Revision: D14669323

Pulled By: bddppq

fbshipit-source-id: 022afe5c20e680295c9cfdfe1ec14650305955a8

caffe2/contrib/CMakeLists.txt
caffe2/contrib/aten/CMakeLists.txt
caffe2/contrib/aten/aten_op_gpu.cc [moved from caffe2/contrib/aten/aten_op_cuda.cc with 100% similarity]
caffe2/contrib/aten/aten_op_template.h
caffe2/onnx/backend.cc
caffe2/onnx/backend.h
caffe2/python/onnx/tests/onnx_backend_test.py
tools/amd_build/build_amd.py
torch/onnx/symbolic.py

index 6034e4d..ba981c8 100644 (file)
@@ -23,3 +23,6 @@ set(Caffe2_GPU_INCLUDE ${Caffe2_GPU_INCLUDE} PARENT_SCOPE)
 set(Caffe2_CUDA_DEPENDENCY_LIBS ${Caffe2_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE)
 set(Caffe2_GPU_BINARY_SRCS ${Caffe2_GPU_BINARY_SRCS} PARENT_SCOPE)
+
+# HIP source
+set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE)
index add3918..95e3b83 100644 (file)
@@ -1,7 +1,12 @@
 if(NOT BUILD_ATEN_MOBILE AND BUILD_CAFFE2_OPS)
   # Add source generated by Codegen.cmake and pass to parent
   list(APPEND Caffe2_CPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/aten_op.cc)
-  list(APPEND Caffe2_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/aten_op_cuda.cc)
+  list(APPEND Caffe2_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/aten_op_gpu.cc)
   set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
   set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE)
+
+  if(USE_ROCM)
+    list(APPEND Caffe2_HIP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/hip/aten_op_gpu.cc)
+    set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE)
+  endif()
 endif()
index e2c0043..9354cd7 100644 (file)
@@ -55,7 +55,13 @@ private:
   }
 
   at::Type& typeFor(const Tensor& ten) {
-    return at::getNonVariableType(backend(), typeMetaToScalarType(ten.meta()));
+    at::Backend b = backend();
+#ifdef __HIP_PLATFORM_HCC__
+    if (b == at::Backend::HIP) {
+      b = at::Backend::CUDA;
+    }
+#endif
+    return at::getNonVariableType(b, typeMetaToScalarType(ten.meta()));
   }
   at::Tensor tensorWrapping(const Tensor& ten_) {
     auto& ten = const_cast<Tensor&>(ten_);
@@ -80,6 +86,11 @@ private:
     auto at_sizes = src.sizes();
     caffe2::TypeMeta type_meta = typeMetaFor(src);
     at::Device device = src.device();
+#ifdef __HIP_PLATFORM_HCC__
+    if (device.type() == at::DeviceType::CUDA) {
+      device = at::Device(at::DeviceType::HIP, device.index());
+    }
+#endif
     at::TensorImpl* src_impl = src.unsafeReleaseTensorImpl();
     std::vector<int64_t> dims(at_sizes.begin(), at_sizes.end());
     dst->Resize(dims);
index e7c512a..3564ebe 100644 (file)
@@ -362,7 +362,8 @@ Caffe2Backend::get_special_operators() const {
               {"Dropout", &Caffe2Backend::CreateDropout},
               {"LRN", &Caffe2Backend::CreateLRN},
               {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
-              {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
+              {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
+              {"Where", &Caffe2Backend::CreateWhereOp}};
   return kSpecialOperators;
 }
 
@@ -580,6 +581,21 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal(
   return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
 }
 
+Caffe2Ops Caffe2Backend::CreateWhereOp(
+    OnnxNode* onnx_node,
+    const ConversionContext& ctx) {
+  // The native Caffe2 op doesn't support broadcasting, so we defer the handling
+  // of this op to the ATen library that does.
+  onnx::NodeProto converted;
+  converted.CopyFrom(onnx_node->node);
+  converted.set_op_type("ATen");
+  onnx::AttributeProto* attr = converted.add_attribute();
+  attr->set_name("operator");
+  attr->set_s("where");
+  OnnxNode new_node(converted);
+  return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
+}
+
 Caffe2Ops Caffe2Backend::CreateReciprocal(
     OnnxNode* onnx_node,
     const ConversionContext& ctx) {
index d61af29..8ee33ef 100644 (file)
@@ -236,6 +236,8 @@ class CAFFE2_API Caffe2Backend {
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
 
+  Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
+
   Caffe2Ops CreateBatchNormalization(
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
index 75d4b5a..f353e22 100644 (file)
@@ -52,7 +52,6 @@ backend_test.exclude(r'(test_hardsigmoid'  # Does not support Hardsigmoid.
                      '|test_isnan.*'  # Needs implementation
                      '|test_scatter.*'  # Should be similar to ScatterAssign
                      '|test_constantofshape_int.*'  # Needs implementation
-                     '|test_where.*'  # Needs implementation
                      '|test_shrink.*'  # Needs implementation
                      '|test_strnorm.*'  # Needs implementation
                      '|test_nonzero.*'  # Needs implementation
index 20cb46b..a33559a 100644 (file)
@@ -56,6 +56,7 @@ includes = [
     "caffe2/video/*",
     "caffe2/distributed/*",
     "caffe2/queue/*",
+    "caffe2/contrib/aten/*",
     "binaries/*",
     "caffe2/**/*_test*",
     "caffe2/core/*",
index fbb8d97..9a1911f 100644 (file)
@@ -548,6 +548,14 @@ def relu(g, input):
     return g.op("Relu", input)
 
 
+def ceil(g, input):
+    return g.op("Ceil", input)
+
+
+def floor(g, input):
+    return g.op("Floor", input)
+
+
 @parse_args('v', 't', 't')
 def threshold(g, self, threshold, value):
     # See Note [Export inplace]
@@ -922,7 +930,7 @@ def le(g, input, other):
 
 
 def where(g, condition, self, other):
-    return g.op("ATen", condition, self, other, operator_s="where")
+    return g.op("Where", condition, self, other)
 
 
 @parse_args('v', 'i', 'i')