PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
+ SliceOpOperandAdaptor adaptor(operands);
auto sliceOp = cast<SliceOp>(op);
auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
// Declare the view descriptor and insert data ptr.
Value *desc = undef(viewDescriptorTy);
desc = insertvalue(viewDescriptorTy, desc,
- getViewPtr(viewType, operands[0]), pos(0));
+ getViewPtr(viewType, adaptor.view()), pos(0));
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value *, 4> strides(viewType.getRank());
- for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
- strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
+ for (int i = 0, e = viewType.getRank(); i < e; ++i) {
+ strides[i] = extractvalue(int64Ty, adaptor.view(), pos({3, i}));
}
// Compute and insert base offset.
- Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
- for (int j = 0, e = viewType.getRank(); j < e; ++j) {
- Value *indexing = operands[1 + j];
+ Value *baseOffset = extractvalue(int64Ty, adaptor.view(), pos(1));
+ for (int i = 0, e = viewType.getRank(); i < e; ++i) {
+ Value *indexing = adaptor.indexings()[i];
Value *min =
- sliceOp.indexing(j)->getType().isa<RangeType>()
+ sliceOp.indexing(i)->getType().isa<RangeType>()
? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
: indexing;
- Value *product = mul(min, strides[j]);
+ Value *product = mul(min, strides[i]);
baseOffset = add(baseOffset, product);
}
desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
- // Compute and insert view sizes (max - min along the range). Skip the
- // non-range operands as they will be projected away from the view.
- int i = 0, j = 0;
- for (Value *index : sliceOp.indexings()) {
- if (!index->getType().isa<RangeType>()) {
- ++j;
- continue;
+ // Compute and insert view sizes (max - min along the range) and strides.
+ // Skip the non-range operands as they will be projected away from the view.
+ int numNewDims = 0;
+ for (auto en : llvm::enumerate(sliceOp.indexings())) {
+ Value *indexing = en.value();
+ if (indexing->getType().isa<RangeType>()) {
+ int i = en.index();
+ Value *rangeDescriptor = adaptor.indexings()[i];
+ Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+ Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+ Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value *size = sub(max, min);
+ Value *stride = mul(strides[i], step);
+ desc = insertvalue(viewDescriptorTy, desc, size, pos({2, numNewDims}));
+ desc =
+ insertvalue(viewDescriptorTy, desc, stride, pos({3, numNewDims}));
+ ++numNewDims;
}
-
- Value *rangeDescriptor = operands[1 + j];
- Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
- Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
- Value *size = sub(max, min);
-
- desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
- ++i;
- ++j;
- }
-
- // Compute and insert view strides. Step over the strides that correspond
- // to non-range operands as they are projected away from the view.
- i = 0;
- for (int j = 0, e = strides.size(); j < e; ++j) {
- if (!sliceOp.indexing(j)->getType().isa<RangeType>())
- continue;
- Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
- Value *stride = mul(strides[j], step);
- desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
- ++i;
}
rewriter.replaceOp(op, desc);
// CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
-// CHECK-NEXT: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
-// CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+// CHECK-NEXT: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {