Implement leaky relu op
author    Vincent Phan <vincentphan@fb.com>
          Fri, 27 Aug 2021 20:51:38 +0000 (13:51 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Fri, 27 Aug 2021 20:52:49 +0000 (13:52 -0700)
Summary: Implemented the leaky relu op (aten::leaky_relu and its in-place variant aten::leaky_relu_) for the Vulkan backend, as per: https://www.internalfb.com/tasks/?t=97492679

Test Plan:
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"

All tests pass, including the new leaky_relu and leaky_relu_ tests.
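
For reference, a minimal usage sketch mirroring the new tests (the shape and
slope values are illustrative, not prescriptive):

  #include <ATen/ATen.h>

  // Move a CPU tensor to the Vulkan backend, run the op, and copy back.
  const auto in_cpu = at::rand({1, 3, 8, 8}, at::device(at::kCPU).dtype(at::kFloat));
  const auto in_vulkan = in_cpu.vulkan();
  const auto out_vulkan = at::leaky_relu(in_vulkan, /*negative_slope=*/0.01);
  const auto out_cpu = out_vulkan.cpu();

  at::leaky_relu_(in_vulkan, 0.01);  // in-place variant, Vulkan tensors only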

Reviewed By: SS-JIA

Differential Revision: D30186225

fbshipit-source-id: fdb1f8f7b3a28b5504581822185c0475dcd53a3e

aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl [new file with mode: 0644]
aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl [new file with mode: 0644]
aten/src/ATen/native/vulkan/ops/Clamp.cpp
aten/src/ATen/test/vulkan_api_test.cpp

diff --git a/aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl b/aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl
new file mode 100644
index 0000000..f947e78
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl
@@ -0,0 +1,28 @@
+#version 450 core
+#define PRECISION $precision
+
+layout(std430) buffer;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D   uOutput;
+layout(set = 0, binding = 1) uniform PRECISION                    sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION restrict           Block {
+  ivec4 size;
+  float negative_slope;
+} uBlock;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (all(lessThan(pos, uBlock.size.xyz))) {
+    const vec4 inval = texelFetch(uInput, pos, 0);
+    const vec4 negative_values = vec4(lessThan(inval, vec4(0.0f)));
+    const vec4 positive_values = vec4(1.0) - negative_values;
+    const vec4 mask = negative_values * vec4(uBlock.negative_slope) + positive_values;
+    const vec4 outval = inval * mask;
+    imageStore(uOutput, pos, outval);
+  }
+}
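
The shader avoids per-texel branching: lessThan yields a 0/1 vector, which is
turned into a multiplicative mask (negative_slope where the input is negative,
1.0 elsewhere). A scalar C++ sketch of the same arithmetic, illustrative only
and not part of this change:

  // out = x          for x >= 0
  // out = x * slope  for x <  0
  float leaky_relu_scalar(float x, float negative_slope) {
    const float is_neg = (x < 0.0f) ? 1.0f : 0.0f;                // lessThan -> 0/1
    const float mask = is_neg * negative_slope + (1.0f - is_neg);
    return x * mask;
  }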
diff --git a/aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl b/aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl
new file mode 100644
index 0000000..345e669
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl
@@ -0,0 +1,27 @@
+#version 450 core
+#define PRECISION $precision
+
+layout(std430) buffer;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput;
+layout(set = 0, binding = 1)          uniform PRECISION restrict Block {
+  ivec4 size;
+  float negative_slope;
+} uBlock;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (all(lessThan(pos, uBlock.size.xyz))) {
+    const vec4 inval = imageLoad(uOutput, pos);
+    const vec4 negative_values = vec4(lessThan(inval, vec4(0.0f)));
+    const vec4 positive_values = vec4(1.0) - negative_values;
+    const vec4 mask = negative_values * vec4(uBlock.negative_slope) + positive_values;
+    const vec4 outval = inval * mask;
+    imageStore(uOutput, pos, outval);
+  }
+}
diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp
index c6f046e..7982b0e 100644
--- a/aten/src/ATen/native/vulkan/ops/Clamp.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp
@@ -404,6 +404,121 @@ Tensor& hardshrink_(
   return self;
 }
 
+Tensor leaky_relu(
+    const Tensor& self_arg,
+    const Scalar& negative_slope) {
+  api::Context* const context = api::context();
+
+  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
+  const vTensor& v_self = convert(self);
+
+  vTensor v_output{
+    context,
+    v_self.sizes(),
+    v_self.options(),
+  };
+
+  api::Command::Pool& command_pool = context->command().pool;
+  api::Command::Buffer& command_buffer = command_pool.stream();
+  {
+    if C10_LIKELY(v_output.has_image() && v_self.has_image()) {
+      const struct Block final {
+        uvec3 extents;
+        uint32_t _;
+        float negative_slope;
+      } block {
+        v_output.extents(),
+        0u,
+        negative_slope.to<float>(),
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+          },
+          VK_KERNEL(leaky_relu),
+          v_output.extents(),
+          context->gpu().adapter->local_work_group_size(),
+          // Write-only access bypasses synchronization but inserts appropriate
+          // barriers if necessary.
+          v_output.image(
+              command_buffer,
+              vTensor::Stage::Compute,
+              vTensor::Access::Write),
+          // Read-only access is implied on const tensors and triggers an async
+          // synchronization if necessary.
+          v_self.image(
+              command_buffer,
+              vTensor::Stage::Compute),
+          // Object lifetime is managed by the resource pool.
+          // It is OK not to keep track of the handle.
+          context->resource().pool.uniform(block).object);
+    }
+    else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_pool.submit(context->gpu().queue, command_buffer);
+
+  return convert(v_output);
+}
+
+Tensor& leaky_relu_(
+    Tensor& self,
+    const Scalar& negative_slope) {
+  api::Context* const context = api::context();
+
+  TORCH_CHECK(
+      self.is_vulkan(),
+      "Vulkan: In-place leaky relu is only supported on Vulkan tensors.");
+
+  vTensor& v_self = convert(self);
+
+  api::Command::Pool& command_pool = context->command().pool;
+  api::Command::Buffer& command_buffer = command_pool.stream();
+  {
+    if C10_LIKELY(v_self.has_image()) {
+      const struct Block final {
+        uvec3 extents;
+        uint32_t _;
+        float negative_slope;
+      } block {
+        v_self.extents(),
+        0u,
+        negative_slope.to<float>(),
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+          },
+          VK_KERNEL(leaky_relu_),
+          v_self.extents(),
+          context->gpu().adapter->local_work_group_size(),
+          // Read-Write access triggers an async synchronization if necessary
+          // and inserts appropriate barriers if hazards are detected.
+          v_self.image(
+              command_buffer,
+              vTensor::Stage::Compute,
+              vTensor::Access::Read | vTensor::Access::Write),
+          // Object lifetime is managed by the resource pool.
+          // It is OK not to keep track of the handle.
+          context->resource().pool.uniform(block).object);
+    }
+    else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_pool.submit(context->gpu().queue, command_buffer);
+
+  return self;
+}
+
 Tensor sigmoid(const Tensor& self) {
   return ops::activation(self, VK_KERNEL(sigmoid));
 }
@@ -433,6 +548,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
   m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), hardswish_);
   m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), hardtanh);
   m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh_"), hardtanh_);
+  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), leaky_relu);
+  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), leaky_relu_);
   m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), sigmoid);
   m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid_"), sigmoid_);
   m.impl(TORCH_SELECTIVE_NAME("aten::tanh"), tanh);
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index 2873d3c..d4b466a 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -979,6 +979,49 @@ TEST(VulkanAPITest, hardshrink_) {
   }
 }
 
+TEST(VulkanAPITest, leaky_relu) {
+  if (!at::is_vulkan_available()) {
+    return;
+  }
+
+  for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) {
+    const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
+    const auto in_vulkan = in_cpu.vulkan();
+
+    const auto out_cpu = at::leaky_relu(in_cpu, negative_slope);
+    const auto out_vulkan = at::leaky_relu(in_vulkan, negative_slope);
+
+    const auto check = almostEqual(out_cpu, out_vulkan.cpu());
+
+    if (!check) {
+      showRtol(out_cpu, out_vulkan.cpu());
+    }
+
+    ASSERT_TRUE(check);
+  }
+}
+
+TEST(VulkanAPITest, leaky_relu_) {
+  if (!at::is_vulkan_available()) {
+    return;
+  }
+
+  for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) {
+    auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
+    auto vulkan = cpu.vulkan();
+
+    at::leaky_relu_(cpu, negative_slope);
+    at::leaky_relu_(vulkan, negative_slope);
+
+    const auto check = almostEqual(cpu, vulkan.cpu());
+    if (!check) {
+      showRtol(cpu, vulkan.cpu());
+    }
+
+    ASSERT_TRUE(check);
+  }
+}
+
 TEST(VulkanAPITest, hardswish) {
   if (!at::is_vulkan_available()) {
     return;