[Runtime][Contrib] Support cudnn softmax (#5214)
authorHaichen Shen <shenhaichen@gmail.com>
Mon, 6 Apr 2020 03:53:59 +0000 (20:53 -0700)
committerGitHub <noreply@github.com>
Mon, 6 Apr 2020 03:53:59 +0000 (20:53 -0700)
16 files changed:
python/tvm/contrib/cudnn.py
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/op_attrs.py
python/tvm/relay/op/strategy/cuda.py
python/tvm/relay/op/strategy/generic.py
python/tvm/relay/op/strategy/hls.py
python/tvm/relay/op/strategy/opengl.py
python/tvm/relay/op/strategy/x86.py
src/relay/op/nn/nn.cc
src/runtime/contrib/cudnn/cudnn_utils.cc
src/runtime/contrib/cudnn/cudnn_utils.h
src/runtime/contrib/cudnn/softmax.cc [new file with mode: 0644]
tests/python/contrib/test_cudnn.py
topi/python/topi/cuda/__init__.py
topi/python/topi/cuda/softmax.py
topi/python/topi/nn/softmax.py

index e627245..5043520 100644 (file)
@@ -402,3 +402,27 @@ def conv_forward(x,
             ins[1],
             outs[0],
             conv_dtype), name="y")
+
+def softmax(x, axis=-1):
+    """Compute softmax using CuDNN
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute the softmax
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    return te.extern(
+        x.shape, [x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.softmax.forward",
+            ins[0],
+            outs[0],
+            axis), name="y")
index 39d98c0..51e7128 100644 (file)
@@ -34,12 +34,12 @@ reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
 
 
 # softmax
-reg.register_schedule("nn.softmax", strategy.schedule_softmax)
+reg.register_strategy("nn.softmax", strategy.softmax_strategy)
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
 
 # log_softmax
-reg.register_schedule("nn.log_softmax", strategy.schedule_softmax)
+reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
index 1a07486..a47be76 100644 (file)
@@ -69,6 +69,11 @@ class DenseAttrs(Attrs):
     """Attributes for nn.dense"""
 
 
+@tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
+class SoftmaxAttrs(Attrs):
+    """Attributes for nn.softmax"""
+
+
 @tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs")
 class FIFOBufferAttrs(Attrs):
     """Attributes for nn.fifo_buffer"""
index 45ee701..845be66 100644 (file)
@@ -60,9 +60,25 @@ def schedule_adaptive_pool_cuda(attrs, outs, target):
     with target:
         return topi.cuda.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register(["cuda", "gpu"])
-def schedule_softmax_cuda(attrs, outs, target):
-    """schedule softmax for cuda"""
+@softmax_strategy.register(["cuda", "gpu"])
+def softmax_strategy_cuda(attrs, inputs, out_type, target):
+    """softmax cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="softmax.cuda")
+    if target.target_name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_softmax(topi.cuda.softmax_cudnn),
+            wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
+            name="softmax.cudnn",
+            plevel=15)
+    return strategy
+
+@schedule_log_softmax.register(["cuda", "gpu"])
+def schedule_log_softmax_cuda(attrs, outs, target):
+    """scheudle log_softmax for cuda"""
     with target:
         return topi.cuda.schedule_softmax(outs)
 
index 388e104..0a26080 100644 (file)
@@ -107,9 +107,27 @@ def schedule_adaptive_pool(attrs, outs, target):
         return topi.generic.schedule_adaptive_pool(outs)
 
 # softmax
