Introduce the ability for "isolated from above" ops to introduce shadowing
authorChris Lattner <clattner@google.com>
Fri, 23 Aug 2019 17:35:24 +0000 (10:35 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Aug 2019 17:35:49 +0000 (10:35 -0700)
names for the basic block arguments in their body.

PiperOrigin-RevId: 265084627

mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/parser.mlir
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestOps.td

index 99c1ff5..c4e87ce 100644 (file)
@@ -85,6 +85,13 @@ public:
   virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
                            bool printBlockTerminators = true) = 0;
 
+  /// Renumber the arguments for the specified region to the same names as the
+  /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
+  /// operations.  If any entry in namesToUse is null, the corresponding
+  /// argument name is left alone.
+  virtual void shadowRegionArgs(Region &region,
+                                ArrayRef<Value *> namesToUse) = 0;
+
   /// Prints an affine map of SSA ids, where SSA id names are used in place
   /// of dims/symbols.
   /// Operand values must come from single-result sources, and be valid
index 82f2c99..9da922c 100644 (file)
@@ -1244,6 +1244,12 @@ public:
     os.indent(currentIndent) << "}";
   }
 
+  /// Renumber the arguments for the specified region to the same names as the
+  /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
+  /// operations.  If any entry in namesToUse is null, the corresponding
+  /// argument name is left alone.
+  void shadowRegionArgs(Region &region, ArrayRef<Value *> namesToUse) override;
+
   void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                               ArrayRef<Value *> operands) override {
     AffineMap map = mapAttr.getValue();
@@ -1270,9 +1276,14 @@ protected:
   void numberValueID(Value *value);
   void numberValuesInRegion(Region &region);
   void numberValuesInBlock(Block &block);
-  void printValueID(Value *value, bool printResultNo = true) const;
+  void printValueID(Value *value, bool printResultNo = true) const {
+    printValueIDImpl(value, printResultNo, os);
+  }
 
 private:
+  void printValueIDImpl(Value *value, bool printResultNo,
+                        raw_ostream &stream) const;
+
   /// Uniques the given value name within the printer. If the given name
   /// conflicts, it is automatically renamed.
   StringRef uniqueValueName(StringRef name);
@@ -1491,7 +1502,8 @@ void OperationPrinter::print(Operation *op) {
   printTrailingLocation(op->getLoc());
 }
 
-void OperationPrinter::printValueID(Value *value, bool printResultNo) const {
+void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
+                                        raw_ostream &stream) const {
   int resultNo = -1;
   auto lookupValue = value;
 
@@ -1507,21 +1519,56 @@ void OperationPrinter::printValueID(Value *value, bool printResultNo) const {
 
   auto it = valueIDs.find(lookupValue);
   if (it == valueIDs.end()) {
-    os << "<<INVALID SSA VALUE>>";
+    stream << "<<INVALID SSA VALUE>>";
     return;
   }
 
-  os << '%';
+  stream << '%';
   if (it->second != nameSentinel) {
-    os << it->second;
+    stream << it->second;
   } else {
     auto nameIt = valueNames.find(lookupValue);
     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
-    os << nameIt->second;
+    stream << nameIt->second;
   }
 
   if (resultNo != -1 && printResultNo)
-    os << '#' << resultNo;
+    stream << '#' << resultNo;
+}
+
+/// Renumber the arguments for the specified region to the same names as the
+/// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
+/// operations.  If any entry in namesToUse is null, the corresponding
+/// argument name is left alone.
+void OperationPrinter::shadowRegionArgs(Region &region,
+                                        ArrayRef<Value *> namesToUse) {
+  assert(!region.empty() && "cannot shadow arguments of an empty region");
+  assert(region.front().getNumArguments() == namesToUse.size() &&
+         "incorrect number of names passed in");
+  assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
+         "only KnownIsolatedFromAbove ops can shadow names");
+
+  SmallVector<char, 16> nameStr;
+  for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
+    auto *nameToUse = namesToUse[i];
+    if (nameToUse == nullptr)
+      continue;
+
+    auto *nameToReplace = region.front().getArgument(i);
+
+    nameStr.clear();
+    llvm::raw_svector_ostream nameStream(nameStr);
+    printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream);
+
+    // Entry block arguments should already have a pretty "arg" name.
+    assert(valueIDs[nameToReplace] == nameSentinel);
+
+    // Use the name without the leading %.
+    auto name = StringRef(nameStream.str()).drop_front();
+
+    // Overwrite the name.
+    valueNames[nameToReplace] = name.copy(usedNameAllocator);
+  }
 }
 
 void OperationPrinter::printOperation(Operation *op) {
index 6f576a8..db4a096 100644 (file)
@@ -1055,13 +1055,25 @@ func @op_with_region_args() {
 // CHECK-LABEL: func @op_with_passthrough_region_args
 func @op_with_passthrough_region_args() {
   // CHECK: [[VAL:%.*]] = constant
-  // CHECK: "test.isolated_region"([[VAL]])
-  // CHECK-NEXT: ^{{.*}}([[ARG:%.*]]: index)
-  // CHECK-NEXT: "foo.consumer"([[ARG]]) : (index)
-
   %0 = constant 10 : index
+
+  // CHECK: test.isolated_region [[VAL]] {
+  // CHECK-NEXT: "foo.consumer"([[VAL]]) : (index)
+  // CHECK-NEXT: }
   test.isolated_region %0 {
     "foo.consumer"(%0) : (index) -> ()
   }
+
+  // CHECK: [[VAL:%.*]]:2 = "foo.op"
+  %result:2 = "foo.op"() : () -> (index, index)
+
+  // CHECK: test.isolated_region [[VAL]]#1 {
+  // CHECK-NEXT: "foo.consumer"([[VAL]]#1) : (index)
+  // CHECK-NEXT: }
+  test.isolated_region %result#1 {
+    "foo.consumer"(%result#1) : (index) -> ()
+  }
+
   return
 }
+
index 40faa0d..8b44b6c 100644 (file)
@@ -54,6 +54,13 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser *parser,
                              /*enableNameShadowing=*/true);
 }
 
+static void print(OpAsmPrinter *p, IsolatedRegionOp op) {
+  *p << "test.isolated_region ";
+  p->printOperand(op.getOperand());
+  p->shadowRegionArgs(op.region(), op.getOperand());
+  p->printRegion(op.region(), /*printEntryBlockArgs=*/false);
+}
+
 //===----------------------------------------------------------------------===//
 // Test PolyForOp - parse list of region arguments.
 //===----------------------------------------------------------------------===//
index 55466b7..2926930 100644 (file)
@@ -704,6 +704,7 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> {
   let arguments = (ins Index:$input);
   let regions = (region SizedRegion<1>:$region);
   let parser = [{ return ::parse$cppClass(parser, result); }];
+  let printer = [{ return ::print(p, *this); }];
 }
 
 def PolyForOp : TEST_Op<"polyfor">