[spirv] Add binary arithmetic operations.
authorDenis Khalikov <dennis.khalikov@gmail.com>
Tue, 30 Jul 2019 16:42:33 +0000 (09:42 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 30 Jul 2019 18:55:12 +0000 (11:55 -0700)
Add binary operations such as: OpIAdd, OpFAdd, OpISub, OpFSub, OpIMul,
OpFDiv, OpFRem, OpFMod.

Closes tensorflow/mlir#54

PiperOrigin-RevId: 260734166

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/test/Dialect/SPIRV/Serialization/bin_ops.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/ops.mlir

index 4483550..a12f339 100644 (file)
@@ -99,7 +99,15 @@ def SPV_OC_OpStore             : I32EnumAttrCase<"OpStore", 62>;
 def SPV_OC_OpAccessChain       : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate          : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpCompositeExtract  : I32EnumAttrCase<"OpCompositeExtract", 81>;
+def SPV_OC_OpIAdd              : I32EnumAttrCase<"OpIAdd", 128>;
+def SPV_OC_OpFAdd              : I32EnumAttrCase<"OpFAdd", 129>;
+def SPV_OC_OpISub              : I32EnumAttrCase<"OpISub", 130>;
+def SPV_OC_OpFSub              : I32EnumAttrCase<"OpFSub", 131>;
+def SPV_OC_OpIMul              : I32EnumAttrCase<"OpIMul", 132>;
 def SPV_OC_OpFMul              : I32EnumAttrCase<"OpFMul", 133>;
+def SPV_OC_OpFDiv              : I32EnumAttrCase<"OpFDiv", 136>;
+def SPV_OC_OpFRem              : I32EnumAttrCase<"OpFRem", 140>;
+def SPV_OC_OpFMod              : I32EnumAttrCase<"OpFMod", 141>;
 def SPV_OC_OpReturn            : I32EnumAttrCase<"OpReturn", 253>;
 
 def SPV_OpcodeAttr :
@@ -112,7 +120,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
       SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
-      SPV_OC_OpFMul, SPV_OC_OpReturn
+      SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
+      SPV_OC_OpFMul, SPV_OC_OpFDiv, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpReturn
       ]> {
     let returnType = "::mlir::spirv::Opcode";
     let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@@ -577,4 +586,22 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
   bit autogenSerialization = 1;
 }
 
+class SPV_ArithmeticOp<string mnemonic, Type type,
+                       list<OpTrait> traits = []> :
+      SPV_Op<mnemonic,
+             !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<type>:$operand1,
+    SPV_ScalarOrVectorOf<type>:$operand2
+  );
+  let results = (outs
+    SPV_ScalarOrVectorOf<type>:$result
+  );
+  let parser = [{ return impl::parseBinaryOp(parser, result); }];
+  let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
+  // No additional verification needed in addition to the ODS-generated ones.
+  let verifier = [{ return success(); }];
+}
+
+
 #endif // SPIRV_BASE
index 022df0f..8f97f78 100644 (file)
@@ -66,8 +66,11 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
     no remaining (unused) indexes.
 
      Each index in Indexes
+
     - must be a scalar integer type,
+
     - is treated as a signed count, and
+
     - must be an OpConstant when indexing into a structure.
 
     ### Custom assembly form
@@ -250,7 +253,100 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
 
 // -----
 