+def wrap_compute_softmax(topi_compute):
+    """Wrap softmax topi compute"""
+    def _compute_softmax(attrs, inputs, out_type):
+        axis = attrs.get_int("axis")
+        return [topi_compute(inputs[0], axis)]
+    return _compute_softmax
+
+@override_native_generic_func("softmax_strategy")
+def softmax_strategy(attrs, inputs, out_type, target):
+    """softmax generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implemenation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.generic.schedule_softmax),
+        name="softmax.generic")
+    return strategy
+
+# log_softmax
 @generic_func
-def schedule_softmax(attrs, outs, target):
-    """Schedule softmax"""
+def schedule_log_softmax(attrs, outs, target):
+    """Schedule log_softmax op"""
     with target:
         return topi.generic.schedule_softmax(outs)
 
index 514902b..d41e85f 100644 (file)
@@ -50,9 +50,19 @@ def schedule_adaptive_pool_hls(attrs, outs, target):
     with target:
         return topi.hls.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("hls")
-def schedule_softmax_hls(attrs, outs, target):
-    """schedule softmax for hls"""
+@softmax_strategy.register("hls")
+def softmax_strategy_hls(attrs, inputs, out_type, target):
+    """softmax hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.hls.schedule_softmax),
+        name="softmax.hls")
+    return strategy
+
+@schedule_log_softmax.register("hls")
+def schedule_log_softmax_hls(attrs, inputs, out_type, target):
+    """schedule log_softmax for hls"""
     with target:
         return topi.hls.schedule_softmax(outs)
 
index 45e290c..12c288c 100644 (file)
@@ -44,9 +44,19 @@ def schedule_adaptive_pool_opengl(attrs, outs, target):
     with target:
         return topi.opengl.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("opengl")
-def schedule_softmax_opengl(attrs, outs, target):
-    """schedule softmax for opengl"""
+@softmax_strategy.register("opengl")
+def softmax_strategy_opengl(attrs, inputs, out_type, target):
+    """softmax opengl strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.opengl.schedule_softmax),
+        name="softmax.opengl")
+    return strategy
+
+@schedule_log_softmax.register("opengl")
+def schedule_log_softmax_opengl(attrs, outs, target):
+    """schedule log_softmax for opengl"""
     with target:
         return topi.opengl.schedule_softmax(outs)
 
index 6606b5c..ba0b3d2 100644 (file)
@@ -55,9 +55,19 @@ def schedule_adaptive_pool_cpu(attrs, outs, target):
     with target:
         return topi.x86.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("cpu")
-def schedule_softmax_cpu(attrs, outs, target):
-    """schedule softmax for x86"""
+@softmax_strategy.register("cpu")
+def softmax_strategy_cpu(attrs, inputs, out_type, target):
+    """softmax x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.x86.schedule_softmax),
+        name="softmax.x86")
+    return strategy
+
+@schedule_log_softmax.register("cpu")
+def schedule_log_softmax_cpu(attrs, outs, target):
+    """schedule log_softmax op for x86"""
     with target:
         return topi.x86.schedule_softmax(outs)
 
index 4934e06..b9ba74f 100644 (file)
@@ -347,14 +347,7 @@ RELAY_REGISTER_OP("nn.softmax")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<te::Tensor>& inputs,
-                                         const Type& out_type) {
-  const auto* param = attrs.as<SoftmaxAttrs>();
-  CHECK(param != nullptr);
-  return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
-});
+.add_type_rel("Identity", IdentityRel);
 
 
 // relay.nn.log_softmax
index fa185e9..9c895c5 100644 (file)
@@ -140,5 +140,15 @@ void ConvEntry::CleanWorkspace() {
   workspace_size = 0;
 }
 
+// SoftmaxEntry
+
+SoftmaxEntry::SoftmaxEntry() {
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc));
+}
+
+SoftmaxEntry::~SoftmaxEntry() {
+  CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc));
+}
+
 }  // namespace contrib
 }  // namespace tvm
index 0042245..ee6bb50 100644 (file)
@@ -85,12 +85,20 @@ struct ConvEntry {
   void CleanWorkspace();
 };  // ConvThreadEntry
 
