[mlir][linalg] Add support for scalar input operands.
authorTobias Gysi <gysit@google.com>
Mon, 14 Jun 2021 05:59:33 +0000 (05:59 +0000)
committerTobias Gysi <gysit@google.com>
Mon, 14 Jun 2021 06:27:16 +0000 (06:27 +0000)
Up to now all structured op operands are assumed to be shaped. The patch relaxes this assumption and allows scalar input operands. In contrast to shaped operands scalar operands are not indexed and directly forwarded to the body of the operation. As all other operands, scalar operands are associated to an indexing map that in case of a scalar or a 0D-operand has an empty range.

We will use scalar operands as a replacement for the capture mechanism. In contrast to captures, the approach ensures we can generate the function signature from the operand list and it prevents outdated capture values in case a transformation updates only the capture operand but not the hidden body of a named operation.

Removing captures and updating existing operations such as linalg.fill is left for a later patch.

The patch depends on https://reviews.llvm.org/D103891 and https://reviews.llvm.org/D103890.

Differential Revision: https://reviews.llvm.org/D104109

24 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index 6dabb6d..bc9acee 100644 (file)
@@ -15,8 +15,6 @@
 
 include "mlir/IR/OpBase.td"
 
-def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
-
 def Linalg_Dialect : Dialect {
   let name = "linalg";
   let description = [{
index 73e5570..e53f2f3 100644 (file)
@@ -586,6 +586,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
+        Return true if the `opOperand` is a scalar value.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isScalar",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        return !opOperand->get().getType().template isa<ShapedType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
         Return the input or output indexing map for `opOperand`.
       }],
       /*retTy=*/"AffineMap",
@@ -694,10 +707,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return this->getOperation()->getNumResults() == 0 &&
-          llvm::all_of(getInputAndOutputOperands(),
-            [](OpOperand *opOperand) {
-              return opOperand->get().getType().template isa<MemRefType>();
-            });
+          llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
+            return isScalar(opOperand) ||
+              opOperand->get().getType().template isa<MemRefType>();
+          }) &&
+          llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
+            return opOperand->get().getType().template isa<MemRefType>();
+          });
       }]
     >,
     InterfaceMethod<
@@ -709,8 +725,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::all_of(getInputAndOutputOperands(),
-          [](OpOperand *opOperand) {
+        return
+          llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
+            return isScalar(opOperand) ||
+              opOperand->get().getType().template isa<RankedTensorType>();
+          }) &&
+          llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
             return opOperand->get().getType().template isa<RankedTensorType>();
           });
       }]
index 5d7e2cc..9b6120e 100644 (file)
@@ -640,8 +640,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
   let arguments = (ins Variadic<Index>:$lowerBound,
                        Variadic<Index>:$upperBound,
                        Variadic<Index>:$step,
-                       Variadic<LinalgOperand>:$inputs,
-                       Variadic<LinalgOperand>:$outputs,
+                       Variadic<AnyType>:$inputs,
+                       Variadic<AnyShaped>:$outputs,
                        ArrayAttr:$iterator_types,
                        OptionalAttr<ArrayAttr>:$distribution_types);
   let results = (outs Variadic<AnyRankedTensor>:$results);
index 41fcc24..2b70572 100644 (file)
@@ -517,17 +517,12 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
 //===----------------------------------------------------------------------===//
 // Generic Linalg ops.
 //===----------------------------------------------------------------------===//
