[mlir] Add MemRefTypeBuilder and refactor some MemRefType::get().
authorTim Shen <timshen@google.com>
Wed, 22 Jan 2020 21:46:11 +0000 (13:46 -0800)
committerTim Shen <timshen@google.com>
Fri, 31 Jan 2020 07:30:46 +0000 (23:30 -0800)
The refactored MemRefType::get() calls all intend to clone from another
memref type, with some modifications. In fact, some calls dropped memory space
during the cloning. Migrate them to the cloning API so that nothing gets
dropped if they are not explicitly listed.

It's close to NFC but not quite, as it helps with propagating memory spaces in
some places.

Differential Revision: https://reviews.llvm.org/D73296

mlir/include/mlir/IR/StandardTypes.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/Transforms/PipelineDataTransfer.cpp
mlir/lib/Transforms/Utils/Utils.cpp

index 32c9a2d..e789ab7 100644 (file)
@@ -390,6 +390,52 @@ public:
 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
                                          detail::MemRefTypeStorage> {
 public:
+  /// This is a builder type that keeps local references to arguments. Arguments
+  /// that are passed into the builder must out-live the builder.
+  class Builder {
+  public:
+    // Build from another MemRefType.
+    explicit Builder(MemRefType other)
+        : shape(other.getShape()), elementType(other.getElementType()),
+          affineMaps(other.getAffineMaps()),
+          memorySpace(other.getMemorySpace()) {}
+
+    // Build from scratch.
+    Builder(ArrayRef<int64_t> shape, Type elementType)
+        : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
+    }
+
+    Builder &setShape(ArrayRef<int64_t> newShape) {
+      shape = newShape;
+      return *this;
+    }
+
+    Builder &setElementType(Type newElementType) {
+      elementType = newElementType;
+      return *this;
+    }
+
+    Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
+      affineMaps = newAffineMaps;
+      return *this;
+    }
+
+    Builder &setMemorySpace(unsigned newMemorySpace) {
+      memorySpace = newMemorySpace;
+      return *this;
+    }
+
+    operator MemRefType() {
+      return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+    }
+
+  private:
+    ArrayRef<int64_t> shape;
+    Type elementType;
+    ArrayRef<AffineMap> affineMaps;
+    unsigned memorySpace;
+  };
+
   using Base::Base;
 
   /// Get or create a new MemRefType based on shape, element type, affine
index a2c4ec6..ab47dcf 100644 (file)
@@ -41,8 +41,7 @@ public:
     auto memref = type.dyn_cast<MemRefType>();
     if (memref &&
         memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
-      type = MemRefType::get(memref.getShape(), memref.getElementType(),
-                             memref.getAffineMaps());
+      type = MemRefType::Builder(memref).setMemorySpace(0);
     }
 
     return LLVMTypeConverter::convertType(type);
index 414acce..0fa3590 100644 (file)
@@ -168,8 +168,8 @@ inline Type castElementType(Type t, Type newElementType) {
     case StandardTypes::Kind::UnrankedTensor:
       return UnrankedTensorType::get(newElementType);
     case StandardTypes::Kind::MemRef:
-      return MemRefType::get(st.getShape(), newElementType,
-                             st.cast<MemRefType>().getAffineMaps());
+      return MemRefType::Builder(st.cast<MemRefType>())
+          .setElementType(newElementType);
     }
   }
   assert(t.isIntOrFloat());
index 1cb62ea..fb18fbf 100644 (file)
@@ -480,7 +480,7 @@ computeReshapeCollapsedType(MemRefType type,
 
   // Early-exit: if `type` is contiguous, the result must be contiguous.
   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
-    return MemRefType::get(newSizes, type.getElementType(), {});
+    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
 
   // Convert back to int64_t because we don't have enough information to create
   // new strided layouts from AffineExpr only. This corresponds to a case where
@@ -499,7 +499,7 @@ computeReshapeCollapsedType(MemRefType type,
   auto layout =
       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
   return canonicalizeStridedLayout(
-      MemRefType::get(newSizes, type.getElementType(), {layout}));
+      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
 }
 
 /// Helper functions assert Attribute of the proper type in attr and returns the
@@ -613,11 +613,10 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
   unsigned rank = memRefType.getRank();
   // TODO(ntv): propagate static size and stride information when available.
   SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
-  Type elementType = memRefType.getElementType();
-  result.addTypes({MemRefType::get(
-      sizes, elementType,
-      {makeStridedLinearLayoutMap(strides, offset, b->getContext())},
-      memRefType.getMemorySpace())});
+  result.addTypes({MemRefType::Builder(memRefType)
+                       .setShape(sizes)
+                       .setAffineMaps(makeStridedLinearLayoutMap(
+                           strides, offset, b->getContext()))});
 }
 
 static void print(OpAsmPrinter &p, SliceOp op) {
@@ -698,8 +697,8 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
   auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
   // Compute result type.
-  auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
-                                    memRefType.getMemorySpace());
+  MemRefType resultType =
+      MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
 
   build(b, result, resultType, view, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
index fded608..b1e0096 100644 (file)
@@ -350,9 +350,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
     }
 
     // Create new memref type (which will have fewer dynamic dimensions).
