From 91f27d025b3c92f7b57efb409960128cfd9a4108 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Tue, 25 Jun 2019 10:29:53 -0700 Subject: [PATCH] Support printing SSA ids in affine.load/store which do not have special names. PiperOrigin-RevId: 254997746 --- mlir/include/mlir/IR/OpImplementation.h | 4 ++ mlir/lib/IR/AsmPrinter.cpp | 84 ++++++++++++++++----------------- mlir/test/AffineOps/load-store.mlir | 17 +++++++ 3 files changed, 61 insertions(+), 44 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index c344a22..96e682f 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -87,6 +87,8 @@ public: /// Prints an affine map of SSA ids, where SSA id names are used in place /// of dims/symbols. + /// Operand values must come from single-result sources, and be valid + /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ArrayRef operands) = 0; @@ -380,6 +382,8 @@ public: } /// Parses an affine map attribute where dims and symbols are SSA operands. + /// Operand values must come from single-result sources, and be valid + /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl &operands, Attribute &map, StringRef attrName, diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index efa1de3..a519d75 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -347,8 +347,9 @@ public: void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); - void printAffineExpr(AffineExpr expr, ArrayRef dimValueNames = {}, - ArrayRef symbolValueNames = {}); + void printAffineExpr( + AffineExpr expr, + llvm::function_ref printValueName = nullptr); void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet set); @@ -370,10 +371,9 @@ protected: Weak, // + and - Strong, // All other binary operators. }; - void printAffineExprInternal(AffineExpr expr, - BindingStrength enclosingTightness, - ArrayRef dimValueNames = {}, - ArrayRef symbolValueNames = {}); + void printAffineExprInternal( + AffineExpr expr, BindingStrength enclosingTightness, + llvm::function_ref printValueName = nullptr); }; } // end anonymous namespace @@ -921,30 +921,28 @@ void ModulePrinter::printType(Type type) { // Affine expressions and maps //===----------------------------------------------------------------------===// -void ModulePrinter::printAffineExpr(AffineExpr expr, - ArrayRef dimValueNames, - ArrayRef symbolValueNames) { - printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames, - symbolValueNames); +void ModulePrinter::printAffineExpr( + AffineExpr expr, llvm::function_ref printValueName) { + printAffineExprInternal(expr, BindingStrength::Weak, printValueName); } void ModulePrinter::printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, - ArrayRef dimValueNames, ArrayRef symbolValueNames) { + llvm::function_ref printValueName) { const char *binopSpelling = nullptr; switch (expr.getKind()) { case AffineExprKind::SymbolId: { unsigned pos = expr.cast().getPosition(); - if (pos < symbolValueNames.size()) - os << "symbol(%" << symbolValueNames[pos] << ')'; + if (printValueName) + printValueName(pos, /*isSymbol=*/true); else os << 's' << pos; return; } case AffineExprKind::DimId: { unsigned pos = expr.cast().getPosition(); - if (pos < dimValueNames.size()) - os << '%' << dimValueNames[pos]; + if (printValueName) + printValueName(pos, /*isSymbol=*/false); else os << 'd' << pos; return; @@ -982,16 +980,14 @@ void ModulePrinter::printAffineExprInternal( auto rhsConst = rhsExpr.dyn_cast(); if (rhsConst && rhsConst.getValue() == -1) { os << "-"; - printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames, - symbolValueNames); + printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); return; } - printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames, - symbolValueNames); + printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); + os << binopSpelling; - printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames, - symbolValueNames); + printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1009,15 +1005,15 @@ void ModulePrinter::printAffineExprInternal( AffineExpr rrhsExpr = rhs.getRHS(); if (auto rrhs = rrhsExpr.dyn_cast()) { if (rrhs.getValue() == -1) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, - symbolValueNames); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, + printValueName); os << " - "; if (rhs.getLHS().getKind() == AffineExprKind::Add) { printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, - dimValueNames, symbolValueNames); + printValueName); } else { printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, - dimValueNames, symbolValueNames); + printValueName); } if (enclosingTightness == BindingStrength::Strong) @@ -1026,11 +1022,11 @@ void ModulePrinter::printAffineExprInternal( } if (rrhs.getValue() < -1) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, - symbolValueNames); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, + printValueName); os << " - "; printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, - dimValueNames, symbolValueNames); + printValueName); os << " * " << -rrhs.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1043,8 +1039,7 @@ void ModulePrinter::printAffineExprInternal( // Pretty print addition to a negative number as a subtraction. if (auto rhsConst = rhsExpr.dyn_cast()) { if (rhsConst.getValue() < 0) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, - symbolValueNames); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - " << -rhsConst.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1052,11 +1047,10 @@ void ModulePrinter::printAffineExprInternal( } } - printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, - symbolValueNames); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); + os << " + "; - printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames, - symbolValueNames); + printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1242,16 +1236,18 @@ public: ArrayRef operands) { AffineMap map = mapAttr.getValue(); unsigned numDims = map.getNumDims(); - SmallVector dimValueNames; - SmallVector symbolValueNames; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (i < numDims) - dimValueNames.push_back(valueNames[operands[i]]); - else - symbolValueNames.push_back(valueNames[operands[i]]); - } + auto printValueName = [&](unsigned pos, bool isSymbol) { + unsigned index = isSymbol ? numDims + pos : pos; + assert(index < operands.size()); + if (isSymbol) + os << "symbol("; + printValueID(operands[index]); + if (isSymbol) + os << ')'; + }; + interleaveComma(map.getResults(), [&](AffineExpr expr) { - printAffineExpr(expr, dimValueNames, symbolValueNames); + printAffineExpr(expr, printValueName); }); } diff --git a/mlir/test/AffineOps/load-store.mlir b/mlir/test/AffineOps/load-store.mlir index fdb32f8..5f7ce36 100644 --- a/mlir/test/AffineOps/load-store.mlir +++ b/mlir/test/AffineOps/load-store.mlir @@ -165,4 +165,21 @@ func @test6(%arg0 : index, %arg1 : index) { } } return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 + 1) + +// Test with operands without special SSA name. +func @test7() { + %0 = alloc() : memref<10xf32> + affine.for %i0 = 0 to 10 { + %1 = affine.apply (d1) -> (d1 + 1)(%i0) + %2 = affine.load %0[%1] : memref<10xf32> + affine.store %2, %0[%1] : memref<10xf32> +// CHECK: affine.load %0[%1] : memref<10xf32> +// CHECK: affine.store %2, %0[%1] : memref<10xf32> + } + return } \ No newline at end of file -- 2.7.4