[mlir][Linalg] NFC - getAssumedNonShapedOperands now returns OperandRange
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 20 Jan 2021 19:02:08 +0000 (19:02 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 20 Jan 2021 19:23:26 +0000 (19:23 +0000)
Also adds a isInput interface method.

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td

index 8513360..b8009a8 100644 (file)
@@ -611,6 +611,22 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
+        Return true if `opOperand` is an input tensor.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isInputTensor",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (!opOperand->get().getType().template isa<RankedTensorType>())
+          return false;
+        if (opOperand->getOperandNumber() < $_op.getNumInputs())
+          return true;
+        return false;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
         Return true if `opOperand` is an init tensor. This is true when it is
         an output tensor operand whose value is used in the payload region.
       }],
@@ -1063,18 +1079,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     /// init_tensors operands. Asserts that these operands are value types to
     /// allow transformations like tiling to just use the values when cloning
     /// `linalgOp`.
-    SmallVector<Value, 4> getAssumedNonShapedOperands() {
-      unsigned numShapedOperands = getNumShapedOperands();
-      unsigned nExtraOperands =
-        getOperation()->getNumOperands() - numShapedOperands;
-      SmallVector<Value, 4> res;
-      res.reserve(nExtraOperands);
-      for (unsigned i = 0; i < nExtraOperands; ++i) {
-        res.push_back(getOperation()->getOperand(numShapedOperands + i));
-        assert((res.back().getType().isSignlessIntOrIndexOrFloat()
-                || res.back().getType().template isa<VectorType>()) &&
-               "expected scalar or vector type");
-      }
+    Operation::operand_range getAssumedNonShapedOperands() {
+      Operation::operand_range res{
+        getOperation()->getOperands().begin() + getNumShapedOperands(),
+        getOperation()->getOperands().end()};
+      for (Type t : TypeRange{res})
+        assert((t.isSignlessIntOrIndexOrFloat() || t.template isa<VectorType>())
+               &&"expected scalar or vector type");
       return res;
     }