Add support for 1->N type mappings in the dialect conversion infrastructure. To suppo...
authorRiver Riddle <riverriddle@google.com>
Fri, 21 Jun 2019 16:29:46 +0000 (09:29 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 22 Jun 2019 16:16:06 +0000 (09:16 -0700)
PiperOrigin-RevId: 254411383

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/TestDialect/TestOps.td
mlir/test/TestDialect/TestPatterns.cpp
mlir/test/Transforms/test-legalizer.mlir

index ede1120..599c7f1 100644 (file)
@@ -216,6 +216,19 @@ public:
   virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
                                             NamedAttributeList attrs,
                                             SignatureConversion &result);
+
+  /// This hook allows for materializing a conversion from a set of types into
+  /// one result type by generating a cast operation of some kind. The generated
+  /// operation should produce one result, of 'resultType', with the provided
+  /// 'inputs' as operands. This hook must be overridden when a type conversion
+  /// results in more than one type.
+  virtual Operation *materializeConversion(PatternRewriter &rewriter,
+                                           Type resultType,
+                                           ArrayRef<Value *> inputs,
+                                           Location loc) {
+    llvm_unreachable("expected 'materializeConversion' to be overridden when "
+                     "generating 1->N type conversions");
+  }
 };
 
 /// This class describes a specific conversion target.
index 3cec1f8..1684655 100644 (file)
@@ -51,6 +51,12 @@ struct ArgConverter {
       if (it == argMapping.end())
         continue;
       for (auto *op : it->second) {
+        // If the operation exists within the parent block, like with 1->N cast
+        // operations, we don't need to drop them. They will be automatically
+        // cleaned up with the region is destroyed.
+        if (op->getBlock())
+          continue;
+
         op->dropAllDefinedValueUses();
         op->destroy();
       }
@@ -77,7 +83,13 @@ struct ArgConverter {
         auto *op = argOps[i];
         auto *arg = block->addArgument(op->getResult(0)->getType());
         op->getResult(0)->replaceAllUsesWith(arg);
-        op->destroy();
+
+        // If this was a 1->N value mapping it exists within the parent block so
+        // erase it instead of destroying.
+        if (op->getBlock())
+          op->erase();
+        else
+          op->destroy();
       }
     }
     argMapping.clear();
