From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Tue, 24 Jan 2023 21:23:52 +0000 (-0800) Subject: [mlir][sparse] (re)introducing getRankedTensorType/getMemrefType X-Git-Tag: upstream/17.0.6~19598 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9916ab03f19dc50c688b8567ac0d30b4a6615f9d;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] (re)introducing getRankedTensorType/getMemrefType The bulk of D142074 seems to have gotten overwritten due to some sort of merge conflict (afaict there's no record of it having been reverted intentionally). So this commit redoes those changes. In addition to the original changes, this commit also: * moves the definition of `getRankedTensorType` (from `Transforms/CodegenUtils.h` to `IR/SparseTensor.h`), so that it can be used by `IR/SparseTensorDialect.cpp`. * adds `getMemRefType` as another abbreviation. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D142503 --- diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index c7c0826..777a5b4 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -38,6 +38,18 @@ namespace mlir { namespace sparse_tensor { +/// Convenience method to abbreviate casting `getType()`. +template +inline RankedTensorType getRankedTensorType(T t) { + return t.getType().template cast(); +} + +/// Convenience method to abbreviate casting `getType()`. +template +inline MemRefType getMemRefType(T t) { + return t.getType().template cast(); +} + /// Convenience method to get a sparse encoding attribute from a type. /// Returns null-attribute for any type without an encoding. SparseTensorEncodingAttr getSparseTensorEncoding(Type type); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index f2495da..364c7e7 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -29,6 +29,15 @@ using namespace mlir; using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// +// Additional convenience methods. +//===----------------------------------------------------------------------===// + +template +static inline int64_t getTypeRank(T t) { + return getRankedTensorType(t).getRank(); +} + +//===----------------------------------------------------------------------===// // TensorDialect Attribute Methods. //===----------------------------------------------------------------------===// @@ -525,12 +534,11 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, //===----------------------------------------------------------------------===// static LogicalResult isInBounds(uint64_t dim, Value tensor) { - return success(dim < - (uint64_t)tensor.getType().cast().getRank()); + return success(dim < static_cast(getTypeRank(tensor))); } static LogicalResult isMatchingWidth(Value result, unsigned width) { - const Type etp = result.getType().cast().getElementType(); + const Type etp = getMemRefType(result).getElementType(); return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); } @@ -562,8 +570,7 @@ static LogicalResult verifySparsifierGetterSetter( } LogicalResult NewOp::verify() { - if (getExpandSymmetry() && - getResult().getType().cast().getRank() != 2) + if (getExpandSymmetry() && getTypeRank(getResult()) != 2) return emitOpError("expand_symmetry can only be used for 2D tensors"); return success(); } @@ -624,8 +631,8 @@ LogicalResult ToIndicesBufferOp::verify() { } LogicalResult ToValuesOp::verify() { - RankedTensorType ttp = getTensor().getType().cast(); - MemRefType mtp = getResult().getType().cast(); + auto ttp = getRankedTensorType(getTensor()); + auto mtp = getMemRefType(getResult()); if (ttp.getElementType() != mtp.getElementType()) return emitError("unexpected mismatch in element types"); return success(); @@ -754,7 +761,7 @@ LogicalResult UnaryOp::verify() { } LogicalResult ConcatenateOp::verify() { - auto dstTp = getType().cast(); + auto dstTp = getRankedTensorType(*this); uint64_t concatDim = getDimension().getZExtValue(); unsigned rank = dstTp.getRank(); @@ -775,8 +782,7 @@ LogicalResult ConcatenateOp::verify() { concatDim)); for (size_t i = 0, e = getInputs().size(); i < e; i++) { - Value input = getInputs()[i]; - auto inputRank = input.getType().cast().getRank(); + const auto inputRank = getTypeRank(getInputs()[i]); if (inputRank != rank) return emitError( llvm::formatv("The input tensor ${0} has a different rank (rank={1}) " @@ -785,15 +791,13 @@ LogicalResult ConcatenateOp::verify() { } for (unsigned i = 0; i < rank; i++) { - auto dstDim = dstTp.getShape()[i]; + const auto dstDim = dstTp.getShape()[i]; if (i == concatDim) { if (!ShapedType::isDynamic(dstDim)) { + // If we reach here, all inputs should have static shapes. unsigned sumDim = 0; - for (auto src : getInputs()) { - // If we reach here, all inputs should have static shapes. - auto d = src.getType().cast().getShape()[i]; - sumDim += d; - } + for (auto src : getInputs()) + sumDim += getRankedTensorType(src).getShape()[i]; // If all dimension are statically known, the sum of all the input // dimensions should be equal to the output dimension. if (sumDim != dstDim) @@ -804,7 +808,7 @@ LogicalResult ConcatenateOp::verify() { } else { int64_t prev = dstDim; for (auto src : getInputs()) { - auto d = src.getType().cast().getShape()[i]; + const auto d = getRankedTensorType(src).getShape()[i]; if (!ShapedType::isDynamic(prev) && d != prev) return emitError("All dimensions (expect for the concatenating one) " "should be equal."); @@ -817,8 +821,7 @@ LogicalResult ConcatenateOp::verify() { } LogicalResult InsertOp::verify() { - RankedTensorType ttp = getTensor().getType().cast(); - if (ttp.getRank() != static_cast(getIndices().size())) + if (getTypeRank(getTensor()) != static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -838,8 +841,7 @@ LogicalResult PushBackOp::verify() { } LogicalResult CompressOp::verify() { - RankedTensorType ttp = getTensor().getType().cast(); - if (ttp.getRank() != 1 + static_cast(getIndices().size())) + if (getTypeRank(getTensor()) != 1 + static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -860,7 +862,7 @@ void ForeachOp::build( // Builds foreach body. if (!bodyBuilder) return; - auto rtp = tensor.getType().cast(); + auto rtp = getRankedTensorType(tensor); int64_t rank = rtp.getRank(); SmallVector blockArgTypes; @@ -886,7 +888,7 @@ void ForeachOp::build( } LogicalResult ForeachOp::verify() { - auto t = getTensor().getType().cast(); + auto t = getRankedTensorType(getTensor()); auto args = getBody()->getArguments(); if (static_cast(t.getRank()) + 1 + getInitArgs().size() != @@ -944,11 +946,11 @@ LogicalResult SortOp::verify() { auto n = getN().getDefiningOp(); - Type xtp = getXs().front().getType().cast().getElementType(); + Type xtp = getMemRefType(getXs().front()).getElementType(); auto checkTypes = [&](ValueRange operands, bool checkEleType = true) -> LogicalResult { for (Value opnd : operands) { - MemRefType mtp = opnd.getType().cast(); + auto mtp = getMemRefType(opnd); int64_t dim = mtp.getShape()[0]; // We can't check the size of dynamic dimension at compile-time, but all // xs and ys should have a dimension not less than n at runtime. @@ -986,7 +988,7 @@ LogicalResult SortCooOp::verify() { } auto checkDim = [&](Value v, uint64_t min, const char *message) { - MemRefType tp = v.getType().cast(); + auto tp = getMemRefType(v); int64_t dim = tp.getShape()[0]; if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) { emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index a73d627..cf2f127 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -558,7 +558,7 @@ Value sparse_tensor::reshapeValuesToLevels( idxBuffer = builder.create( loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer); SmallVector shape(rank, ShapedType::kDynamic); - Type elemTp = valuesBuffer.getType().cast().getElementType(); + Type elemTp = getMemRefType(valuesBuffer).getElementType(); return builder.create(loc, MemRefType::get(shape, elemTp), valuesBuffer, idxBuffer); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 8d8b0f8..b07991e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -78,11 +78,6 @@ StringRef primaryTypeFunctionSuffix(Type elemTp); // Misc code generators and utilities. //===----------------------------------------------------------------------===// -template -inline RankedTensorType getRankedTensorType(T t) { - return t.getType().template cast(); -} - /// Generates a 1-valued attribute of the given type. This supports /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, /// for unsupported types we raise `llvm_unreachable` rather than diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index fc9476c..73b5bd4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -53,16 +53,15 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, uint64_t nx, uint64_t ny, bool isCoo, ValueRange operands) { - nameOstream - << namePrefix << nx << "_" - << operands[xStartIdx].getType().cast().getElementType(); + nameOstream << namePrefix << nx << "_" + << getMemRefType(operands[xStartIdx]).getElementType(); if (isCoo) nameOstream << "_coo_" << ny; uint64_t yBufferOffset = isCoo ? 1 : nx; for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) - nameOstream << "_" << v.getType().cast().getElementType(); + nameOstream << "_" << getMemRefType(v).getElementType(); } /// Looks up a function that is appropriate for the given operands being @@ -719,7 +718,7 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx, // Convert `values` to have dynamic shape and append them to `operands`. for (Value v : xys) { - auto mtp = v.getType().cast(); + auto mtp = getMemRefType(v); if (!mtp.isDynamicDim(0)) { auto newMtp = MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index aaeb041..074a8de 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -505,8 +505,8 @@ static LogicalResult genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) { Location loc = op.getLoc(); - auto srcTp = op.getSrc().getType().template cast(); - auto dstTp = op.getResult().getType().template cast(); + auto srcTp = getRankedTensorType(op.getSrc()); + auto dstTp = getRankedTensorType(op.getResult()); auto encSrc = getSparseTensorEncoding(srcTp); auto encDst = getSparseTensorEncoding(dstTp); if (!encDst || !encSrc) @@ -888,8 +888,8 @@ public: matchAndRewrite(ConvertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - Type resType = op.getType(); - Type srcType = op.getSource().getType(); + auto resType = getRankedTensorType(op); + auto srcType = getRankedTensorType(op.getSource()); auto encDst = getSparseTensorEncoding(resType); auto encSrc = getSparseTensorEncoding(srcType); Value src = adaptor.getOperands()[0]; @@ -953,10 +953,8 @@ public: // dst[elem.indices] = elem.value; // } // delete iter; - RankedTensorType dstTensorTp = resType.cast(); - RankedTensorType srcTensorTp = srcType.cast(); - unsigned rank = dstTensorTp.getRank(); - Type elemTp = dstTensorTp.getElementType(); + const unsigned rank = resType.getRank(); + const Type elemTp = resType.getElementType(); // Fabricate a no-permutation encoding for NewCallParams // The pointer/index types must be those of `src`. // The dimLevelTypes aren't actually used by Action::kToIterator. @@ -965,16 +963,16 @@ public: SmallVector(rank, DimLevelType::Dense), AffineMap(), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, srcTensorTp, src); + getDimSizes(rewriter, loc, encSrc, srcType, src); Value iter = NewCallParams(rewriter, loc) - .genBuffers(encDst, dimSizes, dstTensorTp) + .genBuffers(encDst, dimSizes, resType) .genNewCall(Action::kToIterator, src); Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); Block *insertionBlock = rewriter.getInsertionBlock(); // TODO: Dense buffers should be allocated/deallocated via the callback // in BufferizationOptions. - Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes); + Value dst = allocDenseTensor(rewriter, loc, resType, dimSizes); SmallVector noArgs; SmallVector noTypes; auto whileOp = rewriter.create(loc, noTypes, noArgs); @@ -1192,7 +1190,7 @@ public: // index order. All values are passed by reference through stack // allocated memrefs. Location loc = op->getLoc(); - auto tp = op.getTensor().getType().cast(); + auto tp = getRankedTensorType(op.getTensor()); auto elemTp = tp.getElementType(); unsigned rank = tp.getRank(); auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); @@ -1217,8 +1215,7 @@ public: matchAndRewrite(ExpandOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - RankedTensorType srcType = - op.getTensor().getType().cast(); + auto srcType = getRankedTensorType(op.getTensor()); Type eltType = srcType.getElementType(); Type boolType = rewriter.getIntegerType(1); Type idxType = rewriter.getIndexType(); @@ -1272,7 +1269,7 @@ public: Value added = adaptor.getAdded(); Value count = adaptor.getCount(); Value tensor = adaptor.getTensor(); - auto tp = op.getTensor().getType().cast(); + auto tp = getRankedTensorType(op.getTensor()); Type elemTp = tp.getElementType(); unsigned rank = tp.getRank(); auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); @@ -1326,7 +1323,7 @@ public: // a[ adjustForOffset(elem.indices) ] = elem.value // return a Location loc = op.getLoc(); - auto dstTp = op.getType().cast(); + auto dstTp = getRankedTensorType(op); auto encDst = getSparseTensorEncoding(dstTp); Type elemTp = dstTp.getElementType(); uint64_t concatDim = op.getDimension().getZExtValue(); @@ -1381,7 +1378,7 @@ public: for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) { Value orignalOp = std::get<0>(it); // Input (with encoding) from Op Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor - RankedTensorType srcTp = orignalOp.getType().cast(); + auto srcTp = getRankedTensorType(orignalOp); auto encSrc = getSparseTensorEncoding(srcTp); if (encSrc) { genSparseCOOIterationLoop( diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 4b92540..bc05137 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -69,7 +69,7 @@ static VectorType vectorType(VL vl, Type etp) { /// Constructs vector type from pointer. static VectorType vectorType(VL vl, Value ptr) { - return vectorType(vl, ptr.getType().cast().getElementType()); + return vectorType(vl, getMemRefType(ptr).getElementType()); } /// Constructs vector iteration mask.