-class LinalgOperandOfRank<int rank>: Type<
-  And<[
-    LinalgOperand.predicate,
-    CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
-  >>;
 
 class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     SingleBlockImplicitTerminator<"YieldOp">]> {
-  let arguments = (ins Variadic<AnyShaped>:$inputs,
+  let arguments = (ins Variadic<AnyType>:$inputs,
                        Variadic<AnyShaped>:$outputs,
                        AffineMapArrayAttr:$indexing_maps,
                        ArrayAttr:$iterator_types,
index 8a48e89..18717e9 100644 (file)
@@ -338,7 +338,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
     if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
       return failure();
 
-  // All shaped operands must be indexed.
+  // All input/output operands must be indexed.
   if (static_cast<int64_t>(linalgOp.indexing_maps().size()) !=
       linalgOp.getNumInputsAndOutputs())
     return op->emitOpError("expected the number of indexing_map (")
@@ -363,7 +363,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
 
     int64_t rank = linalgOp.getRank(opOperand);
     if (indexingMap.getNumResults() != rank)
-      return op->emitOpError("expected shaped value rank (")
+      return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand->getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
@@ -444,7 +444,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
 
   if (linalgOp.getNumInputsAndOutputs() + numBBIvs != block.getNumArguments())
     return op->emitOpError("expected as many non-induction variable region "
-                           "arguments as the number of shaped operands");
+                           "arguments as the number of input/output operands");
 
   // Note: the number and type of yield values are checked in the YieldOp.
   for (unsigned i = 0; i < numBBIvs; ++i)
@@ -452,14 +452,14 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
       return op->emitOpError("expected index block argument #") << i;
 
   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+    Type elementType = getElementTypeOrSelf(opOperand->get());
     Type argType =
         block.getArgument(numBBIvs + opOperand->getOperandNumber()).getType();
     if (elementType != argType)
       return op->emitOpError("expected type of bb argument #")
              << numBBIvs + opOperand->getOperandNumber() << " (" << argType
              << ")"
-             << " to match element type of corresponding shaped operand ("
+             << " to match element or self type of the corresponding operand ("
              << elementType << ")";
   }
 
@@ -489,10 +489,11 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
 
         // The first index or last index should be the maximum or the minimum in
         // the inferred index ranges since the range is increasing or
-        // decreasing. The size of dimensions of shaped operands and the maximum
-        // value + 1 in the inferred range should be the same. But, for now we
-        // check if the inferred ranges are in boundary of shaped operands' size
-        // or not in case that Affine Expressions are complicated such as d0 * 3
+        // decreasing. The size of dimensions of input/output operands and the
+        // maximum value + 1 in the inferred range should be the same. But, for
+        // now we check if the inferred ranges are in boundary of input/output
+        // operands' size or not in case that Affine Expressions are complicated
+        // such as d0 * 3
         // + d1 since it is not easy to handle the issues.
         // Found the case that this solution can't check, for example, (d0, d1)
         // -> (d1 - d0)
@@ -510,14 +511,14 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
         }
         if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
           if (inferredDimSize != shape[dim]) {
-            return op->emitOpError("inferred shaped operand #")
+            return op->emitOpError("inferred input/output operand #")
                    << opOperand->getOperandNumber()
                    << " has shape's dimension #" << dim << " to be "
                    << inferredDimSize << ", but found " << shape[dim];
           }
         } else {
           if (inferredDimSize > shape[dim]) {
-            return op->emitOpError("inferred shaped operand #")
+            return op->emitOpError("inferred input/output operand #")
                    << opOperand->getOperandNumber()
                    << " has shape's dimension #" << dim
                    << " to be greater than or equal to " << inferredDimSize
index 6aa0ed1..6eef6b0 100644 (file)
@@ -377,8 +377,7 @@ void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
 static LogicalResult verify(CopyOp op) {
   OpOperand *output = op.getOutputOperand(0);
   OpOperand *input = op.getInputOperand(0);
-  if (getElementTypeOrSelf(input->get().getType()) !=
-      getElementTypeOrSelf(output->get().getType()))
+  if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
     return op.emitOpError("expects views of the same type");
   if (op.getRank(input) != op.getRank(output))
     return op.emitOpError("expects views of the same rank");
@@ -452,7 +451,7 @@ void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
 static LogicalResult verify(FillOp op) {
   OpOperand *output = op.getOutputOperand(0);
   Type fillType = op.value().getType();
-  if (getElementTypeOrSelf(output->get().getType()) != fillType)
+  if (getElementTypeOrSelf(output->get()) != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
   if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
     return op.emitOpError(
@@ -489,7 +488,7 @@ void GenericOp::build(
   SmallVector<Type, 4> blockArgTypes;
   for (ValueRange container : {inputs, outputs})
     for (Value v : container)
-      blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
+      blockArgTypes.push_back(getElementTypeOrSelf(v));
 
   OpBuilder::InsertionGuard guard(builder);
   auto &region = *result.regions.front();
@@ -545,7 +544,7 @@ void IndexedGenericOp::build(
   SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
   for (ValueRange container : {inputs, outputs})
     for (Value v : container)
-      blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
+      blockArgTypes.push_back(getElementTypeOrSelf(v));
 
   OpBuilder::InsertionGuard guard(builder);
   auto &region = *result.regions.front();
@@ -2949,7 +2948,6 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                        TypeRange inputTypes, TypeRange outputTypes,
                        ValueRange captures,
                        std::function<void(unsigned, unsigned)> errorHandler) {
-  assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
   assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
 
   // TODO: atm all operands go through getElementTypeOrSelf,
index 1c2d32b..3b63bf1 100644 (file)
@@ -484,18 +484,21 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
 
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
-  SmallVector<Value, 2> newInputBuffers;
-  newInputBuffers.reserve(op.getNumInputs());
+  SmallVector<Value> newInputs;
+  newInputs.reserve(op.getNumInputs());
   for (OpOperand *opOperand : op.getInputOperands()) {
-    Value v = lookup(bvm, opOperand->get());
-    if (!v)
+    if (op.isScalar(opOperand)) {
+      newInputs.push_back(opOperand->get());
+      continue;
+    }
+    newInputs.push_back(lookup(bvm, opOperand->get()));
+    if (!newInputs.back())
       return failure();
-    newInputBuffers.push_back(v);
   }
-  SmallVector<Value, 2> newOutputBuffers;
+  SmallVector<Value> newOutputBuffers;
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
     return failure();
-  finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
+  finalizeBufferAllocation(b, op, newInputs, newOutputBuffers, bvm);
   return success();
 }
 
index 102dbdb..1deea94 100644 (file)
@@ -301,7 +301,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     ++dim;
   }
   // Compute the tensor or scalar replacement type.
-  Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+  Type elementType = getElementTypeOrSelf(opOperand->get());
   Type replacementType = elementType == opOperand->get().getType()
                              ? elementType
                              : RankedTensorType::get(newShape, elementType);
index 0263bcb..098442c 100644 (file)
@@ -129,14 +129,14 @@ static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
   assert(producer.hasTensorSemantics() &&
          "only fusion on tensors is currently supported for TiledLinalgOp");
 
-  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+  for (OpOperand *producerInput : producer.getInputOperands()) {
     OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
     if (addedInput == nullptr)
       addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
     BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
     tiledOperands.push_back(addedBlockArg);
   }
-  for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+  for (OpOperand *producerOutput : producer.getOutputOperands()) {
     OpResult result = producer.getTiedOpResult(producerOutput);
     OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
     OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
index 530c024..49870da 100644 (file)
@@ -126,8 +126,12 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
 
   // TODO: Avoid the loads if the corresponding argument of the
   // region has no uses.
-  // 1.a. Emit load from input views.
+  // 1.a. Emit load from input operand or for scalars access the operand itself.
   for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
+    if (linalgOp.isScalar(inputOperand)) {
+      indexedValues.push_back(inputOperand->get());
+      continue;
+    }
     auto indexing = makeCanonicalAffineApplies(
         b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
     indexedValues.push_back(
index 44acac1..efd0c3b 100644 (file)
@@ -149,7 +149,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   }
   Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
   auto staticTensorType = RankedTensorType::get(
-      staticSizes, getElementTypeOrSelf(opOperand->get().getType()));
+      staticSizes, getElementTypeOrSelf(opOperand->get()));
   result = linalg::PadTensorOp::createPadHighOp(
       staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
   return success();
index 689aae1..96846bf 100644 (file)
@@ -479,6 +479,10 @@ LogicalResult vectorizeAsLinalgGeneric(
   SmallVector<AffineMap> indexings;
   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
     BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
+    if (linalgOp.isScalar(opOperand)) {
+      bvm.map(bbarg, opOperand->get());
+      continue;
+    }
     // TODO: 0-d vectors.
     if (linalgOp.getShape(opOperand).empty()) {
       Value loaded =
@@ -494,14 +498,13 @@ LogicalResult vectorizeAsLinalgGeneric(
     if (broadcastToMaximalCommonShape) {
       map = inverseAndBroadcastProjectedPermuation(
           linalgOp.getTiedIndexingMap(opOperand));
-      vectorType = VectorType::get(
-          commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
+      vectorType = VectorType::get(commonVectorShape,
+                                   getElementTypeOrSelf(opOperand->get()));
     } else {
       map = inversePermutation(
           reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
-      vectorType =
-          VectorType::get(map.compose(linalgOp.getShape(opOperand)),
-                          getElementTypeOrSelf(opOperand->get().getType()));
+      vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+                                   getElementTypeOrSelf(opOperand->get()));
     }
     Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
@@ -1157,7 +1160,7 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
 
   int64_t rank = op.getRank(input);
   int64_t numDims = mapping.size();
-  Type elemType = getElementTypeOrSelf(input->get().getType());
+  Type elemType = getElementTypeOrSelf(input->get());
 
   auto map = AffineMap::get(rank, 0, mapping, context);
   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
index c99aafb..4460ae8 100644 (file)
@@ -1372,6 +1372,8 @@ public:
     // Detects sparse annotations and translate the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
     assert(op.getNumOutputs() == 1);
+    assert(llvm::none_of(op.getInputAndOutputOperands(),
+                         [&](OpOperand *t) { return op.isScalar(t); }));
     unsigned numTensors = op.getNumInputsAndOutputs();
     unsigned numLoops = op.iterator_types().getValue().size();
     Merger merger(numTensors, numLoops);
index bd4f66d..5a53c22 100644 (file)
@@ -2,6 +2,7 @@
 
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
   affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
 ]
 
   library_call = "some_external_func"
 }
 
-func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
   %0 = linalg.generic #trait
-     ins(%arg0 : tensor<?x1x?xf32>)
+     ins(%arg0, %arg1 : tensor<?x1x?xf32>, f32)
     outs(%shape : tensor<?x1x?x1x?xf32>) {
-       ^bb0(%arg2 : f32, %arg3 : f32) :
-         linalg.yield %arg2 : f32
+       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+         linalg.yield %arg3 : f32
        } -> tensor<?x1x?x1x?xf32>
   return %0 : tensor<?x1x?x1x?xf32>
 }
-//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
 //   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @drop_one_trip_loops
 //       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
 //       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
+//  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
 //       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
 
index 9f1566c..730482d 100644 (file)
@@ -292,7 +292,7 @@ module {
 // TLOOP:  %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
 // TLOOP:  %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
 
-// TLOOP:  %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = 
+// TLOOP:  %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
 // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
 // TLOOP-SAME: step (%[[C32]], %[[C64]])
 // TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
@@ -305,7 +305,80 @@ module {
 // TLOOP:    %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
 // TLOOP:    %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
 
-// TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) 
+// TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
+// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["reduction"] {
+
+// TLOOP:      %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
+// TLOOP:      %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
+
+// TLOOP:      %[[AB_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME:   ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME:   outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
+// TLOOP:      linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP:    }
+// TLOOP:    %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
+// TLOOP-SAME:  into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP:    linalg.yield %[[SUB_RESULT]] : [[TY]]
+// TLOOP:  }
+// TLOOP:  return %[[AB]] : [[TY]]
+
+// -----
+
+module {
+  func @generic_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                      %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = constant 0.0 : f32
+    %0 = linalg.generic {
+      indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+      iterator_types = ["parallel", "parallel"]}
+     ins(%c0 : f32)
+    outs(%arg0: tensor<?x?xf32>) {
+      ^bb(%0: f32, %1: f32) :
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
+      ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %1 : tensor<?x?xf32>
+  }
+}
+
+// TLOOP-LABEL: func @generic_plus_matmul(
+// TLOOP-SAME:    %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME:    %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME:    %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP-DAG:  %[[C0_F32:.*]] = constant 0.0
+// TLOOP-DAG:  %[[C32:.*]] = constant 32 : index
+// TLOOP-DAG:  %[[C64:.*]] = constant 64 : index
+// TLOOP-DAG:  %[[C16:.*]] = constant 16 : index
+// TLOOP-DAG:  %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG:  %[[C1:.*]] = constant 1 : index
+
+// TLOOP:  %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP:  %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP:  %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
+// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
+
+// TLOOP:    %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
+// TLOOP:    %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
+// TLOOP:    %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP:    %[[INIT_SUB:.*]] = linalg.generic
+// TLOOP-SAME: ins(%[[C0_F32_]]
+// TLOOP-SAME: outs(%[[OUT_SUB]]
+
+// TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
 // TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
 // TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
 // TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
index 3146b41..068b875 100644 (file)
@@ -41,6 +41,48 @@ func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : te
 // -----
 
 // CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> ()>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> ()>
+
+// CHECK-LABEL: @scalar_add_mul_fusion
+func @scalar_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : f32, %arg2 : f32) -> tensor<?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, f32)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       // no predecessors
+      %4 = addf %arg3, %arg4 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  // CHECK: linalg.generic {
+  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP1]], [[$MAP0]]{{\]}}
+  %4 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%3, %arg2 : tensor<?x?xf32>, f32)
+      outs(%2 : tensor<?x?xf32>) {
+    // CHECK: ^{{[a-zA-Z0-9_]*}}
+    // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]]
+    // CHECK-SAME: [[ARG4:%[a-zA-Z0-9_]*]]
+    // CHECK-SAME: [[ARG5:%[a-zA-Z0-9_]*]]
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):       // no predecessors
+      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = addf [[ARG3]], [[ARG4]]
+      // CHECK-NOT: linalg.yield
+      // CHECK: mulf [[T1]], [[ARG5]]
+      // CHECK: linalg.yield
+      %5 = mulf %arg5, %arg6 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)>
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1, d0)>
index 9dc20a4..ac56add 100644 (file)
@@ -96,7 +96,7 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
 // -----
 
 func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{expected shaped value rank (1) to match the result rank of indexing_map #0 (2)}}
+  // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> (0, 0)> ],
     iterator_types = []}
@@ -108,6 +108,21 @@ func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>)
 
 // -----
 
+func @generic_scalar_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+  %cst = constant 0.0 : f32
+  // expected-error @+1 {{expected operand rank (0) to match the result rank of indexing_map #0 (1)}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> (0)>, affine_map<() -> (0, 0)> ],
+    iterator_types = []}
+      ins(%cst : f32)
+      outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+    ^bb(%0 : f32, %1 : f32):
+      linalg.yield %0: f32
+  }
+}
+
+// -----
+
 func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