@@ -97,8 +109,14 @@ struct ArgConverter {
         auto *op = argOps[i];
 
         // Handle the case of a 1->N value mapping.
-        if (op->getNumOperands() > 1)
-          llvm_unreachable("1->N argument mappings are currently not handled");
+        if (op->getNumOperands() > 1) {
+          // If all of the uses were removed, we can drop this op. Otherwise,
+          // keep the operation alive and let the user handle any remaining
+          // usages.
+          if (op->use_empty())
+            op->erase();
+          continue;
+        }
 
         // Handle the case where this argument had a direct mapping.
         if (op->getNumOperands() == 1) {
@@ -132,7 +150,8 @@ struct ArgConverter {
   }
 
   /// Converts the signature of the given entry block.
-  void convertSignature(Block *block,
+  void convertSignature(Block *block, PatternRewriter &rewriter,
+                        TypeConverter &converter,
                         TypeConverter::SignatureConversion &signatureConversion,
                         BlockAndValueMapping &mapping) {
     unsigned origArgCount = block->getNumArguments();
@@ -146,13 +165,15 @@ struct ArgConverter {
     // Remap each of the original arguments as determined by the signature
     // conversion.
     auto &newArgMapping = argMapping[block];
+    rewriter.setInsertionPointToStart(block);
     for (unsigned i = 0; i != origArgCount; ++i) {
       ArrayRef<Value *> remappedValues;
       if (auto inputMap = signatureConversion.getInputMapping(i))
         remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
 
       BlockArgument *arg = block->getArgument(i);
-      newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
+      newArgMapping.push_back(
+          convertArgument(arg, remappedValues, rewriter, converter, mapping));
     }
 
     // Erase all of the original arguments.
@@ -161,7 +182,8 @@ struct ArgConverter {
   }
 
   /// Converts the arguments of the given block.
-  LogicalResult convertArguments(Block *block, TypeConverter &converter,
+  LogicalResult convertArguments(Block *block, PatternRewriter &rewriter,
+                                 TypeConverter &converter,
                                  BlockAndValueMapping &mapping) {
     unsigned origArgCount = block->getNumArguments();
     if (origArgCount == 0)
@@ -178,10 +200,11 @@ struct ArgConverter {
 
     // Remap all of the original argument values.
     auto &newArgMapping = argMapping[block];
+    rewriter.setInsertionPointToStart(block);
     for (unsigned i = 0; i != origArgCount; ++i) {
       SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i]));
-      newArgMapping.push_back(
-          convertArgument(block->getArgument(i), newArgs, mapping));
+      newArgMapping.push_back(convertArgument(block->getArgument(i), newArgs,
+                                              rewriter, converter, mapping));
     }
 
     // Erase all of the original arguments.
@@ -195,6 +218,8 @@ struct ArgConverter {
   /// to perform the conversion.
   Operation *convertArgument(BlockArgument *origArg,
                              ArrayRef<Value *> newValues,
+                             PatternRewriter &rewriter,
+                             TypeConverter &converter,
                              BlockAndValueMapping &mapping) {
     // Handle the cases of 1->0 or 1->1 mappings.
     if (newValues.size() < 2) {
@@ -209,7 +234,15 @@ struct ArgConverter {
         mapping.map(cast->getResult(0), newValues[0]);
       return cast;
     }
-    llvm_unreachable("1->N argument mappings are currently not handled");
+
+    // Otherwise, this is a 1->N mapping. Call into the provided type converter
+    // to pack the new values.
+    auto *cast = converter.materializeConversion(rewriter, origArg->getType(),
+                                                 newValues, loc);
+    assert(cast->getNumResults() == 1 &&
+           cast->getNumOperands() == newValues.size());
+    origArg->replaceAllUsesWith(cast->getResult(0));
+    return cast;
   }
 
   /// A utility function used to create a conversion cast operation with the
@@ -874,10 +907,11 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
   // types.
   if (typeConverter) {
     for (Block &block :
-         llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1))
-      if (failed(rewriter.argConverter.convertArguments(&block, *typeConverter,
-                                                        rewriter.mapping)))
+         llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) {
+      if (failed(rewriter.argConverter.convertArguments(
+              &block, rewriter, *typeConverter, rewriter.mapping)))
         return failure();
+    }
   }
 
   // Store the number of blocks before conversion (new blocks may be added due
@@ -909,8 +943,9 @@ LogicalResult FunctionConverter::convertFunction(
 
   // Update the signature of the entry block.
   if (signatureConversion) {
-    rewriter.argConverter.convertSignature(
-        &f->getBody().front(), *signatureConversion, rewriter.mapping);
+    rewriter.argConverter.convertSignature(&f->getBody().front(), rewriter,
+                                           *typeConverter, *signatureConversion,
+                                           rewriter.mapping);
   }
 
   // Rewrite the function body.
index 0286284..dfc9a0c 100644 (file)
@@ -227,4 +227,13 @@ def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>;
 def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>;
 def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>;
 
+//===----------------------------------------------------------------------===//
+// Test Type Legalization
+//===----------------------------------------------------------------------===//
+
+def TestReturnOp : TEST_Op<"return", [Terminator]>,
+  Arguments<(ins Variadic<AnyType>:$inputs)>;
+def TestCastOp : TEST_Op<"cast">,
+  Arguments<(ins Variadic<AnyType>:$inputs)>, Results<(outs AnyType:$res)>;
+
 #endif // TEST_OPS
index d4e5f79..f323d7f 100644 (file)
@@ -49,6 +49,7 @@ static mlir::PassRegistration<TestPatternDriver>
 //===----------------------------------------------------------------------===//
 // Legalization Driver.
 //===----------------------------------------------------------------------===//
+
 namespace {
 /// This pattern is a simple pattern that inlines the first region of a given
 /// operation into the parent region.
@@ -77,6 +78,29 @@ struct TestDropOp : public ConversionPattern {
     return matchSuccess();
   }
 };
+/// This pattern handles the case of a split return value.
+struct TestSplitReturnType : public ConversionPattern {
+  TestSplitReturnType(MLIRContext *ctx)
+      : ConversionPattern("test.return", 1, ctx) {}
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const final {
+    // Check for a return of F32.
+    if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
+      return matchFailure();
+
+    // Check if the first operation is a cast operation, if it is we use the
+    // results directly.
+    auto *defOp = operands[0]->getDefiningOp();
+    if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
+      SmallVector<Value *, 2> returnOperands(packerOp.getOperands());
+      rewriter.replaceOpWithNewOp<TestReturnOp>(op, returnOperands);
+      return matchSuccess();
+    }
+
+    // Otherwise, fail to match.
+    return matchFailure();
+  }
+};
 } // namespace
 
 namespace {
@@ -94,10 +118,35 @@ struct TestTypeConverter : public TypeConverter {
       return success();
     }
 
+    // Split F32 into F16,F16.
+    if (t.isF32()) {
+      results.assign(2, FloatType::getF16(t.getContext()));
+      return success();
+    }
+
     // Otherwise, convert the type directly.
     results.push_back(t);
     return success();
   }
+
+  /// Override the hook to materialize a conversion. This is necessary because
+  /// we generate 1->N type mappings.
+  Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
+                                   ArrayRef<Value *> inputs, Location loc) {
+    return rewriter.create<TestCastOp>(loc, resultType, inputs);
+  }
+};
+
+struct TestConversionTarget : public ConversionTarget {
+  TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
+    addLegalOp<LegalOpA>();
+    addDynamicallyLegalOp<TestReturnOp>();
+  }
+  bool isDynamicallyLegal(Operation *op) const final {
+    // Don't allow F32 operands.
+    return llvm::none_of(op->getOperandTypes(),
+                         [](Type type) { return type.isF32(); });
+  }
 };
 
 struct TestLegalizePatternDriver
@@ -105,12 +154,11 @@ struct TestLegalizePatternDriver
   void runOnModule() override {
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    RewriteListBuilder<TestRegionRewriteBlockMovement, TestDropOp>::build(
-        patterns, &getContext());
+    RewriteListBuilder<TestRegionRewriteBlockMovement, TestDropOp,
+                       TestSplitReturnType>::build(patterns, &getContext());
 
     TestTypeConverter converter;
-    ConversionTarget target(getContext());
-    target.addLegalOp<LegalOpA>();
+    TestConversionTarget target(getContext());
     if (failed(applyConversionPatterns(getModule(), target, converter,
                                        std::move(patterns))))
       signalPassFailure();
index b71e149..449fba9 100644 (file)
@@ -23,6 +23,19 @@ func @remap_input_1_to_1(%arg0: i64) -> i64 {
  return %arg0 : i64
 }
 
+// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) -> (f16, f16)
+func @remap_input_1_to_N(%arg0: f32) -> f32 {
+ // CHECK-NEXT: "test.return"(%arg0, %arg1) : (f16, f16) -> ()
+ "test.return"(%arg0) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
+func @remap_input_1_to_N_remaining_use(%arg0: f32) {
+  // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
+  // CHECK-NEXT: "work"([[CAST]]) : (f32) -> ()
+  "work"(%arg0) : (f32) -> ()
+}
+
 // CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64)
 func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
  // CHECK-NEXT: return %arg0, %arg1 : f64, f64
@@ -44,11 +57,12 @@ func @remap_nested() {
 // CHECK-LABEL: func @remap_moved_region_args
 func @remap_moved_region_args() {
   // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64):
-  // CHECK-NEXT: "work"{{.*}} : (f64, f64)
+  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
+  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
+  // CHECK-NEXT: "work"{{.*}} : (f64, f64, f32)
   "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64):
-      "work"(%i0, %i1) : (i64, i64) -> ()
+    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
+      "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
   return
 }
@@ -58,8 +72,8 @@ func @remap_drop_region() {
   // CHECK-NEXT: return
   // CHECK-NEXT: }
   "test.drop_op"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64):
-      "work"(%i0, %i1) : (i64, i64) -> ()
+    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
+      "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
   return
 }