if (!sourceElementTy || !targetDescTy)
return matchFailure();
- // Early exit for 0-D and operands lesser than `rank` corner cases.
+ // Currently, only rank > 0 memrefs with full dynamic offsets and either all
+ // or none of the sizes/strides provided as operands are supported. Fail to
+ // convert otherwise.
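+ // E.g., with rank 2 (illustrative SSA names, mirroring the tests below):
+ //   subview %src[%i, %j][][%s0, %s1]   // static sizes, dynamic strides
+ //   subview %src[%i, %j][%d0, %d1][]   // dynamic sizes, static strides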
unsigned rank = sourceMemRefType.getRank();
- if (viewMemRefType.getRank() == 0 || rank != dynamicOffsets.size() ||
- rank != dynamicSizes.size() || rank != dynamicStrides.size())
+ if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+ (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
+ (!dynamicStrides.empty() && rank != dynamicStrides.size()))
return matchFailure();
int64_t offset;
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
+ // Fill in missing dynamic sizes.
+ auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
+ if (dynamicSizes.empty()) {
+ dynamicSizes.reserve(viewMemRefType.getRank());
+ auto shape = viewMemRefType.getShape();
+ for (auto extent : shape) {
+ dynamicSizes.push_back(rewriter.create<LLVM::ConstantOp>(
+ loc, llvmIndexType, rewriter.getI64IntegerAttr(extent)));
+ }
+ }
+
// Offset.
Value *baseOffset = sourceMemRef.offset(rewriter, loc);
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
// Update sizes and strides.
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]);
- targetMemRef.setStride(rewriter, loc, i,
- rewriter.create<LLVM::MulOp>(
- loc, dynamicStrides[i], strideValues[i]));
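+ // When the result type has static strides, materialize them as constants;
+ // otherwise multiply the dynamic stride operand by the corresponding stride
+ // of the source memref.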
+ Value *newStride;
+ if (dynamicStrides.empty())
+ newStride = rewriter.create<LLVM::ConstantOp>(
+ loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
+ else
+ newStride = rewriter.create<LLVM::MulOp>(loc, dynamicStrides[i],
+ strideValues[i]);
+ targetMemRef.setStride(rewriter, loc, i, newStride);
}
rewriter.replaceOp(op, {targetMemRef});
memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
return
}
+
+// CHECK-LABEL: func @subview_const_size(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview_const_size(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+ // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+ // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+ // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+ // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+ // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+ // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+ // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+ // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] :
+ memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<4x2xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+ return
+}
+
+// CHECK-LABEL: func @subview_const_stride(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview_const_stride(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+ // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+ // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+ // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+ // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+ // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+ // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] :
+ memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 * 2 + s0)>
+ return
+}