@@ -174,7 +189,7 @@ func @generic_empty_region(%arg0: memref<f32>) {
 // -----
 
 func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
+  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
   linalg.generic {
       indexing_maps =  [ affine_map<() -> ()>, affine_map<() -> ()> ],
       iterator_types = []}
@@ -186,8 +201,8 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
 
 // -----
 
-func @generic_block_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element type of corresponding shaped operand ('f32')}}
+func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> ()> ],
     iterator_types = []}
@@ -199,8 +214,21 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
 
 // -----
 
+func @generic_scalar_operand_block_arg_type(%arg0: f32) {
+  // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> ()> ],
+    iterator_types = []}
+      outs(%arg0 : f32) {
+    ^bb(%i: i1):
+    linalg.yield %i : i1
+  }
+}
+
+// -----
+
 func @indexed_generic_block_arg_count(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
+  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
   linalg.indexed_generic {
     indexing_maps =  [ affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
@@ -226,7 +254,7 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<?xf32>) {
 // -----
 
 func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element type of corresponding shaped operand ('f32')}}
+  // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element or self type of the corresponding operand ('f32')}}
   linalg.indexed_generic {
     indexing_maps =  [ affine_map<(d0) -> (d0)> ],
     iterator_types = ["parallel"]}
@@ -239,7 +267,7 @@ func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
 // -----
 
 func @indexed_generic_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
+  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
   linalg.indexed_generic {
     indexing_maps =  [ affine_map<()[] -> ()> ],
     iterator_types = []}
@@ -401,7 +429,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
 func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
                             %arg1: memref<2x3xf32>,
                             %arg2: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}}
