//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
-/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
+/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
-/// Populates patterns for vectorizing low-D convolution ops. This is a step in
+/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
/// were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
-/// Populates the given list with patterns to bufferize linalg ops.
+/// Populate the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(
bufferization::BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
return *this;
}
- /// Function that allows the caller to control when to stop fusion. Once a
+ /// Function to allow the caller to control when to stop fusion. Once a
/// producer is deemed fusable with the consumer (structurally), this callback
/// can be used to abort the fusion based on non-structural constraints. This
/// is the hook for cost models to control the amount of fusion done.
/// more fusion opportunities.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
-/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
+/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
-/// Returns a struct containing the tiled loops in the specified order
+/// Return a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions);
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts
+/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled by
/// `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
-void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
- ArrayRef<unsigned> interchangeVector);
+FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector);
-/// Creates a GenericOp from the given named operation `namedOp`. Assumes
-/// `namedOp` is not a GenericOp and has a region builder.
-GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+/// Create a GenericOp from the given named operation `namedOp` and replace
+/// `namedOp` with it.
+/// Return failure if `namedOp` is a GenericOp or has no region builder.
+FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
+ LinalgOp namedOp);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
}
};
-/// Creates a new buffer using the `allocationFn` provided. The size of this
+/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
const AllocBufferCallbackFn &allocationFn,
DataLayout &layout);
-/// Promotes the `subViews` into a new buffer allocated at the insertion point
+/// Promote the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
/// 2. Take a full view on the buffer.
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
-/// Returns the modified linalg op (the modification happens in place) as well
+/// Return the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
const LinalgPromotionOptions &options);
/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
- SmallVectorImpl<Value> &newResults);
+LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
-/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
-/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
-/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
-/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
-/// permutated according to `permutation`.
-LogicalResult
-interchangeGenericOpPrecondition(GenericOp genericOp,
- ArrayRef<unsigned> interchangeVector);
-
-/// Generalize named operations to generic operations.
-LogicalResult generalizeNamedOpPrecondition(Operation *op);
-
-/// Promote std.subviews feeding linalg operations.
+/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
-/// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
-
-/// Return success if `op` can be vectorized assuming it is static. This allows
-/// checking if an op will be vectorizable once all the dimensions are folded to
-/// static values.
-/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
-LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
-
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
-/// Base pattern that applied the tiling transformation specified by `options`.
+/// Base pattern that applies the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
/// 1. if the tiling specification is invalid and tiling fails to occur.
/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
};
///
-/// Linalg generic interchage pattern.
+/// Linalg generic interchange pattern.
///
-/// Apply the `interchange` transformation as a pattern.
+/// Apply the `interchangeGenericOp` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
///
/// Linalg vectorization patterns.
///
-/// Apply the `vectorizeLinalgOp` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};
+/// Apply the `vectorize` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `vectorize` for more details.
struct LinalgBaseVectorizationPattern : public RewritePattern {
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgBaseVectorizationPattern(MLIRContext *context,
using namespace mlir;
using namespace mlir::linalg;
-LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+static LogicalResult generalizeNamedOpPrecondition(Operation *op) {
LinalgOp namedOp = dyn_cast<LinalgOp>(op);
// Check if the operation is a LinalgOp but not a GenericOp.
if (!namedOp || isa<GenericOp>(op))
return success();
}
-GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
- LinalgOp namedOp) {
+FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
+ LinalgOp namedOp) {
+ if (failed(generalizeNamedOpPrecondition(namedOp)))
+ return rewriter.notifyMatchFailure(namedOp, "preconditions not met");
+
SmallVector<Value> inputOperands = namedOp.getInputOperands();
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
outputOperands, indexingMaps, iterators);
rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
genericOp.region().begin());
+ rewriter.replaceOp(namedOp, genericOp->getResults());
return genericOp;
}
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
using namespace mlir;
using namespace mlir::linalg;
-LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
- GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
+static LogicalResult
+interchangeGenericOpPrecondition(GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector) {
// Interchange vector must be non-empty and match the number of loops.
if (interchangeVector.empty() ||
genericOp.getNumLoops() != interchangeVector.size())
return success();
}
-void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
- GenericOp genericOp,
- ArrayRef<unsigned> interchangeVector) {
- // 1. Compute the inverse permutation map.
+FailureOr<GenericOp>
+mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector) {
+ if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+ return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
+
+ // 1. Compute the inverse permutation map. It must be non-null since the
+ // preconditions are satisfied.
MLIRContext *context = genericOp.getContext();
AffineMap permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
- assert(permutationMap && "expected permutation to be invertible");
- assert(interchangeVector.size() == genericOp.getNumLoops() &&
- "expected interchange vector to have entry for every loop");
+ assert(permutationMap && "unexpected null map");
+
+ // Start a guarded inplace update.
+ rewriter.startRootUpdate(genericOp);
+ auto guard =
+ llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
// 2. Compute the interchanged indexing maps.
- SmallVector<Attribute, 4> newIndexingMaps;
+ SmallVector<AffineMap> newIndexingMaps;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
AffineMap m = genericOp.getTiedIndexingMap(opOperand);
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
- newIndexingMaps.push_back(AffineMapAttr::get(m));
+ newIndexingMaps.push_back(m);
}
genericOp->setAttr(getIndexingMapsAttrName(),
- ArrayAttr::get(context, newIndexingMaps));
+ rewriter.getAffineMapArrayAttr(newIndexingMaps));
// 3. Compute the interchanged iterator types.
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
- SmallVector<Attribute, 4> itTypesVector;
+ SmallVector<Attribute> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
}
}
+
+ return genericOp;
}
struct LinalgNamedOpConversionPass
: public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
LinalgNamedOpConversionPass() = default;
- LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
+ LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
void runOnOperation() override {
Operation *op = getOperation();
GenericOp genericOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, genericOp)))
return failure();
- if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+
+ FailureOr<GenericOp> transformedOp =
+ interchangeGenericOp(rewriter, genericOp, interchangeVector);
+ if (failed(transformedOp))
return failure();
- // TODO: figure out how this interplays with named ops. In particular this
- // should break the named op property.
- rewriter.updateRootInPlace(genericOp, [&]() {
- interchangeGenericOp(rewriter, genericOp, interchangeVector);
- // New filter if specified.
- filter.replaceLinalgTransformationFilter(rewriter, genericOp);
- });
+ // New filter if specified.
+ filter.replaceLinalgTransformationFilter(rewriter, genericOp);
return success();
}
Operation *op, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- if (failed(generalizeNamedOpPrecondition(op)))
+ FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, op);
+ if (failed(genericOp))
return failure();
-
- GenericOp genericOp = generalizeNamedOp(rewriter, op);
- rewriter.replaceOp(op, genericOp.getResults());
- filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+ filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
return success();
}
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
+ // TODO: Interface-based rewrite.
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
- SmallVector<Value> newResults;
- if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
+ if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- if (!newResults.empty())
- rewriter.replaceOp(op, newResults);
- else
- rewriter.eraseOp(op);
- return success();
+ return vectorize(rewriter, linalgOp);
}
LogicalResult mlir::linalg::applyStagedPatterns(
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
-/// with pad_val) and GenericOp (to copy contents).
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
+/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
return success();
}
-LogicalResult
-mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
if (isElementwise(op))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
return success();
}
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
- auto linalgOp = cast<linalg::LinalgOp>(op);
+static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
// All types must be static shape to go to vector.
if (linalgOp.hasDynamicShape()) {
LDBG("precondition failed: dynamic shape");
return vectorizeStaticLinalgOpPrecondition(linalgOp);
}
-LogicalResult
-mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
- SmallVectorImpl<Value> &newResults) {
- if (failed(vectorizeLinalgOpPrecondition(op)))
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
+ LinalgOp linalgOp) {
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
return failure();
- auto linalgOp = cast<LinalgOp>(op);
-
- // TODO: isaConvolutionOpInterface that can also infer from generic features.
- // But we will still need stride/dilation attributes that will be annoying to
- // reverse-engineer...
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
- FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
- if (failed(resultOrFail))
+ SmallVector<Value> results;
+ // TODO: isaConvolutionOpInterface that can also infer from generic
+ // features. Will require stride/dilation attributes inference.
+ if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ LDBG("Vectorize as a conv: " << linalgOp);
+ FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
+ if (failed(convOr))
+ return failure();
+ llvm::append_range(results, (*convOr)->getResults());
+ } else {
+ LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
+ if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
- Operation *newOp = *resultOrFail;
- llvm::append_range(newResults, newOp->getResults());
- return success();
}
- LDBG(""
- << "Vectorize linalg op as a generic by broadcasting to "
- "maximal common shape: "
- << *op);
- return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
+ if (!results.empty())
+ rewriter.replaceOp(linalgOp, results);
+ else
+ rewriter.eraseOp(linalgOp);
+
+ return success();
}
//----------------------------------------------------------------------------//
return attr.cast<IntegerAttr>().getInt();
}
-/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
-/// are converted to ConstantIndexOps. Other attribute types are not supported.
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
GenericPadTensorOpVectorizationPattern(MLIRContext *context,
PatternBenefit benefit = 1)
: GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
- /// Vectorize the copying of a PadTensorOp's source. This is possible if each
- /// dimension size is statically know in the source type or the result type
- /// (or both).
+ /// Vectorize the copying of a PadTensorOp's source. This is possible if
+ /// each dimension size is statically known in the source type or the result
+ /// type (or both).
static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
PadTensorOp padOp, Value dest) {
auto sourceType = padOp.getSourceType();
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
if (!sourceType.isDynamicDim(i)) {
vecShape.push_back(sourceType.getDimSize(i));
- // Source shape is statically known: Neither read nor write are out-of-
- // bounds.
+ // Source shape is statically known: Neither read nor write are
+ // out-of-bounds.
readInBounds.push_back(true);
writeInBounds.push_back(true);
} else if (!resultType.isDynamicDim(i)) {
- // Source shape is not statically known, but result shape is. Vectorize
- // with size of result shape. This may be larger than the source size.
+ // Source shape is not statically known, but result shape is.
+ // Vectorize with size of result shape. This may be larger than the
+ // source size.
vecShape.push_back(resultType.getDimSize(i));
// Read may be out-of-bounds because the result size could be larger
// than the source size.
padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
ArrayRef<bool>{readInBounds});
- // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
- // tensor, write directly to the FillOp's operand.
+ // If `dest` is a FillOp and the TransferWriteOp would overwrite the
+ // entire tensor, write directly to the FillOp's operand.
if (llvm::equal(vecShape, resultType.getShape()) &&
llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
}
};
-/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
-/// operation type OpTy.
+/// Base pattern for rewriting PadTensorOps whose result is consumed by a
+/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
};
/// Rewrite use of PadTensorOp result in TransferWriteOp.
-/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
-/// where the same amount of padding is immediately removed again after the
-/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
-/// value and apply out-of-bounds masking. E.g.:
+/// This pattern rewrites TransferWriteOps that write to a padded tensor
+/// value, where the same amount of padding is immediately removed again after
+/// the write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
/// : tensor<...> to tensor<?x?xf32>
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
/// : tensor<...> to tensor<?x?xf32>
-/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
+/// tensor<?x?xf32>
/// ```
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
-/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even
-/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
-/// %r's old dimensions.
+/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
///
/// This rewrite is possible if:
/// - Low padding is static 0.
/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
-/// ExtractSliceOp trims the same amount of padding that was added beforehand.
+/// ExtractSliceOp trims the same amount of padding that was added
+/// beforehand.
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferWritePattern
: public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
/// sizes may turn out to be equal at runtime.
bool hasSameTensorSize(Value beforePadding,
tensor::ExtractSliceOp afterTrimming) const {
- // If the input to PadTensorOp is a CastOp, try with with both CastOp result
- // and CastOp operand.
+ // If the input to PadTensorOp is a CastOp, try with both the CastOp
+ // result and the CastOp operand.
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
if (hasSameTensorSize(castOp.source(), afterTrimming))
return true;
if (t1.getNumDynamicDims() == 0)
return true;
- // All dynamic sizes must be the same. The only supported case at the moment
- // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
+ // All dynamic sizes must be the same. The only supported case at the
+ // moment is when `beforePadding` is an ExtractSliceOp (or a cast
+ // thereof).
// Apart from CastOp, only ExtractSliceOp is supported.
auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
// InsertSliceOp.
rewriter.setInsertionPoint(insertOp);
- // Generate TransferReadOp: Read entire source tensor and add high padding.
+ // Generate TransferReadOp: Read entire source tensor and add high
+ // padding.
SmallVector<Value> readIndices(
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
// Forwarding patterns
//----------------------------------------------------------------------------//
-/// Check whether there is any interleaved use of any `values` between `firstOp`
-/// and `secondOp`. Conservatively return `true` if any op or value is in a
-/// different block.
+/// Check whether there is any interleaved use of any `values` between
+/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
+/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
return false;
}
-/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
+/// Return the unique subview use of `v` if it is indeed unique, null
+/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
memref::SubViewOp subViewOp;
for (auto &u : v.getUses()) {
return failure();
LDBG("with copy " << *copyOp);
- // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+ // Find the fill into `viewOrAlloc` without interleaved uses before the
+ // copy.
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
/// ```
/// kw is always unrolled.
- /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+ /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
+ /// > 1.
FailureOr<Operation *> conv() {
if (!valid)
return failure();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
- // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+ // When strideW == 1, we can batch the contiguous loads and avoid
+ // unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
Type lhsEltType = lhsShapedType.getElementType();
VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
- // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, 0].
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
+ // 0].
Value lhs = builder.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c, f} @ [0, 0, 0].
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
/// ```
/// kw is always unrolled.
- /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+ /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
+ /// > 1.
FailureOr<Operation *> dilatedConv() {
if (!valid)
return failure();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
- // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+ // When strideW == 1, we can batch the contiguous loads and avoid
+ // unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
Type lhsEltType = lhsShapedType.getElementType();
VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
- // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
+ // 0].
Value lhs = builder.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c} @ [0, 0].