[spirv] Add a canonicalizer for BitcastOp.
authorDenis Khalikov <khalikov.denis@huawei.com>
Mon, 18 Nov 2019 20:36:16 +0000 (12:36 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 18 Nov 2019 20:37:00 +0000 (12:37 -0800)
Convert chained `spirv::BitcastOp` operations into
one `spirv::BitcastOp` operation.

Closes tensorflow/mlir#238

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/238 from denis0x0D:sandbox/canon_bitcast 4352ed4f81b959ec92f849c599e733b62a99c010
PiperOrigin-RevId: 281129234

mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/canonicalize.mlir

index 245a224..1798b9d 100644 (file)
@@ -98,6 +98,8 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
 
   let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
   let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
index 5cda907..8964963 100644 (file)
@@ -652,8 +652,8 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
 
 namespace {
 
-// Combine chained `spirv::AccessChainOp` operations into one
-// `spirv::AccessChainOp` operation.
+/// Combines chained `spirv::AccessChainOp` operations into one
+/// `spirv::AccessChainOp` operation.
 struct CombineChainedAccessChain
     : public OpRewritePattern<spirv::AccessChainOp> {
   using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
@@ -678,7 +678,7 @@ struct CombineChainedAccessChain
     return matchSuccess();
   }
 };
-} // namespace
+} // end anonymous namespace
 
 void spirv::AccessChainOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
@@ -771,6 +771,35 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
   return success();
 }
 
+namespace {
+
+/// Converts chained `spirv::BitcastOp` operations into one
+/// `spirv::BitcastOp` operation.
+struct ConvertChainedBitcast : public OpRewritePattern<spirv::BitcastOp> {
+  using OpRewritePattern<spirv::BitcastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp,
+                                     PatternRewriter &rewriter) const override {
+    auto parentBitcastOp = dyn_cast_or_null<spirv::BitcastOp>(
+        bitcastOp.operand()->getDefiningOp());
+
+    if (!parentBitcastOp) {
+      return matchFailure();
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(
+        /*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp,
+        bitcastOp.result()->getType(), parentBitcastOp.operand());
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace
+
+void spirv::BitcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertChainedBitcast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.BitFieldInsert
 //===----------------------------------------------------------------------===//
@@ -2278,7 +2307,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
 
   return matchSuccess();
 }
-} // namespace
+} // end anonymous namespace
 
 void spirv::SelectionOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
index 02d8645..87be892 100644 (file)
@@ -135,6 +135,34 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.Bitcast
+//===----------------------------------------------------------------------===//
+
+func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 {
+  // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
+  // CHECK-NEXT: spv.ReturnValue %[[RESULT]]
+  %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32>
+  %1 = spv.Bitcast %0 : vector<2xi32> to i64
+  %2 = spv.Bitcast %1 : i64 to f64
+  spv.ReturnValue %2 : f64
+}
+
+// -----
+
+func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 {
+  // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64
+  // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
+  // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]]
+  // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
+  %0 = spv.Bitcast %arg0 : vector<2xf32> to i64
+  %1 = spv.Bitcast %0 : i64 to f64
+  spv.Store "Uniform" %arg1, %0 : i64
+  spv.ReturnValue %1 : f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//