candidateLayout = getStridedLinearLayoutMap(candidateReduced);
else
candidateLayout = candidateReduced.getAffineMaps().front();
- if (inferredType != candidateLayout) {
+ assert(inferredType.getNumResults() == 1 &&
+ candidateLayout.getNumResults() == 1);
+ if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
+ inferredType.getNumDims() != candidateLayout.getNumDims()) {
+ if (errMsg) {
+ llvm::raw_string_ostream os(*errMsg);
+ os << "inferred type: " << inferredType;
+ }
+ return SubViewVerificationResult::AffineMapMismatch;
+ }
+ // Check that the difference of the affine maps simplifies to 0.
+ AffineExpr diffExpr =
+ inferredType.getResult(0) - candidateLayout.getResult(0);
+ diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
+ inferredType.getNumSymbols());
+ auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
+ if (!(cst && cst.getValue() == 0)) {
if (errMsg) {
llvm::raw_string_ostream os(*errMsg);
os << "inferred type: " << inferredType;
/// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
- Type resultType = SubViewOp::inferResultType(
- castOp.source().getType().cast<MemRefType>(),
- extractFromI64ArrayAttr(subViewOp.static_offsets()),
- extractFromI64ArrayAttr(subViewOp.static_sizes()),
- extractFromI64ArrayAttr(subViewOp.static_strides()));
+ auto resultType = SubViewOp::inferResultType(
+ castOp.source().getType().cast<MemRefType>(),
+ extractFromI64ArrayAttr(subViewOp.static_offsets()),
+ extractFromI64ArrayAttr(subViewOp.static_sizes()),
+ extractFromI64ArrayAttr(subViewOp.static_strides()))
+ .cast<MemRefType>();
+ uint32_t rankDiff =
+ subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
+ if (rankDiff > 0) {
+ auto shape = resultType.getShape();
+ auto projectedShape = shape.drop_front(rankDiff);
+ AffineMap map;
+ auto maps = resultType.getAffineMaps();
+ if (!maps.empty() && maps.front()) {
+ auto optionalUnusedDimsMask =
+ computeRankReductionMask(shape, projectedShape);
+ llvm::SmallDenseSet<unsigned> dimsToProject =
+ optionalUnusedDimsMask.getValue();
+ map = getProjectedMap(maps.front(), dimsToProject);
+ }
+ resultType = MemRefType::get(projectedShape, resultType.getElementType(),
+ map, resultType.getMemorySpace());
+ }
Value newSubView = rewriter.create<SubViewOp>(
subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
%1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}
+
+// CHECK-LABEL: func @subview_of_memcast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+ memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+ %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+ %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+ memref<?x?x16x32xi8> to
+ memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+ return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}