[mlir][linalg] move isElementwise() to Linalg/Utils (NFC)
authorOkwan Kwon <okwan@google.com>
Wed, 22 Jun 2022 22:59:53 +0000 (15:59 -0700)
committerOkwan Kwon <okwan@google.com>
Thu, 23 Jun 2022 01:55:45 +0000 (18:55 -0700)
Differential Revision: https://reviews.llvm.org/D128398

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 5a65298..72b126e 100644 (file)
@@ -32,6 +32,15 @@ class LinalgDependenceGraph;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// Check if all indexing maps are projected permutations.
+bool allIndexingsAreProjectedPermutation(LinalgOp op);
+
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
+bool hasOnlyScalarElementwiseOp(Region &r);
+
+/// Check if a LinalgOp is an element-wise operation.
+bool isElementwise(LinalgOp op);
+
 /// Check if `permutation` is a permutation of the range
 /// `[0, permutation.size())`.
 bool isPermutation(ArrayRef<int64_t> permutation);
index 794fe97..efbf051 100644 (file)
@@ -417,48 +417,6 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                llvm::to_vector<4>(returnTypes), op->getAttrs())};
 }
 
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  for (Operation &op : r.front()) {
-    if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
-              linalg::IndexOp>(op) ||
-          OpTrait::hasElementwiseMappableTraits(&op)) ||
-        llvm::any_of(op.getResultTypes(),
-                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
-      return false;
-  }
-  return true;
-}
-
-/// Returns `true` if all indexing maps of the linalg op are projected
-/// permutations.
-static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
-  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
-    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
-  });
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
-    return false;
-
-  if (!allIndexingsAreProjectedPermutation(linalgOp))
-    return false;
-
-  // TODO: relax the restrictions on indexing map.
-  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
-      return false;
-  }
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 ///   1. Verify the `linalgOp` has one non-empty region.
index 2af56d6..91d661c 100644 (file)
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRLinalgUtils
   MLIRAffineAnalysis
   MLIRAffineUtils
   MLIRArithmeticDialect
+  MLIRFuncDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRSCFDialect
index 6e68686..1090119 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -141,6 +142,41 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
 namespace mlir {
 namespace linalg {
 
+bool allIndexingsAreProjectedPermutation(LinalgOp op) {
+  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
+    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
+  });
+}
+
+bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  for (Operation &op : r.front()) {
+    if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
+              linalg::IndexOp>(op) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
+        llvm::any_of(op.getResultTypes(),
+                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+      return false;
+  }
+  return true;
+}
+
+bool isElementwise(LinalgOp op) {
+  if (op.getNumLoops() != op.getNumParallelLoops())
+    return false;
+
+  if (!allIndexingsAreProjectedPermutation(op))
+    return false;
+
+  // TODO: relax the restrictions on indexing map.
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    if (!op.getTiedIndexingMap(opOperand).isPermutation())
+      return false;
+  }
+  return hasOnlyScalarElementwiseOp(op->getRegion(0));
+}
+
 bool isPermutation(ArrayRef<int64_t> permutation) {
   // Count the number of appearances for all indices.
   SmallVector<int64_t> indexCounts(permutation.size(), 0);
index f0813db..1b7cb25 100644 (file)
@@ -7472,6 +7472,7 @@ cc_library(
         ":ArithmeticDialect",
         ":ArithmeticUtils",
         ":DialectUtils",
+        ":FuncDialect",
         ":IR",
         ":LinalgAnalysis",
         ":LinalgDialect",