From 74127bc062e246fa3bc62cdd6c162ec5d9eb0c10 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 24 Jun 2019 10:59:05 -0700 Subject: [PATCH] Add SPIR-V Load/Store operations. Currently this only support memory operands being None, Volatile, Aligned and Nontemporal PiperOrigin-RevId: 254792353 --- mlir/include/mlir/IR/OpImplementation.h | 6 + mlir/include/mlir/SPIRV/SPIRVBase.td | 18 +++ mlir/include/mlir/SPIRV/SPIRVOps.td | 103 ++++++++++++++ mlir/lib/Parser/Parser.cpp | 10 ++ mlir/lib/SPIRV/SPIRVOps.cpp | 242 ++++++++++++++++++++++++++++++++ mlir/test/SPIRV/ops.mlir | 221 +++++++++++++++++++++++++++++ 6 files changed, 600 insertions(+) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 162ed11..c344a22 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -244,9 +244,15 @@ public: /// Parse a `[` token. virtual ParseResult parseLSquare() = 0; + /// Parse a `[` token if present. + virtual ParseResult parseOptionalLSquare() = 0; + /// Parse a `]` token. virtual ParseResult parseRSquare() = 0; + /// Parse a `]` token if present. + virtual ParseResult parseOptionalRSquare() = 0; + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/SPIRV/SPIRVBase.td b/mlir/include/mlir/SPIRV/SPIRVBase.td index c27a4cd..52f8eef 100644 --- a/mlir/include/mlir/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/SPIRV/SPIRVBase.td @@ -190,6 +190,24 @@ def SPV_ImageFormatAttr : let underlyingType = "uint32_t"; } +def SPV_MA_None : EnumAttrCase<"None", 0x0000>; +def SPV_MA_Volatile : EnumAttrCase<"Volatile", 0x0001>; +def SPV_MA_Aligned : EnumAttrCase<"Aligned", 0x0002>; +def SPV_MA_Nontemporal : EnumAttrCase<"Nontemporal", 0x0004>; +def SPV_MA_MakePointerAvailableKHR : EnumAttrCase<"MakePointerAvailableKHR", 0x0008>; +def SPV_MA_MakePointerVisibleKHR : EnumAttrCase<"MakePointerVisibleKHR", 0x0010>; +def SPV_MA_NonPrivatePointerKHR : EnumAttrCase<"NonPrivatePointerKHR", 0x0020>; + +def SPV_MemoryAccessAttr : + EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ + SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal, + SPV_MA_MakePointerAvailableKHR, SPV_MA_MakePointerVisibleKHR, + SPV_MA_NonPrivatePointerKHR + ]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + def SPV_MM_Simple : EnumAttrCase<"Simple", 0>; def SPV_MM_GLSL450 : EnumAttrCase<"GLSL450", 1>; def SPV_MM_OpenCL : EnumAttrCase<"OpenCL", 2>; diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.td b/mlir/include/mlir/SPIRV/SPIRVOps.td index 76853cb..3ce4f64 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVOps.td @@ -65,6 +65,61 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> { let opcode = 133; } +def SPV_LoadOp : SPV_Op<"Load"> { + let summary = "Load value through a pointer."; + + let description = [{ + Result Type is the type of the loaded object. It must be a type + with fixed size; i.e., it cannot be, nor include, any + OpTypeRuntimeArray types. + + Pointer is the pointer to load through. Its type must be an + OpTypePointer whose Type operand is the same as Result Type. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory + operand None. + + ### Custom assembly form + + ``` {.ebnf} + memory-access ::= `"None"` | `"Volatile"` | `"Aligned"` integer-literal + | `"NonTemporal"` + + load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use + (`[` memory-access `]`)? ` : ` spirv-element-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr + %1 = spv.Load "Function" %0 : f32 + %2 = spv.Load "Function" %0 ["Volatile"] : f32 + %3 = spv.Load "Function" %0 ["Aligned", 4] : f32 + ``` + }]; + + let arguments = (ins + SPV_AnyPtr:$ptr, + OptionalAttr:$memory_access, + OptionalAttr:$alignment + ); + + let results = (outs + SPV_Type:$value + ); + + let extraClassDeclaration = [{ + static StringRef getMemoryAccessAttrName() { + return "memory_access"; + } + static StringRef getAlignmentAttrName() { + return "alignment"; + } + }]; +} + def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { let summary = "Return with no value from a function with void return type"; @@ -84,6 +139,54 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { let opcode = 253; } +def SPV_StoreOp : SPV_Op<"Store"> { + let summary = "Store through a pointer."; + + let description = [{ + Pointer is the pointer to store through. Its type must be an + OpTypePointer whose Type operand is the same as the type of + Object. + + Object is the object to store. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory + operand None. + + ### Custom assembly form + + ``` {.ebnf} + store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, ` + (memory-access)? : spirv-element-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr + %1 = spv.FMul ... : f32 + spv.Store "Function" %0, %1 : f32 + spv.Store "Function" %0, %1 ["Volatile"] : f32 + spv.Store "Function" %0, %1 ["Aligned", 4] : f32 + }]; + + let arguments = (ins + SPV_AnyPtr:$ptr, + SPV_Type:$value, + OptionalAttr:$memory_access, + OptionalAttr:$alignment + ); + + let extraClassDeclaration = [{ + static StringRef getMemoryAccessAttrName() { + return "memory_access"; + } + static StringRef getAlignmentAttrName() { + return "alignment"; + } + }]; +} + def SPV_VariableOp : SPV_Op<"Variable"> { let summary = [{ Allocate an object in memory, resulting in a pointer to it, which can be diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index d4428d4..e274012 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3104,11 +3104,21 @@ public: return parser.parseToken(Token::l_square, "expected '['"); } + /// Parses a '[' if present. + ParseResult parseOptionalLSquare() override { + return success(parser.consumeIf(Token::l_square)); + } + /// Parse a `]` token. ParseResult parseRSquare() override { return parser.parseToken(Token::r_square, "expected ']'"); } + /// Parses a ']' if present. + ParseResult parseOptionalRSquare() override { + return success(parser.consumeIf(Token::r_square)); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/SPIRV/SPIRVOps.cpp b/mlir/lib/SPIRV/SPIRVOps.cpp index 4a2eace..c404c90 100644 --- a/mlir/lib/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/SPIRV/SPIRVOps.cpp @@ -37,6 +37,70 @@ static constexpr const char kValueAttrName[] = "value"; // Common utility functions //===----------------------------------------------------------------------===// +static ParseResult parseStorageClassAttribute(spirv::StorageClass &storageClass, + OpAsmParser *parser, + OperationState *state) { + Attribute storageClassAttr; + SmallVector storageAttr; + auto loc = parser->getCurrentLocation(); + if (parser->parseAttribute(storageClassAttr, "storage_class", storageAttr)) { + return failure(); + } + if (!storageClassAttr.isa()) { + return parser->emitError(loc, "expected a string storage class specifier"); + } + auto storageClassOptional = spirv::symbolizeStorageClass( + storageClassAttr.cast().getValue()); + if (!storageClassOptional) { + return parser->emitError(loc, "invalid storage class specifier :") + << storageClassAttr; + } + storageClass = storageClassOptional.getValue(); + return success(); +} + +template +static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser, + OperationState *state) { + // Parse an optional list of attributes staring with '[' + if (parser->parseOptionalLSquare()) { + // Nothing to do + return success(); + } + + Attribute memAccessAttr; + auto loc = parser->getCurrentLocation(); + if (parser->parseAttribute(memAccessAttr, + LoadStoreOpTy::getMemoryAccessAttrName(), + state->attributes)) { + return failure(); + } + // Check that this is a memory attribute + if (!memAccessAttr.isa()) { + return parser->emitError(loc, "expected a string memory access specifier"); + } + auto memAccessOptional = + spirv::symbolizeMemoryAccess(memAccessAttr.cast().getValue()); + if (!memAccessOptional) { + return parser->emitError(loc, "invalid memory access specifier :") + << memAccessAttr; + } + + if (auto memAccess = + memAccessOptional.getValue() == spirv::MemoryAccess::Aligned) { + // Parse integer attribute for alignment. + Attribute alignmentAttr; + Type i32Type = parser->getBuilder().getIntegerType(32); + if (parser->parseComma() || + parser->parseAttribute(alignmentAttr, i32Type, + LoadStoreOpTy::getAlignmentAttrName(), + state->attributes)) { + return failure(); + } + } + return parser->parseRSquare(); +} + // Parses an op that has no inputs and no outputs. static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) { if (parser->parseOptionalAttributeDict(state->attributes)) @@ -44,6 +108,74 @@ static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) { return success(); } +template +static void +printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter *printer, + SmallVectorImpl &elidedAttrs) { + // Print optional memory access attribute. + if (auto memaccess = loadStoreOp.memory_access()) { + elidedAttrs.push_back(LoadStoreOpTy::getMemoryAccessAttrName()); + *printer << " [\"" << memaccess << "\""; + + // Print integer alignment attribute. + if (auto alignment = loadStoreOp.alignment()) { + elidedAttrs.push_back(LoadStoreOpTy::getAlignmentAttrName()); + *printer << ", " << alignment; + } + *printer << "]"; + } +} + +template +static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) { + // ODS checks for attributes values. Just need to verify that if the + // memory-access attribute is Aligned, then the alignment attribute must be + // present. + auto *op = loadStoreOp.getOperation(); + auto memaccessAttr = op->getAttr(LoadStoreOpTy::getMemoryAccessAttrName()); + if (!memaccessAttr) { + // Alignment attribute shouldnt be present if memory access attribute is not + // present. + if (op->getAttr(LoadStoreOpTy::getAlignmentAttrName())) { + return loadStoreOp.emitOpError( + "invalid alignment specification without aligned memory access " + "specification"); + } + return success(); + } + + if (auto memaccess = + spirv::symbolizeMemoryAccess( + memaccessAttr.template cast().getValue()) == + spirv::MemoryAccess::Aligned) { + if (!op->getAttr(LoadStoreOpTy::getAlignmentAttrName())) { + return loadStoreOp.emitOpError("missing alignment value"); + } + } else { + if (op->getAttr(LoadStoreOpTy::getAlignmentAttrName())) { + return loadStoreOp.emitOpError( + "invalid alignment specification with non-aligned memory access " + "specification"); + } + } + return success(); +} + +template +static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr, + Value *val) { + // ODS already checks ptr is spirv::PointerType. Just check that the pointee + // type of the pointer and the type of the value are the same + // + // TODO(ravishankarm): Check that the value type satisfies restrictions of + // SPIR-V OpLoad/OpStore operations + if (val->getType() != + ptr->getType().cast().getPointeeType()) { + return op.emitOpError("mismatch in result type and pointer type"); + } + return success(); +} + // Prints an op that has no inputs and no outputs. static void printNoIOOp(Operation *op, OpAsmPrinter *printer) { *printer << op->getName(); @@ -122,6 +254,59 @@ static LogicalResult verify(spirv::ConstantOp constOp) { } //===----------------------------------------------------------------------===// +// spv.LoadOp +//===----------------------------------------------------------------------===// + +static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) { + // Parse the storage class specification + spirv::StorageClass storageClass; + OpAsmParser::OperandType ptrInfo; + Type elementType; + if (parseStorageClassAttribute(storageClass, parser, state) || + parser->parseOperand(ptrInfo) || + parseMemoryAccessAttributes(parser, state) || + parser->parseOptionalAttributeDict(state->attributes) || + parser->parseColon() || parser->parseType(elementType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get(elementType, storageClass); + if (parser->resolveOperand(ptrInfo, ptrType, state->operands)) { + return failure(); + } + + state->addTypes(elementType); + return success(); +} + +static void print(spirv::LoadOp loadOp, OpAsmPrinter *printer) { + auto *op = loadOp.getOperation(); + SmallVector elidedAttrs; + *printer + << spirv::LoadOp::getOperationName() << " \"" + << loadOp.ptr()->getType().cast().getStorageClassStr() + << "\" "; + // Print the pointer operand. + printer->printOperand(loadOp.ptr()); + + printMemoryAccessAttribute(loadOp, printer, elidedAttrs); + + printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); + *printer << " : " << loadOp.getType(); +} + +static LogicalResult verify(spirv::LoadOp loadOp) { + // SPIR-V spec : "Result Type is the type of the loaded object. It must be a + // type with fixed size; i.e., it cannot be, nor include, any + // OpTypeRuntimeArray types." + if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(), + loadOp.value()))) { + return failure(); + } + return verifyMemoryAccessAttribute(loadOp); +} + +//===----------------------------------------------------------------------===// // spv.module //===----------------------------------------------------------------------===// @@ -205,6 +390,63 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { } //===----------------------------------------------------------------------===// +// spv.StoreOp +//===----------------------------------------------------------------------===// + +static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) { + // Parse the storage class specification + spirv::StorageClass storageClass; + SmallVector operandInfo; + auto loc = parser->getCurrentLocation(); + Type elementType; + if (parseStorageClassAttribute(storageClass, parser, state) || + parser->parseOperandList(operandInfo, 2) || + parseMemoryAccessAttributes(parser, state) || + parser->parseColon() || parser->parseType(elementType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get(elementType, storageClass); + if (parser->resolveOperands(operandInfo, {ptrType, elementType}, loc, + state->operands)) { + return failure(); + } + return success(); +} + +static void print(spirv::StoreOp storeOp, OpAsmPrinter *printer) { + auto *op = storeOp.getOperation(); + SmallVector elidedAttrs; + *printer << spirv::StoreOp::getOperationName() << " \"" + << storeOp.ptr() + ->getType() + .cast() + .getStorageClassStr() + << "\" "; + // Print the pointer operand + printer->printOperand(storeOp.ptr()); + *printer << ", "; + // Print the value operand + printer->printOperand(storeOp.value()); + + printMemoryAccessAttribute(storeOp, printer, elidedAttrs); + + *printer << " : " << storeOp.value()->getType(); + + printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); +} + +static LogicalResult verify(spirv::StoreOp storeOp) { + // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an + // OpTypePointer whose Type operand is the same as the type of Object." + if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(), + storeOp.value()))) { + return failure(); + } + return verifyMemoryAccessAttribute(storeOp); +} + +//===----------------------------------------------------------------------===// // spv.Variable //===----------------------------------------------------------------------===// diff --git a/mlir/test/SPIRV/ops.mlir b/mlir/test/SPIRV/ops.mlir index 3b1c2db..7c1b812 100644 --- a/mlir/test/SPIRV/ops.mlir +++ b/mlir/test/SPIRV/ops.mlir @@ -43,6 +43,117 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { // ----- //===----------------------------------------------------------------------===// +// spv.LoadOp +//===----------------------------------------------------------------------===// + +// CHECK_LABEL: @simple_load +func @simple_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %0 : f32 + %1 = spv.Load "Function" %0 : f32 + return +} + +// CHECK_LABEL: @volatile_load +func @volatile_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %0 ["Volatile"] : f32 + %1 = spv.Load "Function" %0 ["Volatile"] : f32 + return +} + +// CHECK_LABEL: @aligned_load +func @aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %0 ["Aligned", 4] : f32 + %1 = spv.Load "Function" %0 ["Aligned", 4] : f32 + return +} + +// ----- + +func @simple_load_missing_storageclass() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected non-function type}} + %1 = spv.Load %0 : f32 + return +} + +// ----- + +func @simple_load_missing_operand() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected SSA operand}} + %1 = spv.Load "Function" : f32 + return +} + +// ----- + +func @simple_load_missing_rettype() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+2 {{expected ':'}} + %1 = spv.Load "Function" %0 + return +} + +// ----- + +func @volatile_load_missing_lbrace() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ':'}} + %1 = spv.Load "Function" %0 "Volatile"] : f32 + return +} + +// ----- + +func @volatile_load_missing_rbrace() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ']'}} + %1 = spv.Load "Function" %0 ["Volatile" : f32 + return +} + +// ----- + +func @aligned_load_missing_alignment() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ','}} + %1 = spv.Load "Function" %0 ["Aligned"] : f32 + return +} + +// ----- + +func @aligned_load_missing_comma() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ','}} + %1 = spv.Load "Function" %0 ["Aligned" 4] : f32 + return +} + +// ----- + +func @load_incorrect_attributes() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ']'}} + %1 = spv.Load "Function" %0 ["Volatile", 4] : f32 + return +} + +// ----- + +func @aligned_load_incorrect_attributes() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ']'}} + %1 = spv.Load "Function" %0 ["Aligned", 4, 23] : f32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.Return //===----------------------------------------------------------------------===// @@ -69,6 +180,116 @@ func @return_mismatch_func_signature() -> () { // ----- //===----------------------------------------------------------------------===// +// spv.StoreOp +//===----------------------------------------------------------------------===// + +func @simple_store(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Store "Function" %0, %arg0 : f32 + spv.Store "Function" %0, %arg0 : f32 + return +} + +// CHECK_LABEL: @volatile_store +func @volatile_store(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Store "Function" %0, %arg0 ["Volatile"] : f32 + spv.Store "Function" %0, %arg0 ["Volatile"] : f32 + return +} + +// CHECK_LABEL: @aligned_store +func @aligned_store(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Store "Function" %0, %arg0 ["Aligned", 4] : f32 + spv.Store "Function" %0, %arg0 ["Aligned", 4] : f32 + return +} + +// ----- + +func @simple_store_missing_ptr_type(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected non-function type}} + spv.Store %0, %arg0 : f32 + return +} + +// ----- + +func @simple_store_missing_operand(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{custom op 'spv.Store' invalid operand}} : f32 + spv.Store "Function" , %arg0 : f32 + return +} + +// ----- + +func @simple_store_missing_operand(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{custom op 'spv.Store' expected 2 operands}} : f32 + spv.Store "Function" %0 : f32 + return +} + +// ----- + +func @volatile_store_missing_lbrace(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ':'}} + spv.Store "Function" %0, %arg0 "Volatile"] : f32 + return +} + +// ----- + +func @volatile_store_missing_rbrace(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ']'}} + spv.Store "Function" %0, %arg0 ["Volatile" : f32 + return +} + +// ----- + +func @aligned_store_missing_alignment(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ','}} + spv.Store "Function" %0, %arg0 ["Aligned"] : f32 + return +} + +// ----- + +func @aligned_store_missing_comma(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ','}} + spv.Store "Function" %0, %arg0 ["Aligned" 4] : f32 + return +} + +// ----- + +func @load_incorrect_attributes(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ']'}} + spv.Store "Function" %0, %arg0 ["Volatile", 4] : f32 + return +} + +// ----- + +func @aligned_store_incorrect_attributes(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{expected ']'}} + spv.Store "Function" %0, %arg0 ["Aligned", 4, 23] : f32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.Variable //===----------------------------------------------------------------------===// -- 2.7.4