[mlir][Tensor] Avoid dropping attributes for `tensor.pad` operations during canonical...
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 20 Mar 2023 20:56:41 +0000 (20:56 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 20 Mar 2023 21:03:46 +0000 (21:03 +0000)
Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D146440

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

index cc8bbd5..3c3fa70 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "llvm/ADT/StringSet.h"
 #include <optional>
 
@@ -461,18 +462,10 @@ struct GenerateLoopNest {
 /// Returns an attribute list that excludes pre-defined attributes.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  llvm::StringSet<> elidedAttrs;
-  elidedAttrs.insert(op.getAttributeNames().begin(),
-                     op.getAttributeNames().end());
+  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
-    elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
-  SmallVector<NamedAttribute> attrs;
-  for (auto attr : op->getAttrs()) {
-    if (elidedAttrs.count(attr.getName()))
-      continue;
-    attrs.push_back(attr);
-  }
-  return attrs;
+    elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  return getPrunedAttributeList(op, elidedAttrs);
 }
 
 } // namespace linalg
index 09b7775..66d6dcc 100644 (file)
@@ -1295,13 +1295,13 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 
   let builders = [
     // Build a PadOp with mixed static and dynamic entries.
-    OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
-      "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$nofold,
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh,
+      "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadOp with all dynamic entries.
-    OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$nofold,
+    OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low,
+      "ValueRange":$high, CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.
index 1297e87..c4f9fa8 100644 (file)
@@ -123,6 +123,11 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
                                TypeRange newResultTypes,
                                ValueRange newOperands);
 
+// Get the list of attributes associated with the op, ignoring
+// those with the provided name.
+SmallVector<NamedAttribute>
+getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
index f2da108..9d26e51 100644 (file)
@@ -2518,26 +2518,27 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
   return RankedTensorType::get(inferredShape, sourceType.getElementType());
 }
 
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
-                  ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
-                  ValueRange low, ValueRange high, bool nofold,
-                  ArrayRef<NamedAttribute> attrs) {
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ArrayRef<int64_t> staticLow,
+                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
+                  bool nofold, ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
-  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
+  if (!resultType)
+    resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
-                  ValueRange low, ValueRange high, bool nofold,
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ValueRange low, ValueRange high, bool nofold,
                   ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
-  build(b, result, source, staticVector, staticVector, low, high, nofold,
-        attrs);
+  build(b, result, resultType, source, staticVector, staticVector, low, high,
+        nofold, attrs);
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -2635,9 +2636,9 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
     } else {
       auto newOp = rewriter.create<PadOp>(
           padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
-          padTensorOp.getLow(), padTensorOp.getHigh(),
           padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
-          padTensorOp.getNofold());
+          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
       IRMapping mapper;
       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
@@ -2667,9 +2668,10 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
 
     auto replacementOp = rewriter.create<PadOp>(
         padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
-        padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
-        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
-        padTensorOp.getNofold());
+        padTensorOp.getSource(), padTensorOp.getStaticLow(),
+        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
+        padTensorOp.getHigh(), padTensorOp.getNofold(),
+        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
 
     rewriter.replaceOp(padTensorOp, replacementOp.getResult());
@@ -2827,7 +2829,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
         innerSliceOp.getMixedStrides());
     auto newPadOp = rewriter.create<PadOp>(
         padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
-        padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
+        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
+        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                 newPadOp.getRegion().begin());
     rewriter.replaceOp(padOp, newPadOp.getResult());
@@ -2916,8 +2919,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
     auto newResultType = RankedTensorType::get(
         newOutDims, padTensorOp.getType().getElementType());
     auto newOp = rewriter.create<PadOp>(
-        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
-        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
+        padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
 
     IRMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
index b22f42c..49b49ef 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
+#include "llvm/ADT/StringSet.h"
 
 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
 
@@ -114,3 +115,16 @@ Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
     state.addRegion();
   return b.create(state);
 }
+
+SmallVector<NamedAttribute>
+mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
+  llvm::StringSet elidedAttrsSet;
+  elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
+  SmallVector<NamedAttribute> attrs;
+  for (auto attr : op->getAttrs()) {
+    if (elidedAttrsSet.count(attr.getName()))
+      continue;
+    attrs.push_back(attr);
+  }
+  return attrs;
+}