```mlir
global @variable : tuple<i32, f32> {
- %0 = constant 45 : i32
- %1 = constant 100.0 : f32
+ %0 = arith.constant 45 : i32
+ %1 = arith.constant 100.0 : f32
%2 = fir.undefined tuple<i32, f32>
- %3 = constant 0 : index
+ %3 = arith.constant 0 : index
%4 = fir.insert_value %2, %0, %3 : (tuple<i32, f32>, i32, index) -> tuple<i32, f32>
- %5 = constant 1 : index
+ %5 = arith.constant 1 : index
%6 = fir.insert_value %4, %1, %5 : (tuple<i32, f32>, f32, index) -> tuple<i32, f32>
fir.has_value %6 : tuple<i32, f32>
}
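```

The initializer region computes the global's value and yields it through `fir.has_value`. For illustration, a use site would first take the global's address; a minimal sketch, assuming the usual `fir.address_of` spelling (the SSA name is hypothetical):

```mlir
// Obtain a typed reference to the initialized global.
%addr = fir.address_of(@variable) : !fir.ref<tuple<i32, f32>>
```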
by the calling routine. (In Fortran, these are called descriptors.)
```mlir
- %c1 = constant 1 : index
- %c10 = constant 10 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
%5 = ... : !fir.ref<!fir.array<10 x i32>>
%6 = fir.embox %5 : (!fir.ref<!fir.array<10 x i32>>) -> !fir.box<!fir.array<10 x i32>>
```
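The box created by `fir.embox` plays the role of the descriptor mentioned above. The raw address can be recovered from it; a minimal sketch continuing from `%6`, assuming `fir.box_addr` (illustrative only):

```mlir
// Extract the base address carried by the descriptor.
%7 = fir.box_addr %6 : (!fir.box<!fir.array<10 x i32>>) -> !fir.ref<!fir.array<10 x i32>>
```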
```mlir
%4 = ... : !fir.ref<!fir.array<10 x !fir.char<1>>>
- %5 = constant 10 : i32
+ %5 = arith.constant 10 : i32
%6 = fir.emboxchar %4, %5 : (!fir.ref<!fir.array<10 x !fir.char<1>>>, i32) -> !fir.boxchar<1>
```
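The inverse operation splits the boxchar back into its components; a minimal sketch continuing from `%6`, assuming `fir.unboxchar` returns the buffer reference and the length:

```mlir
// Recover the character buffer and its length.
%7:2 = fir.unboxchar %6 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1>>, i32)
```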
`dim` is out of bounds.
```mlir
- %c1 = constant 1 : i32
+ %c1 = arith.constant 1 : i32
%52:3 = fir.box_dims %40, %c1 : (!fir.box<!fir.array<*:f64>>, i32) -> (index, index, index)
```
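The three results are the lower bound, extent, and stride of the queried dimension. As a worked use of the triple, the largest valid index is `lb + extent - 1`; a small sketch (names are illustrative):

```mlir
%one = arith.constant 1 : index
%sum = arith.addi %52#0, %52#1 : index
%last = arith.subi %sum, %one : index
```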
```mlir
%r = ... : !fir.ref<i64>
- %c_100 = constant 100 : index
+ %c_100 = arith.constant 100 : index
%d = fir.shape %c_100 : (index) -> !fir.shape<1>
%b = fir.embox %r(%d) : (!fir.ref<i64>, !fir.shape<1>) -> !fir.box<i64>
%a = fir.box_isarray %b : (!fir.box<i64>) -> i1 // true
```

```mlir
%a = ... : !fir.array<10xtuple<i32, f32>>
%f = ... : f32
%o = ... : i32
- %c = constant 1 : i32
+ %c = arith.constant 1 : i32
%b = fir.insert_value %a, %f, %o, %c : (!fir.array<10xtuple<i32, f32>>, f32, i32, i32) -> !fir.array<10xtuple<i32, f32>>
```
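Reading a component back is the symmetric operation; a sketch assuming `fir.extract_value` mirrors the coordinate form used by `fir.insert_value` above:

```mlir
// Read the f32 member of element %o back out of %b.
%v = fir.extract_value %b, %o, %c : (!fir.array<10xtuple<i32, f32>>, i32, i32) -> f32
```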
}];
```mlir
%a = fir.undefined !fir.array<10x10xf32>
- %c = constant 3.0 : f32
+ %c = arith.constant 3.0 : f32
%1 = fir.insert_on_range %a, %c, [0 : index, 7 : index, 0 : index, 2 : index] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
```
MLIR's `scf.for`.
```mlir
- %l = constant 0 : index
- %u = constant 9 : index
- %s = constant 1 : index
+ %l = arith.constant 0 : index
+ %u = arith.constant 9 : index
+ %s = arith.constant 1 : index
fir.do_loop %i = %l to %u step %s unordered {
%x = fir.convert %i : (index) -> i32
%v = fir.call @compute(%x) : (i32) -> f32
```mlir
fir.global @_QV_Mquark_Vvarble : tuple<i32, f32> {
- %1 = constant 1 : i32
- %2 = constant 2.0 : f32
+ %1 = arith.constant 1 : i32
+ %2 = arith.constant 2.0 : f32
%3 = fir.undefined tuple<i32, f32>
- %z = constant 0 : index
- %o = constant 1 : index
+ %z = arith.constant 0 : index
+ %o = arith.constant 1 : index
%4 = fir.insert_value %3, %1, %z : (tuple<i32, f32>, i32, index) -> tuple<i32, f32>
%5 = fir.insert_value %4, %2, %o : (tuple<i32, f32>, f32, index) -> tuple<i32, f32>
fir.has_value %5 : tuple<i32, f32>
- Bufferizes only `arith.constant` ops of `tensor` type (a qualifying op is
  sketched after this list).
- This is an example of setting up the legality so that only a subset of
  `arith.constant` ops get bufferized.
- This is an example of a pass that is not split along dialect
subdivisions.
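As a concrete illustration of that restricted legality, only tensor-typed constants such as the first op below are in scope; scalar constants are left alone (a minimal sketch):

```mlir
// In scope for this bufferization: a constant of tensor type.
%cst = arith.constant dense<[1.0, 2.0]> : tensor<2xf32>
// Out of scope: a scalar constant is already legal.
%c0 = arith.constant 0 : index
```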
/// Mark all operations within the LLVM dialect as legal.
addLegalDialect<LLVMDialect>();
- /// Mark `std.constant` op is always legal on this target.
- addLegalOp<ConstantOp>();
+ /// Mark `arith.constant` op as always legal on this target.
+ addLegalOp<arith::ConstantOp>();
//--------------------------------------------------------------------------
// Marking an operation as dynamically legal.
'%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory
space 1 at indices [%k + 7, %l], would be specified as follows:
- %num_elements = constant 256
+ %num_elements = arith.constant 256 : index
%idx = arith.constant 0 : index
%tag = memref.alloc() : memref<1xi32, 4>
affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
*** IR Dump Before CSE ***
func @simple_constant() -> (i32, i32) {
- %c1_i32 = constant 1 : i32
- %c1_i32_0 = constant 1 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %c1_i32_0 = arith.constant 1 : i32
return %c1_i32, %c1_i32_0 : i32, i32
}
```
*** IR Dump After CSE ***
func @simple_constant() -> (i32, i32) {
- %c1_i32 = constant 1 : i32
+ %c1_i32 = arith.constant 1 : i32
return %c1_i32, %c1_i32 : i32, i32
}
```
*** IR Dump After CSE ***
func @simple_constant() -> (i32, i32) {
- %c1_i32 = constant 1 : i32
+ %c1_i32 = arith.constant 1 : i32
return %c1_i32, %c1_i32 : i32, i32
}
```
*** IR Dump After BadPass Failed ***
func @simple_constant() -> (i32, i32) {
- %c1_i32 = constant 1 : i32
+ %c1_i32 = arith.constant 1 : i32
return %c1_i32, %c1_i32 : i32, i32
}
```
}
func @simple_constant() -> (i32, i32) {
- %c1_i32 = constant 1 : i32
- %c1_i32_0 = constant 1 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %c1_i32_0 = arith.constant 1 : i32
return %c1_i32, %c1_i32_0 : i32, i32
}
}
func @simple_constant() -> (i32, i32) {
- %c1_i32 = constant 1 : i32
+ %c1_i32 = arith.constant 1 : i32
return %c1_i32, %c1_i32 : i32, i32
}
```
[include "AffinePasses.md"]
+## `arith` Dialect Passes
+
+[include "ArithmeticPasses.md"]
+
## `gpu` Dialect Passes
[include "GPUPasses.md"]
return %y: i32
}
// CHECK-LABEL: func @test_subi_zero_cfg(%arg0: i32)
- // CHECK-NEXT: %c0_i32 = constant 0 : i32
+ // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: return %c0
```
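For context, the canonicalization exercised by this test folds a self-subtraction into a zero constant; a minimal before/after sketch:

```mlir
// Before canonicalization:
%y = arith.subi %x, %x : i32
// After: the result is replaced by a constant zero.
%c0_i32 = arith.constant 0 : i32
```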
Example:
```mlir
- %0 = constant 2 : i32
+ %0 = arith.constant 2 : i32
// Apply the foo operation to %0
%1 = standalone.foo %0 : i32
```
// memref '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in
// memory space 1 at indices [%k + 7, %l], would be specified as follows:
//
-// %num_elements = constant 256
-// %idx = constant 0 : index
+// %num_elements = arith.constant 256 : index
+// %idx = arith.constant 0 : index
// %tag = alloc() : memref<1xi32, 4>
// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
// %num_elements :
```mlir
func @reduce(%buffer: memref<1024xf32>) -> (f32) {
// Initial sum set to 0.
- %sum_0 = constant 0.0 : f32
+ %sum_0 = arith.constant 0.0 : f32
// iter_args binds initial values to the loop's region arguments.
%sum = affine.for %i = 0 to 10 step 2
iter_args(%sum_iter = %sum_0) -> (f32) {
%2 = load %I[%i - 1, %j - 1] : memref<10x10xf32>
affine.yield %2 : f32
} else {
- %2 = constant 0.0 : f32
+ %2 = arith.constant 0.0 : f32
affine.yield %2 : f32
}
affine.store %1, %O[%i, %j] : memref<12x12xf32>
```mlir
func @store_load_affine_apply() -> memref<10x10xf32> {
- %cf7 = constant 7.0 : f32
+ %cf7 = arith.constant 7.0 : f32
%m = alloc() : memref<10x10xf32>
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
```mlir
module {
func @store_load_affine_apply() -> memref<10x10xf32> {
- %cst = constant 7.000000e+00 : f32
+ %cst = arith.constant 7.000000e+00 : f32
%0 = alloc() : memref<10x10xf32>
affine.for %arg0 = 0 to 10 {
affine.for %arg1 = 0 to 10 {
/// affine.for %arg2 = 0 to 64 {
/// affine.for %arg3 = 0 to 128 step 8 {
/// affine.for %arg4 = 0 to 512 step 4 {
-/// %cst = constant 0.000000e+00 : f32
+/// %cst = arith.constant 0.000000e+00 : f32
/// %0 = vector.transfer_read %arg0[%arg2, %arg3, %arg4], %cst : ...
/// vector.transfer_write %0, %arg1[%arg2, %arg3, %arg4] : ...
/// }
```mlir
// Always returns 4, can be constant folded:
- %c0 = constant 0 : index
+ %c0 = arith.constant 0 : index
%x = memref.dim %A, %c0 : memref<4 x ? x f32>
// Returns the dynamic dimension of %A.
- %c1 = constant 1 : index
+ %c1 = arith.constant 1 : index
%y = memref.dim %A, %c1 : memref<4 x ? x f32>
// Equivalent generic form:
space 1 at indices [%k, %l], would be specified as follows:
```mlir
- %num_elements = constant 256
- %idx = constant 0 : index
+ %num_elements = arith.constant 256 : index
+ %idx = arith.constant 0 : index
%tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
memref<40 x 128 x f32, (d0) -> (d0), 0>,
cond_br %cond, ^bb1, ^bb2
^bb1:
- %c1 = constant 1 : i64
+ %c1 = arith.constant 1 : i64
br ^bb3(%c1 : i64)
^bb2:
- %c2 = constant 2 : i64
+ %c2 = arith.constant 2 : i64
br ^bb3(%c2 : i64)
^bb3(%x : i64):
func @reduce(%buffer: memref<1024xf32>, %lb: index,
%ub: index, %step: index) -> (f32) {
// Initial sum set to 0.
- %sum_0 = constant 0.0 : f32
+ %sum_0 = arith.constant 0.0 : f32
// iter_args binds initial values to the loop's region arguments.
%sum = scf.for %iv = %lb to %ub step %step
iter_args(%sum_iter = %sum_0) -> (f32) {
```mlir
func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index,
%ub: index, %step: index) -> (f32) {
- %sum_0 = constant 0.0 : f32
- %c0 = constant 0.0 : f32
+ %sum_0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0.0 : f32
%sum = scf.for %iv = %lb to %ub step %step
iter_args(%sum_iter = %sum_0) -> (f32) {
%t = load %buffer[%iv] : memref<1024xf32>
Example:
```mlir
- %init = constant 0.0 : f32
+ %init = arith.constant 0.0 : f32
scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init) -> f32 {
%elem_to_reduce = load %buffer[%iv] : memref<100xf32>
scf.reduce(%elem_to_reduce) : f32 {
Example:
```mlir
- %operand = constant 1.0 : f32
+ %operand = arith.constant 1.0 : f32
scf.reduce(%operand) : f32 {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
```mlir
Before:
- %c1 = constant 1 : index
+ %c1 = arith.constant 1 : index
%0 = sparse_tensor.pointers %arg0, %c1
: tensor<8x8xf32, #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
}>> to memref<?xindex>
After:
- %c1 = constant 1 : index
+ %c1 = arith.constant 1 : index
%0 = call @sparsePointers(%arg0, %c1) : (!llvm.ptr<i8>, index) -> memref<?xindex>
```
}];
```mlir
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
- %c1 = constant 1.0 : f32
+ %c1 = arith.constant 1.0 : f32
%inc = arith.addf %c1, %current_value : f32
atomic_yield %inc : f32
}
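```

Because the region simply adds a constant, the same update can be expressed with the non-generic form; a sketch, assuming the era's `atomic_rmw` spelling with an `"addf"` kind:

```mlir
%c1 = arith.constant 1.0 : f32
%x = atomic_rmw "addf" %c1, %I[%i] : (f32, memref<10xf32>) -> f32
```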
Example:
```mlir
- // Integer constant
- %1 = constant 42 : i32
+ // Complex constant
+ %1 = constant [1.0 : f32, 1.0 : f32] : complex<f32>
// Reference to function @myfn.
- %3 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
+ %2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
// Equivalent generic forms
- %1 = "std.constant"() {value = 42 : i32} : () -> i32
- %3 = "std.constant"() {value = @myfn}
+ %1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex<f32>}
+ : () -> complex<f32>
+ %2 = "std.constant"() {value = @myfn}
: () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
```
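After this split, integer constants are created through the `arith` dialect instead; a minimal sketch:

```mlir
// Integer constant, now spelled via arith:
%c42 = arith.constant 42 : i32
```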
```mlir
// Always returns 4, can be constant folded:
- %c0 = constant 0 : index
+ %c0 = arith.constant 0 : index
%x = tensor.dim %A, %c0 : tensor<4x?xf32>
// Returns the dynamic dimension of %A.
- %c1 = constant 1 : index
+ %c1 = arith.constant 1 : index
%y = tensor.dim %A, %c1 : tensor<4x?xf32>
// Equivalent generic form:
Example:
```mlir
- %0 = constant 0.0 : f32
+ %0 = arith.constant 0.0 : f32
%1 = vector.broadcast %0 : f32 to vector<16xf32>
%2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
```
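When the ranks already match, broadcasting may also stretch a unit dimension to the target size; a small sketch of that case:

```mlir
%s = ... : vector<1x16xf32>
%t = vector.broadcast %s : vector<1x16xf32> to vector<4x16xf32>
```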
Example:
```mlir
- %c = constant 15 : i32
+ %c = arith.constant 15 : i32
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
```
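The position operand is an ordinary SSA value, so it need not be a constant; a minimal sketch with a runtime index:

```mlir
%i = ... : i32
%2 = vector.extractelement %0[%i : i32] : vector<16xf32>
```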
}];
Example:
```mlir
- %c = constant 15 : i32
- %f = constant 0.0f : f32
+ %c = arith.constant 15 : i32
+ %f = arith.constant 0.0 : f32
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
```
}];
memref<?x?x?x?xf32>
store %tmp[%i, %j, %k] : vector<3x4x5xf32>
}}}
- %c0 = constant 0 : index
+ %c0 = arith.constant 0 : index
%vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
```
memref<?x?x?x?xf32>
store %tmp[%i, 0, %k] : vector<3x4x5xf32>
}}
- %c0 = constant 0 : index
+ %c0 = arith.constant 0 : index
%tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
%vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
```
```mlir
// Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
// and pad with %f0 to handle the boundary case:
- %f0 = constant 0.0f : f32
+ %f0 = arith.constant 0.0 : f32
for %i0 = 0 to %0 {
affine.for %i1 = 0 to %1 step 256 {
affine.for %i2 = 0 to %2 step 32 {
Example:
```mlir
- %0 = constant 0.0 : f32
+ %0 = arith.constant 0.0 : f32
%1 = vector.broadcast %0 : f32 to vector<4xf32>
vector.print %1 : vector<4xf32>
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
-/// %out = constant ... : vector<MxNxelt_type>
+/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// The following MLIR snippet:
///
/// ```mlir
-/// %cst0 = constant 0 : index
+/// %cst0 = arith.constant 0 : index
/// affine.for %i0 = 0 to %0 {
/// %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
/// }
func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
%0 = alloc() : memref<10xf32>
%1 = alloc() : memref<10xf32>
- %cst = constant 0.000000e+00 : f32
+ %cst = arith.constant 0.000000e+00 : f32
affine.for %arg2 = 0 to 10 {
affine.store %cst, %0[%arg2] : memref<10xf32>
affine.store %cst, %1[%arg2] : memref<10xf32>
func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
%0 = alloc() : memref<1xf32>
%1 = alloc() : memref<1xf32>
- %cst = constant 0.000000e+00 : f32
+ %cst = arith.constant 0.000000e+00 : f32
affine.for %arg2 = 0 to 10 {
affine.store %cst, %0[0] : memref<1xf32>
affine.store %cst, %1[0] : memref<1xf32>
%0 = alloc() : memref<256xf32>
%1 = alloc() : memref<32xf32, 1>
%2 = alloc() : memref<1xf32>
- %c0 = constant 0 : index
- %c128 = constant 128 : index
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
affine.for %i0 = 0 to 8 {
affine.dma_start %0[%i0], %1[%i0], %2[%c0], %c128 : memref<256xf32>, memref<32xf32, 1>, memref<1xf32>
affine.dma_wait %2[%c0], %c128 : memref<1xf32>
```mlir
module {
func @pipelinedatatransfer() {
- %c8 = constant 8 : index
- %c0 = constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
%0 = alloc() : memref<256xf32>
- %c0_0 = constant 0 : index
- %c128 = constant 128 : index
+ %c0_0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
%1 = alloc() : memref<2x32xf32, 1>
%2 = alloc() : memref<2x1xf32>
affine.dma_start %0[%c0], %1[%c0 mod 2, %c0], %2[%c0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
func @linearize(%arg0: memref<8x8xi32, #linear8>,
%arg1: memref<8x8xi32, #linear8>,
%arg2: memref<8x8xi32, #linear8>) {
- %c8 = constant 8 : index
- %c0 = constant 0 : index
- %c1 = constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
affine.for %arg3 = %c0 to %c8 {
affine.for %arg4 = %c0 to %c8 {
affine.for %arg5 = %c0 to %c8 {
func @linearize(%arg0: memref<64xi32>,
%arg1: memref<64xi32>,
%arg2: memref<64xi32>) {
- %c8 = constant 8 : index
- %c0 = constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
affine.for %arg3 = %c0 to %c8 {
affine.for %arg4 = %c0 to %c8 {
affine.for %arg5 = %c0 to %c8 {
if (!srcType)
return failure();
- // std.constant should only have vector or tenor types.
+ // arith.constant should only have vector or tensor types.
assert((srcType.isa<VectorType, RankedTensorType>()));
auto dstType = getTypeConverter()->convertType(srcType);
// Bool type.
if (srcType.isInteger(1)) {
- // std.constant can use 0/1 instead of true/false for i1 values. We need to
- // handle that here.
+ // arith.constant can use 0/1 instead of true/false for i1 values. We need
+ // to handle that here.
auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
if (!dstAttr)
return failure();
///
/// becomes
///
-/// %c0 = constant 0 : index
+/// %c0 = arith.constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
using namespace mlir;
//===----------------------------------------------------------------------===//
-// Utility functions
-//===----------------------------------------------------------------------===//
-
-/// Converts the given `srcAttr` into a boolean attribute if it holds an
-/// integral value. Returns null attribute if conversion fails.
-static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
- if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
- return boolAttr;
- if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
- return builder.getBoolAttr(intAttr.getValue().getBoolValue());
- return BoolAttr();
-}
-
-/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
-/// Returns null attribute if conversion fails.
-static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
- Builder builder) {
- // If the source number uses less active bits than the target bitwidth, then
- // it should be safe to convert.
- if (srcAttr.getValue().isIntN(dstType.getWidth()))
- return builder.getIntegerAttr(dstType, srcAttr.getInt());
-
- // XXX: Try again by interpreting the source number as a signed value.
- // Although integers in the standard dialect are signless, they can represent
- // a signed number. It's the operation decides how to interpret. This is
- // dangerous, but it seems there is no good way of handling this if we still
- // want to change the bitwidth. Emit a message at least.
- if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
- auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
- LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
- << dstAttr << "' for type '" << dstType << "'\n");
- return dstAttr;
- }
-
- LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
- << "' illegal: cannot fit into target type '"
- << dstType << "'\n");
- return IntegerAttr();
-}
-
-/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
-/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
-static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
- Builder builder) {
- // Only support converting to float for now.
- if (!dstType.isF32())
- return FloatAttr();
-
- // Try to convert the source floating-point number to single precision.
- APFloat dstVal = srcAttr.getValue();
- bool losesInfo = false;
- APFloat::opStatus status =
- dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
- if (status != APFloat::opOK || losesInfo) {
- LLVM_DEBUG(llvm::dbgs()
- << srcAttr << " illegal: cannot fit into converted type '"
- << dstType << "'\n");
- return FloatAttr();
- }
-
- return builder.getF32FloatAttr(dstVal.convertToFloat());
-}
-
-//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
-/// Converts composite std.constant operation to spv.Constant.
-class ConstantCompositeOpPattern final
- : public OpConversionPattern<ConstantOp> {
-public:
- using OpConversionPattern<ConstantOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts scalar std.constant operation to spv.Constant.
-class ConstantScalarOpPattern final : public OpConversionPattern<ConstantOp> {
-public:
- using OpConversionPattern<ConstantOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
/// Converts std.return to spv.Return.
class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
public:
} // namespace
//===----------------------------------------------------------------------===//
-// ConstantOp with composite type.
-//===----------------------------------------------------------------------===//
-
-// TODO: This probably should be split into the vector case and tensor case,
-// so that the tensor case can be moved to TensorToSPIRV conversion. But,
-// std.constant is for the standard dialect though.
-LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
- ConstantOp constOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- auto srcType = constOp.getType().dyn_cast<ShapedType>();
- if (!srcType)
- return failure();
-
- // std.constant should only have vector or tenor types.
- assert((srcType.isa<VectorType, RankedTensorType>()));
-
- auto dstType = getTypeConverter()->convertType(srcType);
- if (!dstType)
- return failure();
-
- auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
- ShapedType dstAttrType = dstElementsAttr.getType();
- if (!dstElementsAttr)
- return failure();
-
- // If the composite type has more than one dimensions, perform linearization.
- if (srcType.getRank() > 1) {
- if (srcType.isa<RankedTensorType>()) {
- dstAttrType = RankedTensorType::get(srcType.getNumElements(),
- srcType.getElementType());
- dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
- } else {
- // TODO: add support for large vectors.
- return failure();
- }
- }
-
- Type srcElemType = srcType.getElementType();
- Type dstElemType;
- // Tensor types are converted to SPIR-V array types; vector types are
- // converted to SPIR-V vector/array types.
- if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
- dstElemType = arrayType.getElementType();
- else
- dstElemType = dstType.cast<VectorType>().getElementType();
-
- // If the source and destination element types are different, perform
- // attribute conversion.
- if (srcElemType != dstElemType) {
- SmallVector<Attribute, 8> elements;
- if (srcElemType.isa<FloatType>()) {
- for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
- if (!dstAttr)
- return failure();
- elements.push_back(dstAttr);
- }
- } else if (srcElemType.isInteger(1)) {
- return failure();
- } else {
- for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
- IntegerAttr dstAttr = convertIntegerAttr(
- srcAttr, dstElemType.cast<IntegerType>(), rewriter);
- if (!dstAttr)
- return failure();
- elements.push_back(dstAttr);
- }
- }
-
- // Unfortunately, we cannot use dialect-specific types for element
- // attributes; element attributes only works with builtin types. So we need
- // to prepare another converted builtin types for the destination elements
- // attribute.
- if (dstAttrType.isa<RankedTensorType>())
- dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
- else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
-
- dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
- }
-
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
- dstElementsAttr);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ConstantOp with scalar type.
-//===----------------------------------------------------------------------===//
-
-LogicalResult ConstantScalarOpPattern::matchAndRewrite(
- ConstantOp constOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type srcType = constOp.getType();
- if (!srcType.isIntOrIndexOrFloat())
- return failure();
-
- Type dstType = getTypeConverter()->convertType(srcType);
- if (!dstType)
- return failure();
-
- // Floating-point types.
- if (srcType.isa<FloatType>()) {
- auto srcAttr = constOp.value().cast<FloatAttr>();
- auto dstAttr = srcAttr;
-
- // Floating-point types not supported in the target environment are all
- // converted to float type.
- if (srcType != dstType) {
- dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
- if (!dstAttr)
- return failure();
- }
-
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
- return success();
- }
-
- // Bool type.
- if (srcType.isInteger(1)) {
- // std.constant can use 0/1 instead of true/false for i1 values. We need to
- // handle that here.
- auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
- if (!dstAttr)
- return failure();
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
- return success();
- }
-
- // IndexType or IntegerType. Index values are converted to 32-bit integer
- // values when converting to SPIR-V.
- auto srcAttr = constOp.value().cast<IntegerAttr>();
- auto dstAttr =
- convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
- if (!dstAttr)
- return failure();
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
spirv::UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
spirv::UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
- // Constant patterns
- ConstantCompositeOpPattern, ConstantScalarOpPattern,
-
ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context);
}
// each yielded value.
//
// %token, %result = async.execute -> !async.value<T> {
- // %0 = constant ... : T
+ // %0 = arith.constant ... : T
// async.yield %0 : T
// }
Value asyncToken; // token representing completion of the async region
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example
///
-/// %c5 = constant 5: index
+/// %c5 = arith.constant 5: index
/// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
///
/// to
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
-// %c1 = constant 1 : index
-// %c0 = constant 0 : index
-// %c25 = constant 25 : index
-// %c10 = constant 10 : index
+// %c1 = arith.constant 1 : index
+// %c0 = arith.constant 0 : index
+// %c25 = arith.constant 25 : index
+// %c10 = arith.constant 10 : index
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
/// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
/// is used in dynamicSizes of new AllocOp.
/// %0 = dim %arg0, %c1 : memref<4x?xf32>
-/// %c4 = constant 4 : index
+/// %c4 = arith.constant 4 : index
/// %1 = affine.apply #map1(%c4, %0)
/// %2 = affine.apply #map2(%c4, %0)
static void createNewDynamicSizes(MemRefType oldMemRefType,
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
- mlirStringRefCreateFromCString("std.constant"), loc);
+ mlirStringRefCreateFromCString("arith.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
MlirOperation constZero = mlirOperationCreate(&constZeroState);
if (mlirOperationIsNull(constZero)) {
fprintf(stderr, "ERROR: Expected registered operation to be present\n");
return 5;
}
+
+ if (!mlirOperationVerify(constZero)) {
+ fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
+ return 6;
+ }
MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);
if (mlirTypeIDIsNull(registeredOpID)) {
fprintf(stderr,
"ERROR: Expected registered operation type id to be present\n");
- return 6;
+ return 7;
}
// Create an unregistered operation, which should not have a type id.
MlirOperation unregisteredOp = mlirOperationCreate(&opState);
if (mlirOperationIsNull(unregisteredOp)) {
fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
- return 7;
+ return 8;
}
MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);
if (!mlirTypeIDIsNull(unregisteredOpID)) {
fprintf(stderr,
"ERROR: Expected unregistered operation type id to be null\n");
- return 8;
+ return 9;
}
mlirOperationDestroy(constZero);
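For reference, the operation assembled through the C API above round-trips to this textual form (the SSA name is illustrative):

```mlir
%c0 = arith.constant 0 : index
```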