Support printing SSA ids in affine.load/store which do not have special names.
authorAndy Davis <andydavis@google.com>
Tue, 25 Jun 2019 17:29:53 +0000 (10:29 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 25 Jun 2019 17:30:26 +0000 (10:30 -0700)
PiperOrigin-RevId: 254997746

mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/AffineOps/load-store.mlir

index c344a2207dace8c53058e1abaf761ac1edc18953..96e682f57b225f3e3905ed92dfee964aaffb3b94 100644 (file)
@@ -87,6 +87,8 @@ public:
 
   /// Prints an affine map of SSA ids, where SSA id names are used in place
   /// of dims/symbols.
+  /// Operand values must come from single-result sources, and be valid
+  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
   virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                                       ArrayRef<Value *> operands) = 0;
 
@@ -380,6 +382,8 @@ public:
   }
 
   /// Parses an affine map attribute where dims and symbols are SSA operands.
+  /// Operand values must come from single-result sources, and be valid
+  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
   virtual ParseResult
   parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
                          StringRef attrName,
index efa1de3ef86c2848d7773dade5cbafdc59c54754..a519d75aa2d6256b17e7bf8c3757061e653dd56b 100644 (file)
@@ -347,8 +347,9 @@ public:
   void printLocation(LocationAttr loc);
 
   void printAffineMap(AffineMap map);
-  void printAffineExpr(AffineExpr expr, ArrayRef<StringRef> dimValueNames = {},
-                       ArrayRef<StringRef> symbolValueNames = {});
+  void printAffineExpr(
+      AffineExpr expr,
+      llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
   void printAffineConstraint(AffineExpr expr, bool isEq);
   void printIntegerSet(IntegerSet set);
 
@@ -370,10 +371,9 @@ protected:
     Weak,   // + and -
     Strong, // All other binary operators.
   };
-  void printAffineExprInternal(AffineExpr expr,
-                               BindingStrength enclosingTightness,
-                               ArrayRef<StringRef> dimValueNames = {},
-                               ArrayRef<StringRef> symbolValueNames = {});
+  void printAffineExprInternal(
+      AffineExpr expr, BindingStrength enclosingTightness,
+      llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
 };
 } // end anonymous namespace
 
@@ -921,30 +921,28 @@ void ModulePrinter::printType(Type type) {
 // Affine expressions and maps
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::printAffineExpr(AffineExpr expr,
-                                    ArrayRef<StringRef> dimValueNames,
-                                    ArrayRef<StringRef> symbolValueNames) {
-  printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames,
-                          symbolValueNames);
+void ModulePrinter::printAffineExpr(
+    AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) {
+  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
 }
 
 void ModulePrinter::printAffineExprInternal(
     AffineExpr expr, BindingStrength enclosingTightness,
-    ArrayRef<StringRef> dimValueNames, ArrayRef<StringRef> symbolValueNames) {
+    llvm::function_ref<void(unsigned, bool)> printValueName) {
   const char *binopSpelling = nullptr;
   switch (expr.getKind()) {
   case AffineExprKind::SymbolId: {
     unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
-    if (pos < symbolValueNames.size())
-      os << "symbol(%" << symbolValueNames[pos] << ')';
+    if (printValueName)
+      printValueName(pos, /*isSymbol=*/true);
     else
       os << 's' << pos;
     return;
   }
   case AffineExprKind::DimId: {
     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
-    if (pos < dimValueNames.size())
-      os << '%' << dimValueNames[pos];
+    if (printValueName)
+      printValueName(pos, /*isSymbol=*/false);
     else
       os << 'd' << pos;
     return;
@@ -982,16 +980,14 @@ void ModulePrinter::printAffineExprInternal(
     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
     if (rhsConst && rhsConst.getValue() == -1) {
       os << "-";
-      printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
-                              symbolValueNames);
+      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
       return;
     }
 
-    printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
-                            symbolValueNames);
+    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
+
     os << binopSpelling;
-    printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames,
-                            symbolValueNames);
+    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
 
     if (enclosingTightness == BindingStrength::Strong)
       os << ')';
@@ -1009,15 +1005,15 @@ void ModulePrinter::printAffineExprInternal(
       AffineExpr rrhsExpr = rhs.getRHS();
       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
         if (rrhs.getValue() == -1) {
-          printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
-                                  symbolValueNames);
+          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
+                                  printValueName);
           os << " - ";
           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
-                                    dimValueNames, symbolValueNames);
+                                    printValueName);
           } else {
             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
-                                    dimValueNames, symbolValueNames);
+                                    printValueName);
           }
 
           if (enclosingTightness == BindingStrength::Strong)
@@ -1026,11 +1022,11 @@ void ModulePrinter::printAffineExprInternal(
         }
 
         if (rrhs.getValue() < -1) {
-          printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
-                                  symbolValueNames);
+          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
+                                  printValueName);
           os << " - ";
           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
-                                  dimValueNames, symbolValueNames);
+                                  printValueName);
           os << " * " << -rrhs.getValue();
           if (enclosingTightness == BindingStrength::Strong)
             os << ')';
@@ -1043,8 +1039,7 @@ void ModulePrinter::printAffineExprInternal(
   // Pretty print addition to a negative number as a subtraction.
   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
     if (rhsConst.getValue() < 0) {
-      printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
-                              symbolValueNames);
+      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
       os << " - " << -rhsConst.getValue();
       if (enclosingTightness == BindingStrength::Strong)
         os << ')';
@@ -1052,11 +1047,10 @@ void ModulePrinter::printAffineExprInternal(
     }
   }
 
-  printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
-                          symbolValueNames);
+  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
+
   os << " + ";
-  printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames,
-                          symbolValueNames);
+  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
 
   if (enclosingTightness == BindingStrength::Strong)
     os << ')';
@@ -1242,16 +1236,18 @@ public:
                               ArrayRef<Value *> operands) {
     AffineMap map = mapAttr.getValue();
     unsigned numDims = map.getNumDims();
-    SmallVector<StringRef, 2> dimValueNames;
-    SmallVector<StringRef, 1> symbolValueNames;
-    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-      if (i < numDims)
-        dimValueNames.push_back(valueNames[operands[i]]);
-      else
-        symbolValueNames.push_back(valueNames[operands[i]]);
-    }
+    auto printValueName = [&](unsigned pos, bool isSymbol) {
+      unsigned index = isSymbol ? numDims + pos : pos;
+      assert(index < operands.size());
+      if (isSymbol)
+        os << "symbol(";
+      printValueID(operands[index]);
+      if (isSymbol)
+        os << ')';
+    };
+
     interleaveComma(map.getResults(), [&](AffineExpr expr) {
-      printAffineExpr(expr, dimValueNames, symbolValueNames);
+      printAffineExpr(expr, printValueName);
     });
   }
 
index fdb32f8d97bd9cccac5817be1c88ca3ab0f65ee5..5f7ce361b91430ffab8b7ffcf22082493f9428b1 100644 (file)
@@ -165,4 +165,21 @@ func @test6(%arg0 : index, %arg1 : index) {
     }
   }
   return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 + 1)
+
+// Test with operands without special SSA name.
+func @test7() {
+  %0 = alloc() : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    %1 = affine.apply (d1) -> (d1 + 1)(%i0)
+    %2 = affine.load %0[%1] : memref<10xf32>
+    affine.store %2, %0[%1] : memref<10xf32>
+// CHECK: affine.load %0[%1] : memref<10xf32>
+// CHECK: affine.store %2, %0[%1] : memref<10xf32>
+  }
+  return
 }
\ No newline at end of file