[spirv] Add AccessChainOp operation.
authorDenis Khalikov <dennis.khalikov@gmail.com>
Thu, 25 Jul 2019 22:42:41 +0000 (15:42 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 25 Jul 2019 22:43:12 +0000 (15:43 -0700)
AccessChainOp creates a pointer into a composite object that can be used with
OpLoad and OpStore.

Closes tensorflow/mlir#52

PiperOrigin-RevId: 260035676

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/ops.mlir

index 0702be4cc81290ee4a59601f897c9133ffa6182b..448355036bafe48c24697eaab878a49dcd31c769 100644 (file)
@@ -96,6 +96,7 @@ def SPV_OC_OpFunctionEnd       : I32EnumAttrCase<"OpFunctionEnd", 56>;
 def SPV_OC_OpVariable          : I32EnumAttrCase<"OpVariable", 59>;
 def SPV_OC_OpLoad              : I32EnumAttrCase<"OpLoad", 61>;
 def SPV_OC_OpStore             : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpAccessChain       : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate          : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpCompositeExtract  : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpFMul              : I32EnumAttrCase<"OpFMul", 133>;
@@ -110,7 +111,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
       SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
-      SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn
+      SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpFMul, SPV_OC_OpReturn
       ]> {
     let returnType = "::mlir::spirv::Opcode";
     let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
index 7fd4d64c6c1d39e83e7306de5e29c90f16f5c6b3..e2facd3df1669576c7b331b0000adc795de6afa1 100644 (file)
@@ -43,6 +43,64 @@ include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 
 // -----
 
+def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
+  let summary = [{
+    Create a pointer into a composite object that can be used with OpLoad
+    and OpStore.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypePointer. Its Type operand must be the type
+    reached by walking the Base’s type hierarchy down to the last provided
+    index in Indexes, and its Storage Class operand must be the same as the
+    Storage Class of Base.
+
+    Base must be a pointer, pointing to the base of a composite object.
+
+    Indexes walk the type hierarchy to the desired depth, potentially down
+    to scalar granularity. The first index in Indexes will select the top-
+    level member/element/component/element of the base composite. All
+    composite constituents use zero-based numbering, as described by their
+    OpType… instruction. The second index will apply similarly to that
+    result, and so on. Once any non-composite type is reached, there must be
+    no remaining (unused) indexes.
+
+     Each index in Indexes
+    - must be a scalar integer type,
+    - is treated as a signed count, and
+    - must be an OpConstant when indexing into a structure.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    access-chain-op ::= ssa-id `=` `spv.AccessChain` ssa-use
+                        `[` ssa-use (',' ssa-use)* `]`
+                        `:` pointer-type
+    ```
+
+    For example:
+
+    ```
+    %0 = "spv.constant"() { value = 1: i32} : () -> i32
+    %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+    %2 = spv.AccessChain %1[%0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+    %3 = spv.Load "Function" %2 ["Volatile"] : !spv.array<4xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$base_ptr,
+    Variadic<SPV_Integer>:$indices
+  );
+
+  let results = (outs
+    SPV_AnyPtr:$component_ptr
+  );
+
+  let autogenSerialization = 0;
+}
+
+// -----
+
 def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
   let summary = "Extract a part of a composite object.";
 
index d84cbf111dca972b77dff615536db126c0333894..76b26e0dbbb92b4b586bdab5909b9e78e1d95fe6 100644 (file)
@@ -57,6 +57,21 @@ inline Dst bitwiseCast(Src source) noexcept {
   return dest;
 }
 
+static LogicalResult extractValueFromConstOp(Operation *op,
+                                             int32_t &indexValue) {
+  auto constOp = llvm::dyn_cast<spirv::ConstantOp>(op);
+  if (!constOp) {
+    return failure();
+  }
+  auto valueAttr = constOp.value();
+  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
+  if (!integerValueAttr) {
+    return failure();
+  }
+  indexValue = integerValueAttr.getInt();
+  return success();
+}
+
 template <typename EnumClass>
 static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
                                       OperationState *state) {
@@ -199,6 +214,129 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
   printer->printOptionalAttrDict(op->getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// spv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
+                              Location baseLoc) {
+  if (!indices.size()) {
+    emitError(baseLoc, "'spv.AccessChain' op expected at least "
+                       "one index ");
+    return nullptr;
+  }
+
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType) {
+    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
+                       "to composite type, but provided ")
+        << type;
+    return nullptr;
+  }
+
+  auto resultType = ptrType.getPointeeType();
+  auto resultStorageClass = ptrType.getStorageClass();
+  int32_t index = 0;
+
+  for (auto indexSSA : indices) {
+    auto cType = resultType.dyn_cast<spirv::CompositeType>();
+    if (!cType) {
+      emitError(baseLoc,
+                "'spv.AccessChain' op cannot extract from non-composite type ")
+          << resultType << " with index " << index;
+      return nullptr;
+    }
+    index = 0;
+    if (resultType.isa<spirv::StructType>()) {
+      Operation *op = indexSSA->getDefiningOp();
+      if (!op) {
+        emitError(baseLoc, "'spv.AccessChain' op index must be an "
+                           "integer spv.constant to access "
+                           "element of spv.struct");
+        return nullptr;
+      }
+
+      // TODO(denis0x0D): this should be relaxed to allow
+      // integer literals of other bitwidths.
+      if (failed(extractValueFromConstOp(op, index))) {
+        emitError(baseLoc,
+                  "'spv.AccessChain' index must be an integer spv.constant to "
+                  "access element of spv.struct, but provided ")
+            << op->getName();
+        return nullptr;
+      }
+      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+        emitError(baseLoc, "'spv.AccessChain' op index ")
+            << index << " out of bounds for " << resultType;
+        return nullptr;
+      }
+    }
+    resultType = cType.getElementType(index);
+  }
+  return spirv::PointerType::get(resultType, resultStorageClass);
+}
+
+static ParseResult parseAccessChainOp(OpAsmParser *parser,
+                                      OperationState *state) {
+  OpAsmParser::OperandType ptrInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
+  Type type;
+  // TODO(denis0x0D): regarding to the spec an index must be any integer type,
+  // figure out how to use resolveOperand with a range of types and do not
+  // fail on first attempt.
+  Type indicesType = parser->getBuilder().getIntegerType(32);
+
+  if (parser->parseOperand(ptrInfo) ||
+      parser->parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(ptrInfo, type, state->operands) ||
+      parser->resolveOperands(indicesInfo, indicesType, state->operands)) {
+    return failure();
+  }
+
+  Location baseLoc = state->operands.front()->getLoc();
+  auto resultType = getElementPtrType(
+      type, llvm::makeArrayRef(state->operands).drop_front(), baseLoc);
+  if (!resultType) {
+    return failure();
+  }
+
+  state->addTypes(resultType);
+  return success();
+}
+
+static void print(spirv::AccessChainOp op, OpAsmPrinter *printer) {
+  *printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
+           << '[';
+  printer->printOperands(op.indices());
+  *printer << "] : " << op.base_ptr()->getType();
+}
+
+static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
+  SmallVector<Value *, 4> indices(accessChainOp.indices().begin(),
+                                  accessChainOp.indices().end());
+  auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
+                                      indices, accessChainOp.getLoc());
+  if (!resultType) {
+    return failure();
+  }
+
+  auto providedResultType =
+      accessChainOp.getType().dyn_cast<spirv::PointerType>();
+  if (!providedResultType) {
+    return accessChainOp.emitOpError(
+               "result type must be a pointer, but provided")
+           << providedResultType;
+  }
+
+  if (resultType != providedResultType) {
+    return accessChainOp.emitOpError("invalid result type: expected ")
+           << resultType << ", but provided " << providedResultType;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
index e3056c9f2c63c009a79ae766a6cbe5a5221c1839..7da21b91849e027b8ccf1071cbf56ac299d3e0a7 100644 (file)
@@ -1,5 +1,124 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spv.AccessChain
+//===----------------------------------------------------------------------===//
+
+func @access_chain_struct() -> () {
+  %0 = spv.constant 1: i32
+  %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Function>
+  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_1D_array(%arg0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4xf32>, Function>
+  // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
+  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4xf32>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_2D_array_1(%arg0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
+  %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %2 = spv.Load "Function" %1 ["Volatile"] : f32
+  return
+}
+
+// -----
+
+func @access_chain_2D_array_2(%arg0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
+  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
+  return
+}
+
+// -----
+
+func @access_chain_non_composite() -> () {
+  %0 = spv.constant 1: i32
+  // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
+  %1 = spv.Variable : !spv.ptr<f32, Function>
+  %2 = spv.AccessChain %1[%0] : !spv.ptr<f32, Function>
+  return
+}
+
+// -----
+
+func @access_chain_no_indices(%index0 : i32) -> () {
+  // expected-error @+1 {{expected at least one index}}
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_invalid_type(%index0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // expected-error @+1 {{expected a pointer to composite type, but provided '!spv.array<4 x !spv.array<4 x f32>>'}}
+  %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
+  %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>>
+  return
+}
+
+// -----
+
+func @access_chain_invalid_index_1(%index0 : i32) -> () {
+   %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // expected-error @+1 {{expected SSA operand}}
+  %1 = spv.AccessChain %0[%index, 4] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_invalid_index_2(%index0 : i32) -> () {
+  // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}}
+  %0 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_invalid_constant_type_1() -> () {
+  %0 = std.constant 1: i32
+  // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}}
+  %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_out_of_bounds() -> () {
+  %index0 = "spv.constant"() { value = 12: i32} : () -> i32
+  // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}}
+  %0 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
+func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
+  // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//