From 51afca640c2968c51a92dfb89e67e10bdcb98216 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 23 Nov 2022 18:57:02 -0500 Subject: [PATCH] [mlir][vector] Simplify fold pattern for ExtractOp(constant). NFC. Use helper functions. Reuse array element attributes. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D138609 --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bd96ee7..22d7bdc 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1623,7 +1623,6 @@ public: return failure(); auto vecTy = sourceVector.getType().cast(); - Type elemTy = vecTy.getElementType(); ArrayAttr positions = extractOp.getPosition(); if (vecTy.isScalable()) return failure(); @@ -1631,36 +1630,17 @@ public: // constants. if (vecTy.getRank() != static_cast(positions.size())) return failure(); - // TODO: Handle more element types, e.g., complex values. - if (!elemTy.isIntOrIndexOrFloat()) - return failure(); // The splat case is handled by `ExtractOpSplatConstantFolder`. auto dense = vectorCst.dyn_cast(); if (!dense || dense.isSplat()) return failure(); - // Calculate the flattened position. - int64_t elemPosition = 0; - int64_t innerElems = 1; - for (auto [dimSize, positionInDim] : - llvm::reverse(llvm::zip(vecTy.getShape(), positions))) { - int64_t positionVal = positionInDim.cast().getInt(); - elemPosition += positionVal * innerElems; - innerElems *= dimSize; - } - - Attribute newAttr; - if (vecTy.getElementType().isIntOrIndex()) { - auto values = to_vector(dense.getValues()); - newAttr = IntegerAttr::get(extractOp.getType(), values[elemPosition]); - } else if (vecTy.getElementType().isa()) { - auto values = to_vector(dense.getValues()); - newAttr = FloatAttr::get(extractOp.getType(), values[elemPosition]); - } - assert(newAttr && "Unhandled case"); - - rewriter.replaceOpWithNewOp(extractOp, newAttr); + // Calculate the linearized position. + int64_t elemPosition = + linearize(getI64SubArray(positions), computeStrides(vecTy.getShape())); + Attribute elementValue = *(dense.value_begin() + elemPosition); + rewriter.replaceOpWithNewOp(extractOp, elementValue); return success(); } }; -- 2.7.4