[iOS][GPU] Fix the clamp shader function for x86_64 (#63062)
authorTao Xu <taox@fb.com>
Thu, 12 Aug 2021 20:18:42 +0000 (13:18 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 12 Aug 2021 20:20:27 +0000 (13:20 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63062

Pervasively, due to the need of supporting 10.0, we used a fp16 version of the clamp kernel on Metal, which didn't work well on x86_64. Since we don't need to support 10.0 anymore, we can use the fp32 version, which works both on arm64 and x86_64.
ghstack-source-id: 135536785

Test Plan:
- `buck test pp-macos`
- Op tests in the playground app

{F641013793}

Reviewed By: husthyc

Differential Revision: D30239931

fbshipit-source-id: 6ad1bf71422b537e052fbd7b7465ba8deb7ca0cf

aten/src/ATen/native/metal/MetalShaders.h
aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm
aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm

index 90b6c68..5c25672 100644 (file)
@@ -367,33 +367,30 @@ kernel void append_features_off(texture2d<half, access::read> in_tex[[texture(0)
     out.write(outtex, gid_, outz);
 }
 
-kernel void clamp_half4(texture2d_array<half, access::read> in[[texture(0)]],
-                 texture2d_array<half, access::write> out[[texture(1)]],
-                 constant half* clamp_buf[[buffer(0)]],
+constant bool clamp_is_arr = (ushort_arg_1 > 1 || ushort_arg_0 > 4);
+constant bool clamp_is_tex = !clamp_is_arr;
+kernel void clamp(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(clamp_is_arr)]],
+                  texture2d<half, access::read> in_tex[[texture(0), function_constant(clamp_is_tex)]],
+                  texture2d_array<half, access::write> out_arr[[texture(1), function_constant(clamp_is_arr)]],
+                  texture2d<half, access::write> out_tex[[texture(1), function_constant(clamp_is_tex)]],
                  ushort3 gid[[thread_position_in_grid]]) {
-    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
+    const ushort w = clamp_is_arr? out_arr.get_width() : out_tex.get_width();
+    const ushort h = clamp_is_arr? out_arr.get_height() : out_tex.get_height();
+    if (gid.x >= w || gid.y >= h) {
         return;
     }
-    const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]);
-    const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]);
+    const float4 min_(float_arg_0, float_arg_0, float_arg_0, float_arg_0);
+    const float4 max_(float_arg_1, float_arg_1, float_arg_1, float_arg_1);
     ushort2 gid_ = gid.xy;
-    half4 value = in.read(gid_, gid.z);
-    half4 clamped = clamp(value, min_, max_);
-    out.write(clamped, gid_, gid.z);
-}
-
-kernel void clamp_half4_nonarray(texture2d<half, access::read> in[[texture(0)]],
-                          texture2d<half, access::write> out[[texture(1)]],
-                          constant half* clamp_buf[[buffer(0)]],
-                          ushort2 gid[[thread_position_in_grid]]) {
-    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
-        return;
+    if(clamp_is_arr){
+        float4 value = (float4)in_arr.read(gid_, gid.z);
+        half4 clamped = (half4)clamp(value, min_, max_);
+        out_arr.write(clamped, gid_, gid.z);
+    } else {
+        float4 value = (float4)in_tex.read(gid_);
+        half4 clamped = (half4)clamp(value, min_, max_);
+        out_tex.write(clamped, gid_);
     }
-    const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]);
-    const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]);
-    half4 value = in.read(gid);
-    half4 clamped = clamp(value, min_, max_);
-    out.write(clamped, gid);
 }
 
 kernel void hardswish(texture2d_array<half, access::read> in[[texture(0)]],
index 2f303a3..d1776b6 100644 (file)
@@ -1,7 +1,7 @@
+#import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/MetalTensorUtils.h>
-#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
-#import <ATen/native/metal/MetalContext.h>
+#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
 
 @implementation MPSCNNClampOp {
 }
 
 - (void)encode:(id<MTLCommandBuffer>)cb {
-  /*
-  `clamp(vector<half4>, float, float)` is not available on iOS 10.0,
-  have to use `clamp(vector<half4>, half4, half4)` instead.
-  */
   id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
-  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
-      pipelineState:at::native::metal::mpscnn::kernelFor(
-                        _X, "clamp_half4", "clamp_half4_nonarray")];
-
+  id<MTLComputePipelineState> state =
+      [[MetalContext sharedInstance] specializedPipelineState:"clamp"
+                                                    Constants:@[
+                                                      @(_min.floatValue),
+                                                      @(_max.floatValue),
+                                                      @(_X.featureChannels),
+                                                      @(_X.numberOfImages)
+                                                    ]];
   [encoder setComputePipelineState:state];
   [encoder setTexture:[_X texture] atIndex:0];
   [encoder setTexture:[_Y texture] atIndex:1];
-  id<MTLBuffer> clampBuffer = [[MetalContext sharedInstance].device
-      newBufferWithLength:2 * sizeof(fp16_t)
-                  options:MTLResourceOptionCPUCacheModeWriteCombined];
-  fp16_t* clampBufferPtr = (fp16_t*)[clampBuffer contents];
-  clampBufferPtr[0] = _min.floatValue;
-  clampBufferPtr[1] = _max.floatValue;
-  [encoder setBuffer:clampBuffer offset:0 atIndex:0];
   const auto& launchParams =
       at::native::metal::mpscnn::spatialPointwiseKernelLaunchParams(state, _Y);
   [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
index 3041e8d..ee992d9 100644 (file)
@@ -45,6 +45,7 @@ bool test_upsampling_nearest2d_vec();
 bool test_upsampling_nearest2d_vec2();
 bool test_adaptive_avg_pool2d();
 bool test_hardtanh_();
+bool test_hardtanh();
 bool test_reshape();
 bool test_mean_dim();
 bool test_mean_dim2();
index 0c47d9a..69497a9 100644 (file)
@@ -792,7 +792,6 @@ bool test_reflection_pad2d() {
 }
 
 bool test_hardtanh_() {
-#if TARGET_OS_IPHONE
   __block std::vector<int64_t> size{1, 32, 112, 112};
   return TEST(size, __PRETTY_FUNCTION__, ^bool {
     auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
@@ -801,11 +800,17 @@ bool test_hardtanh_() {
     auto Y2 = at::hardtanh_(X2, 0, 6.0).cpu();
     return almostEqual(Y1, Y2);
   });
-#else
-  // Skip this test on MacOS as the shader function doesn't work well
-  // Will get back and fix it - T82700462
-  return true;
-#endif
+}
+
+bool test_hardtanh() {
+  __block std::vector<int64_t> size{1, 3, 4, 4};
+  return TEST(size, __PRETTY_FUNCTION__, ^bool {
+    auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+    auto Y1 = at::hardtanh(X1, 0, 6.0);
+    auto X2 = X1.metal();
+    auto Y2 = at::hardtanh(X2, 0, 6.0).cpu();
+    return almostEqual(Y1, Y2);
+  });
 }
 
 bool test_mean_dim() {
@@ -841,7 +846,6 @@ bool test_mean_dim3() {
     });
 }
 
-
 bool test_chunk() {
 __block std::vector<int64_t> size{1, 4, 2, 2};
 return TEST(size, __PRETTY_FUNCTION__, ^bool {
index a9d1ad1..d8b69ad 100644 (file)
@@ -73,6 +73,7 @@
   REG_TEST("test_upsampling_nearest2d_vec2", test_upsampling_nearest2d_vec2);
   REG_TEST("test_adaptive_avg_pool2d", test_adaptive_avg_pool2d);
   REG_TEST("test_hardtanh_", test_hardtanh_);
+  REG_TEST("test_hardtanh", test_hardtanh);
   REG_TEST("test_reshape", test_reshape);
   REG_TEST("test_mean_dim", test_mean_dim);
   REG_TEST("test_mean_dim2", test_mean_dim2);