[spirv] AccessChainOp canonicalization.
authorDenis Khalikov <khalikov.denis@huawei.com>
Fri, 25 Oct 2019 01:40:38 +0000 (18:40 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 Oct 2019 01:41:34 +0000 (18:41 -0700)
Combine chained `spirv::AccessChainOp` operations into one
`spirv::AccessChainOp` operation.

Closes tensorflow/mlir#198

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/198 from denis0x0D:sandbox/canon_access_chain 0cb87955a85511071143d62637ff939d0dabc2bd
PiperOrigin-RevId: 276609345

mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/canonicalize.mlir

index 7cfb7e3..1ffb352 100644 (file)
@@ -127,6 +127,8 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
 
   let builders = [OpBuilder<[{Builder *builder, OperationState &state,
                               Value *basePtr, ArrayRef<Value *> indices}]>];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
index 44fecf3..e56d9e6 100644 (file)
@@ -544,6 +544,41 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
   return success();
 }
 
+namespace {
+
+// Combine chained `spirv::AccessChainOp` operations into one
+// `spirv::AccessChainOp` operation.
+struct CombineChainedAccessChain
+    : public OpRewritePattern<spirv::AccessChainOp> {
+  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
+                                     PatternRewriter &rewriter) const override {
+    auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
+        accessChainOp.base_ptr()->getDefiningOp());
+
+    if (!parentAccessChainOp) {
+      return matchFailure();
+    }
+
+    // Combine indices.
+    SmallVector<Value *, 4> indices(parentAccessChainOp.indices());
+    indices.append(accessChainOp.indices().begin(),
+                   accessChainOp.indices().end());
+
+    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+        accessChainOp, parentAccessChainOp.base_ptr(), indices);
+
+    return matchSuccess();
+  }
+};
+} // namespace
+
+void spirv::AccessChainOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CombineChainedAccessChain>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spv._address_of
 //===----------------------------------------------------------------------===//
@@ -1976,7 +2011,8 @@ namespace {
 //                       | merge block |
 //                       +-------------+
 //
-struct SelectionOpCanonicalizer : public OpRewritePattern<spirv::SelectionOp> {
+struct ConvertSelectionOpToSelect
+    : public OpRewritePattern<spirv::SelectionOp> {
   using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
 
   PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
@@ -2071,7 +2107,7 @@ private:
   }
 };
 
-PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection(
+PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
     Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
   // Each block must consists of 2 operations.
   if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
@@ -2110,7 +2146,7 @@ PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection(
 
 void spirv::SelectionOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<SelectionOpCanonicalizer>(context);
+  results.insert<ConvertSelectionOpToSelect>(context);
 }
 
 //===----------------------------------------------------------------------===//
index a9a6d0f..02d8645 100644 (file)
@@ -1,6 +1,64 @@
 // RUN: mlir-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 //===----------------------------------------------------------------------===//
+// spv.AccsessChain
+//===----------------------------------------------------------------------===//
+
+func @combine_full_access_chain() -> f32 {
+  // CHECK: %[[INDEX:.*]] = spv.constant 0
+  // CHECK-NEXT: %[[VAR:.*]] = spv.Variable
+  // CHECK-NEXT: %[[PTR:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]]
+  // CHECK-NEXT: spv.Load "Function" %[[PTR]]
+  %c0 = spv.constant 0: i32
+  %0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %3 = spv.Load "Function" %2 : f32
+  spv.ReturnValue %3 : f32
+}
+
+// -----
+
+func @combine_access_chain_multi_use() -> !spv.array<4xf32> {
+  // CHECK: %[[INDEX:.*]] = spv.constant 0
+  // CHECK-NEXT: %[[VAR:.*]] = spv.Variable
+  // CHECK-NEXT: %[[PTR_0:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]]]
+  // CHECK-NEXT: %[[PTR_1:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]]
+  // CHECK-NEXT: spv.Load "Function" %[[PTR_0]]
+  // CHECK-NEXT: spv.Load "Function" %[[PTR_1]]
+  %c0 = spv.constant 0: i32
+  %0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %2 = spv.AccessChain %1[%c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %3 = spv.AccessChain %2[%c0] : !spv.ptr<!spv.array<4xf32>, Function>
+  %4 = spv.Load "Function" %2 : !spv.array<4xf32>
+  %5 = spv.Load "Function" %3 : f32
+  spv.ReturnValue %4: !spv.array<4xf32>
+}
+
+// -----
+
+func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> {
+  // CHECK: %[[INDEX:.*]] = spv.constant 1
+  // CHECK-NEXT: %[[VAR_0:.*]] = spv.Variable
+  // CHECK-NEXT: %[[VAR_1:.*]] = spv.Variable
+  // CHECK-NEXT: %[[VAR_0_PTR:.*]] = spv.AccessChain %[[VAR_0]][%[[INDEX]]]
+  // CHECK-NEXT: %[[VAR_1_PTR:.*]] = spv.AccessChain %[[VAR_1]][%[[INDEX]]]
+  // CHECK-NEXT: spv.Load "Function" %[[VAR_0_PTR]]
+  // CHECK-NEXT: spv.Load "Function" %[[VAR_1_PTR]]
+  %c1 = spv.constant 1: i32
+  %0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %1 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %2 = spv.AccessChain %0[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %3 = spv.AccessChain %1[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %4 = spv.Load "Function" %2 : !spv.array<4xi32>
+  %5 = spv.Load "Function" %3 : !spv.array<4xi32>
+  spv.ReturnValue %4 : !spv.array<4xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.CompositeExtract
 //===----------------------------------------------------------------------===//