From 52ff9bc63916b6cb438d816d2b938a4d93d4fa96 Mon Sep 17 00:00:00 2001 From: Tao Xu Date: Tue, 7 Sep 2021 15:36:11 -0700 Subject: [PATCH] [iOS][Metal] Add aten:hardswish (#64588) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64588 Add `aten::hardswish` to run the mobilenetv3 model from torchvision. ghstack-source-id: 137479323 Test Plan: - buck test pp-macos - circleCI Reviewed By: beback4u Differential Revision: D30781008 fbshipit-source-id: 83454869195ef4ab50570ea9b3bf2a55f32a3e86 --- .../ATen/native/metal/mpscnn/tests/MPSCNNTests.h | 2 +- .../ATen/native/metal/mpscnn/tests/MPSCNNTests.mm | 8 ++-- .../native/metal/mpscnn/tests/MetalOpTestRunner.mm | 2 +- aten/src/ATen/native/metal/ops/MetalHardswish.mm | 51 ++++++++++++++++++---- 4 files changed, 48 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h index 599f2ce..12e0296 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h @@ -40,8 +40,8 @@ bool test_log_softmax(); bool test_softmax(); bool test_sigmoid(); bool test_hardsigmoid(); +bool test_hardswish_(); bool test_hardswish(); -bool test_hardswish2(); bool test_upsampling_nearest2d_vec(); bool test_upsampling_nearest2d_vec2(); bool test_adaptive_avg_pool2d(); diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm index 5a8f6de..ca776fb 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -250,7 +250,7 @@ bool test_hardsigmoid() { }); } -bool test_hardswish() { +bool test_hardswish_() { __block std::vector size{3, 3, 44, 44}; return TEST(size, __PRETTY_FUNCTION__, ^bool { auto X = @@ -262,14 +262,14 @@ bool test_hardswish() { }); } -bool test_hardswish2() { +bool test_hardswish() { __block std::vector size{1, 3, 44, 44}; return TEST(size, __PRETTY_FUNCTION__, ^bool { auto X = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6; auto X2 = X.metal(); - auto Y1 = at::hardswish_(X); - auto Y2 = at::hardswish_(X2).cpu(); + auto Y1 = at::hardswish(X); + auto Y2 = at::hardswish(X2).cpu(); return almostEqual(Y1, Y2); }); } diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm index 5e74998..5dff3be 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm @@ -68,8 +68,8 @@ REG_TEST("test_softmax", test_softmax); REG_TEST("test_sigmoid", test_sigmoid); REG_TEST("test_hardsigmoid", test_hardsigmoid); + REG_TEST("test_hardswish_", test_hardswish_); REG_TEST("test_hardswish", test_hardswish); - REG_TEST("test_hardswish2", test_hardswish2); REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec); REG_TEST("test_upsampling_nearest2d_vec2", test_upsampling_nearest2d_vec2); REG_TEST("test_adaptive_avg_pool2d", test_adaptive_avg_pool2d); diff --git a/aten/src/ATen/native/metal/ops/MetalHardswish.mm b/aten/src/ATen/native/metal/ops/MetalHardswish.mm index d571e48..ba5e09c 100644 --- a/aten/src/ATen/native/metal/ops/MetalHardswish.mm +++ b/aten/src/ATen/native/metal/ops/MetalHardswish.mm @@ -1,9 +1,9 @@ #include #import +#import #import #import #import -#import #import #import #import @@ -23,14 +23,14 @@ Tensor& hardswish_(Tensor& input) { MPSImage* Y = createTemporaryImage(commandBuffer, imageSize); id encoder = [commandBuffer.buffer computeCommandEncoder]; - id state = [[MetalContext sharedInstance] - specializedPipelineState:"hardswish" - Constants:@[ - @(X.numberOfImages), - @(X.featureChannels), - @(X.height), - @(X.width) - ]]; + id state = + [[MetalContext sharedInstance] specializedPipelineState:"hardswish" + Constants:@[ + @(X.numberOfImages), + @(X.featureChannels), + @(X.height), + @(X.width) + ]]; [encoder setComputePipelineState:state]; [encoder setTexture:[X texture] atIndex:0]; @@ -47,8 +47,41 @@ Tensor& hardswish_(Tensor& input) { return input; } +Tensor hardswish(const at::Tensor& input) { + MPSImage* X = imageFromTensor(input); + IntArrayRef outputSize = input.sizes(); + MetalTensorImplStorage mt{outputSize.vec()}; + MetalCommandBuffer* commandBuffer = getCommandBuffer(input); + mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer); + MPSImage* Y = mt.texture()->image(); + id encoder = + [commandBuffer.buffer computeCommandEncoder]; + id state = + [[MetalContext sharedInstance] specializedPipelineState:"hardswish" + Constants:@[ + @(X.numberOfImages), + @(X.featureChannels), + @(X.height), + @(X.width) + ]]; + + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + + auto output = makeTensor(std::move(mt), input.options()); + return output; +} + TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl("hardswish_", TORCH_FN(hardswish_)); + m.impl("hardswish", TORCH_FN(hardswish)); }; } -- 2.7.4