[spirv] Define common types using op definition spec
authorLei Zhang <antiagainst@google.com>
Mon, 17 Jun 2019 20:29:06 +0000 (13:29 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jun 2019 06:04:39 +0000 (23:04 -0700)
This CL also tightens spv.FMul to only accept 16/32/64-bit floats.

PiperOrigin-RevId: 253649352

mlir/include/mlir/SPIRV/SPIRVBase.td
mlir/include/mlir/SPIRV/SPIRVOps.td
mlir/include/mlir/SPIRV/SPIRVStructureOps.td
mlir/lib/SPIRV/SPIRVOps.cpp
mlir/test/SPIRV/ops.mlir

index aff5e1a34962bb24f903fb28026cbd48ef078a4d..64c692f2ce907a8b379855c2f78330ecc14f2595 100644 (file)
@@ -62,10 +62,44 @@ def SPV_Dialect : Dialect {
 // SPIR-V type definitions
 //===----------------------------------------------------------------------===//
 
+def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
+def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
+def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
+
+// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
+// for the definition of the following types and type categories.
+
+def SPV_Void : TypeAlias<NoneType, "void type">;
+def SPV_Bool : IntOfWidths<[1]>;
+def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
+def SPV_Float : FloatOfWidths<[16, 32, 64]>;
+def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
+// Component type check is done in the type parser for the following SPIR-V
+// dialect-specific types so we use "Any" here.
+def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
+def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
+def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
+
+def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
+def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
+def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>;
+def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>;
+def SPV_Type : AnyTypeOf<[
+    SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
+    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray
+  ]>;
+
 class SPV_ScalarOrVectorOf<Type type> :
     Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
          "scalar/vector of " # type.description>;
 
+// TODO(antiagainst): Use a more appropriate way to model optional operands
+class SPV_Optional<Type type> : Variadic<type>;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V enum definitions
+//===----------------------------------------------------------------------===//
+
 // Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
 
 def SPV_AM_Logical                    : EnumAttrCase<"Logical", 0>;
index cab3820270bbbd8c3a2ee03bd085c070668f05f2..a58a17968ab30eef42e5a7d4383e07466b43ced9 100644 (file)
@@ -48,8 +48,8 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
   }];
 
   let arguments = (ins
-    SPV_ScalarOrVectorOf<AnyFloat>:$operand1,
-    SPV_ScalarOrVectorOf<AnyFloat>:$operand2
+    SPV_ScalarOrVectorOf<SPV_Float>:$operand1,
+    SPV_ScalarOrVectorOf<SPV_Float>:$operand2
   );
 
   let results = (outs
index 6e084b91d69a7405a5f79c3f03df08497b98aac0..ad86fb200e3460e4b136c859505d4845f1fc3630 100644 (file)
@@ -79,7 +79,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
 
   // Custom parser and printer implemented by static functions in SPVOps.cpp
   let parser = [{ return parseModule(parser, result); }];
-  let printer = [{ printModule(getOperation(), p); }];
+  let printer = [{ printModule(*this, p); }];
 
   let verifier = [{ return verifyModule(*this); }];
 }
index aa77fdd6ea593714d4d8fd3cc1b780107a52e886..2579f173226ea0eed4049ecc1848c4cccb7da68a 100644 (file)
@@ -85,8 +85,10 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
   return success();
 }
 
-static ParseResult printModule(Operation *op, OpAsmPrinter *printer) {
-  *printer << op->getName();
+static ParseResult printModule(spirv::ModuleOp moduleOp,
+                               OpAsmPrinter *printer) {
+  auto *op = moduleOp.getOperation();
+  *printer << spirv::ModuleOp::getOperationName();
   printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                        /*printBlockTerminators=*/false);
   *printer << " attributes";
@@ -113,8 +115,6 @@ static LogicalResult verifyModule(spirv::ModuleOp moduleOp) {
 
     for (auto &block : funcOp)
       for (auto &op : block) {
-        // TODO(antiagainst): verify that return ops have the same type as the
-        // enclosing function
         if (op.getDialect() == dialect)
           continue;
 
index 225a76e349f048df6a05fe7c7ec620bfef9b7264..bfc5f586f515f1888285e881443d0b271c4d2ab1 100644 (file)
@@ -19,15 +19,23 @@ func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> {
 // -----
 
 func @fmul_i32(%arg: i32) -> i32 {
-  // expected-error @+1 {{must be scalar/vector of floating-point}}
+  // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
   %0 = spv.FMul %arg, %arg : i32
   return %0 : i32
 }
 
 // -----
 
+func @fmul_bf16(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+  %0 = spv.FMul %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error @+1 {{must be scalar/vector of floating-point}}
+  // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
   %0 = spv.FMul %arg, %arg : tensor<4xf32>
   return %0 : tensor<4xf32>
 }