return LLVM::LLVMPointerType::get(converted);
}
-
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
auto result = getStridesAndOffset(type, strides, offset);
(void)result;
assert(succeeded(result) && "unexpected failure in stride computation");
- assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+ assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
"expected static offset");
- assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
- "expected static strides");
+ assert(!llvm::any_of(strides, [](int64_t stride) {
+ return MemRefType::isDynamicStrideOrOffset(stride);
+ }) && "expected static strides");
auto convertedType = typeConverter.convertType(type);
assert(convertedType && "unexpected failure in memref type conversion");
Value index;
if (offset != 0) // Skip if offset is zero.
- index = offset == MemRefType::getDynamicStrideOrOffset()
+ index = MemRefType::isDynamicStrideOrOffset(offset)
? memRefDescriptor.offset(rewriter, loc)
: createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
- Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+ Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
? memRefDescriptor.stride(rewriter, loc, i)
: createIndexConstant(rewriter, loc, strides[i]);
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
- // Copy the buffer pointer from the old descriptor to the new one.
+ // Copy the aligned pointer from the old descriptor to the new one.
extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
ArrayRef<int64_t> strides, Value nextSize,
Value runningStride, unsigned idx) const {
assert(idx < strides.size());
- if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+ if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
return createIndexConstant(rewriter, loc, strides[idx]);
if (nextSize)
return runningStride
/// This function is slight variant of `is subsequence` algorithm where
/// not matching dimension must be 1.
static SubViewVerificationResult isRankReducedType(Type originalType,
- Type reducedType) {
- if (originalType == reducedType)
+ Type candidateReducedType) {
+ if (originalType == candidateReducedType)
return SubViewVerificationResult::Success;
if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
return SubViewVerificationResult::Success;
if (originalType.isa<RankedTensorType>() &&
- !reducedType.isa<RankedTensorType>())
+ !candidateReducedType.isa<RankedTensorType>())
return SubViewVerificationResult::Success;
- if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
+ if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
return SubViewVerificationResult::Success;
ShapedType originalShapedType = originalType.cast<ShapedType>();
- ShapedType reducedShapedType = reducedType.cast<ShapedType>();
+ ShapedType candidateReducedShapedType =
+ candidateReducedType.cast<ShapedType>();
// Rank and size logic is valid for all ShapedTypes.
ArrayRef<int64_t> originalShape = originalShapedType.getShape();
- ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
+ ArrayRef<int64_t> candidateReducedShape =
+ candidateReducedShapedType.getShape();
unsigned originalRank = originalShape.size(),
- reducedRank = reducedShape.size();
- if (reducedRank > originalRank)
+ candidateReducedRank = candidateReducedShape.size();
+ if (candidateReducedRank > originalRank)
return SubViewVerificationResult::RankTooLarge;
- auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+ auto optionalMask =
+ computeRankReductionMask(originalShape, candidateReducedShape);
// Sizes cannot be matched in case empty vector is returned.
if (!optionalMask.hasValue())
// Strided layout logic is relevant for MemRefType only.
MemRefType original = originalType.cast<MemRefType>();
- MemRefType reduced = reducedType.cast<MemRefType>();
+ MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
MLIRContext *c = original.getContext();
- int64_t originalOffset, reducedOffset;
- SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+ int64_t originalOffset, candidateReducedOffset;
+ SmallVector<int64_t, 4> originalStrides, candidateReducedStrides, keepStrides;
SmallVector<bool, 4> keepMask = optionalMask.getValue();
getStridesAndOffset(original, originalStrides, originalOffset);
- getStridesAndOffset(reduced, reducedStrides, reducedOffset);
+ getStridesAndOffset(candidateReduced, candidateReducedStrides,
+ candidateReducedOffset);
// Filter strides based on the mask and check that they are the same
- // as reduced ones.
- unsigned reducedIdx = 0;
+ // as candidateReduced ones.
+ unsigned candidateReducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
if (keepMask[originalIdx]) {
- if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
+ if (originalStrides[originalIdx] !=
+ candidateReducedStrides[candidateReducedIdx++])
return SubViewVerificationResult::StrideMismatch;
keepStrides.push_back(originalStrides[originalIdx]);
}
}
- if (original.getElementType() != reduced.getElementType())
+ if (original.getElementType() != candidateReduced.getElementType())
return SubViewVerificationResult::ElemTypeMismatch;
- if (original.getMemorySpace() != reduced.getMemorySpace())
+ if (original.getMemorySpace() != candidateReduced.getMemorySpace())
return SubViewVerificationResult::MemSpaceMismatch;
+ // reducedMap is obtained by projecting away the dimensions inferred from
+ // matching the 1's positions in candidateReducedType.
auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
- if (!reduced.getAffineMaps().empty() &&
- reducedMap != reduced.getAffineMaps().front())
+
+ MemRefType expectedReducedType = MemRefType::get(
+ candidateReduced.getShape(), candidateReduced.getElementType(),
+ reducedMap, candidateReduced.getMemorySpace());
+ expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
+
+ if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
return SubViewVerificationResult::AffineMapMismatch;
return SubViewVerificationResult::Success;