From dd2e444325ddf047f224524dfc7d1881aa1051e8 Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Thu, 24 Oct 2019 18:40:38 -0700 Subject: [PATCH] [spirv] AccessChainOp canonicalization. Combine chained `spirv::AccessChainOp` operations into one `spirv::AccessChainOp` operation. Closes tensorflow/mlir#198 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/198 from denis0x0D:sandbox/canon_access_chain 0cb87955a85511071143d62637ff939d0dabc2bd PiperOrigin-RevId: 276609345 --- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 2 + mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 42 +++++++++++++++++++-- mlir/test/Dialect/SPIRV/canonicalize.mlir | 58 +++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 7cfb7e3..1ffb352 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -127,6 +127,8 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { let builders = [OpBuilder<[{Builder *builder, OperationState &state, Value *basePtr, ArrayRef indices}]>]; + + let hasCanonicalizer = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 44fecf3..e56d9e6 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -544,6 +544,41 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) { return success(); } +namespace { + +// Combine chained `spirv::AccessChainOp` operations into one +// `spirv::AccessChainOp` operation. +struct CombineChainedAccessChain + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp, + PatternRewriter &rewriter) const override { + auto parentAccessChainOp = dyn_cast_or_null( + accessChainOp.base_ptr()->getDefiningOp()); + + if (!parentAccessChainOp) { + return matchFailure(); + } + + // Combine indices. + SmallVector indices(parentAccessChainOp.indices()); + indices.append(accessChainOp.indices().begin(), + accessChainOp.indices().end()); + + rewriter.replaceOpWithNewOp( + accessChainOp, parentAccessChainOp.base_ptr(), indices); + + return matchSuccess(); + } +}; +} // namespace + +void spirv::AccessChainOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // spv._address_of //===----------------------------------------------------------------------===// @@ -1976,7 +2011,8 @@ namespace { // | merge block | // +-------------+ // -struct SelectionOpCanonicalizer : public OpRewritePattern { +struct ConvertSelectionOpToSelect + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp, @@ -2071,7 +2107,7 @@ private: } }; -PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection( +PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection( Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { // Each block must consists of 2 operations. if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) || @@ -2110,7 +2146,7 @@ PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection( void spirv::SelectionOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index a9a6d0f..02d8645 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -1,6 +1,64 @@ // RUN: mlir-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s //===----------------------------------------------------------------------===// +// spv.AccsessChain +//===----------------------------------------------------------------------===// + +func @combine_full_access_chain() -> f32 { + // CHECK: %[[INDEX:.*]] = spv.constant 0 + // CHECK-NEXT: %[[VAR:.*]] = spv.Variable + // CHECK-NEXT: %[[PTR:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] + // CHECK-NEXT: spv.Load "Function" %[[PTR]] + %c0 = spv.constant 0: i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function> + %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function> + %3 = spv.Load "Function" %2 : f32 + spv.ReturnValue %3 : f32 +} + +// ----- + +func @combine_access_chain_multi_use() -> !spv.array<4xf32> { + // CHECK: %[[INDEX:.*]] = spv.constant 0 + // CHECK-NEXT: %[[VAR:.*]] = spv.Variable + // CHECK-NEXT: %[[PTR_0:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]]] + // CHECK-NEXT: %[[PTR_1:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] + // CHECK-NEXT: spv.Load "Function" %[[PTR_0]] + // CHECK-NEXT: spv.Load "Function" %[[PTR_1]] + %c0 = spv.constant 0: i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function> + %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function> + %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function> + %4 = spv.Load "Function" %2 : !spv.array<4xf32> + %5 = spv.Load "Function" %3 : f32 + spv.ReturnValue %4: !spv.array<4xf32> +} + +// ----- + +func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> { + // CHECK: %[[INDEX:.*]] = spv.constant 1 + // CHECK-NEXT: %[[VAR_0:.*]] = spv.Variable + // CHECK-NEXT: %[[VAR_1:.*]] = spv.Variable + // CHECK-NEXT: %[[VAR_0_PTR:.*]] = spv.AccessChain %[[VAR_0]][%[[INDEX]]] + // CHECK-NEXT: %[[VAR_1_PTR:.*]] = spv.AccessChain %[[VAR_1]][%[[INDEX]]] + // CHECK-NEXT: spv.Load "Function" %[[VAR_0_PTR]] + // CHECK-NEXT: spv.Load "Function" %[[VAR_1_PTR]] + %c1 = spv.constant 1: i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> + %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> + %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function> + %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function> + %4 = spv.Load "Function" %2 : !spv.array<4xi32> + %5 = spv.Load "Function" %3 : !spv.array<4xi32> + spv.ReturnValue %4 : !spv.array<4xi32> +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.CompositeExtract //===----------------------------------------------------------------------===// -- 2.7.4