From 894641e974e55164481fe3e7c90ef3cf5af89cf9 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Mon, 25 Jul 2022 10:28:23 -0600 Subject: [PATCH] Revert "[mlir][Arithmetic] Add `arith.delinearize_index` operation" This reverts commit 535b507ba58e8b5f604d53ffc961be1456d229a7. --- .../mlir/Dialect/Arithmetic/IR/ArithmeticOps.td | 43 --------------------- mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h | 16 -------- mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp | 30 --------------- mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt | 1 - .../Dialect/Arithmetic/Transforms/CMakeLists.txt | 1 - .../Dialect/Arithmetic/Transforms/ExpandOps.cpp | 24 +----------- mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp | 45 ---------------------- mlir/test/Dialect/Arithmetic/expand-ops.mlir | 21 ---------- mlir/test/Dialect/Arithmetic/invalid.mlir | 16 -------- mlir/test/Dialect/Arithmetic/ops.mlir | 6 --- 10 files changed, 2 insertions(+), 201 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td index 2bdff67..75710d60 100644 --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -1219,47 +1219,4 @@ def SelectOp : Arith_Op<"select", [ let hasCustomAssemblyFormat = 1; } -//===----------------------------------------------------------------------===// -// DelinearizeIndexOp -//===----------------------------------------------------------------------===// - -def DelinearizeIndexOp : Op { - let summary = "delinearize an index"; - let description = [{ - The `arith.delinearize_index` operation takes a single index value and - calculates the multi-index according to the given basis. - - Example: - - ``` - %indices:3 = arith.delinearize_index %linear_index (%1, %2, %3) : index, index, index - ``` - - In the above example, `%indices:3` conceptually holds the following: - - ``` - %v1 = arith.muli %1, %2 : index - %indices#0 = floorDiv(%linear_index , %v1) - %indices#1 = floorDiv(remander(%linear_index , %v1), %3) - %indices#2 = remainder(remainder(%linear_idnex, %v1), %3) - ``` - }]; - - let arguments = (ins Index:$linear_index, Variadic:$basis); - let results = (outs Variadic:$multi_index); - - let assemblyFormat = [{ - $linear_index `(` $basis `)` attr-dict `:` type($multi_index) - }]; - - let builders = [ - OpBuilder<(ins "Value":$linear_index, "ArrayRef":$basis)> - ]; - - let hasVerifier = 1; -} - - - #endif // ARITHMETIC_OPS diff --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h index a59e108..924de08 100644 --- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h @@ -108,22 +108,6 @@ private: OpBuilder &b; Location loc; }; - -/// Holds the result of (div a, b) and (mod a, b) -struct DivModValue { - Value quotient; - Value remainder; -}; - -/// Create IR to calculate (div a, b) and (mod a, b) -DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); - -/// Generate the IR to delinearize `linearIndex` given the `basis` and return -/// the multi-index. -FailureOr> delinearizeIndex(OpBuilder &b, Location loc, - Value linearIndex, - ArrayRef dimSizes); - } // namespace mlir #endif // MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index f2a88e2..4d0d50e 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -10,7 +10,6 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/CommonFolders.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" @@ -2042,35 +2041,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, } //===----------------------------------------------------------------------===// -// DelinearizeIndexOp -//===----------------------------------------------------------------------===// - -void arith::DelinearizeIndexOp::build(OpBuilder &builder, - OperationState &result, - Value linear_index, - ArrayRef basis) { - result.addTypes(SmallVector(basis.size(), builder.getIndexType())); - result.addOperands(linear_index); - SmallVector basisValues = - llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value { - Optional staticDim = getConstantIntValue(ofr); - if (staticDim.has_value()) - return builder.create(result.location, - *staticDim); - return ofr.dyn_cast(); - })); - result.addOperands(basisValues); -} - -LogicalResult arith::DelinearizeIndexOp::verify() { - if (getBasis().empty()) - return emitOpError("basis should not be empty"); - if (getNumResults() != getBasis().size()) - return emitOpError("should return an index for each basis element"); - return success(); -} - -//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt index 5ca47d9..e23504b 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -15,7 +15,6 @@ add_mlir_dialect_library(MLIRArithmeticDialect LINK_LIBS PUBLIC MLIRDialect - MLIRDialectUtils MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt index eead11d..f140715 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -13,7 +13,6 @@ add_mlir_dialect_library(MLIRArithmeticTransforms LINK_LIBS PUBLIC MLIRAnalysis MLIRArithmeticDialect - MLIRArithmeticUtils MLIRBufferizationDialect MLIRBufferizationTransforms MLIRInferIntRangeInterface diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp index 7c48b6f..afe7aab 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -9,7 +9,6 @@ #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -189,23 +188,6 @@ public: } }; -/// Lowers `arith.delinearize_index` into a sequence of division and remainder -/// operations. -struct LowerDelinearizeIndexOps - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(arith::DelinearizeIndexOp op, - PatternRewriter &rewriter) const override { - FailureOr> multiIndex = - delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), - llvm::to_vector(op.getBasis())); - if (failed(multiIndex)) - return failure(); - rewriter.replaceOp(op, *multiIndex); - return success(); - } -}; - struct ArithmeticExpandOpsPass : public ArithmeticExpandOpsBase { void runOnOperation() override { @@ -225,8 +207,7 @@ struct ArithmeticExpandOpsPass arith::MaxUIOp, arith::MinFOp, arith::MinSIOp, - arith::MinUIOp, - arith::DelinearizeIndexOp + arith::MinUIOp >(); // clang-format on if (failed(applyPartialConversion(getOperation(), target, @@ -249,8 +230,7 @@ void mlir::arith::populateArithmeticExpandOpsPatterns( MaxMinIOpConverter, MaxMinIOpConverter, MaxMinIOpConverter, - MaxMinIOpConverter, - LowerDelinearizeIndexOps + MaxMinIOpConverter >(patterns.getContext()); // clang-format on } diff --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp index 4058886..b568891 100644 --- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/IR/OpDefinition.h" #include "llvm/ADT/SmallBitVector.h" using namespace mlir; @@ -116,47 +115,3 @@ Value ArithBuilder::slt(Value lhs, Value rhs) { Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { return b.create(loc, cmp, lhs, rhs); } - -DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) { - DivModValue result; - result.quotient = b.create(loc, lhs, rhs); - result.remainder = b.create(loc, lhs, rhs); - return result; -} - -/// Create IR that computes the product of all elements in the set. -static FailureOr getIndexProduct(OpBuilder &b, Location loc, - ArrayRef set) { - if (set.empty()) - return failure(); - OpFoldResult result = set[0]; - for (unsigned i = 1; i < set.size(); i++) - result = b.createOrFold( - loc, getValueOrCreateConstantIndexOp(b, loc, result), set[i]); - return result; -} - -FailureOr> mlir::delinearizeIndex(OpBuilder &b, Location loc, - Value linearIndex, - ArrayRef dimSizes) { - unsigned numDims = dimSizes.size(); - - SmallVector divisors; - for (unsigned i = 1; i < numDims; i++) { - ArrayRef slice(dimSizes.begin() + i, dimSizes.end()); - FailureOr prod = getIndexProduct(b, loc, slice); - if (failed(prod)) - return failure(); - divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); - } - - SmallVector results; - Value residual = linearIndex; - for (Value divisor : divisors) { - DivModValue divMod = getDivMod(b, loc, residual, divisor); - results.push_back(divMod.quotient); - residual = divMod.remainder; - } - results.push_back(residual); - return results; -} diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir index 1da12d3..e0990fd 100644 --- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir +++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir @@ -230,24 +230,3 @@ func.func @minui(%a: i32, %b: i32) -> i32 { } // CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 - -// ----- - -// CHECK-LABEL: @static_basis -// CHECK-SAME: (%[[IDX:.+]]: index) -// CHECK: arith.constant -// CHECK: arith.constant -// CHECK-DAG: %[[c224:.+]] = arith.constant 224 : index -// CHECK-DAG: %[[c50176:.+]] = arith.constant 50176 : index -// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index -// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index -// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index -// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index -// CHECK: return %[[N]], %[[P]], %[[Q]] -func.func @static_basis(%linear_index: index) -> (index, index, index) { - %b0 = arith.constant 16 : index - %b1 = arith.constant 224 : index - %b2 = arith.constant 224 : index - %1:3 = arith.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index - return %1#0, %1#1, %1#2 : index, index, index -} diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir index 201b757..19c427b 100644 --- a/mlir/test/Dialect/Arithmetic/invalid.mlir +++ b/mlir/test/Dialect/Arithmetic/invalid.mlir @@ -721,19 +721,3 @@ func.func @func() { %x = arith.constant 1 : i32 } - -// ----- - -func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) { - // expected-error@+1 {{'arith.delinearize_index' op should return an index for each basis element}} - %1 = arith.delinearize_index %idx (%basis0, %basis1) : index - return -} - -// ----- - -func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) { - // expected-error@+1 {{'arith.delinearize_index' op basis should not be empty}} - arith.delinearize_index %idx () : index - return -} diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir index fe241ba..61f9a2d 100644 --- a/mlir/test/Dialect/Arithmetic/ops.mlir +++ b/mlir/test/Dialect/Arithmetic/ops.mlir @@ -958,9 +958,3 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, %min_unsigned = arith.minui %i1, %i2 : i32 return } - -// CHECK-LABEL: func @delinearize -func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) -> (index, index) { - %1:2 = arith.delinearize_index %idx (%basis0, %basis1) : index, index - return %1#0, %1#1 : index, index -} -- 2.7.4