Arguments<(ins AnyMemRef:$base,
VectorOfRankAndType<[1], [AnyInteger]>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
- Variadic<VectorOfRank<[1]>>:$pass_thru)>,
+ VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
The gather operation gathers elements from memory into a 1-D vector as
defined by a base and a 1-D index vector, but only if the corresponding
bit is set in a 1-D mask vector. Otherwise, the element is taken from a
- 1-D pass-through vector, if provided, or left undefined. Informally the
- semantics are:
+ 1-D pass-through vector. Informally the semantics are:
```
- if (!defined(pass_thru)) pass_thru = [undef, .., undef]
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
etc.
Example:
```mlir
- %g = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %g = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
}];
let extraClassDeclaration = [{
return mask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
- return (llvm::size(pass_thru()) == 0)
- ? VectorType()
- : (*pass_thru().begin()).getType().cast<VectorType>();
+ return pass_thru().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return result().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
+ "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
}
Example:
```mlir
- vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
return value().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
- "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
+ "type($base) `,` type($indices) `,` type($mask) `,` type($value)";
let hasCanonicalizer = 1;
}
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
// The old @gather8 (no pass-through) and @gather_pass_thru8 are merged into a
// single @gather8: pass_thru is now a required operand of vector.gather, and the
// op uses the new `%base[%indices]` assembly syntax with `into <result-type>`.
-func @gather8(%base: memref<?xf32>,
- %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> {
- %g = vector.gather %base, %indices, %mask
- : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
- return %g : vector<8xf32>
-}
-
-func @gather_pass_thru8(%base: memref<?xf32>,
- %indices: vector<8xi32>, %mask: vector<8xi1>,
- %pass_thru: vector<8xf32>) -> vector<8xf32> {
- %g = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
+func @gather8(%base: memref<?xf32>, %indices: vector<8xi32>,
+ %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %g = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
return %g : vector<8xf32>
}
// Gather tests.
//
- %g1 = call @gather8(%A, %idx, %all)
- : (memref<?xf32>, vector<8xi32>, vector<8xi1>)
+ %g1 = call @gather8(%A, %idx, %all, %pass)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g1 : vector<8xf32>
// CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
- %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass)
+ %g2 = call @gather8(%A, %idx, %none, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g2 : vector<8xf32>
// CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
- %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass)
+ %g3 = call @gather8(%A, %idx, %some, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g3 : vector<8xf32>
// CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
- %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass)
+ %g4 = call @gather8(%A, %idx, %more, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g4 : vector<8xf32>
// CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
- %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass)
+ %g5 = call @gather8(%A, %idx, %all, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g5 : vector<8xf32>
// Scatter helper: writes the 8 lanes of %value to %base at positions %indices
// where %mask is set; updated to the new `%base[%indices]` syntax (types now
// listed base-first, no trailing `into`).
func @scatter8(%base: memref<?xf32>,
%indices: vector<8xi32>,
%mask: vector<8xi1>, %value: vector<8xf32>) {
- vector.scatter %base, %indices, %mask, %value
- : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
return
}
%cn = constant 8 : index
%f0 = constant 0.0 : f32
%mask = vector.constant_mask [4] : vector<4xi1>
+ %pass = vector.broadcast %f0 : f32 to vector<4xf32>
scf.for %i = %c0 to %cn step %c1 {
%aval = load %AVAL[%i] : memref<8xvector<4xf32>>
%aidx = load %AIDX[%i] : memref<8xvector<4xi32>>
- %0 = vector.gather %X, %aidx, %mask
- : (memref<?xf32>, vector<4xi32>, vector<4xi1>) -> vector<4xf32>
+ %0 = vector.gather %X[%aidx], %mask, %pass
+ : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32
store %1, %B[%i] : memref<?xf32>
}
%c0 = constant 0 : index
%c1 = constant 1 : index
%cn = constant 4 : index
+ %f0 = constant 0.0 : f32
%mask = vector.constant_mask [8] : vector<8xi1>
+ %pass = vector.broadcast %f0 : f32 to vector<8xf32>
%b = load %B[%c0] : memref<1xvector<8xf32>>
%b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) {
%aval = load %AVAL[%k] : memref<4xvector<8xf32>>
%aidx = load %AIDX[%k] : memref<4xvector<8xi32>>
- %0 = vector.gather %X, %aidx, %mask : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
+ %0 = vector.gather %X[%aidx], %mask, %pass
+ : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
%b_new = vector.fma %aval, %0, %b_iter : vector<8xf32>
scf.yield %b_new : vector<8xf32>
}
return op.emitOpError("expected result dim to match indices dim");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
- if (llvm::size(op.pass_thru()) != 0) {
- VectorType passVType = op.getPassThruVectorType();
- if (resVType != passVType)
- return op.emitOpError("expected pass_thru of same type as result type");
- }
+ if (resVType != op.getPassThruVectorType())
+ return op.emitOpError("expected pass_thru of same type as result type");
return success();
}
// CHECK: llvm.return
// LLVM-conversion test for vector.gather, updated to the new
// `%base[%indices]` assembly syntax (lowered result checked by the
// llvm.return CHECK lines around this function).
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
- %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+ %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
return %0 : vector<3xf32>
}
// CHECK: llvm.return %[[G]] : !llvm.vec<3 x f32>
// LLVM-conversion test for vector.scatter, updated to the new
// `%base[%indices]` assembly syntax.
func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
- vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref<?xf32>
+ vector.scatter %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}
// -----
// Negative test: base element type (f64) differs from the result element
// type (f32) — the verifier must reject this.
-func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.gather' op base and result element type should match}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
// Negative test: result must be a rank-1 vector; a 2-D result is rejected
// by the ODS rank constraint.
-func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
}
// -----
// Negative test: indices vector length (17) differs from the result
// length (16).
-func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>) {
+func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<17xi32>, vector<16xi1>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
// Negative test: mask vector length (17) differs from the result
// length (16).
-func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>) {
+func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<17xi1>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}
// -----
// Negative test: pass_thru type (f64) must match the result vector type
// (f32) — now unconditionally checked since pass_thru is a required operand.
-func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
// expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}}
- %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32>
}
// -----
// Negative test: base element type (f64) differs from the stored value
// element type (f32).
-func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.scatter' op base and value element type should match}}
- vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf64>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
// -----
// Negative test: scattered value must be a rank-1 vector; a 2-D value is
// rejected by the ODS rank constraint.
-func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) {
+func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<2x16xf32>) {
// expected-error@+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}}
- vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
}
// -----
// Negative test: indices length (17) differs from the value length (16).
-func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}}
- vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
}
// -----
// Negative test: mask length (17) differs from the value length (16).
-func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<17xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
- vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
}
// -----
}
// CHECK-LABEL: @gather_and_scatter
// Round-trip parse/print test: the two-gather form (with and without
// pass_thru) collapses to a single gather now that pass_thru is required;
// CHECK lines verify the new printed syntax for both ops.
-func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
- // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
- // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
- %1 = vector.gather %base, %indices, %mask, %0 : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
- // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
- vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ vector.scatter %base[%indices], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: @expand_and_compress
-func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = constant 0 : index
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- %0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
// Canonicalization input: gather under an all-true constant mask; the
// surrounding CHECK lines show the gather is kept (with pass_thru operand).
func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [16] : vector<16xi1>
- %ld = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %ld = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
// Canonicalization input: gather under an all-false constant mask ([0]);
// per the preceding CHECK, this folds to returning %pass_thru directly.
func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [0] : vector<16xi1>
- %ld = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %ld = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+// CHECK-NEXT: vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
// Canonicalization input: scatter under an all-true constant mask; the
// preceding CHECK lines show the scatter is kept, in the new syntax.
func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%mask = vector.constant_mask [16] : vector<16xi1>
- vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
// Canonicalization input: scatter under an all-false constant mask ([0]).
// NOTE(review): %0 (the type_cast) is unused here — presumably deliberate
// for this test pattern; confirm against the CHECK lines for @scatter2.
func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
%mask = vector.constant_mask [0] : vector<16xi1>
- vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}