Write a pass to annotate constant operands on FIR ops. This works
authorEric Schweitz <eschweitz@nvidia.com>
Wed, 2 Feb 2022 20:46:21 +0000 (12:46 -0800)
committerEric Schweitz <eschweitz@nvidia.com>
Mon, 14 Mar 2022 18:14:44 +0000 (11:14 -0700)
around the feature in MLIR's canonicalizer, which considers the semantics
of constants differently based on how they are packaged and not their
values and use.  Add test.

Reviewed By: clementval

Differential Revision: https://reviews.llvm.org/D121625

flang/include/flang/Optimizer/Transforms/Passes.h
flang/include/flang/Optimizer/Transforms/Passes.td
flang/lib/Optimizer/Transforms/AnnotateConstant.cpp [new file with mode: 0644]
flang/lib/Optimizer/Transforms/CMakeLists.txt
flang/test/Fir/annotate-constant.fir [new file with mode: 0644]

index 4c13572..7a28f85 100644 (file)
@@ -37,6 +37,7 @@ std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
 std::unique_ptr<mlir::Pass>
 createMemoryAllocationPass(bool dynOnHeap, std::size_t maxStackSize);
+std::unique_ptr<mlir::Pass> createAnnotateConstantOperandsPass();
 
 // declarative passes
 #define GEN_PASS_REGISTRATION
index d81cb36..ca466ce 100644 (file)
@@ -74,6 +74,23 @@ def AffineDialectDemotion : Pass<"demote-affine", "::mlir::FuncOp"> {
   ];
 }
 
+def AnnotateConstantOperands : Pass<"annotate-constant"> {
+  let summary = "Annotate constant operands to all FIR operations";
+  let description = [{
+    The MLIR canonicalizer makes a distinction between constants based on how
+    they are packaged in the IR. A constant value is wrapped in an Attr and that
+    Attr can be attached to an Op. There is a distinguished Op, ConstantOp, that
+    merely has one of these Attr attached.
+
+    The MLIR canonicalizer treats constants referenced by an Op and constants
+    referenced through a ConstantOp as having distinct semantics. This pass
+    eliminates that distinction, so hashconsing of Ops, basic blocks, etc.
+    behaves as one would expect.
+  }];
+  let constructor = "::fir::createAnnotateConstantOperandsPass()";
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+}
+
 def ArrayValueCopy : Pass<"array-value-copy", "::mlir::FuncOp"> {
   let summary = "Convert array value operations to memory operations.";
   let description = [{
@@ -91,6 +108,7 @@ def ArrayValueCopy : Pass<"array-value-copy", "::mlir::FuncOp"> {
     This pass is required before code gen to the LLVM IR dialect.
   }];
   let constructor = "::fir::createArrayValueCopyPass()";
+  let dependentDialects = [ "fir::FIROpsDialect" ];
 }
 
 def CharacterConversion : Pass<"character-conversion"> {
diff --git a/flang/lib/Optimizer/Transforms/AnnotateConstant.cpp b/flang/lib/Optimizer/Transforms/AnnotateConstant.cpp
new file mode 100644 (file)
index 0000000..0437883
--- /dev/null
@@ -0,0 +1,55 @@
+//===-- AnnotateConstant.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+#define DEBUG_TYPE "flang-annotate-constant"
+
+using namespace fir;
+
+namespace {
+struct AnnotateConstantOperands
+    : AnnotateConstantOperandsBase<AnnotateConstantOperands> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    mlir::Dialect *firDialect = context->getLoadedDialect("fir");
+    getOperation()->walk([&](mlir::Operation *op) {
+      // We filter out other dialects even though they may undergo merging of
+      // non-equal constant values by the canonicalizer as well.
+      if (op->getDialect() == firDialect) {
+        llvm::SmallVector<mlir::Attribute> attrs;
+        bool hasOneOrMoreConstOpnd = false;
+        for (mlir::Value opnd : op->getOperands()) {
+          if (auto constOp = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+                  opnd.getDefiningOp())) {
+            attrs.push_back(constOp.getValue());
+            hasOneOrMoreConstOpnd = true;
+          } else if (auto addrOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
+                         opnd.getDefiningOp())) {
+            attrs.push_back(addrOp.getSymbol());
+            hasOneOrMoreConstOpnd = true;
+          } else {
+            attrs.push_back(mlir::UnitAttr::get(context));
+          }
+        }
+        if (hasOneOrMoreConstOpnd)
+          op->setAttr("canonicalize_constant_operands",
+                      mlir::ArrayAttr::get(context, attrs));
+      }
+    });
+  }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createAnnotateConstantOperandsPass() {
+  return std::make_unique<AnnotateConstantOperands>();
+}
index 30337a1..9c20db0 100644 (file)
@@ -2,6 +2,7 @@ add_flang_library(FIRTransforms
   AbstractResult.cpp
   AffinePromotion.cpp
   AffineDemotion.cpp
+  AnnotateConstant.cpp
   CharacterConversion.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/test/Fir/annotate-constant.fir b/flang/test/Fir/annotate-constant.fir
new file mode 100644 (file)
index 0000000..7c7e0e9
--- /dev/null
@@ -0,0 +1,9 @@
+// RUN: fir-opt -annotate-constant %s | FileCheck %s
+
+// CHECK-LABEL: func @annotate_test() -> !fir.ref<!fir.array<?xi32>> {
+func @annotate_test() -> !fir.ref<!fir.array<?xi32>> {
+  %1 = arith.constant 5 : index
+  // CHECK: %[[a:.*]] = fir.alloca !fir.array<?xi32>, %{{.*}} {canonicalize_constant_operands = [5 : index]}
+  %2 = fir.alloca !fir.array<?xi32>, %1
+  return %2 : !fir.ref<!fir.array<?xi32>>
+}