--- /dev/null
+#version 450 core
+#define PRECISION $precision
+
+layout(std430) buffer;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION restrict Block {
+ ivec4 size;
+ float negative_slope;
+} uBlock;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+ const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+ if (all(lessThan(pos, uBlock.size.xyz))) {
+ const vec4 inval = texelFetch(uInput, pos, 0);
+ const vec4 negative_values = vec4(lessThan(inval, vec4(0.0f)));
+ const vec4 positive_values = vec4(1.0) - negative_values;
+ const vec4 mask = negative_values * vec4(uBlock.negative_slope) + positive_values;
+ const vec4 outval = inval * mask;
+ imageStore(uOutput, pos, outval);
+ }
+}
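Editor's note: the shader computes leaky ReLU, f(x) = x for x >= 0 and negative_slope * x otherwise, without per-component branching. `lessThan` yields a boolean vector that is converted to 0.0/1.0 lanes, and the resulting mask scales negative lanes by `negative_slope` while leaving the rest untouched. A minimal scalar sketch of the same math, assuming nothing beyond the formula itself (the `leaky_relu_ref` name is illustrative, not part of this change):

float leaky_relu_ref(const float x, const float negative_slope) {
  // 1.0 for negative inputs, 0.0 otherwise.
  const float negative = (x < 0.0f) ? 1.0f : 0.0f;
  // mask is negative_slope on negative lanes and 1.0 on the rest.
  const float mask = negative * negative_slope + (1.0f - negative);
  return x * mask;
}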
--- /dev/null
+#version 450 core
+#define PRECISION $precision
+
+layout(std430) buffer;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION restrict Block {
+ ivec4 size;
+ float negative_slope;
+} uBlock;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+ const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+ if (all(lessThan(pos, uBlock.size.xyz))) {
+ const vec4 inval = imageLoad(uOutput, pos);
+ const vec4 negative_values = vec4(lessThan(inval, vec4(0.0f)));
+ const vec4 positive_values = vec4(1.0) - negative_values;
+ const vec4 mask = negative_values * vec4(uBlock.negative_slope) + positive_values;
+ const vec4 outval = inval * mask;
+ imageStore(uOutput, pos, outval);
+ }
+}
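Editor's note: unlike the out-of-place shader above, this in-place variant reads its input with `imageLoad` from the very image it later overwrites with `imageStore`, so it needs a single storage-image binding and no sampler. The explicit `rgba16f` format qualifier is required here because GLSL only lets a storage image omit its format when it is declared `writeonly`, as `uOutput` is in the first shader.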
return self;
}
+Tensor leaky_relu(
+ const Tensor& self_arg,
+ const Scalar& negative_slope) {
+ api::Context* const context = api::context();
+
+ const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
+ const vTensor& v_self = convert(self);
+
+ vTensor v_output{
+ context,
+ v_self.sizes(),
+ v_self.options(),
+ };
+
+ api::Command::Pool& command_pool = context->command().pool;
+ api::Command::Buffer& command_buffer = command_pool.stream();
+ {
+ if C10_LIKELY(v_output.has_image() && v_self.has_image()) {
+ const struct Block final {
+ uvec3 extents;
+ uint32_t _;
+ float negative_slope;
+ } block {
+ v_output.extents(),
+ 0u,
+ negative_slope.to<float>(),
+ };
+
+ context->dispatch(
+ command_buffer,
+ {
+ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+ VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+ },
+ VK_KERNEL(leaky_relu),
+ v_output.extents(),
+ context->gpu().adapter->local_work_group_size(),
+ // Write-only access bypasses synchronization but inserts appropriate
+ // barriers if necessary.
+ v_output.image(
+ command_buffer,
+ vTensor::Stage::Compute,
+ vTensor::Access::Write),
+ // Read-only access is implied on const tensors and triggers an async
+ // synchronization if necessary.
+ v_self.image(
+ command_buffer,
+ vTensor::Stage::Compute),
+ // Object lifetime is managed by the resource pool.
+ // It is OK not to keep track of the handle.
+ context->resource().pool.uniform(block).object);
+ }
+ else {
+ TORCH_CHECK(false, "Not implemented!");
+ }
+ }
+ command_pool.submit(context->gpu().queue, command_buffer);
+
+ return convert(v_output);
+}
+
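Editor's note: on the host side, the anonymous `uint32_t _` member pads the 12-byte `extents` field out to 16 bytes so that `negative_slope` lands at byte offset 16, where the shader expects it after `ivec4 size` (the offset is the same under std140 and std430 rules). A self-contained layout check, with a local `Block` mirror that is illustrative rather than part of this change:

#include <cstddef>
#include <cstdint>

struct Block {
  uint32_t extents[3]; // stands in for uvec3: 12 bytes
  uint32_t _;          // explicit pad to the 16-byte boundary
  float negative_slope;
};
static_assert(offsetof(Block, negative_slope) == 16,
              "negative_slope must sit at the shader's byte offset 16");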
+Tensor& leaky_relu_(
+ Tensor& self,
+ const Scalar& negative_slope) {
+ api::Context* const context = api::context();
+
+ TORCH_CHECK(
+ self.is_vulkan(),
+ "Vulkan: In-place leaky relu is only supported on Vulkan tensors.");
+
+ vTensor& v_self = convert(self);
+
+ api::Command::Pool& command_pool = context->command().pool;
+ api::Command::Buffer& command_buffer = command_pool.stream();
+ {
+ if C10_LIKELY(v_self.has_image()) {
+ const struct Block final {
+ uvec3 extents;
+ uint32_t _;
+ float negative_slope;
+ } block {
+ v_self.extents(),
+ 0u,
+ negative_slope.to<float>(),
+ };
+
+ context->dispatch(
+ command_buffer,
+ {
+ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+ },
+ VK_KERNEL(leaky_relu_),
+ v_self.extents(),
+ context->gpu().adapter->local_work_group_size(),
+      // Read-Write access triggers an async synchronization if necessary
+ // and inserts appropriate barriers if hazards are detected.
+ v_self.image(
+ command_buffer,
+ vTensor::Stage::Compute,
+ vTensor::Access::Read | vTensor::Access::Write),
+ // Object lifetime is managed by the resource pool.
+ // It is OK not to keep track of the handle.
+ context->resource().pool.uniform(block).object);
+ }
+ else {
+ TORCH_CHECK(false, "Not implemented!");
+ }
+ }
+ command_pool.submit(context->gpu().queue, command_buffer);
+
+ return self;
+}
+
Tensor sigmoid(const Tensor& self) {
return ops::activation(self, VK_KERNEL(sigmoid));
}
m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), hardswish_);
m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), hardtanh);
m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh_"), hardtanh_);
+ m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), leaky_relu);
+ m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), leaky_relu_);
m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), sigmoid);
m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid_"), sigmoid_);
m.impl(TORCH_SELECTIVE_NAME("aten::tanh"), tanh);
}
}
+TEST(VulkanAPITest, leaky_relu) {
+ if (!at::is_vulkan_available()) {
+ return;
+ }
+
+ for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) {
+ const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
+ const auto in_vulkan = in_cpu.vulkan();
+
+ const auto out_cpu = at::leaky_relu(in_cpu, negative_slope);
+ const auto out_vulkan = at::leaky_relu(in_vulkan, negative_slope);
+
+ const auto check = almostEqual(out_cpu, out_vulkan.cpu());
+
+ if (!check) {
+ showRtol(out_cpu, out_vulkan.cpu());
+ }
+
+ ASSERT_TRUE(check);
+ }
+}
+
+TEST(VulkanAPITest, leaky_relu_) {
+ if (!at::is_vulkan_available()) {
+ return;
+ }
+
+ for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) {
+ auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
+ auto vulkan = cpu.vulkan();
+
+ at::leaky_relu_(cpu, negative_slope);
+ at::leaky_relu_(vulkan, negative_slope);
+
+ const auto check = almostEqual(cpu, vulkan.cpu());
+ if (!check) {
+ showRtol(cpu, vulkan.cpu());
+ }
+
+ ASSERT_TRUE(check);
+ }
+}
+
TEST(VulkanAPITest, hardswish) {
if (!at::is_vulkan_available()) {
return;