From 799ff356b9308da9c06b61c214508851a35c3b93 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Sun, 5 Apr 2020 20:53:59 -0700
Subject: [PATCH] [Runtime][Contrib] Support cudnn softmax (#5214)

---
 python/tvm/contrib/cudnn.py              | 24 ++++++++
 python/tvm/relay/op/nn/_nn.py            |  4 +-
 python/tvm/relay/op/op_attrs.py          |  5 ++
 python/tvm/relay/op/strategy/cuda.py     | 22 +++++++-
 python/tvm/relay/op/strategy/generic.py  | 22 +++++++-
 python/tvm/relay/op/strategy/hls.py      | 16 +++++-
 python/tvm/relay/op/strategy/opengl.py   | 16 +++++-
 python/tvm/relay/op/strategy/x86.py      | 16 +++++-
 src/relay/op/nn/nn.cc                    |  9 +--
 src/runtime/contrib/cudnn/cudnn_utils.cc | 10 ++++
 src/runtime/contrib/cudnn/cudnn_utils.h  |  8 +++
 src/runtime/contrib/cudnn/softmax.cc     | 94 ++++++++++++++++++++++++++++++++
 tests/python/contrib/test_cudnn.py       | 46 ++++++++++++++++
 topi/python/topi/cuda/__init__.py        |  2 +-
 topi/python/topi/cuda/softmax.py         | 12 ++++
 topi/python/topi/nn/softmax.py           |  1 -
 16 files changed, 281 insertions(+), 26 deletions(-)
 create mode 100644 src/runtime/contrib/cudnn/softmax.cc

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index e627245..5043520 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -402,3 +402,27 @@ def conv_forward(x,
                        ins[1],
                        outs[0],
                        conv_dtype), name="y")
+
+def softmax(x, axis=-1):
+    """Compute softmax using CuDNN
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute the softmax
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    return te.extern(
+        x.shape, [x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.softmax.forward",
+            ins[0],
+            outs[0],
+            axis), name="y")
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 39d98c0..51e7128 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -34,12 +34,12 @@ reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
 
 # softmax
-reg.register_schedule("nn.softmax", strategy.schedule_softmax)
+reg.register_strategy("nn.softmax", strategy.softmax_strategy)
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
 
 # log_softmax
-reg.register_schedule("nn.log_softmax", strategy.schedule_softmax)
+reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 1a07486..a47be76 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -69,6 +69,11 @@ class DenseAttrs(Attrs):
     """Attributes for nn.dense"""
 
 
+@tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
+class SoftmaxAttrs(Attrs):
+    """Attributes for nn.softmax"""
+
+
 @tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs")
 class FIFOBufferAttrs(Attrs):
     """Attributes for nn.fifo_buffer"""
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 45ee701..845be66 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -60,9 +60,25 @@ def schedule_adaptive_pool_cuda(attrs, outs, target):
     with target:
         return topi.cuda.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register(["cuda", "gpu"])
-def schedule_softmax_cuda(attrs, outs, target):
-    """schedule softmax for cuda"""
+@softmax_strategy.register(["cuda", "gpu"])
+def softmax_strategy_cuda(attrs, inputs, out_type, target):
+    """softmax cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="softmax.cuda")
+    if target.target_name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_softmax(topi.cuda.softmax_cudnn),
+            wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
+            name="softmax.cudnn",
+            plevel=15)
+    return strategy
+
+@schedule_log_softmax.register(["cuda", "gpu"])
+def schedule_log_softmax_cuda(attrs, outs, target):
+    """schedule log_softmax for cuda"""
     with target:
         return topi.cuda.schedule_softmax(outs)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 388e104..0a26080 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -107,9 +107,27 @@ def schedule_adaptive_pool(attrs, outs, target):
         return topi.generic.schedule_adaptive_pool(outs)
 
 # softmax
+def wrap_compute_softmax(topi_compute):
+    """Wrap softmax topi compute"""
+    def _compute_softmax(attrs, inputs, out_type):
+        axis = attrs.get_int("axis")
+        return [topi_compute(inputs[0], axis)]
+    return _compute_softmax
+
+@override_native_generic_func("softmax_strategy")
+def softmax_strategy(attrs, inputs, out_type, target):
+    """softmax generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.generic.schedule_softmax),
+        name="softmax.generic")
+    return strategy
+
+# log_softmax
 @generic_func
-def schedule_softmax(attrs, outs, target):
-    """Schedule softmax"""
+def schedule_log_softmax(attrs, outs, target):
+    """Schedule log_softmax op"""
     with target:
         return topi.generic.schedule_softmax(outs)
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
index 514902b..d41e85f 100644
--- a/python/tvm/relay/op/strategy/hls.py
+++ b/python/tvm/relay/op/strategy/hls.py
@@ -50,9 +50,19 @@ def schedule_adaptive_pool_hls(attrs, outs, target):
     with target:
         return topi.hls.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("hls")
-def schedule_softmax_hls(attrs, outs, target):
-    """schedule softmax for hls"""
+@softmax_strategy.register("hls")
+def softmax_strategy_hls(attrs, inputs, out_type, target):
+    """softmax hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.hls.schedule_softmax),
+        name="softmax.hls")
+    return strategy
+
+@schedule_log_softmax.register("hls")
+def schedule_log_softmax_hls(attrs, outs, target):
+    """schedule log_softmax for hls"""
     with target:
         return topi.hls.schedule_softmax(outs)
diff --git a/python/tvm/relay/op/strategy/opengl.py b/python/tvm/relay/op/strategy/opengl.py
index 45e290c..12c288c 100644
--- a/python/tvm/relay/op/strategy/opengl.py
+++ b/python/tvm/relay/op/strategy/opengl.py
@@ -44,9 +44,19 @@ def schedule_adaptive_pool_opengl(attrs, outs, target):
     with target:
         return topi.opengl.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("opengl")
-def schedule_softmax_opengl(attrs, outs, target):
-    """schedule softmax for opengl"""
+@softmax_strategy.register("opengl")
+def softmax_strategy_opengl(attrs, inputs, out_type, target):
+    """softmax opengl strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.opengl.schedule_softmax),
+        name="softmax.opengl")
+    return strategy
+
+@schedule_log_softmax.register("opengl")
+def schedule_log_softmax_opengl(attrs, outs, target):
+    """schedule log_softmax for opengl"""
     with target:
         return topi.opengl.schedule_softmax(outs)
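For context, the strategy changes above mean that on a CUDA target with cuDNN in its library list, the "softmax.cudnn" implementation (plevel=15) outranks the default TOPI "softmax.cuda" implementation. A minimal usage sketch, assuming the standard Relay executor API of this TVM version (the variable names and executor call below are illustrative, not part of the patch):

import numpy as np
import tvm
from tvm import relay

# Softmax over the last axis of an (8, 100) input.
x = relay.var("x", shape=(8, 100), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.softmax(x, axis=-1)))

# With "-libs=cudnn" in the target string, softmax_strategy_cuda registers the
# cuDNN implementation, which wins the plevel comparison against softmax.cuda.
target = "cuda -libs=cudnn"
data = np.random.uniform(size=(8, 100)).astype("float32")
ex = relay.create_executor("graph", mod=mod, ctx=tvm.gpu(0), target=target)
out = ex.evaluate()(data)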
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 6606b5c..ba0b3d2 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -55,9 +55,19 @@ def schedule_adaptive_pool_cpu(attrs, outs, target):
     with target:
         return topi.x86.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("cpu")
-def schedule_softmax_cpu(attrs, outs, target):
-    """schedule softmax for x86"""
+@softmax_strategy.register("cpu")
+def softmax_strategy_cpu(attrs, inputs, out_type, target):
+    """softmax x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.x86.schedule_softmax),
+        name="softmax.x86")
+    return strategy
+
+@schedule_log_softmax.register("cpu")
+def schedule_log_softmax_cpu(attrs, outs, target):
+    """schedule log_softmax op for x86"""
     with target:
         return topi.x86.schedule_softmax(outs)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 4934e06..b9ba74f 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -347,14 +347,7 @@ RELAY_REGISTER_OP("nn.softmax")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<te::Tensor>& inputs,
-                                         const Type& out_type) {
-  const auto* param = attrs.as<SoftmaxAttrs>();
-  CHECK(param != nullptr);
-  return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
-});
+.add_type_rel("Identity", IdentityRel);
 
 
 // relay.nn.log_softmax
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index fa185e9..9c895c5 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -140,5 +140,15 @@ void ConvEntry::CleanWorkspace() {
   workspace_size = 0;
 }
 
+// SoftmaxEntry
+
+SoftmaxEntry::SoftmaxEntry() {
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc));
+}
+
+SoftmaxEntry::~SoftmaxEntry() {
+  CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc));
+}
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 0042245..ee6bb50 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -85,12 +85,20 @@ struct ConvEntry {
   void CleanWorkspace();
 };  // ConvThreadEntry
 
+struct SoftmaxEntry {
+  cudnnSoftmaxMode_t mode;
+  cudnnDataType_t data_type;
+  cudnnTensorDescriptor_t shape_desc;
+  SoftmaxEntry();
+  ~SoftmaxEntry();
+};  // SoftmaxEntry
+
 struct CuDNNThreadEntry {
   CuDNNThreadEntry();
   ~CuDNNThreadEntry();
   cudnnHandle_t handle{nullptr};
   ConvEntry conv_entry;
+  SoftmaxEntry softmax_entry;
   runtime::DeviceAPI *cuda_api{nullptr};
   static CuDNNThreadEntry* ThreadLocal();
 };  // CuDNNThreadEntry
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
new file mode 100644
index 0000000..fb6d8a6
--- /dev/null
+++ b/src/runtime/contrib/cudnn/softmax.cc
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/softmax.cc
+ * \brief Use external cudnn softmax function
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/device_api.h>
+#include "cudnn_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor* x = args[0];
+  DLTensor* y = args[1];
+  int axis = args[2];
+  int ndim = x->ndim;
+  int64_t* shape = x->shape;
+  if (axis < 0) axis += ndim;
+  CHECK(axis >= 0 && axis < ndim);
+
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+
+  // Set mode and shape descriptor
+  if (axis == ndim - 1) {
+    int64_t N = 1;
+    for (int i = 0; i < ndim - 1; ++i) {
+      N *= shape[i];
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                          CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type,
+                                          static_cast<int>(N),
+                                          static_cast<int>(shape[ndim - 1]),
+                                          1,
+                                          1));
+  } else {
+    int64_t pre_axis_dim = 1;
+    int64_t post_axis_dim = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (i < axis) {
+        pre_axis_dim *= shape[i];
+      } else if (i > axis) {
+        post_axis_dim *= shape[i];
+      }
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                          CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type,
+                                          static_cast<int>(pre_axis_dim),
+                                          static_cast<int>(shape[axis]),
+                                          static_cast<int>(post_axis_dim),
+                                          1));
+  }
+
+  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
+  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
+                                 CUDNN_SOFTMAX_ACCURATE,
+                                 entry_ptr->softmax_entry.mode,
+                                 alpha,
+                                 entry_ptr->softmax_entry.shape_desc,
+                                 x->data,
+                                 beta,
+                                 entry_ptr->softmax_entry.shape_desc,
+                                 y->data));
+});
+
+}  // namespace contrib
+}  // namespace tvm
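The descriptor setup in softmax.cc folds an arbitrary N-d input shape into the 4-d tensor descriptor cuDNN expects. The following pure-Python sketch (illustrative only, mirroring the branch above; the function name is hypothetical) shows the mapping for both modes:

def cudnn_softmax_shape(shape, axis):
    """Mirror of the shape folding in softmax.cc (not part of the patch)."""
    ndim = len(shape)
    if axis < 0:
        axis += ndim
    assert 0 <= axis < ndim
    if axis == ndim - 1:
        # Last-axis softmax: fold every leading axis into N, reduce over C,
        # and use CUDNN_SOFTMAX_MODE_INSTANCE.
        n = 1
        for dim in shape[:-1]:
            n *= dim
        return "INSTANCE", (n, shape[-1], 1, 1)
    # Inner-axis softmax: axes before `axis` fold into N, axes after it fold
    # into H, and CUDNN_SOFTMAX_MODE_CHANNEL reduces over C.
    pre, post = 1, 1
    for i, dim in enumerate(shape):
        if i < axis:
            pre *= dim
        elif i > axis:
            post *= dim
    return "CHANNEL", (pre, shape[axis], post, 1)

# For example, the (1, 16, 256, 256) input with axis=1 used in the tests below
# maps to ("CHANNEL", (1, 16, 65536, 1)).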
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 58e7b49..5d1f100 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -158,6 +158,52 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
 def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0)
 
+
+def verify_softmax(shape, axis, dtype="float32"):
+    A = te.placeholder(shape, dtype=dtype, name='A')
+    B = cudnn.softmax(A, axis)
+    s = te.create_schedule([B.op])
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = topi.testing.softmax_python(a_np)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
+
+def verify_softmax_4d(shape, dtype="float32"):
+    A = te.placeholder(shape, dtype=dtype, name='A')
+    B = cudnn.softmax(A, axis=1)
+    s = te.create_schedule([B.op])
+
+    ctx = tvm.gpu(0)
+    n, c, h, w = shape
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
+
+def test_softmax():
+    if not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
+        print("skip because cudnn is not enabled...")
+        return
+
+    verify_softmax((32, 10), -1)
+    verify_softmax((3, 4), -1)
+    verify_softmax((1, 5), -1, "float64")
+    verify_softmax_4d((1, 16, 256, 256))
+    verify_softmax_4d((1, 16, 256, 256), "float64")
+
 if __name__ == "__main__":
     test_conv2d()
     test_conv3d()
+    test_softmax()
diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 83ddedc..c20e257 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -34,7 +34,7 @@ from .conv3d import *
 from .conv3d_winograd import *
 from . import conv3d_alter_op
 from .reduction import schedule_reduce
-from .softmax import schedule_softmax
+from .softmax import *
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import *
 from .pooling import *
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 54d5bfb..62c437a 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -17,6 +17,8 @@
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
 from tvm import te
+from tvm.contrib import cudnn
+from .. import generic
 from .injective import schedule_injective_from_existing
 
 
@@ -79,3 +81,13 @@ def schedule_softmax(outs):
             s[softmax].bind(tx, thread_x)
 
     return s
+
+
+def softmax_cudnn(x, axis=-1):
+    """Perform softmax on the data using cudnn"""
+    return cudnn.softmax(x, axis)
+
+
+def schedule_softmax_cudnn(outs):
+    """Schedule for softmax cudnn op"""
+    return generic.schedule_extern(outs)
diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py
index c414372..fb51384 100644
--- a/topi/python/topi/nn/softmax.py
+++ b/topi/python/topi/nn/softmax.py
@@ -77,7 +77,6 @@ def softmax(x, axis=-1):
     return te.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
                       name='T_softmax_norm', attrs={"axis" : axis})
 
-
 @tvm.te.tag_scope(tag='log_softmax_output')
 def log_softmax(x):
     """Perform log softmax activation on the data
-- 
2.7.4
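As a final sketch (not part of the patch), the registered packed function can also be exercised directly on device arrays, which is what the te.extern node in python/tvm/contrib/cudnn.py lowers to; availability is probed the same way the tests above probe for cuDNN:

import numpy as np
import tvm

# The global is only registered when TVM was built with cuDNN support.
f = tvm.get_global_func("tvm.contrib.cudnn.softmax.forward", True)
if f is not None:
    ctx = tvm.gpu(0)
    x = tvm.nd.array(np.random.uniform(size=(4, 10)).astype("float32"), ctx)
    y = tvm.nd.empty((4, 10), "float32", ctx)
    f(x, y, -1)  # softmax over the last axis
    # Each row of the result should sum to ~1.
    print(y.asnumpy().sum(axis=-1))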