+  // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
   linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
     memref<?x?x?xf32>, memref<2x3xf32>, memref<?x?x?xf32>
   return
@@ -410,7 +438,7 @@ func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
 // -----
 
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}}
+  // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
                      outs(%c3 : memref<?x?x?xf32>)
   return
@@ -714,7 +742,7 @@ func @illegal_fill_tensor_with_memref_return
 // -----
 
 func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
-  // expected-error @+1 {{inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3}}
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 3}}
   linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
                       outs(%arg2 :memref<2x4xf32>)
   return
@@ -723,7 +751,7 @@ func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg
 // -----
 
 func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
-  // expected-error @+1 {{inferred shaped operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
+  // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_input_nhwc_filter_hwcf
     { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %filter : memref<1x3x4x2xf32>, memref<3x2x2x1xf32>)
index c469160..7bc8dc5 100644 (file)
@@ -975,6 +975,30 @@ func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
 //       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][]
 //       CHECKPARALLEL:   store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
 
+func @generic_op_scalar(%arg0: f32, %arg1: memref<3x4xf32>)
+{
+  linalg.generic #trait_broadcast
+      ins(%arg0 : f32)
+     outs(%arg1 : memref<3x4xf32>) {
+    ^bb(%a: f32, %b: f32) :
+      linalg.yield %a : f32
+  }
+  return
+}
+
+// CHECK-LABEL: @generic_op_scalar
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
+//  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
+//       CHECK: scf.for %[[i:.*]] = {{.*}}
+//       CHECK:   scf.for %[[j:.*]] = {{.*}}
+//       CHECK:     store %[[ARG0]], %[[ARG1]][%[[i]], %[[j]]]
+
+// CHECKPARALLEL-LABEL: @generic_op_scalar
+//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
+//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
+//       CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
+//       CHECKPARALLEL:   store %[[ARG0]], %[[ARG1]][%[[i]], %[[j]]]
+
 func @generic_index_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 {
   linalg.generic #trait_broadcast
index 4a3d54c..1cf4e6b 100644 (file)
@@ -2,37 +2,42 @@
 // RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
-                                         %arg1 : tensor<?x?x?xf32>) ->
+                                         %arg1 : tensor<?x?x?xf32>,
+                                         %arg2 : f32) ->
                                          tensor<?x?x?xf32>
 {
   %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] :
     tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
   %1 = linalg.generic {
-     indexing_maps = [#map0, #map1, #map1],
+     indexing_maps = [#map0, #map1, #map2, #map1],
      iterator_types = ["parallel", "parallel", "parallel"]}
-       ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+       ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
        outs(%0 : tensor<?x?x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
+      %2 = addf %1, %arg5 : f32
+      linalg.yield %2 : f32
   } -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
 
 //  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
 //  CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
+//  CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
 //      CHECK: func @generic_op_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
 //      CHECK:   %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
 // CHECK-SAME:     [0], [1, 2], [3]
 //      CHECK:   %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 // CHECK-SAME:     [0], [1], [2, 3]
 //      CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>)
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
 // CHECK-SAME:     outs(%{{.+}} : tensor<?x?x?x4xf32>)
 //      CHECK:   %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]]
 // CHECK-SAME:     [0], [1], [2, 3]
@@ -42,18 +47,21 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 // -----
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> ()>
 func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
-                                         %arg1 : tensor<?x?xf32>) ->
+                                         %arg1 : tensor<?x?xf32>,
+                                         %arg2 : f32) ->
                                          tensor<?x4x?x5xf32>
 {
   %0 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
+     indexing_maps = [#map0, #map0, #map1, #map0],
      iterator_types = ["parallel", "parallel"]}
-       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>, f32)
        outs(%arg0 : tensor<?x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
+      %2 = addf %1, %arg5 : f32
+      linalg.yield %2 : f32
   } -> tensor<?x?xf32>
   %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
     tensor<?x?xf32> into tensor<?x4x?x5xf32>
