[mlir][linalg] All StructuredOp parameters are inputs or outputs.
authorTobias Gysi <gysit@google.com>
Tue, 29 Jun 2021 06:54:39 +0000 (06:54 +0000)
committerTobias Gysi <gysit@google.com>
Tue, 29 Jun 2021 07:45:50 +0000 (07:45 +0000)
Adapt the StructuredOp verifier to ensure all operands are either in the input or the output group. The change is possible after adding support for scalar input operands (https://reviews.llvm.org/D104220).

Differential Revision: https://reviews.llvm.org/D104783

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index ad91e23..e1f096d 100644 (file)
@@ -253,7 +253,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumInputs() + getNumOutputs();
+        return this->getOperation()->getNumOperands();
       }]
     >,
     //===------------------------------------------------------------------===//
@@ -346,8 +346,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         result.reserve(numOutputs);
         llvm::transform(
           this->getOperation()->getOpOperands()
-            .drop_front(getNumInputs())
-            .take_front(numOutputs),
+            .take_back(numOutputs),
           std::back_inserter(result),
           [](OpOperand &opOperand) { return &opOperand; });
         return result;
@@ -458,8 +457,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         OpOperandVector result;
         result.reserve(numInputsAndOutputs);
         llvm::transform(
-          this->getOperation()->getOpOperands()
-            .take_front(numInputsAndOutputs),
+          this->getOperation()->getOpOperands(),
           std::back_inserter(result),
           [](OpOperand &opOperand) { return &opOperand; });
         return result;
@@ -928,22 +926,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     /// `createFlatListOfOperandStaticDims`.
     SmallVector<int64_t, 4> computeStaticLoopSizes();
 
-    /// Returns all the operands past the inputs, output_buffers and
-    /// init_tensors operands. Asserts that these operands are value types to
-    /// allow transformations like tiling to just use the values when cloning
-    /// `linalgOp`.
-    Operation::operand_range getAssumedNonShapedOperands() {
-      Operation::operand_range res{
-        getOperation()->getOperands().begin() + getNumInputsAndOutputs(),
-        getOperation()->getOperands().end()};
-      for (Type t : TypeRange{res}) {
-        (void)t;
-        assert((t.isSignlessIntOrIndexOrFloat() || t.template isa<VectorType>())
-               &&"expected scalar or vector type");
-      }
-      return res;
-    }
-
     /// Returns the value that expresses the shape of the output in terms of
     /// shape of the input operands where possible
     LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
index 45a9f8e..e83c624 100644 (file)
@@ -318,14 +318,15 @@ LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-  // Expect at least one input/output operand.
+  // Expect at least one output operand.
   // This means an op that constructs a tensor out of indices cannot be a
   // LinalgOp at the moment. For now this will have to be a special op until we
   // have output shape operands that are not tensors.
-  int64_t numInputsAndOutputs = linalgOp.getNumInputsAndOutputs();
-  if (numInputsAndOutputs == 0)
-    return op->emitOpError("expected at least one input/output operand");
-  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, numInputsAndOutputs)))
+  int64_t numInputs = linalgOp.getNumInputs();
+  int64_t numOutputs = linalgOp.getNumOutputs();
+  if (numOutputs == 0)
+    return op->emitOpError("expected at least one output operand");
+  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
     return failure();
   // Should have at least one output tensor per result tensor.
   // Can also have outbut buffers that do not correspond to results.
index 11cb3e1..f4524f1 100644 (file)
@@ -3038,8 +3038,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
                                  : opOperand->get());
       newResultTypes.push_back(newOperands.back().getType());
     }
-    auto extraOperands = op.getAssumedNonShapedOperands();
-    newOperands.append(extraOperands.begin(), extraOperands.end());
     // Clone op.
     Operation *newOp =
         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
@@ -3109,7 +3107,6 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
         newOperands.push_back(opOperand->get());
     SmallVector<Value> outputOperands = op.getOutputOperands();
     llvm::append_range(newOperands, outputOperands);
-    llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
 
     // Repair the indexing maps by filtering out the ones that have been
     // eliminated.
index 414aa63..fba709a 100644 (file)
@@ -119,8 +119,6 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
   assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
   SmallVector<Value, 8> newOperands = inputs;
   newOperands.append(outputs.begin(), outputs.end());
-  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalgOp.clone(rewriter, linalgOp.getLoc(),
                  /*resultTypes=*/ArrayRef<Type>{}, newOperands);
   // Replace the results of the old op with the new output buffers.
index c951e70..287d2d4 100644 (file)
@@ -1241,8 +1241,6 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
   newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
-  auto otherOperands = op.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
 
   // Replace the results of the old op with the new output buffers.
index 0ff0594..d596495 100644 (file)
@@ -205,10 +205,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
                                       getTiledOperands(b, producer), ivs,
                                       tileSizes, sizeBounds));
 
-  // Append the other operands.
-  auto operands = producer.getAssumedNonShapedOperands();
-  clonedShapes.append(operands.begin(), operands.end());
-
   // Iterate over the results in order.
   // Extract the subtensor type from the linearized range.
   // Since we do not enforce any canonicalizations on the fly, this is always
index a9366d1..b6420f7 100644 (file)
@@ -242,8 +242,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
         b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
-    auto nonShapedOperands = op.getAssumedNonShapedOperands();
-    tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end());
 
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.
index 79335c3..f1c8a6f 100644 (file)
@@ -190,8 +190,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
-  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalg::LinalgOp paddedOp =
       opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);