#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
  return genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}
-// TODO(ntv): This is in fact much more general than just vectorization for
-// matmul ops.
+// TODO(ntv, ataei): This is in fact much more general than just vectorization
+// for matmul and fill ops.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
-  if (isa<linalg::MatmulOp>(op))
+  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
    return success();
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
                                                      Operation *op) {
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
+  using edsc::intrinsics::std_load;
+  using edsc::intrinsics::std_store;
+  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
+  using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>;
+  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
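+  // ValueBuilder<OpTy> is the EDSC helper that creates an OpTy at the current
+  // insertion point and returns its result value; std_load/std_store build
+  // std.load/std.store the same way.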
  assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
         "DRR failure case must be a precondition");
-
  auto linalgOp = cast<linalg::LinalgOp>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  edsc::ScopedContext scope(rewriter, op->getLoc());
-  using edsc::intrinsics::std_load;
-  using edsc::intrinsics::std_store;
-  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
-  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
-  auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
-  auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
-  auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
-  auto vC = std_load(vectorMemRefC);
-  auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
-                              linalgOp.iterator_types());
-  std_store(vRes, vectorMemRefC);
+
+  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+    // Vectorize fill as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg.fill as vector.broadcast: "
+                      << *op << ":\n");
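+    // Illustrative sketch of the rewrite, assuming a static 4x4xf32 buffer;
+    // the shapes and the exact IR spelling below are hypothetical:
+    //   linalg.fill(%A, %f) : memref<4x4xf32>, f32
+    // becomes roughly
+    //   %vA = vector.type_cast %A : memref<4x4xf32> to memref<vector<4x4xf32>>
+    //   %b  = vector.broadcast %f : f32 to vector<4x4xf32>
+    //   store %b, %vA[] : memref<vector<4x4xf32>>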
+    auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
+    auto dstVec = std_load(dstMemrefVec);
+    auto resVec = vector_broadcast(dstVec, fillOp.value());
+    std_store(resVec, dstMemrefVec);
+  } else {
+    // Vectorize other ops as vector contraction (currently only matmul).
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg op as vector.contract: "
+                      << *op << ":\n");
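+    // Illustrative sketch, assuming a static 4x8 * 8x16 matmul; the shapes and
+    // the exact IR spelling below are hypothetical:
+    //   linalg.matmul(%A, %B, %C)
+    //     : memref<4x8xf32>, memref<8x16xf32>, memref<4x16xf32>
+    // becomes roughly: reinterpret the three buffers with vector.type_cast,
+    // load them, and form
+    //   %r = vector.contract {indexing_maps = [...], iterator_types = [...]}
+    //          %a, %b, %c : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
+    // before storing %r back through the vector-typed view of %C.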
+    auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
+    auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
+    auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
+    auto vC = std_load(vectorMemRefC);
+    auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
+                                linalgOp.iterator_types());
+    std_store(vRes, vectorMemRefC);
+  }
  return {};
}