@@ -61,9 +69,12 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 }
 
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
+
 //      CHECK: func @generic_op_reshape_consumer_fusion
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>)
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
 //      CHECK:   %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:     [0], [1, 2, 3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
@@ -71,9 +82,9 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:     [0], [1, 2, 3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>)
+// CHECK-SAME:     ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
 // CHECK-SAME:     outs(%{{.+}} : tensor<?x4x?x5xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x4x?x5xf32>
 
index 9ca383a..4675f82 100644 (file)
@@ -316,6 +316,7 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
 
 #accesses_0 = [
   affine_map<(i, j, k) -> (j, i)>,
+  affine_map<(i, j, k) -> ()>,
   affine_map<(i, j, k) -> (i, k, i + j)>
 ]
 
@@ -327,34 +328,34 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
 
 func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
               %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  %cst = constant 0.0 : f32
   linalg.generic #trait_0
-       ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+       ins(%arg0, %cst : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, f32)
       outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
       attrs = {foo = 1} {
-    ^bb(%0: vector<3x4xi4>, %1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
+    ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
+      linalg.yield %1 : f32
   }
   return
 }
 // CHECK-LABEL: func @generic
 //       CHECK:   linalg.generic {
-//  CHECK-SAME:     indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
+//  CHECK-SAME:     indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
 //  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"],
 //  CHECK-SAME:     library_call = "some_external_function_name_1"}
