namespace mlir {
class SPIRVTypeConverter;
+
/// Appends to a pattern list additional patterns for translating StandardOps to
-/// SPIR-V ops.
+/// SPIR-V ops. Also adds patterns to legalize ops that are not directly
+/// translated to the SPIR-V dialect.
void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
+/// Appends to a pattern list patterns to legalize ops that are not directly
+/// lowered to SPIR-V.
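+///
+/// A minimal usage sketch (mirroring the test pass registered as
+/// "legalize-std-for-spirv"; `op`, `context`, and `patterns` are illustrative
+/// names, and the exact driver may differ):
+///
+///   OwningRewritePatternList patterns;
+///   populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+///   applyPatternsGreedily(op->getRegions(), patterns);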
+void populateStdLegalizationPatternsForSPIRVLowering(
+ MLIRContext *context, OwningRewritePatternList &patterns);
+
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
#include "mlir/Pass/Pass.h"
namespace mlir {
+
/// Pass to convert StandardOps to SPIR-V ops.
std::unique_ptr<OpPassBase<ModuleOp>> createConvertStandardToSPIRVPass();
+
+/// Pass to legalize ops that are not directly lowered to SPIR-V.
+std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass();
+
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
/// Returns the dynamic sizes for this subview operation if specified.
operand_range getDynamicSizes() { return sizes(); }
+  /// Returns in `staticStrides` the static values of the stride
+  /// operands. Returns failure() if these values could not be retrieved
+  /// (e.g. when dynamic stride operands are specified).
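+  ///
+  /// For example, using the subview from the accompanying
+  /// legalize-std-for-spirv test (SSA names are illustrative), for
+  ///
+  ///   %1 = subview %0[%i0, %i1][][]
+  ///     : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  ///
+  /// the base strides are [32, 1] and the result strides are [64, 3], so the
+  /// computed static strides are [64/32, 3/1] = [2, 3].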
+ LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);
+
// Auxiliary range data structure and helper function that unpacks the
// offset, size and stride operands of the SubViewOp into a list of triples.
  // Such a list of triples is sometimes more convenient to manipulate.
add_llvm_library(MLIRStandardToSPIRVTransforms
ConvertStandardToSPIRV.cpp
ConvertStandardToSPIRVPass.cpp
+ LegalizeStandardForSPIRV.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
+ // Add patterns that lower operations into SPIR-V dialect.
populateWithGenerated(context, &patterns);
- // Add the return op conversion.
patterns
.insert<ConstantIndexOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
--- /dev/null
+//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This transformation pass legalizes operations before the conversion to the
+// SPIR-V dialect, to handle ops that cannot be lowered directly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Merges a subview operation with a load operation.
+class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
+public:
+ using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges a subview operation with a store operation.
+class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
+public:
+ using OpRewritePattern<StoreOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for op legalization.
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t. the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
+/// memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+///   %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
+///          memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+ SubViewOp subViewOp, ArrayRef<Value *> indices,
+ SmallVectorImpl<Value *> &sourceIndices) {
+  // TODO: Aborting when the offsets are static (i.e., there are no offset
+  // operands). There might be a way to fold the subview op into the load even
+  // if the offsets have been canonicalized away.
+ if (subViewOp.getNumOffsets() == 0)
+ return failure();
+
+ SmallVector<Value *, 2> opOffsets = llvm::to_vector<2>(subViewOp.offsets());
+ SmallVector<Value *, 2> opStrides;
+ if (subViewOp.getNumStrides()) {
+ // If the strides are dynamic, get the stride operands.
+ opStrides = llvm::to_vector<2>(subViewOp.strides());
+ } else {
+    // When static, the stride operands can be retrieved by taking the strides
+    // of the result of the subview op and dividing them by the strides of the
+    // base memref.
+ SmallVector<int64_t, 2> staticStrides;
+ if (failed(subViewOp.getStaticStrides(staticStrides))) {
+ return failure();
+ }
+ opStrides.reserve(opOffsets.size());
+ for (auto stride : staticStrides) {
+ auto constValAttr = rewriter.getIntegerAttr(
+ IndexType::get(rewriter.getContext()), stride);
+ opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr));
+ }
+ }
+ assert(opOffsets.size() == opStrides.size());
+
+ // New indices for the load are the current indices * subview_stride +
+ // subview_offset.
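+  // For example (using the SSA names from the function comment above and the
+  // static strides [2, 3] from the accompanying test), the load indices
+  // (%i1, %i2) become (%arg0 + %i1 * 2, %arg1 + %i2 * 3).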
+ assert(indices.size() == opStrides.size());
+ sourceIndices.resize(indices.size());
+ for (auto index : enumerate(indices)) {
+ auto offset = opOffsets[index.index()];
+ auto stride = opStrides[index.index()];
+ auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
+ sourceIndices[index.index()] =
+ rewriter.create<AddIOp>(loc, offset, mul).getResult();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and LoadOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value *, 4> sourceIndices,
+ indices = llvm::to_vector<4>(loadOp.indices());
+ if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, indices,
+ sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
+ sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and StoreOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value *, 4> sourceIndices,
+ indices = llvm::to_vector<4>(storeOp.indices());
+ if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
+ indices, sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
+ subViewOp.source(), sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Hook for adding patterns.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateStdLegalizationPatternsForSPIRVLowering(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
+}
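+
+// Note: a full lowering pipeline is expected to run these legalization
+// patterns before the StandardOps-to-SPIR-V conversion (e.g.
+// mlir-opt -legalize-std-for-spirv -convert-std-to-spirv, as exercised by the
+// accompanying tests); combining all of the patterns in a single pass does not
+// currently work (see the TODO in the accompanying test).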
+
+//===----------------------------------------------------------------------===//
+// Pass for testing just the legalization patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void SPIRVLegalization::runOnOperation() {
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+ applyPatternsGreedily(getOperation()->getRegions(), patterns);
+}
+
+std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
+ return std::make_unique<SPIRVLegalization>();
+}
+
+static PassRegistration<SPIRVLegalization>
+ pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering");
return res;
}
+LogicalResult
+SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
+  // If the strides are dynamic, return failure.
+ if (getNumStrides())
+ return failure();
+
+  // When static, the stride operands can be retrieved by taking the strides of
+  // the result of the subview op and dividing them by the strides of the base
+  // memref.
+ int64_t resultOffset, baseOffset;
+ SmallVector<int64_t, 2> resultStrides, baseStrides;
+ if (failed(
+ getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
+ llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
+ failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
+ return failure();
+
+ assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() &&
+ baseStrides.size() == resultStrides.size() &&
+ "base and result memrefs must have the same rank");
+ assert(!llvm::is_contained(resultStrides,
+ MemRefType::getDynamicStrideOrOffset()) &&
+ "strides of subview op must be static, when there are no dynamic "
+ "strides specified");
+ staticStrides.resize(getType().getRank());
+ for (auto resultStride : enumerate(resultStrides)) {
+ auto baseStride = baseStrides[resultStride.index()];
+ // The result stride is expected to be a multiple of the base stride. Abort
+ // if that is not the case.
+ if (resultStride.value() < baseStride ||
+ resultStride.value() % baseStride != 0)
+ return failure();
+ staticStrides[resultStride.index()] = resultStride.value() / baseStride;
+ }
+ return success();
+}
+
static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) {
if (memrefType.getNumDynamicDims() > 0)
return false;
--- /dev/null
+// RUN: mlir-opt -legalize-std-for-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// CHECK-LABEL: @fold_static_stride_subview_with_load
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
+func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) {
+ // CHECK-NOT: subview
+ // CHECK: [[C2:%.*]] = constant 2 : index
+ // CHECK: [[C3:%.*]] = constant 3 : index
+ // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
+ // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
+ // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+ // CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+ return
+}
+
+// CHECK-LABEL: @fold_dynamic_stride_subview_with_load
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index
+func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) {
+ // CHECK-NOT: subview
+ // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[ARG5]] : index
+ // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
+ // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+ // CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
+ return
+}
+
+// CHECK-LABEL: @fold_static_stride_subview_with_store
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: f32
+func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
+ // CHECK-NOT: subview
+ // CHECK: [[C2:%.*]] = constant 2 : index
+ // CHECK: [[C3:%.*]] = constant 3 : index
+ // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
+ // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
+ // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+ // CHECK: store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+ return
+}
+
+// CHECK-LABEL: @fold_dynamic_stride_subview_with_store
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: f32
+func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : f32) {
+ // CHECK-NOT: subview
+ // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[ARG5]] : index
+ // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
+ // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+ // CHECK: store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
+ return
+}
--- /dev/null
+// RUN: mlir-opt -legalize-std-for-spirv -convert-std-to-spirv %s -o - | FileCheck %s
+
+// TODO: For these examples, running the two passes separately produces the
+// desired output. Adding all of the patterns within a single pass does not
+// seem to work.
+
+//===----------------------------------------------------------------------===//
+// std.subview
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fold_static_stride_subview_with_load
+// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<384 x f32 [4]> [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32
+func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) {
+ // CHECK: [[C2:%.*]] = spv.constant 2
+ // CHECK: [[C3:%.*]] = spv.constant 3
+ // CHECK: [[T2:%.*]] = spv.IMul [[ARG3]], [[C2]]
+ // CHECK: [[T3:%.*]] = spv.IAdd [[ARG1]], [[T2]]
+ // CHECK: [[T4:%.*]] = spv.IMul [[ARG4]], [[C3]]
+ // CHECK: [[T5:%.*]] = spv.IAdd [[ARG2]], [[T4]]
+ // CHECK: [[C32:%.*]] = spv.constant 32
+ // CHECK: [[T7:%.*]] = spv.IMul [[C32]], [[T3]]
+ // CHECK: [[C1:%.*]] = spv.constant 1
+ // CHECK: [[T9:%.*]] = spv.IMul [[C1]], [[T5]]
+ // CHECK: [[T10:%.*]] = spv.IAdd [[T7]], [[T9]]
+ // CHECK: [[C0:%.*]] = spv.constant 0
+ // CHECK: [[T12:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[C0]], [[T10]]
+ // CHECK: spv.Load "StorageBuffer" [[T12]] : f32
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+ return
+}
+
+// CHECK-LABEL: @fold_static_stride_subview_with_store
+// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<384 x f32 [4]> [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: f32
+func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
+ // CHECK: [[C2:%.*]] = spv.constant 2
+ // CHECK: [[C3:%.*]] = spv.constant 3
+ // CHECK: [[T2:%.*]] = spv.IMul [[ARG3]], [[C2]]
+ // CHECK: [[T3:%.*]] = spv.IAdd [[ARG1]], [[T2]]
+ // CHECK: [[T4:%.*]] = spv.IMul [[ARG4]], [[C3]]
+ // CHECK: [[T5:%.*]] = spv.IAdd [[ARG2]], [[T4]]
+ // CHECK: [[C32:%.*]] = spv.constant 32
+ // CHECK: [[T7:%.*]] = spv.IMul [[C32]], [[T3]]
+ // CHECK: [[C1:%.*]] = spv.constant 1
+ // CHECK: [[T9:%.*]] = spv.IMul [[C1]], [[T5]]
+ // CHECK: [[T10:%.*]] = spv.IAdd [[T7]], [[T9]]
+ // CHECK: [[C0:%.*]] = spv.constant 0
+ // CHECK: [[T12:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[C0]], [[T10]]
+ // CHECK: spv.Store "StorageBuffer" [[T12]], [[ARG5]] : f32
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+ return
+}