[spirv] Add lowering for std.fdiv, std.frem, std.fsub
authorDenis Khalikov <khalikov.denis@huawei.com>
Wed, 11 Dec 2019 19:17:03 +0000 (11:17 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Dec 2019 19:17:35 +0000 (11:17 -0800)
Closes tensorflow/mlir#313

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/313 from denis0x0D:sandbox/lowering_std_farith 41715070a74d13bfa9401957478978c1bb8006c0
PiperOrigin-RevId: 285023586

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir

index 9ebbea6..6f3a6a8 100644 (file)
@@ -21,7 +21,10 @@ class BinaryOpPattern<Op src, Op tgt> :
           (tgt $l, $r)>;
 
 def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
+def : BinaryOpPattern<DivFOp, SPV_FDivOp>;
 def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
+def : BinaryOpPattern<RemFOp, SPV_FRemOp>;
+def : BinaryOpPattern<SubFOp, SPV_FSubOp>;
 
 // Constant Op
 // TODO(ravishankarm): Handle lowering other constant types.
index 3e1b442..c113493 100644 (file)
@@ -13,6 +13,20 @@ func @add_sub(%arg0 : i32, %arg1 : i32) {
   return
 }
 
+// CHECK-LABEL: @fadd_scalar
+func @fadd_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FAdd
+  %0 = addf %arg, %arg : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @fdiv_scalar
+func @fdiv_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FDiv
+  %0 = divf %arg, %arg : f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: @fmul_scalar
 func @fmul_scalar(%arg: f32) -> f32 {
   // CHECK: spv.FMul
@@ -57,6 +71,20 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: @frem_scalar
+func @frem_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FRem
+  %0 = remf %arg, %arg : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @fsub_scalar
+func @fsub_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FSub
+  %0 = subf %arg, %arg : f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: @div_rem
 func @div_rem(%arg0 : i32, %arg1 : i32) {
   // CHECK: spv.SDiv