def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
- SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn
+ SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
+ SPV_OC_OpFMul, SPV_OC_OpReturn
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
// -----
+def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
+ let summary = [{
+ Create a pointer into a composite object that can be used with OpLoad
+ and OpStore.
+ }];
+
+ let description = [{
+ Result Type must be an OpTypePointer. Its Type operand must be the type
+ reached by walking the Base’s type hierarchy down to the last provided
+ index in Indexes, and its Storage Class operand must be the same as the
+ Storage Class of Base.
+
+ Base must be a pointer, pointing to the base of a composite object.
+
+ Indexes walk the type hierarchy to the desired depth, potentially down
+ to scalar granularity. The first index in Indexes will select the top-
+ level member/element/component/element of the base composite. All
+ composite constituents use zero-based numbering, as described by their
+ OpType… instruction. The second index will apply similarly to that
+ result, and so on. Once any non-composite type is reached, there must be
+ no remaining (unused) indexes.
+
+ Each index in Indexes
+ - must be a scalar integer type,
+ - is treated as a signed count, and
+ - must be an OpConstant when indexing into a structure.
+
+ ### Custom assembly form
+ ``` {.ebnf}
+ access-chain-op ::= ssa-id `=` `spv.AccessChain` ssa-use
+ `[` ssa-use (',' ssa-use)* `]`
+ `:` pointer-type
+ ```
+
+ For example:
+
+ ```
+ %0 = "spv.constant"() { value = 1: i32} : () -> i32
+ %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ %2 = spv.AccessChain %1[%0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ %3 = spv.Load "Function" %2 ["Volatile"] : !spv.array<4xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_AnyPtr:$base_ptr,
+ Variadic<SPV_Integer>:$indices
+ );
+
+ let results = (outs
+ SPV_AnyPtr:$component_ptr
+ );
+
+ let autogenSerialization = 0;
+}
+
+// -----
+
def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
let summary = "Extract a part of a composite object.";
return dest;
}
+static LogicalResult extractValueFromConstOp(Operation *op,
+ int32_t &indexValue) {
+ auto constOp = llvm::dyn_cast<spirv::ConstantOp>(op);
+ if (!constOp) {
+ return failure();
+ }
+ auto valueAttr = constOp.value();
+ auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
+ if (!integerValueAttr) {
+ return failure();
+ }
+ indexValue = integerValueAttr.getInt();
+ return success();
+}
+
template <typename EnumClass>
static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
OperationState *state) {
printer->printOptionalAttrDict(op->getAttrs());
}
+//===----------------------------------------------------------------------===//
+// spv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
+ Location baseLoc) {
+ if (!indices.size()) {
+ emitError(baseLoc, "'spv.AccessChain' op expected at least "
+ "one index ");
+ return nullptr;
+ }
+
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
+ "to composite type, but provided ")
+ << type;
+ return nullptr;
+ }
+
+ auto resultType = ptrType.getPointeeType();
+ auto resultStorageClass = ptrType.getStorageClass();
+ int32_t index = 0;
+
+ for (auto indexSSA : indices) {
+ auto cType = resultType.dyn_cast<spirv::CompositeType>();
+ if (!cType) {
+ emitError(baseLoc,
+ "'spv.AccessChain' op cannot extract from non-composite type ")
+ << resultType << " with index " << index;
+ return nullptr;
+ }
+ index = 0;
+ if (resultType.isa<spirv::StructType>()) {
+ Operation *op = indexSSA->getDefiningOp();
+ if (!op) {
+ emitError(baseLoc, "'spv.AccessChain' op index must be an "
+ "integer spv.constant to access "
+ "element of spv.struct");
+ return nullptr;
+ }
+
+ // TODO(denis0x0D): this should be relaxed to allow
+ // integer literals of other bitwidths.
+ if (failed(extractValueFromConstOp(op, index))) {
+ emitError(baseLoc,
+ "'spv.AccessChain' index must be an integer spv.constant to "
+ "access element of spv.struct, but provided ")
+ << op->getName();
+ return nullptr;
+ }
+ if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+ emitError(baseLoc, "'spv.AccessChain' op index ")
+ << index << " out of bounds for " << resultType;
+ return nullptr;
+ }
+ }
+ resultType = cType.getElementType(index);
+ }
+ return spirv::PointerType::get(resultType, resultStorageClass);
+}
+
+static ParseResult parseAccessChainOp(OpAsmParser *parser,
+ OperationState *state) {
+ OpAsmParser::OperandType ptrInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
+ Type type;
+ // TODO(denis0x0D): regarding to the spec an index must be any integer type,
+ // figure out how to use resolveOperand with a range of types and do not
+ // fail on first attempt.
+ Type indicesType = parser->getBuilder().getIntegerType(32);
+
+ if (parser->parseOperand(ptrInfo) ||
+ parser->parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(ptrInfo, type, state->operands) ||
+ parser->resolveOperands(indicesInfo, indicesType, state->operands)) {
+ return failure();
+ }
+
+ Location baseLoc = state->operands.front()->getLoc();
+ auto resultType = getElementPtrType(
+ type, llvm::makeArrayRef(state->operands).drop_front(), baseLoc);
+ if (!resultType) {
+ return failure();
+ }
+
+ state->addTypes(resultType);
+ return success();
+}
+
+static void print(spirv::AccessChainOp op, OpAsmPrinter *printer) {
+ *printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
+ << '[';
+ printer->printOperands(op.indices());
+ *printer << "] : " << op.base_ptr()->getType();
+}
+
+static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
+ SmallVector<Value *, 4> indices(accessChainOp.indices().begin(),
+ accessChainOp.indices().end());
+ auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
+ indices, accessChainOp.getLoc());
+ if (!resultType) {
+ return failure();
+ }
+
+ auto providedResultType =
+ accessChainOp.getType().dyn_cast<spirv::PointerType>();
+ if (!providedResultType) {
+ return accessChainOp.emitOpError(
+ "result type must be a pointer, but provided")
+ << providedResultType;
+ }
+
+ if (resultType != providedResultType) {
+ return accessChainOp.emitOpError("invalid result type: expected ")
+ << resultType << ", but provided " << providedResultType;
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// spv.AccessChain
+//===----------------------------------------------------------------------===//
+
+func @access_chain_struct() -> () {
+ %0 = spv.constant 1: i32
+ %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Function>
+ %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_1D_array(%arg0 : i32) -> () {
+ %0 = spv.Variable : !spv.ptr<!spv.array<4xf32>, Function>
+ // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
+ %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4xf32>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_2D_array_1(%arg0 : i32) -> () {
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
+ %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ %2 = spv.Load "Function" %1 ["Volatile"] : f32
+ return
+}
+
+// -----
+
+func @access_chain_2D_array_2(%arg0 : i32) -> () {
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
+ %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
+ return
+}
+
+// -----
+
+func @access_chain_non_composite() -> () {
+ %0 = spv.constant 1: i32
+ // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
+ %1 = spv.Variable : !spv.ptr<f32, Function>
+ %2 = spv.AccessChain %1[%0] : !spv.ptr<f32, Function>
+ return
+}
+
+// -----
+
+func @access_chain_no_indices(%index0 : i32) -> () {
+ // expected-error @+1 {{expected at least one index}}
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ %1 = spv.AccessChain %0[] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_invalid_type(%index0 : i32) -> () {
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ // expected-error @+1 {{expected a pointer to composite type, but provided '!spv.array<4 x !spv.array<4 x f32>>'}}
+ %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
+ %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @access_chain_invalid_index_1(%index0 : i32) -> () {
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ // expected-error @+1 {{expected SSA operand}}
+ %1 = spv.AccessChain %0[%index, 4] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_invalid_index_2(%index0 : i32) -> () {
+ // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}}
+ %0 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_invalid_constant_type_1() -> () {
+ %0 = std.constant 1: i32
+ // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}}
+ %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_out_of_bounds() -> () {
+ %index0 = "spv.constant"() { value = 12: i32} : () -> i32
+ // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}}
+ %0 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
+func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
+ // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//