+struct SoftmaxEntry {
+  cudnnSoftmaxMode_t mode;
+  cudnnDataType_t data_type;
+  cudnnTensorDescriptor_t shape_desc;
+  SoftmaxEntry();
+  ~SoftmaxEntry();
+};  // SoftmaxEntry
 
 struct CuDNNThreadEntry {
   CuDNNThreadEntry();
   ~CuDNNThreadEntry();
   cudnnHandle_t handle{nullptr};
   ConvEntry conv_entry;
+  SoftmaxEntry softmax_entry;
   runtime::DeviceAPI *cuda_api{nullptr};
   static CuDNNThreadEntry* ThreadLocal();
 };  // CuDNNThreadEntry
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
new file mode 100644 (file)
index 0000000..fb6d8a6
--- /dev/null
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/softmax.cc
+ * \brief Use external cudnn softmax function
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/device_api.h>
+#include "cudnn_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor* x = args[0];
+  DLTensor* y = args[1];
+  int axis = args[2];
+  int ndim = x->ndim;
+  int64_t* shape = x->shape;
+  if (axis < 0) axis += ndim;
+  CHECK(axis >= 0 && axis < ndim);
+
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+
+  // Set mode and shape descriptor
+  if (axis == ndim - 1) {
+    int64_t N = 1;
+    for (int i = 0; i < ndim - 1; ++i) {
+      N *= shape[i];
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                          CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type,
+                                          static_cast<int>(N),
+                                          static_cast<int>(shape[ndim - 1]),
+                                          1,
+                                          1));
+  } else {
+    int64_t pre_axis_dim = 1;
+    int64_t post_axis_dim = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (i < axis) {
+        pre_axis_dim *= shape[i];
+      } else if (i > axis) {
+        post_axis_dim *= shape[i];
+      }
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                          CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type,
+                                          static_cast<int>(pre_axis_dim),
+                                          static_cast<int>(shape[axis]),
+                                          static_cast<int>(post_axis_dim),
+                                          1));
+  }
+
+  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
+  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
+                                 CUDNN_SOFTMAX_ACCURATE,
+                                 entry_ptr->softmax_entry.mode,
+                                 alpha,
+                                 entry_ptr->softmax_entry.shape_desc,
+                                 x->data,
+                                 beta,
+                                 entry_ptr->softmax_entry.shape_desc,
+                                 y->data));
+});
+
+}  // namespace contrib
+}  // namespace tvm
index 58e7b49..5d1f100 100644 (file)
@@ -158,6 +158,52 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
 def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0)
 
+
+def verify_softmax(shape, axis, dtype="float32"):
+    A = te.placeholder(shape, dtype=dtype, name='A')
+    B = cudnn.softmax(A, axis)
+    s = te.create_schedule([B.op])
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = topi.testing.softmax_python(a_np)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
+
+def verify_softmax_4d(shape, dtype="float32"):
+    A = te.placeholder(shape, dtype=dtype, name='A')
+    B = cudnn.softmax(A, axis=1)
+    s = te.create_schedule([B.op])
+
+    ctx = tvm.gpu(0)
+    n, c, h, w = shape
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
+
+def test_softmax():
+    if not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
+        print("skip because cudnn is not enabled...")
+        return
+
+    verify_softmax((32, 10), -1)
+    verify_softmax((3, 4), -1)
+    verify_softmax((1, 5), -1, "float64")
+    verify_softmax_4d((1, 16, 256, 256))
+    verify_softmax_4d((1, 16, 256, 256), "float64")
+
 if __name__ == "__main__":
     test_conv2d()
     test_conv3d()
+    test_softmax()
index 83ddedc..c20e257 100644 (file)
@@ -34,7 +34,7 @@ from .conv3d import *
 from .conv3d_winograd import *
 from . import conv3d_alter_op
 from .reduction import schedule_reduce
-from .softmax import schedule_softmax
+from .softmax import *
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import *
 from .pooling import *
index 54d5bfb..62c437a 100644 (file)
@@ -17,6 +17,8 @@
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
 from tvm import te
+from tvm.contrib import cudnn
+from .. import generic
 from .injective import schedule_injective_from_existing
 
 
@@ -79,3 +81,13 @@ def schedule_softmax(outs):
         s[softmax].bind(tx, thread_x)
 
     return s
+
+
+def softmax_cudnn(x, axis=-1):
+    """Perform softmax on the data using cudnn"""
+    return cudnn.softmax(x, axis)
+
+
+def schedule_softmax_cudnn(outs):
+    """Schedule for softmax cudnn op"""
+    return generic.schedule_extern(outs)
index c414372..fb51384 100644 (file)
@@ -77,7 +77,6 @@ def softmax(x, axis=-1):
     return te.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
                       name='T_softmax_norm', attrs={"axis" : axis})
 
-
 @tvm.te.tag_scope(tag='log_softmax_output')
 def log_softmax(x):
     """Perform log softmax activation on the data