From 9926ad17525a3424618b710a5bb882b392793e8d Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 21 Sep 2022 12:47:21 +0000 Subject: [PATCH] [mlir] Move getDimsOfType to StructuredOpsUtils.h. Summary: This change will bring all helpers that work with iterator types to one place. Currently getDimsOfType is is declared in Linalg.h, but not directly included by LinalgInterfaces. It worked so far only because all the places that include LinalgInterfaces.h also include Linalg.h directly or indirectly. Differential Revision: https://reviews.llvm.org/D134350 --- mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 5 ----- mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 9 ++++++--- mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h | 11 +++++++++++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 16 ---------------- 4 files changed, 17 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h index 4cf2e93..70e7fc9 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -69,11 +69,6 @@ AffineMap extractOrIdentityMap(Optional maybeMap, unsigned rank, SmallVector concat(ArrayRef a, ArrayRef b); -/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. -/// Assumes `op` is a LinalgOp. -void getDimsOfType(Operation *op, StringRef iteratorTypeName, - SmallVectorImpl &res); - } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 513aedc..bf2509f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -206,7 +206,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins "SmallVectorImpl &":$res), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getDimsOfType($_op, getParallelIteratorTypeName(), res); + return findPositionsOfType($_op.iterator_types(), + getParallelIteratorTypeName(), res); }] >, InterfaceMethod< @@ -231,7 +232,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins "SmallVectorImpl &":$res), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getDimsOfType($_op, getReductionIteratorTypeName(), res); + return findPositionsOfType($_op.iterator_types(), + getReductionIteratorTypeName(), res); }] >, InterfaceMethod< @@ -256,7 +258,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins "SmallVectorImpl &":$res), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res); + return findPositionsOfType($_op.iterator_types(), + getWindowIteratorTypeName(), res); }] >, InterfaceMethod< diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 9dfde80..fd470c1 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -110,6 +110,17 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) { return res; } +/// Return positions in `iteratorTypes` that match `iteratorTypeName`. +inline void findPositionsOfType(ArrayAttr iteratorTypes, + StringRef iteratorTypeName, + SmallVectorImpl &res) { + for (const auto &en : + llvm::enumerate(iteratorTypes.getAsValueRange())) { + if (en.value() == iteratorTypeName) + res.push_back(en.index()); + } +} + /// Helper StructuredGenerator class to manipulate and rewrite ops with /// `StructuredOpInterface`. This is templated for now because VectorOps do not /// yet implement the StructuredOpInterface itself. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 5190073..c088310 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1540,22 +1540,6 @@ LogicalResult IndexOp::verify() { #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" -/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. -/// Assumes `op` is a LinalgOp. -void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, - SmallVectorImpl &res) { - if (!cast(op).iterator_types()) - return; - - unsigned dim = 0; - for (auto tn : - cast(op).iterator_types().getAsValueRange()) { - if (tn == iteratorTypeName) - res.push_back(dim); - ++dim; - } -} - AffineMap mlir::linalg::extractOrIdentityMap(Optional maybeMap, unsigned rank, MLIRContext *context) { -- 2.7.4