From 18a2f479bf475c7cb94b742b39290bef98e3d305 Mon Sep 17 00:00:00 2001 From: Vladislav Vinogradov Date: Tue, 23 Mar 2021 11:45:24 +0300 Subject: [PATCH] [mlir][NFC] Replace `getMemorySpaceAsInt` with `getMemorySpace` where possible Use new `MemRefType::getMemorySpace` method with generic Attribute in cases, where there is no specific logic around the memory space. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D99154 --- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 8 +++----- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 20 +++++++++----------- mlir/lib/Dialect/Vector/VectorOps.cpp | 6 +++--- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 1c8e05b..f95193f 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -116,11 +116,9 @@ public: VectorType::get(vectorType.getShape().take_back(minorRank), vectorType.getElementType()); /// Memref of minor vector type is used for individual transfers. - memRefMinorVectorType = - MemRefType::get(majorVectorType.getShape(), minorVectorType, {}, - xferOp.getShapedType() - .template cast() - .getMemorySpaceAsInt()); + memRefMinorVectorType = MemRefType::get( + majorVectorType.getShape(), minorVectorType, {}, + xferOp.getShapedType().template cast().getMemorySpace()); } LogicalResult doReplace(); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e0e273d..546c43a 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -420,7 +420,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) return false; } - if (aT.getMemorySpaceAsInt() != bT.getMemorySpaceAsInt()) + if (aT.getMemorySpace() != bT.getMemorySpace()) return false; // They must have the same rank, and any specified dimensions must match. @@ -447,10 +447,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (aEltType != bEltType) return false; - auto aMemSpace = - (aT) ? aT.getMemorySpaceAsInt() : uaT.getMemorySpaceAsInt(); - auto bMemSpace = - (bT) ? bT.getMemorySpaceAsInt() : ubT.getMemorySpaceAsInt(); + auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); + auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); if (aMemSpace != bMemSpace) return false; @@ -1204,7 +1202,7 @@ static LogicalResult verify(ReinterpretCastOp op) { // The source and result memrefs should be in the same memory space. auto srcType = op.source().getType().cast(); auto resultType = op.getType().cast(); - if (srcType.getMemorySpaceAsInt() != resultType.getMemorySpaceAsInt()) + if (srcType.getMemorySpace() != resultType.getMemorySpace()) return op.emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; if (srcType.getElementType() != resultType.getElementType()) @@ -1389,7 +1387,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, staticSizes, sourceMemRefType.getElementType(), makeStridedLinearLayoutMap(targetStrides, targetOffset, sourceMemRefType.getContext()), - sourceMemRefType.getMemorySpaceAsInt()); + sourceMemRefType.getMemorySpace()); } Type SubViewOp::inferResultType(MemRefType sourceMemRefType, @@ -1435,7 +1433,7 @@ Type SubViewOp::inferRankReducedResultType( map = getProjectedMap(maps.front(), dimsToProject); inferredType = MemRefType::get(projectedShape, inferredType.getElementType(), map, - inferredType.getMemorySpaceAsInt()); + inferredType.getMemorySpace()); } return inferredType; } @@ -1613,7 +1611,7 @@ isRankReducedType(Type originalType, Type candidateReducedType, // Strided layout logic is relevant for MemRefType only. MemRefType original = originalType.cast(); MemRefType candidateReduced = candidateReducedType.cast(); - if (original.getMemorySpaceAsInt() != candidateReduced.getMemorySpaceAsInt()) + if (original.getMemorySpace() != candidateReduced.getMemorySpace()) return SubViewVerificationResult::MemSpaceMismatch; llvm::SmallDenseSet unusedDims = optionalUnusedDimsMask.getValue(); @@ -1687,7 +1685,7 @@ static LogicalResult verify(SubViewOp op) { MemRefType subViewType = op.getType(); // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpaceAsInt() != subViewType.getMemorySpaceAsInt()) + if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and subview memref type " << subViewType; @@ -1979,7 +1977,7 @@ static LogicalResult verify(ViewOp op) { return op.emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpaceAsInt() != viewType.getMemorySpaceAsInt()) + if (baseType.getMemorySpace() != viewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index d1703ca..9079f99 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -3235,7 +3235,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result, VectorType::get(extractShape(memRefType), getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); result.addTypes( - MemRefType::get({}, vectorType, {}, memRefType.getMemorySpaceAsInt())); + MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace())); } static LogicalResult verify(TypeCastOp op) { @@ -3244,8 +3244,8 @@ static LogicalResult verify(TypeCastOp op) { return op.emitOpError("expects operand to be a memref with no layout"); if (!op.getResultMemRefType().getAffineMaps().empty()) return op.emitOpError("expects result to be a memref with no layout"); - if (op.getResultMemRefType().getMemorySpaceAsInt() != - op.getMemRefType().getMemorySpaceAsInt()) + if (op.getResultMemRefType().getMemorySpace() != + op.getMemRefType().getMemorySpace()) return op.emitOpError("expects result in same memory space"); auto sourceType = op.getMemRefType(); -- 2.7.4