From 7c5ecc8b7e1bcd1b02eafeba9bbf3d5bc50d72c5 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 18 Nov 2021 05:41:25 +0000 Subject: [PATCH] [mlir][vector] Insert/extract element can accept index `vector::InsertElementOp` and `vector::ExtractElementOp` have had their `position` operand changed to accept `AnySignlessIntegerOrIndex` for better operability with operations that use `index`, such as affine loops. LLVM's `extractelement` and `insertelement` can also accept `i64`, so lowering directly to these operations without explicitly inserting casts is allowed. SPIRV's equivalent ops can also accept `i64`. Reviewed By: nicolasvasilache, jpienaar Differential Revision: https://reviews.llvm.org/D114139 --- mlir/include/mlir/Dialect/Vector/VectorOps.td | 6 +- mlir/include/mlir/IR/OpBase.td | 10 +++- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 12 +--- .../Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 5 ++ .../SparseTensor/Transforms/Sparsification.cpp | 6 +- .../Vector/VectorMultiDimReductionTransforms.cpp | 5 +- mlir/lib/Dialect/Vector/VectorOps.cpp | 14 ----- .../Conversion/VectorToLLVM/vector-to-llvm.mlir | 29 ++++++++++ .../test/Conversion/VectorToSCF/vector-to-scf.mlir | 12 ++-- mlir/test/Conversion/VectorToSPIRV/simple.mlir | 16 ++++++ mlir/test/Dialect/SparseTensor/sparse_vector.mlir | 6 +- .../Dialect/SparseTensor/sparse_vector_chain.mlir | 3 +- .../Vector/vector-multi-reduction-lowering.mlir | 64 +++++++++++----------- 13 files changed, 108 insertions(+), 80 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 06bb0a5..bbd45b7 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -482,7 +482,7 @@ def Vector_ExtractElementOp : TypesMatchWith<"result type matches element type of vector operand", "vector", "result", "$_self.cast().getElementType()">]>, - Arguments<(ins AnyVector:$vector, AnySignlessInteger:$position)>, + Arguments<(ins AnyVector:$vector, AnySignlessIntegerOrIndex:$position)>, Results<(outs AnyType:$result)> { let summary = "extractelement operation"; let description = [{ @@ -504,7 +504,6 @@ def Vector_ExtractElementOp : }]; let builders = [ - OpBuilder<(ins "Value":$source, "int64_t":$position)>, OpBuilder<(ins "Value":$source, "Value":$position)> ]; let extraClassDeclaration = [{ @@ -658,7 +657,7 @@ def Vector_InsertElementOp : "$_self.cast().getElementType()">, AllTypesMatch<["dest", "result"]>]>, Arguments<(ins AnyType:$source, AnyVector:$dest, - AnySignlessInteger:$position)>, + AnySignlessIntegerOrIndex:$position)>, Results<(outs AnyVector:$result)> { let summary = "insertelement operation"; let description = [{ @@ -683,7 +682,6 @@ def Vector_InsertElementOp : }]; let builders = [ - OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>, OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 1398bd7..e088d037 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -502,6 +502,10 @@ def Index : Type()">, "index", "::mlir::IndexType">, BuildableType<"$_builder.getIndexType()">; +// Any signless integer type or index type. +def AnySignlessIntegerOrIndex : Type, + "signless integer or index">; + // Floating point types. // Any float type irrespective of its width. @@ -823,9 +827,9 @@ def BoolLike : TypeConstraint.predicate, // Type constraint for signless-integer-like types: signless integers, indices, // vectors of signless integers or indices, tensors of signless integers. def SignlessIntegerLike : TypeConstraint.predicate, - TensorOf<[AnySignlessInteger, Index]>.predicate]>, + AnySignlessIntegerOrIndex.predicate, + VectorOf<[AnySignlessIntegerOrIndex]>.predicate, + TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>, "signless-integer-like">; // Type constraint for float-like types: floats, vectors or tensors thereof. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 2fd4959..6d2c91f 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -123,9 +123,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { return Value(); Location loc = xferOp.getLoc(); - Value ivI32 = b.create( - loc, IntegerType::get(b.getContext(), 32), iv); - return b.create(loc, xferOp.mask(), ivI32); + return b.create(loc, xferOp.mask(), iv); } /// Helper function TransferOpConversion and TransferOp1dConversion. @@ -1115,8 +1113,6 @@ struct Strategy1d { ValueRange loopState) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = b.create( - loc, IntegerType::get(b.getContext(), 32), iv); auto vec = loopState[0]; // In case of out-of-bounds access, leave `vec` as is (was initialized with @@ -1126,7 +1122,7 @@ struct Strategy1d { /*inBoundsCase=*/ [&](OpBuilder &b, Location loc) { Value val = b.create(loc, xferOp.source(), indices); - return b.create(loc, val, vec, ivI32); + return b.create(loc, val, vec, iv); }, /*outOfBoundsCase=*/ [&](OpBuilder & /*b*/, Location loc) { return vec; }); @@ -1148,15 +1144,13 @@ struct Strategy1d { ValueRange /*loopState*/) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = b.create( - loc, IntegerType::get(b.getContext(), 32), iv); // Nothing to do in case of out-of-bounds access. generateInBoundsCheck( b, xferOp, iv, dim, /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { auto val = - b.create(loc, xferOp.vector(), ivI32); + b.create(loc, xferOp.vector(), iv); b.create(loc, val, xferOp.source(), indices); }); b.create(loc); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 99e7a66..7c63fce 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -880,6 +880,11 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { valueTypes.append(op->operand_type_begin(), op->operand_type_end()); valueTypes.append(op->result_type_begin(), op->result_type_end()); + // Ensure that all types have been converted to SPIRV types. + if (llvm::any_of(valueTypes, + [](Type t) { return !t.isa(); })) + return false; + // Special treatment for global variables, whose type requirements are // conveyed by type attributes. if (auto globalVar = dyn_cast(op)) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 31d3ee5..676c86e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -391,7 +391,8 @@ static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, // Initialize reduction vector to: | 0 | .. | 0 | r | Attribute zero = rewriter.getZeroAttr(vtp); Value vec = rewriter.create(loc, vtp, zero); - return rewriter.create(loc, r, vec, 0); + return rewriter.create( + loc, r, vec, rewriter.create(loc, 0)); } case kProduct: { // Initialize reduction vector to: | 1 | .. | 1 | r | @@ -403,7 +404,8 @@ static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, one = rewriter.getIntegerAttr(etp, 1); Value vec = rewriter.create( loc, vtp, DenseElementsAttr::get(vtp, one)); - return rewriter.create(loc, r, vec, 0); + return rewriter.create( + loc, r, vec, rewriter.create(loc, 0)); } case kAnd: case kOr: diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp index 84102f0..0f1c401 100644 --- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp @@ -347,8 +347,9 @@ struct TwoDimMultiReductionToReduction loc, getElementTypeOrSelf(multiReductionOp.getDestType()), rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v, ValueRange{}); - result = rewriter.create(loc, reducedValue, - result, i); + result = rewriter.create( + loc, reducedValue, result, + rewriter.create(loc, i)); } rewriter.replaceOp(multiReductionOp, result); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index b03c4ec..3ca8fa0 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -838,13 +838,6 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, result.addTypes(source.getType().cast().getElementType()); } -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source, int64_t position) { - Value pos = - builder.create(result.location, position, 32); - build(builder, result, source, pos); -} - static LogicalResult verify(vector::ExtractElementOp op) { VectorType vectorType = op.getVectorType(); if (vectorType.getRank() != 1) @@ -1505,13 +1498,6 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result, result.addTypes(dest.getType()); } -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest, int64_t position) { - Value pos = - builder.create(result.location, position, 32); - build(builder, result, source, dest, pos); -} - static LogicalResult verify(InsertElementOp op) { auto dstVectorType = op.getDestVectorType(); if (dstVectorType.getRank() != 1) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7f77516..d5d8509 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -431,6 +431,20 @@ func @extract_element(%arg0: vector<16xf32>) -> f32 { // ----- +func @extract_element_index(%arg0: vector<16xf32>) -> f32 { + %0 = arith.constant 15 : index + %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32> + return %1 : f32 +} +// CHECK-LABEL: @extract_element_index( +// CHECK-SAME: %[[A:.*]]: vector<16xf32>) +// CHECK: %[[c:.*]] = arith.constant 15 : index +// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64 +// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<16xf32> +// CHECK: return %[[x]] : f32 + +// ----- + func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 { %0 = vector.extract %arg0[15]: vector<16xf32> return %0 : f32 @@ -502,6 +516,21 @@ func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { // ----- +func @insert_element_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { + %0 = arith.constant 3 : index + %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32> + return %1 : vector<4xf32> +} +// CHECK-LABEL: @insert_element_index( +// CHECK-SAME: %[[A:.*]]: f32, +// CHECK-SAME: %[[B:.*]]: vector<4xf32>) +// CHECK: %[[c:.*]] = arith.constant 3 : index +// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64 +// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[i]] : i64] : vector<4xf32> +// CHECK: return %[[x]] : vector<4xf32> + +// ----- + func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32> return %0 : vector<4xf32> diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir index 9279c47..08b3ffb 100644 --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -8,16 +8,14 @@ func @vector_transfer_ops_0d(%M: memref) { // CHECK: %[[V0:.*]] = arith.constant dense<0{{.*}}> : vector<1xf32> // CHECK: %[[R0:.*]] = scf.for %[[I:.*]] = {{.*}} iter_args(%[[V0_ITER:.*]] = %[[V0]]) -> (vector<1xf32>) { -// CHECK: %[[IDX:.*]] = arith.index_cast %[[I]] : index to i32 // CHECK: %[[S:.*]] = memref.load %[[MEM]][] : memref -// CHECK: %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[IDX]] : i32] : vector<1xf32> +// CHECK: %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[I]] : index] : vector<1xf32> // CHECK: scf.yield %[[R_ITER]] : vector<1xf32> %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : memref, vector<1xf32> // CHECK: scf.for %[[J:.*]] = %{{.*}} -// CHECK: %[[JDX:.*]] = arith.index_cast %[[J]] : index to i32 -// CHECK: %[[SS:.*]] = vector.extractelement %[[R0]][%[[JDX]] : i32] : vector<1xf32> +// CHECK: %[[SS:.*]] = vector.extractelement %[[R0]][%[[J]] : index] : vector<1xf32> // CHECK: memref.store %[[SS]], %[[MEM]][] : memref vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} : vector<1xf32>, memref @@ -107,10 +105,9 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK: scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] { // CHECK: %[[VEC:.*]] = scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<3xf32>) { // CHECK: %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]]) - // CHECK: %[[VIDX:.*]] = arith.index_cast %[[I6]] // CHECK: scf.if {{.*}} -> (vector<3xf32>) { // CHECK-NEXT: %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref - // CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[VIDX]] : i32] : vector<3xf32> + // CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32> // CHECK-NEXT: scf.yield // CHECK-NEXT: } else { // CHECK-NEXT: scf.yield @@ -181,9 +178,8 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK: %[[VEC:.*]] = memref.load %[[VECTOR_VIEW2]][%[[I4]], %[[I5]]] : memref<5x4xvector<3xf32>> // CHECK: scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] { // CHECK: %[[S0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]]) - // CHECK: %[[VIDX:.*]] = arith.index_cast %[[I6]] // CHECK: scf.if - // CHECK: %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[VIDX]] : i32] : vector<3xf32> + // CHECK: %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[I6]] : index] : vector<3xf32> // CHECK: memref.store %[[SCAL]], {{.*}}[%[[S0]], %[[S1]], %[[I2]], %[[S3]]] : memref // CHECK: } // CHECK: } diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir index 4a471a4..a253fc7 100644 --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -80,6 +80,14 @@ func @extract_element(%arg0 : vector<4xf32>, %id : i32) { // ----- +func @extract_element_index(%arg0 : vector<4xf32>, %id : index) { +// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}} + %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> + spv.ReturnValue %0: f32 +} + +// ----- + func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) { // expected-error @+1 {{failed to legalize operation 'vector.extractelement'}} %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> @@ -110,6 +118,14 @@ func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) { // ----- +func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) { +// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}} + %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> + spv.ReturnValue %0: vector<4xf32> +} + +// ----- + func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) { // expected-error @+1 {{failed to legalize operation 'vector.insertelement'}} %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir index 3752cdb..7f27922 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir @@ -210,12 +210,11 @@ func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %ar // // CHECK-VEC1-LABEL: func @reduction_d // CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC1-DAG: %[[i0:.*]] = arith.constant 0 : i32 // CHECK-VEC1-DAG: %[[c16:.*]] = arith.constant 16 : index // CHECK-VEC1-DAG: %[[c1024:.*]] = arith.constant 1024 : index // CHECK-VEC1-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> // CHECK-VEC1: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC1: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[i0]] : i32] : vector<16xf32> +// CHECK-VEC1: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> // CHECK-VEC1: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { // CHECK-VEC1: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> // CHECK-VEC1: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> @@ -228,12 +227,11 @@ func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %ar // // CHECK-VEC2-LABEL: func @reduction_d // CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC2-DAG: %[[i0:.*]] = arith.constant 0 : i32 // CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index // CHECK-VEC2-DAG: %[[c1024:.*]] = arith.constant 1024 : index // CHECK-VEC2-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> // CHECK-VEC2: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC2: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[i0]] : i32] : vector<16xf32> +// CHECK-VEC2: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> // CHECK-VEC2: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { // CHECK-VEC2: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> // CHECK-VEC2: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir index ab69441..030eff2 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir @@ -23,7 +23,6 @@ // CHECK-SAME: %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor { // CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64> // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 64 : index // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index @@ -89,7 +88,7 @@ // CHECK: %[[VAL_63:.*]] = select %[[VAL_61]], %[[VAL_62]], %[[VAL_34]] : index // CHECK: scf.yield %[[VAL_60]], %[[VAL_63]], %[[VAL_64:.*]] : index, index, f64 // CHECK: } -// CHECK: %[[VAL_65:.*]] = vector.insertelement %[[VAL_66:.*]]#2, %[[VAL_3]]{{\[}}%[[VAL_5]] : i32] : vector<8xf64> +// CHECK: %[[VAL_65:.*]] = vector.insertelement %[[VAL_66:.*]]#2, %[[VAL_3]]{{\[}}%[[VAL_6]] : index] : vector<8xf64> // CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_66]]#0 to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_69:.*]] = %[[VAL_65]]) -> (vector<8xf64>) { // CHECK: %[[VAL_70:.*]] = affine.min #map(%[[VAL_22]], %[[VAL_68]]) // CHECK: %[[VAL_71:.*]] = vector.create_mask %[[VAL_70]] : vector<8xi1> diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir index 42a90f9..7e6c171 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -7,14 +7,14 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { // CHECK-LABEL: func @vector_multi_reduction // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> // CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32> -// CHECK: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK: %[[C1:.+]] = arith.constant 1 : i32 +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] // CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32 -// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32> +// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] // CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32 -// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 { @@ -36,31 +36,31 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { // CHECK-LABEL: func @vector_reduction_inner // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32> // CHECK: %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32> -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i32 -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i32 -// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index // CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32> // CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32> // CHECK: %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32 -// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32> +// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32> // CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32> // CHECK: %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32 -// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32> +// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32> // CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32> // CHECK: %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32 -// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32> +// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32> // CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32> // CHECK: %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32 -// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32> +// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32> // CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32> // CHECK: %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32 -// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32> +// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32> /// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32> // CHECK: %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32 -// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32> +// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32> // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32> // CHECK: return %[[RESULT]] @@ -84,38 +84,38 @@ func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf3 // CHECK-LABEL: func @vector_multi_reduction_ordering // CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32> // CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32> -// CHECK: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK: %[[C1:.+]] = arith.constant 1 : i32 -// CHECK: %[[C2:.+]] = arith.constant 2 : i32 -// CHECK: %[[C3:.+]] = arith.constant 3 : i32 -// CHECK: %[[C4:.+]] = arith.constant 4 : i32 -// CHECK: %[[C5:.+]] = arith.constant 5 : i32 -// CHECK: %[[C6:.+]] = arith.constant 6 : i32 -// CHECK: %[[C7:.+]] = arith.constant 7 : i32 +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[C3:.+]] = arith.constant 3 : index +// CHECK: %[[C4:.+]] = arith.constant 4 : index +// CHECK: %[[C5:.+]] = arith.constant 5 : index +// CHECK: %[[C6:.+]] = arith.constant 6 : index +// CHECK: %[[C7:.+]] = arith.constant 7 : index // CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0] // CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1] // CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2] // CHECK: %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3] // CHECK: %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32> // CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0] // CHECK: %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32> // CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1] // CHECK: %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32> // CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2] // CHECK: %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32> // CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3] // CHECK: %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32 -// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32> +// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32> // CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32> // CHECK: return %[[RESHAPED_VEC]] -- 2.7.4