From a3a7a67048c11ee74fbdd54037a6dbaf90367964 Mon Sep 17 00:00:00 2001
From: Yuchen Huang
Date: Fri, 27 Aug 2021 18:57:22 -0700
Subject: [PATCH] [iOS][GPU] Consolidate array and non-array kernel for hardswish (#63369)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63369

ghstack-source-id: 136918152

(Note: this ignores all push blocking failures!)

Test Plan:
- `buck test pp-macos`
- Op tests in PyTorchPlayground app
- Run mobilenetv3 test https://pxl.cl/1Ncls

Reviewed By: xta0

Differential Revision: D30354454

fbshipit-source-id: 88bf4f8b5871e63170161b3f3e44f99b8a3086c6
---
 aten/src/ATen/native/metal/MetalShaders.h          | 41 +++++++++++-----------
 .../ATen/native/metal/mpscnn/tests/MPSCNNTests.h   |  1 +
 .../ATen/native/metal/mpscnn/tests/MPSCNNTests.mm  | 12 +++++++
 .../native/metal/mpscnn/tests/MetalOpTestRunner.mm |  1 +
 aten/src/ATen/native/metal/ops/MetalHardswish.mm   |  4 +--
 5 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h
index 5c25672..0ee703f 100644
--- a/aten/src/ATen/native/metal/MetalShaders.h
+++ b/aten/src/ATen/native/metal/MetalShaders.h
@@ -393,31 +393,32 @@ kernel void clamp(texture2d_array in_arr[[texture(0), functi
     }
 }
 
-kernel void hardswish(texture2d_array in[[texture(0)]],
-                      texture2d_array out[[texture(1)]],
+constant bool hardswish_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
+constant bool hardswish_is_tex = !hardswish_is_arr;
+kernel void hardswish(texture2d_array in_arr[[texture(0), function_constant(hardswish_is_arr)]],
+                      texture2d in_tex[[texture(0), function_constant(hardswish_is_tex)]],
+                      texture2d_array out_arr[[texture(1), function_constant(hardswish_is_arr)]],
+                      texture2d out_tex[[texture(1), function_constant(hardswish_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
-    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
+    const ushort oH = ushort_arg_2;
+    const ushort oW = ushort_arg_3;
+    if (gid.x >= oW || gid.y >= oH) {
         return;
     }
     ushort2 gid_ = gid.xy;
-    half4 value = in.read(gid_, gid.z);
-    half4 mask1 = half4(value < 3.0);
-    half4 mask2 = half4(value > -3.0);
-    half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
-    out.write(outval, gid_, gid.z);
-}
-
-kernel void hardswish_nonarray(texture2d in[[texture(0)]],
-                               texture2d out[[texture(1)]],
-                               ushort2 gid[[thread_position_in_grid]]) {
-    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
-        return;
+    if (hardswish_is_arr) {
+      half4 value = in_arr.read(gid_, gid.z);
+      half4 mask1 = half4(value < 3.0);
+      half4 mask2 = half4(value > -3.0);
+      half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
+      out_arr.write(outval, gid_, gid.z);
+    } else {
+      half4 value = in_tex.read(gid_);
+      half4 mask1 = half4(value < 3);
+      half4 mask2 = half4(value > -3.0);
+      half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
+      out_tex.write(outval, gid_);
     }
-    half4 value = in.read(gid);
-    half4 mask1 = half4(value < 3);
-    half4 mask2 = half4(value > -3.0);
-    half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
-    out.write(outval, gid);
 }
 
 constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
index ee992d9..599f2ce 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
@@ -41,6 +41,7 @@ bool test_softmax();
 bool test_sigmoid();
 bool test_hardsigmoid();
 bool test_hardswish();
+bool test_hardswish2();
 bool test_upsampling_nearest2d_vec();
 bool test_upsampling_nearest2d_vec2();
 bool test_adaptive_avg_pool2d();
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
index 69497a9..5a8f6de 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
@@ -262,6 +262,18 @@ bool test_hardswish() {
   });
 }
 
+bool test_hardswish2() {
+  __block std::vector size{1, 3, 44, 44};
+  return TEST(size, __PRETTY_FUNCTION__, ^bool {
+    auto X =
+        at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6;
+    auto X2 = X.metal();
+    auto Y1 = at::hardswish_(X);
+    auto Y2 = at::hardswish_(X2).cpu();
+    return almostEqual(Y1, Y2);
+  });
+}
+
 bool test_addmm() {
   bool result = true;
   for (int i = 0; i < ITER_COUNT; ++i) {
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm
index d8b69ad..f337e1d 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm
@@ -69,6 +69,7 @@
   REG_TEST("test_sigmoid", test_sigmoid);
   REG_TEST("test_hardsigmoid", test_hardsigmoid);
   REG_TEST("test_hardswish", test_hardswish);
+  REG_TEST("test_hardswish2", test_hardswish2);
   REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec);
   REG_TEST("test_upsampling_nearest2d_vec2", test_upsampling_nearest2d_vec2);
   REG_TEST("test_adaptive_avg_pool2d", test_adaptive_avg_pool2d);
diff --git a/aten/src/ATen/native/metal/ops/MetalHardswish.mm b/aten/src/ATen/native/metal/ops/MetalHardswish.mm
index 8d3526a..d571e48 100644
--- a/aten/src/ATen/native/metal/ops/MetalHardswish.mm
+++ b/aten/src/ATen/native/metal/ops/MetalHardswish.mm
@@ -24,9 +24,9 @@ Tensor& hardswish_(Tensor& input) {
   id encoder =
       [commandBuffer.buffer computeCommandEncoder];
   id state = [[MetalContext sharedInstance]
-      specializedPipelineState:mpscnn::kernelFor(
-                                   X, "hardswish", "hardswish_nonarray")
+      specializedPipelineState:"hardswish"
                      Constants:@[
+                       @(X.numberOfImages),
                        @(X.featureChannels),
                        @(X.height),
                        @(X.width)
-- 
2.7.4