-def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
+def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> {
+  let summary = "Floating-point addition of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fadd-op ::= ssa-id `=` `spv.FAdd` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.FAdd %0, %1 : f32
+    %5 = spv.FAdd %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> {
+  let summary = "Floating-point division of Operand 1 divided by Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fdiv-op ::= ssa-id `=` `spv.FDiv` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FDiv %0, %1 : f32
+    %5 = spv.FDiv %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FModOp : SPV_ArithmeticOp<"FMod", SPV_Float> {
+  let summary = [{
+    The floating-point remainder whose sign matches the sign of Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.  Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 2.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmod-op ::= ssa-id `=` `spv.FMod` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.FMod %0, %1 : f32
+    %5 = spv.FMod %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FMulOp : SPV_ArithmeticOp<"FMul", SPV_Float, [Commutative]> {
   let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -266,32 +362,190 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
     ``` {.ebnf}
     float-scalar-vector-type ::= float-type |
                                  `vector<` integer-literal `x` float-type `>`
-    execution-mode-op ::= `spv.FMul` ssa-use, ssa-use
+    fmul-op ::= `spv.FMul` ssa-use, ssa-use
                           `:` float-scalar-vector-type
     ```
 
     For example:
 
     ```
-    spv.FMul %0, %1 : f32
-    spv.FMul %2, %3 : vector<4xf32>
+    %4 = spv.FMul %0, %1 : f32
+    %5 = spv.FMul %2, %3 : vector<4xf32>
     ```
   }];
+}
 
-  let arguments = (ins
-    SPV_ScalarOrVectorOf<SPV_Float>:$operand1,
-    SPV_ScalarOrVectorOf<SPV_Float>:$operand2
-  );
+// -----
 
-  let results = (outs
-    SPV_ScalarOrVectorOf<AnyFloat>:$result
-  );
+def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float> {
+  let summary = [{
+    The floating-point remainder whose sign matches the sign of Operand 1.
+  }];
 
-  let parser = [{ return impl::parseBinaryOp(parser, result); }];
-  let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
 
-  // No additional verification needed in addition to the ODS-generated ones.
-  let verifier = [{ return success(); }];
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.  Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 1.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    frem-op ::= ssa-id `=` `spv.FRemOp` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FRemOp %0, %1 : f32
+    %5 = spv.FRemOp %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FSubOp : SPV_ArithmeticOp<"FSub", SPV_Float> {
+  let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fsub-op ::= ssa-id `=` `spv.FRemOp` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FRemOp %0, %1 : f32
+    %5 = spv.FRemOp %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IAddOp : SPV_ArithmeticOp<"IAdd", SPV_Integer, [Commutative]> {
+  let summary = "Integer addition of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    iadd-op ::= ssa-id `=` `spv.IAdd` ssa-use, ssa-use
+                          `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.IAdd %0, %1 : i32
+    %5 = spv.IAdd %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IMulOp : SPV_ArithmeticOp<"IMul", SPV_Integer, [Commutative]> {
+  let summary = "Integer multiplication of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    imul-op ::= ssa-id `=` `spv.IMul` ssa-use, ssa-use
+                          `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.IMul %0, %1 : i32
+    %5 = spv.IMul %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer> {
+  let summary = "Integer subtraction of Operand 2 from Operand 1.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    isub-op ::= `spv.ISub` ssa-use, ssa-use
+                          `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.ISub %0, %1 : i32
+    %5 = spv.ISub %2, %3 : vector<4xi32>
+
+    ```
+  }];
 }
 
 // -----
diff --git a/mlir/test/Dialect/SPIRV/Serialization/bin_ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/bin_ops.mlir
new file mode 100644 (file)
index 0000000..e7d5ac6
--- /dev/null
@@ -0,0 +1,52 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @spirv_bin_ops() -> () {
+  spv.module "Logical" "VulkanKHR" {
+   func @fmul(%arg0 : f32, %arg1 : f32) {
+      // CHECK: {{%.*}}= spv.FMul {{%.*}}, {{%.*}} : f32
+      %0 = spv.FMul %arg0, %arg1 : f32
+      spv.Return
+    }
+    func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) {
+      // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : vector<4xf32>
+      %0 = spv.FAdd %arg0, %arg1 : vector<4xf32>
+      spv.Return
+    }
+    func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) {
+      // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : vector<4xf32>
+      %0 = spv.FDiv %arg0, %arg1 : vector<4xf32>
+      spv.Return
+    }
+    func @fmod(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) {
+      // CHECK: {{%.*}} = spv.FMod {{%.*}}, {{%.*}} : vector<4xf32>
+      %0 = spv.FMod %arg0, %arg1 : vector<4xf32>
+      spv.Return
+    }
+    func @fsub(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) {
+      // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : vector<4xf32>
+      %0 = spv.FSub %arg0, %arg1 : vector<4xf32>
+      spv.Return
+    }
+    func @frem(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) {
+      // CHECK: {{%.*}} = spv.FRem {{%.*}}, {{%.*}} : vector<4xf32>
+      %0 = spv.FRem %arg0, %arg1 : vector<4xf32>
+      spv.Return
+    }
+    func @iadd(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+      // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : vector<4xi32>
+      %0 = spv.IAdd %arg0, %arg1 : vector<4xi32>
+      spv.Return
+    }
+    func @isub(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+      // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : vector<4xi32>
+      %0 = spv.ISub %arg0, %arg1 : vector<4xi32>
+      spv.Return
+    }
+    func @imul(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+      // CHECK: {{%.*}} = spv.IMul {{%.*}}, {{%.*}} : vector<4xi32>
+      %0 = spv.IMul %arg0, %arg1 : vector<4xi32>
+      spv.Return
+    }
+  }
+  return
+}
index 7da21b9..c8771ec 100644 (file)
@@ -392,6 +392,42 @@ spv.module "Logical" "VulkanKHR" {
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.FAdd
+//===----------------------------------------------------------------------===//
+
+func @fadd_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FAdd
+  %0 = spv.FAdd %arg, %arg : f32
+  return %0 : f32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.FDiv
+//===----------------------------------------------------------------------===//
+
+func @fdiv_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FDiv
+  %0 = spv.FDiv %arg, %arg : f32
+  return %0 : f32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.FMod
+//===----------------------------------------------------------------------===//
+
+func @fmod_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FMod
+  %0 = spv.FMod %arg, %arg : f32
+  return %0 : f32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.FMul
 //===----------------------------------------------------------------------===//
 
@@ -434,6 +470,66 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.FRem
+//===----------------------------------------------------------------------===//
+
+func @frem_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FRem
+  %0 = spv.FRem %arg, %arg : f32
+  return %0 : f32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.FSub
+//===----------------------------------------------------------------------===//
+
+func @fsub_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FSub
+  %0 = spv.FSub %arg, %arg : f32
+  return %0 : f32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+func @iadd_scalar(%arg: i32) -> i32 {
+  // CHECK: spv.IAdd
+  %0 = spv.IAdd %arg, %arg : i32
+  return %0 : i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+func @imul_scalar(%arg: i32) -> i32 {
+  // CHECK: spv.IMul
+  %0 = spv.IMul %arg, %arg : i32
+  return %0 : i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+func @isub_scalar(%arg: i32) -> i32 {
+  // CHECK: spv.ISub
+  %0 = spv.ISub %arg, %arg : i32
+  return %0 : i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.LoadOp
 //===----------------------------------------------------------------------===//