// SPIR-V type definitions
//===----------------------------------------------------------------------===//
+def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
+def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
+def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
+
+// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
+// for the definition of the following types and type categories.
+
+def SPV_Void : TypeAlias<NoneType, "void type">;
+def SPV_Bool : IntOfWidths<[1]>;
+def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
+def SPV_Float : FloatOfWidths<[16, 32, 64]>;
+def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
+// Component type check is done in the type parser for the following SPIR-V
+// dialect-specific types so we use "Any" here.
+def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
+def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
+def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
+
+def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
+def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
+def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>;
+def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>;
+def SPV_Type : AnyTypeOf<[
+ SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
+ SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray
+ ]>;
+
class SPV_ScalarOrVectorOf<Type type> :
Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
"scalar/vector of " # type.description>;
+// TODO(antiagainst): Use a more appropriate way to model optional operands
+class SPV_Optional<Type type> : Variadic<type>;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V enum definitions
+//===----------------------------------------------------------------------===//
+
// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_AM_Logical : EnumAttrCase<"Logical", 0>;
return success();
}
-static ParseResult printModule(Operation *op, OpAsmPrinter *printer) {
- *printer << op->getName();
+static ParseResult printModule(spirv::ModuleOp moduleOp,
+ OpAsmPrinter *printer) {
+ auto *op = moduleOp.getOperation();
+ *printer << spirv::ModuleOp::getOperationName();
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
*printer << " attributes";
for (auto &block : funcOp)
for (auto &op : block) {
- // TODO(antiagainst): verify that return ops have the same type as the
- // enclosing function
if (op.getDialect() == dialect)
continue;
// -----
func @fmul_i32(%arg: i32) -> i32 {
- // expected-error @+1 {{must be scalar/vector of floating-point}}
+ // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
%0 = spv.FMul %arg, %arg : i32
return %0 : i32
}
// -----
+func @fmul_bf16(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+ %0 = spv.FMul %arg, %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
- // expected-error @+1 {{must be scalar/vector of floating-point}}
+ // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
%0 = spv.FMul %arg, %arg : tensor<4xf32>
return %0 : tensor<4xf32>
}