Add Linalg FillOp
authorNicolas Vasilache <ntv@google.com>
Wed, 12 Jun 2019 18:50:19 +0000 (11:50 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jun 2019 05:59:54 +0000 (22:59 -0700)
This CL adds a generic FillOp to Linalg and its lowering to loops.
This is achieved by avoiding to specify the static NLoopTypes and ViewRanks type traits but instead defines the relevant methods as `extraClassDeclaration`.
The relevant AffineMap and scalar emission code are added, with relevant tests.

This gives us a first rank-agnostic Linalg op with its generic lowering to loops that should compose with view-based tiling and fusion.

PiperOrigin-RevId: 252869205

mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/test/Linalg/loops.mlir

index 2e1a3a5..a8567be 100644 (file)
@@ -68,7 +68,6 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
 // Base Tablegen class for Linalg ops.
 class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
   : Op<Linalg_Dialect, mnemonic, props> {
-  let arguments = (ins Variadic<View>); // default variadic builder
   let parser = [{ return parseLinalgLibraryOp(parser, result); }];
   let printer = [{ printLinalgLibraryOp(p, *this); }];
 
@@ -82,17 +81,39 @@ class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
 ////////////////////////////////////////////////////////////////////////////////
 // Concrete Linalg ops.
 ////////////////////////////////////////////////////////////////////////////////
+def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
+  let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
+  let extraClassDeclaration = [{
+    unsigned getNumParallelLoops() {
+      auto *view = *(getOperands().begin());
+      return view->getType().cast<ViewType>().getRank();
+    }
+    unsigned getNumReductionLoops() { return 0; }
+    unsigned getNumWindowLoops() { return 0; }
+    unsigned getNumLoops() { return getNumParallelLoops(); }
+    Value *getValue() {
+      return *(getOperands().begin() + getNumInputsAndOutputs());
+    }
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
 def DotOp : LinalgLibrary_Op<"dot",
                             [NInputsAndOutputs<2, 1>,
                              NLoopTypes<0, 1, 0>,
-                             ViewRanks<[1, 1, 0]>]> {}
+                             ViewRanks<[1, 1, 0]>]> {
+  let arguments = (ins View, View, View);
+}
 def MatvecOp : LinalgLibrary_Op<"matvec",
                                   [NInputsAndOutputs<2, 1>,
                                    NLoopTypes<1, 1, 0>,
-                                   ViewRanks<[2, 1, 1]>]> {}
+                                   ViewRanks<[2, 1, 1]>]> {
+  let arguments = (ins View, View, View);
+}
 def MatmulOp : LinalgLibrary_Op<"matmul",
                                   [NInputsAndOutputs<2, 1>,
                                    NLoopTypes<2, 1, 0>,
-                                   ViewRanks<[2, 2, 2]>]> {}
+                                   ViewRanks<[2, 2, 2]>]> {
+  let arguments = (ins View, View, View);
+}
 
 #endif // LINALG_LIBRARY_OPS
index bad8c47..25bd9e5 100644 (file)
@@ -527,7 +527,8 @@ private:
     }
     Operation *create(OpBuilder &builder, Location loc,
                       ArrayRef<Value *> operands) override {
-      return builder.create<ConcreteOp>(loc, operands);
+      return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
+                                        ArrayRef<NamedAttribute>{});
     }
   };
   Concept *impl;
index 7d41c86..0e6fa9e 100644 (file)
@@ -713,6 +713,14 @@ static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
                                          result->operands));
 }
 
