// Base Tablegen class for Linalg ops.
class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, props> {
- let arguments = (ins Variadic<View>); // default variadic builder
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
let printer = [{ printLinalgLibraryOp(p, *this); }];
////////////////////////////////////////////////////////////////////////////////
// Concrete Linalg ops.
////////////////////////////////////////////////////////////////////////////////
+// Fills an n-D output view with a single scalar (or vector) value.
+// Declared with 0 inputs and 1 output: the view being filled is the sole
+// output; the fill value is an extra trailing operand, not a view.
+def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
+  // Operand 0: the output view. Operand 1: the fill value, restricted to
+  // float, integer, or vector types.
+  let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
+  let extraClassDeclaration = [{
+    // A fill visits every element of the output view independently: one
+    // parallel loop per view dimension, no reduction or window loops.
+    unsigned getNumParallelLoops() {
+      auto *view = *(getOperands().begin());
+      return view->getType().cast<ViewType>().getRank();
+    }
+    unsigned getNumReductionLoops() { return 0; }
+    unsigned getNumWindowLoops() { return 0; }
+    unsigned getNumLoops() { return getNumParallelLoops(); }
+    // The fill value sits immediately after the input/output view operands.
+    Value *getValue() {
+      return *(getOperands().begin() + getNumInputsAndOutputs());
+    }
+  }];
+  // Type agreement between the value and the view element type is checked in
+  // the free-standing ::verify(FillOp) defined in the C++ implementation.
+  let verifier = [{ return ::verify(*this); }];
+}
def DotOp : LinalgLibrary_Op<"dot",
    [NInputsAndOutputs<2, 1>,
     NLoopTypes<0, 1, 0>,
-     ViewRanks<[1, 1, 0]>]> {}
+     ViewRanks<[1, 1, 0]>]> {
+  // Operands: two rank-1 input views and a rank-0 output view, made explicit
+  // here now that the base class no longer supplies a variadic default.
+  let arguments = (ins View, View, View);
+}
def MatvecOp : LinalgLibrary_Op<"matvec",
    [NInputsAndOutputs<2, 1>,
     NLoopTypes<1, 1, 0>,
-     ViewRanks<[2, 1, 1]>]> {}
+     ViewRanks<[2, 1, 1]>]> {
+  // Operands: rank-2 matrix view, rank-1 input vector view, rank-1 output
+  // vector view, made explicit now that the base class default was removed.
+  let arguments = (ins View, View, View);
+}
def MatmulOp : LinalgLibrary_Op<"matmul",
    [NInputsAndOutputs<2, 1>,
     NLoopTypes<2, 1, 0>,
-     ViewRanks<[2, 2, 2]>]> {}
+     ViewRanks<[2, 2, 2]>]> {
+  // Operands: two rank-2 input views and a rank-2 output view, made explicit
+  // now that the base class default was removed.
+  let arguments = (ins View, View, View);
+}
#endif // LINALG_LIBRARY_OPS
}
Operation *create(OpBuilder &builder, Location loc,
ArrayRef<Value *> operands) override {
- return builder.create<ConcreteOp>(loc, operands);
+ return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
+ ArrayRef<NamedAttribute>{});
}
};
Concept *impl;
result->operands));
}
+// Verifier for linalg.fill: the fill value's type must match the element
+// type of the single output view exactly — no implicit conversion.
+static LogicalResult verify(FillOp op) {
+  auto viewType = op.getOutputViewType(0);
+  auto fillType = op.getValue()->getType();
+  if (viewType.getElementType() != fillType)
+    return op.emitOpError("expects fill type to match view elemental type");
+  return success();
+}
+
namespace mlir {
namespace linalg {
auto i = getAffineDimExpr(0, context);
auto j = getAffineDimExpr(1, context);
auto k = getAffineDimExpr(2, context);
+ if (auto fillOp = dyn_cast<FillOp>(op)) {
+ // filling_value -> O(ivs)
+ unsigned rank = fillOp.getNumLoops();
+ return SmallVector<AffineMap, 4>{
+ AffineMap::getMultiDimIdentityMap(rank, op->getContext())};
+ }
if (isa<DotOp>(op))
// A(r_i) * B(r_i) -> C()
return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}),
using linalg_load = ValueBuilder<linalg::LoadOp>;
using linalg_store = OperationBuilder<linalg::StoreOp>;
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
- assert(reductionIvs.size() == 1);
- auto innermostLoop = linalg::getForInductionVarOwner(reductionIvs.back());
+ auto *innermostIv =
+ reductionIvs.empty() ? parallelIvs.back() : reductionIvs.back();
+ auto innermostLoop = linalg::getForInductionVarOwner(innermostIv);
auto *body = innermostLoop.getBody();
using edsc::op::operator+;
using edsc::op::operator*;
OpBuilder b(body, std::prev(body->end(), 1));
ScopedContext scope(b, innermostLoop.getLoc());
auto *op = linalgOp.getOperation();
- if (isa<DotOp>(op)) {
+ if (auto fillOp = dyn_cast<FillOp>(op)) {
+ IndexedValue O(fillOp.getOutput(0));
+ SmallVector<IndexHandle, 8> ivs(parallelIvs.begin(), parallelIvs.end());
+ O(ivs) = ValueHandle(fillOp.getValue());
+ return;
+ }
+ if (auto dotOp = dyn_cast<DotOp>(op)) {
IndexHandle r_i(reductionIvs[0]);
- IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
- C(op->getOperand(2));
+ IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+ C(dotOp.getOutput(0));
C() = C() + A(r_i) * B(r_i);
return;
}
- if (isa<MatvecOp>(op)) {
+ if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
- IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
- C(op->getOperand(2));
+ IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+ C(matvecOp.getOutput(0));
C(i) = C(i) + A(i, r_j) * B(r_j);
return;
}
- if (isa<MatmulOp>(op)) {
+ if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
- IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
- C(op->getOperand(2));
+ IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+ C(matmulOp.getOutput(0));
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
return;
}
- llvm_unreachable("Missing loopToOperandRangesMaps for op");
+ llvm_unreachable("Missing emitScalarImplementation for op");
}
// CHECK-DAG: %[[c:.*]] = linalg.load %arg2[] : !linalg.view<f32>
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
// CHECK: linalg.store %[[res]], %arg2[] : !linalg.view<f32>
+
+// Lowering test: filling a rank-1 dynamic view should produce a single
+// parallel loop storing the scalar into every element.
+func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : !linalg.view<?xf32>, f32
+  return
+}
+// CHECK-LABEL: func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
+// CHECK: linalg.for %i0 = %c0 to %0 step %c1 {
+// CHECK: linalg.store %arg1, %arg0[%i0] : !linalg.view<?xf32>
+
+// Lowering test: filling a rank-3 dynamic view should produce a 3-deep
+// parallel loop nest (one loop per view dimension) around the store.
+func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
+  return
+}
+// CHECK-LABEL: func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
+// CHECK: linalg.for %i0 = %c0 to %{{.*}} step %c1 {
+// CHECK: linalg.for %i1 = %c0 to %{{.*}} step %c1 {
+// CHECK: linalg.for %i2 = %c0 to %{{.*}} step %c1 {
+// CHECK: linalg.store %arg1, %arg0[%i0, %i1, %i2] : !linalg.view<?x?x?xf32>