-    auto newMemRefType = MemRefType::get(
-        newShapeConstants, memrefType.getElementType(),
-        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefType::Builder(memrefType).setShape(newShapeConstants);
     assert(static_cast<int64_t>(newOperands.size()) ==
            newMemRefType.getNumDynamicDims());
 
@@ -2453,9 +2452,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
                                      rewriter.getContext());
 
     // Create new memref type with constant folded dims and/or offset/strides.
-    auto newMemRefType =
-        MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
-                        memrefType.getMemorySpace());
+    MemRefType newMemRefType = MemRefType::Builder(memrefType)
+                                   .setShape(newShapeConstants)
+                                   .setAffineMaps({map});
     (void)dynamicOffsetOperandCount; // unused in opt mode
     assert(static_cast<int64_t>(newOperands.size()) ==
            dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
@@ -2509,7 +2508,6 @@ static Type inferSubViewResultType(MemRefType memRefType) {
   auto rank = memRefType.getRank();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  Type elementType = memRefType.getElementType();
   auto res = getStridesAndOffset(memRefType, strides, offset);
   assert(succeeded(res) && "SubViewOp expected strided memref type");
   (void)res;
@@ -2524,8 +2522,9 @@ static Type inferSubViewResultType(MemRefType memRefType) {
   auto stridedLayout =
       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
-  return MemRefType::get(sizes, elementType, stridedLayout,
-                         memRefType.getMemorySpace());
+  return MemRefType::Builder(memRefType)
+      .setShape(sizes)
+      .setAffineMaps(stridedLayout);
 }
 
 void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
@@ -2774,9 +2773,8 @@ public:
       assert(defOp);
       staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
     }
-    MemRefType newMemRefType = MemRefType::get(
-        staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
-        subViewType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefType::Builder(subViewType).setShape(staticShape);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
@@ -2825,8 +2823,7 @@ public:
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         staticStrides, resultOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
@@ -2877,8 +2874,7 @@ public:
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         resultStrides, staticOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
         subViewOp.sizes(), subViewOp.strides(), newMemRefType);
index 6f9f6a8..bd12cff 100644 (file)
@@ -723,11 +723,9 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
-    return MemRefType::get(t.getShape(), t.getElementType(),
-                           {AffineMap::get(m.getNumDims(), m.getNumSymbols(),
-                                           {simplifiedLayoutExpr})});
-
-  return MemRefType::get(t.getShape(), t.getElementType(), {});
+    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
+        m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
+  return MemRefType::Builder(t).setAffineMaps({});
 }
 
 /// Return true if the layout for `t` is compatible with strided semantics.
index a9a41a6..58a2f17 100644 (file)
@@ -72,10 +72,9 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
     newShape[0] = 2;
     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
-    auto newMemRefType =
-        MemRefType::get(newShape, oldMemRefType.getElementType(), {},
-                        oldMemRefType.getMemorySpace());
-    return newMemRefType;
+    return MemRefType::Builder(oldMemRefType)
+        .setShape(newShape)
+        .setAffineMaps({});
   };
 
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
index ca41ac2..5e82413 100644 (file)
@@ -445,8 +445,10 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   auto oldMemRef = allocOp.getResult();
   SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
 
-  auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
-                                       b.getMultiDimIdentityMap(newRank));
+  MemRefType newMemRefType =
+      MemRefType::Builder(memrefType)
+          .setShape(newShape)
+          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
   auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
 
   // Replace all uses of the old memref.