From de32c03ebeefcbebbb914ee13308acf1bea15428 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <ntv@google.com>
Date: Wed, 12 Jun 2019 11:50:19 -0700
Subject: [PATCH] Add Linalg FillOp

This CL adds a generic FillOp to Linalg and its lowering to loops. This is
achieved by not specifying the static NLoopTypes and ViewRanks type traits
and instead defining the relevant methods in `extraClassDeclaration`.
The corresponding AffineMap and scalar emission code are added, along with
tests.

This gives us a first rank-agnostic Linalg op with a generic lowering to
loops that should compose with view-based tiling and fusion.

PiperOrigin-RevId: 252869205
---
 mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td | 29 +++++++++++++---
 mlir/include/mlir/Linalg/IR/LinalgOps.h         |  3 +-
 mlir/lib/Linalg/IR/LinalgOps.cpp                | 45 ++++++++++++++++++-------
 mlir/test/Linalg/loops.mlir                     | 18 ++++++++++
 4 files changed, 78 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
index 2e1a3a5..a8567be 100644
--- a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
+++ b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
@@ -68,7 +68,6 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
 // Base Tablegen class for Linalg ops.
 class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
   : Op<Linalg_Dialect, mnemonic, props> {
-  let arguments = (ins Variadic<View>); // default variadic builder
   let parser = [{ return parseLinalgLibraryOp(parser, result); }];
   let printer = [{ printLinalgLibraryOp(p, *this); }];
 }
@@ -82,17 +81,39 @@ class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
 ////////////////////////////////////////////////////////////////////////////////
 // Concrete Linalg ops.
 ////////////////////////////////////////////////////////////////////////////////
+def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
+  let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
+  let extraClassDeclaration = [{
+    unsigned getNumParallelLoops() {
+      auto *view = *(getOperands().begin());
+      return view->getType().cast<ViewType>().getRank();
+    }
+    unsigned getNumReductionLoops() { return 0; }
+    unsigned getNumWindowLoops() { return 0; }
+    unsigned getNumLoops() { return getNumParallelLoops(); }
+    Value *getValue() {
+      return *(getOperands().begin() + getNumInputsAndOutputs());
+    }
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
 def DotOp : LinalgLibrary_Op<"dot",
                              [NInputsAndOutputs<2, 1>,
                               NLoopTypes<0, 1, 0>,
-                              ViewRanks<[1, 1, 0]>]> {}
+                              ViewRanks<[1, 1, 0]>]> {
+  let arguments = (ins View, View, View);
+}
 def MatvecOp : LinalgLibrary_Op<"matvec",
                                 [NInputsAndOutputs<2, 1>,
                                  NLoopTypes<1, 1, 0>,
-                                 ViewRanks<[2, 1, 1]>]> {}
+                                 ViewRanks<[2, 1, 1]>]> {
+  let arguments = (ins View, View, View);
+}
 def MatmulOp : LinalgLibrary_Op<"matmul",
                                 [NInputsAndOutputs<2, 1>,
                                  NLoopTypes<2, 1, 0>,
-                                 ViewRanks<[2, 2, 2]>]> {}
+                                 ViewRanks<[2, 2, 2]>]> {
+  let arguments = (ins View, View, View);
+}
 
 #endif // LINALG_LIBRARY_OPS
diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h
index bad8c47..25bd9e5 100644
--- a/mlir/include/mlir/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h
@@ -527,7 +527,8 @@ private:
   }
   Operation *create(OpBuilder &builder, Location loc,
                     ArrayRef<Value *> operands) override {
-    return builder.create<ConcreteOp>(loc, operands);
+    return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
+                                      ArrayRef<NamedAttribute>{});
   }
 };
 Concept *impl;
diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp
index 7d41c86..0e6fa9e 100644
--- a/mlir/lib/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Linalg/IR/LinalgOps.cpp
@@ -713,6 +713,14 @@ static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
                              result->operands));
 }
 
+static LogicalResult verify(FillOp op) {
+  auto viewType = op.getOutputViewType(0);
+  auto fillType = op.getValue()->getType();
+  if (viewType.getElementType() != fillType)
+    return op.emitOpError("expects fill type to match view elemental type");
+  return success();
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -732,6 +740,12 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   auto i = getAffineDimExpr(0, context);
   auto j = getAffineDimExpr(1, context);
   auto k = getAffineDimExpr(2, context);
+  if (auto fillOp = dyn_cast<FillOp>(op)) {
+    // filling_value -> O(ivs)
+    unsigned rank = fillOp.getNumLoops();
+    return SmallVector<AffineMap, 4>{
+        AffineMap::getMultiDimIdentityMap(rank, op->getContext())};
+  }
   if (isa<DotOp>(op))
     // A(r_i) * B(r_i) -> C()
     return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}),
@@ -757,8 +771,9 @@ void mlir::linalg::emitScalarImplementation(
   using linalg_load = ValueBuilder<linalg::LoadOp>;
   using linalg_store = OperationBuilder<linalg::StoreOp>;
   using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
-  assert(reductionIvs.size() == 1);
-  auto innermostLoop = linalg::getForInductionVarOwner(reductionIvs.back());
+  auto *innermostIv =
+      reductionIvs.empty() ? parallelIvs.back() : reductionIvs.back();
+  auto innermostLoop = linalg::getForInductionVarOwner(innermostIv);
   auto *body = innermostLoop.getBody();
   using edsc::op::operator+;
   using edsc::op::operator*;
@@ -769,26 +784,32 @@ void mlir::linalg::emitScalarImplementation(
   OpBuilder b(body, std::prev(body->end(), 1));
   ScopedContext scope(b, innermostLoop.getLoc());
   auto *op = linalgOp.getOperation();
-  if (isa<DotOp>(op)) {
+  if (auto fillOp = dyn_cast<FillOp>(op)) {
+    IndexedValue O(fillOp.getOutput(0));
+    SmallVector<IndexHandle, 4> ivs(parallelIvs.begin(), parallelIvs.end());
+    O(ivs) = ValueHandle(fillOp.getValue());
+    return;
+  }
+  if (auto dotOp = dyn_cast<DotOp>(op)) {
     IndexHandle r_i(reductionIvs[0]);
-    IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
-        C(op->getOperand(2));
+    IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+        C(dotOp.getOutput(0));
     C() = C() + A(r_i) * B(r_i);
     return;
   }
-  if (isa<MatvecOp>(op)) {
+  if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
     IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
-    IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
-        C(op->getOperand(2));
+    IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+        C(matvecOp.getOutput(0));
     C(i) = C(i) + A(i, r_j) * B(r_j);
     return;
   }
-  if (isa<MatmulOp>(op)) {
+  if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
     IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
-    IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
-        C(op->getOperand(2));
+    IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+        C(matmulOp.getOutput(0));
     C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
     return;
   }
-  llvm_unreachable("Missing loopToOperandRangesMaps for op");
+  llvm_unreachable("Missing emitScalarImplementation for op");
 }
diff --git a/mlir/test/Linalg/loops.mlir b/mlir/test/Linalg/loops.mlir
index fbed1c7..7c8c816 100644
--- a/mlir/test/Linalg/loops.mlir
+++ b/mlir/test/Linalg/loops.mlir
@@ -91,3 +91,21 @@ func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !l
 // CHECK-DAG: %[[c:.*]] = linalg.load %arg2[] : !linalg.view<f32>
 // CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
 // CHECK: linalg.store %[[res]], %arg2[] : !linalg.view<f32>
+
+func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : !linalg.view<?xf32>, f32
+  return
+}
+// CHECK-LABEL: func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
+// CHECK: linalg.for %i0 = %c0 to %0 step %c1 {
+// CHECK:   linalg.store %arg1, %arg0[%i0] : !linalg.view<?xf32>
+
+func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
+  return
+}
+// CHECK-LABEL: func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
+// CHECK: linalg.for %i0 = %c0 to %{{.*}} step %c1 {
+// CHECK:   linalg.for %i1 = %c0 to %{{.*}} step %c1 {
+// CHECK:     linalg.for %i2 = %c0 to %{{.*}} step %c1 {
+// CHECK:       linalg.store %arg1, %arg0[%i0, %i1, %i2] : !linalg.view<?x?x?xf32>
-- 
2.7.4
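
A note for readers of the lowering above: the rank-agnostic path emits one
parallel linalg.for per view dimension and stores the fill value at every
coordinate, which is what the new loops.mlir tests check. The sketch below is
a hand-written, illustrative rendering of that loop nest for the rank-3 case;
the function name @fill_view3_loops and the bound arguments %d0..%d2 are
placeholders standing in for the view sizes that the actual pass materializes
itself, used here only to keep the snippet self-contained.

  func @fill_view3_loops(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32,
                         %d0: index, %d1: index, %d2: index) {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    // One loop per parallel dimension of the output view.
    linalg.for %i0 = %c0 to %d0 step %c1 {
      linalg.for %i1 = %c0 to %d1 step %c1 {
        linalg.for %i2 = %c0 to %d2 step %c1 {
          // The scalar operand of linalg.fill is stored unconditionally.
          linalg.store %arg1, %arg0[%i0, %i1, %i2] : !linalg.view<?x?x?xf32>
        }
      }
    }
    return
  }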