// Debug-stream helper: writes the "[<DEBUG_TYPE>] " prefix to llvm::dbgs()
// so all debug output from this file is tagged consistently.
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// Emit X (an ostream expression) to the tagged debug stream; compiled out
// entirely in non-debug builds via LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+/// Try to vectorize `convOp` as a convolution.
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+ LinalgOp convOp);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
SmallVector<Value> results;
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
- LDBG("Vectorize as a conv: " << linalgOp);
- FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
- if (failed(convOr))
- return failure();
+ FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
+ if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
} else {
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+ return failure();
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
};
} // namespace
-/// Helper function to vectorize a `linalgOp` with convolution semantics.
+/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
- // TODO: these are legitimately part of ConvolutionOpInterface.
- auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+ // The ConvolutionOpInterface gives us guarantees of existence for
+ // strides/dilations. However, we do not need to rely on those, we can simply
+ // use them if present, otherwise use the default and let the generic conv.
+ // matcher in the ConvGenerator succeed or fail.
+ auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
- LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
- Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
+ Conv1DNwcGenerator e(b, op, stride, dilation);
auto res = e.generateConv();
if (succeeded(res))
return res;
return e.generateDilatedConv();
}
-struct VectorizeConvolution
- : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+ LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- FailureOr<Operation *> resultOrFail =
- vectorizeConvolution(rewriter, convOp);
+ FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
if (failed(resultOrFail))
return failure();
Operation *newOp = *resultOrFail;
if (newOp->getNumResults() == 0) {
- rewriter.eraseOp(convOp.getOperation());
+ rewriter.eraseOp(op.getOperation());
return success();
}
assert(newOp->getNumResults() == 1 && "expected single result");
- rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+ rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
return success();
}
};