+static LogicalResult verify(FillOp op) {
+  auto viewType = op.getOutputViewType(0);
+  auto fillType = op.getValue()->getType();
+  if (viewType.getElementType() != fillType)
+    return op.emitOpError("expects fill type to match view elemental type");
+  return success();
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -732,6 +740,12 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   auto i = getAffineDimExpr(0, context);
   auto j = getAffineDimExpr(1, context);
   auto k = getAffineDimExpr(2, context);
+  if (auto fillOp = dyn_cast<FillOp>(op)) {
+    // filling_value -> O(ivs)
+    unsigned rank = fillOp.getNumLoops();
+    return SmallVector<AffineMap, 4>{
+        AffineMap::getMultiDimIdentityMap(rank, op->getContext())};
+  }
   if (isa<DotOp>(op))
     // A(r_i) * B(r_i) -> C()
     return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}),
@@ -757,8 +771,9 @@ void mlir::linalg::emitScalarImplementation(
   using linalg_load = ValueBuilder<linalg::LoadOp>;
   using linalg_store = OperationBuilder<linalg::StoreOp>;
   using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
-  assert(reductionIvs.size() == 1);
-  auto innermostLoop = linalg::getForInductionVarOwner(reductionIvs.back());
+  auto *innermostIv =
+      reductionIvs.empty() ? parallelIvs.back() : reductionIvs.back();
+  auto innermostLoop = linalg::getForInductionVarOwner(innermostIv);
   auto *body = innermostLoop.getBody();
   using edsc::op::operator+;
   using edsc::op::operator*;
@@ -769,26 +784,32 @@ void mlir::linalg::emitScalarImplementation(
   OpBuilder b(body, std::prev(body->end(), 1));
   ScopedContext scope(b, innermostLoop.getLoc());
   auto *op = linalgOp.getOperation();
-  if (isa<DotOp>(op)) {
+  if (auto fillOp = dyn_cast<FillOp>(op)) {
+    IndexedValue O(fillOp.getOutput(0));
+    SmallVector<IndexHandle, 8> ivs(parallelIvs.begin(), parallelIvs.end());
+    O(ivs) = ValueHandle(fillOp.getValue());
+    return;
+  }
+  if (auto dotOp = dyn_cast<DotOp>(op)) {
     IndexHandle r_i(reductionIvs[0]);
-    IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
-        C(op->getOperand(2));
+    IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+        C(dotOp.getOutput(0));
     C() = C() + A(r_i) * B(r_i);
     return;
   }
-  if (isa<MatvecOp>(op)) {
+  if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
     IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
-    IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
-        C(op->getOperand(2));
+    IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+        C(matvecOp.getOutput(0));
     C(i) = C(i) + A(i, r_j) * B(r_j);
     return;
   }
-  if (isa<MatmulOp>(op)) {
+  if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
     IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
-    IndexedValue A(op->getOperand(0)), B(op->getOperand(1)),
-        C(op->getOperand(2));
+    IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+        C(matmulOp.getOutput(0));
     C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
     return;
   }
-  llvm_unreachable("Missing loopToOperandRangesMaps for op");
+  llvm_unreachable("Missing emitScalarImplementation for op");
 }
index fbed1c7..7c8c816 100644 (file)
@@ -91,3 +91,21 @@ func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !l
 //   CHECK-DAG:   %[[c:.*]] = linalg.load %arg2[] : !linalg.view<f32>
 //   CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
 //       CHECK:   linalg.store %[[res]], %arg2[] : !linalg.view<f32>
+
+func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : !linalg.view<?xf32>, f32
+  return
+}
+// CHECK-LABEL: func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
+//       CHECK:   linalg.for %i0 = %c0 to %0 step %c1 {
+//       CHECK:     linalg.store %arg1, %arg0[%i0] : !linalg.view<?xf32>
+
+func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
+  return
+}
+// CHECK-LABEL: func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
+//       CHECK:   linalg.for %i0 = %c0 to %{{.*}} step %c1 {
+//       CHECK:     linalg.for %i1 = %c0 to %{{.*}} step %c1 {
+//       CHECK:       linalg.for %i2 = %c0 to %{{.*}} step %c1 {
+//       CHECK:         linalg.store %arg1, %arg0[%i0, %i1, %i2] : !linalg.view<?x?x?xf32>