-//  CHECK-SAME:      ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+//  CHECK-SAME:      ins({{.*}}, {{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>, f32)
 //  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
 //  CHECK-SAME:     {foo = 1 : i64}
 
 func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
                                 %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  %cst = constant 0.0 : f32
   linalg.generic #trait_0
-       ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
+       ins(%arg0, %cst : tensor<?x?xvector<3x4xi4>>, f32)
       outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
       attrs = {foo = 1} {
-    ^bb(%0: vector<3x4xi4>, %1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
+    ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
+      linalg.yield %1 : f32
   }
   return
 }
@@ -362,7 +363,7 @@ func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
 //       CHECK:   linalg.generic {
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 //  CHECK-SAME:     library_call = "some_external_function_name_1"}
-//  CHECK-SAME:     ins({{.*}} : tensor<?x?xvector<3x4xi4>>)
+//  CHECK-SAME:     ins({{.*}}, {{.*}} : tensor<?x?xvector<3x4xi4>>, f32)
 //  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
 //  CHECK-SAME:     {foo = 1 : i64}
 
index 63dc9fb..a7322a7 100644 (file)
@@ -8,6 +8,7 @@
 func @matmul_tensors(
   %arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
     -> tensor<?x?xi32> {
+//      CHECK: %[[C0:.*]] = constant 0 : index
 //      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
 //      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
 //      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
@@ -19,11 +20,11 @@ func @matmul_tensors(
 //  CHECK-NOT:       linalg.matmul {{.*}} tensor<?x?xi8>
 
 // Padding injects static information.
-//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi8> to tensor<2x4xi8>
-//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi8> to tensor<4x3xi8>
-//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi32> to tensor<2x3xi32>
 //      CHECK:       %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
 // CHECK-SAME:                                           outs(%[[pC]] : tensor<2x3xi32>)  -> tensor<2x3xi32>
@@ -41,6 +42,41 @@ func @matmul_tensors(
   return %0 : tensor<?x?xi32>
 }
 
+// CHECK-LABEL: func @generic_scalar_and_tensor(
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:    %[[VAL:[0-9a-z]+]]: f32) -> tensor<?x?x?xf32> {
+func @generic_scalar_and_tensor(
+  %arg0: tensor<?x?x?xf32>, %arg1: f32)
+    -> tensor<?x?x?xf32> {
+//      CHECK: %[[C0:.*]] = constant 0 : index
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?x?xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?x?xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?x?xf32>) {
+//      CHECK:       %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+
+// Padding injects static information.
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
+//      CHECK:        : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+//      CHECK:       %[[pD:.*]] = linalg.generic
+// CHECK-SAME:         ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
+//      CHECK:       %[[sTD:.*]] = subtensor %[[pD]][0, 0, 0] [%{{.*}}, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+//      CHECK:       %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?x?xf32>
+      %0 = linalg.generic {
+      indexing_maps =  [ affine_map<(d0, d1, d2) -> ()>,
+                        affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      {__internal_linalg_transform__ = "tile-and-pad"}
+     ins(%arg1 : f32)
+    outs(%arg0: tensor<?x?x?xf32>) {
+      ^bb(%0: f32, %1: f32) :
+        linalg.yield %0 : f32
+    } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
 // CHECK-1DIM-TILE: func @matmul_tensors(
 // CHECK-1DIM-TILE:    %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
 // CHECK-1DIM-TILE:    %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
@@ -65,6 +101,7 @@ func @matmul_partially_padded_tensors(
 // CHECK-1DIM-TILE-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
 // CHECK-1DIM-TILE-SAME:    %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
 // CHECK-1DIM-TILE-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+//      CHECK-1DIM-TILE:        %[[C0:.*]] = constant 0 : index
 //      CHECK-1DIM-TILE:        %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
 //      CHECK-1DIM-TILE:            %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
 //      CHECK-1DIM-TILE:                %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
@@ -72,11 +109,11 @@ func @matmul_partially_padded_tensors(
 //      CHECK-1DIM-TILE:                %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
 //      CHECK-1DIM-TILE:                %[[sTBc:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor<?x?xi8>
 //      CHECK-1DIM-TILE:                %[[sTC:.*]] = subtensor %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x?xi8> to tensor<2x8xi8>
-//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x?xi8> to tensor<8x3xi8>
-//      CHECK-1DIM-TILE:                %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x?xi32> to tensor<2x3xi32>
 //      CHECK-1DIM-TILE:               %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
 //      CHECK-1DIM-TILE:                                           outs(%[[pC]] : tensor<2x3xi32>)  -> tensor<2x3xi32>
index 5a7110f..ff594a0 100644 (file)
@@ -136,6 +136,23 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_scalar_input
+func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
+  //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+  //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+  linalg.generic {
+    indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+    iterator_types = ["parallel", "parallel"]}
+   ins(%arg0 : f32)
+  outs(%A: memref<8x16xf32>) {
+    ^bb(%0: f32, %1: f32) :
+      linalg.yield %0 : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @test_vectorize_fill
 func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
index 26e4f6a..1e68c6c 100644 (file)
@@ -162,8 +162,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
           changed = true;
         }
-      } else {
-        assert(opOperand->get().getType().isa<RankedTensorType>());
+      } else if (opOperand->get().getType().isa<RankedTensorType>()) {
         // Tile and Fuse tensor input.
         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
           continue;
index 16eb79f..9159dcf 100644 (file)
@@ -533,7 +533,7 @@ static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
 // For now, just assume it is the zero of type.
 // In the future, it should be the zero of type + op.
 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
-  auto t = getElementTypeOrSelf(op.get().getType());
+  auto t = getElementTypeOrSelf(op.get());
   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }
 
@@ -544,7 +544,8 @@ static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
       linalg::LinalgTilingOptions()
           .setTileSizes(tileSizes)
           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
-  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
+  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
+                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
       context, linalgTilingOptions,
       linalg::LinalgTransformationFilter(
           Identifier::get("tile-and-pad", context)));