out.write(outtex, gid_, outz);
}
-kernel void clamp_half4(texture2d_array<half, access::read> in[[texture(0)]],
- texture2d_array<half, access::write> out[[texture(1)]],
- constant half* clamp_buf[[buffer(0)]],
+constant bool clamp_is_arr = (ushort_arg_1 > 1 || ushort_arg_0 > 4);
+constant bool clamp_is_tex = !clamp_is_arr;
+kernel void clamp(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(clamp_is_arr)]],
+ texture2d<half, access::read> in_tex[[texture(0), function_constant(clamp_is_tex)]],
+ texture2d_array<half, access::write> out_arr[[texture(1), function_constant(clamp_is_arr)]],
+ texture2d<half, access::write> out_tex[[texture(1), function_constant(clamp_is_tex)]],
ushort3 gid[[thread_position_in_grid]]) {
- if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
+ const ushort w = clamp_is_arr? out_arr.get_width() : out_tex.get_width();
+ const ushort h = clamp_is_arr? out_arr.get_height() : out_tex.get_height();
+ if (gid.x >= w || gid.y >= h) {
return;
}
- const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]);
- const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]);
+ const float4 min_(float_arg_0, float_arg_0, float_arg_0, float_arg_0);
+ const float4 max_(float_arg_1, float_arg_1, float_arg_1, float_arg_1);
ushort2 gid_ = gid.xy;
- half4 value = in.read(gid_, gid.z);
- half4 clamped = clamp(value, min_, max_);
- out.write(clamped, gid_, gid.z);
-}
-
-kernel void clamp_half4_nonarray(texture2d<half, access::read> in[[texture(0)]],
- texture2d<half, access::write> out[[texture(1)]],
- constant half* clamp_buf[[buffer(0)]],
- ushort2 gid[[thread_position_in_grid]]) {
- if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
- return;
+ if(clamp_is_arr){
+ float4 value = (float4)in_arr.read(gid_, gid.z);
+ half4 clamped = (half4)clamp(value, min_, max_);
+ out_arr.write(clamped, gid_, gid.z);
+ } else {
+ float4 value = (float4)in_tex.read(gid_);
+ half4 clamped = (half4)clamp(value, min_, max_);
+ out_tex.write(clamped, gid_);
}
- const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]);
- const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]);
- half4 value = in.read(gid);
- half4 clamped = clamp(value, min_, max_);
- out.write(clamped, gid);
}
kernel void hardswish(texture2d_array<half, access::read> in[[texture(0)]],
+#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorUtils.h>
-#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
-#import <ATen/native/metal/MetalContext.h>
+#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@implementation MPSCNNClampOp {
}
- (void)encode:(id<MTLCommandBuffer>)cb {
- /*
- `clamp(vector<half4>, float, float)` is not available on iOS 10.0,
- have to use `clamp(vector<half4>, half4, half4)` instead.
- */
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
- id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
- pipelineState:at::native::metal::mpscnn::kernelFor(
- _X, "clamp_half4", "clamp_half4_nonarray")];
-
+ id<MTLComputePipelineState> state =
+ [[MetalContext sharedInstance] specializedPipelineState:"clamp"
+ Constants:@[
+ @(_min.floatValue),
+ @(_max.floatValue),
+ @(_X.featureChannels),
+ @(_X.numberOfImages)
+ ]];
[encoder setComputePipelineState:state];
[encoder setTexture:[_X texture] atIndex:0];
[encoder setTexture:[_Y texture] atIndex:1];
- id<MTLBuffer> clampBuffer = [[MetalContext sharedInstance].device
- newBufferWithLength:2 * sizeof(fp16_t)
- options:MTLResourceOptionCPUCacheModeWriteCombined];
- fp16_t* clampBufferPtr = (fp16_t*)[clampBuffer contents];
- clampBufferPtr[0] = _min.floatValue;
- clampBufferPtr[1] = _max.floatValue;
- [encoder setBuffer:clampBuffer offset:0 atIndex:0];
const auto& launchParams =
at::native::metal::mpscnn::spatialPointwiseKernelLaunchParams(state, _Y);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
}
bool test_hardtanh_() {
-#if TARGET_OS_IPHONE
__block std::vector<int64_t> size{1, 32, 112, 112};
return TEST(size, __PRETTY_FUNCTION__, ^bool {
auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
auto Y2 = at::hardtanh_(X2, 0, 6.0).cpu();
return almostEqual(Y1, Y2);
});
-#else
- // Skip this test on MacOS as the shader function doesn't work well
- // Will get back and fix it - T82700462
- return true;
-#endif
+}
+
+bool test_hardtanh() {
+ __block std::vector<int64_t> size{1, 3, 4, 4};
+ return TEST(size, __PRETTY_FUNCTION__, ^bool {
+ auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+ auto Y1 = at::hardtanh(X1, 0, 6.0);
+ auto X2 = X1.metal();
+ auto Y2 = at::hardtanh(X2, 0, 6.0).cpu();
+ return almostEqual(Y1, Y2);
+ });
}
bool test_mean_dim() {
});
}
-
bool test_chunk() {
__block std::vector<int64_t> size{1, 4, 2, 2};
return TEST(size, __PRETTY_FUNCTION__, ^bool {