# Op definition: writes `value`, signed-cast to type U, into O[None].
# `value` is declared as a ScalarDef of T1 and O as the output TensorDef of U.
# NOTE(review): the decorator for this definition is not visible in this excerpt.
def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
O[None] = TypeFn.cast_signed(U, value)
# Structured-op definition whose input is a TensorDef (I) rather than a
# ScalarDef: writes I[None], signed-cast to type U, into O[None].
+@linalg_structured_op
+def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
+ O[None] = TypeFn.cast_signed(U, I[None])
# Enter a new Context together with an unknown Location, then create the
# Module that the test functions below are built into.
with Context() as ctx, Location.unknown():
module = Module.create()
# CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
# CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
# CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+ # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+ # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
# CHECK-LABEL: @test_fill_0d
# CHECK: linalg.generic
# Returns the result of applying fill_poly to `value`, with `init_result`
# passed as the single output operand.
# NOTE(review): the decorator and argument types are not visible in this excerpt.
def test_fill_2d(value, init_result):
return fill_poly(value, outs=[init_result])
+ # CHECK-LABEL: @test_fill_rank_zero_3d
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
# Function built from two f32 ranked-tensor argument types: one with shape []
# and one with shape [4, 8, 16]. It returns fill_rank_zero_poly applied to the
# first argument with the second as the single output operand.
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
+ def test_fill_rank_zero_3d(input, init_result):
+ return fill_rank_zero_poly(input, outs=[init_result])
# Print the module; presumably this output is what the directive comments
# above are matched against -- confirm against the test's run line.
print(module)
%v1 = arith.constant 1.0 : f32
%v2 = arith.constant 2.0 : f32
- %lhs = memref.alloc() : memref<4x8xf32>
+ %lhs = memref.alloc() : memref<f32>
%rhs = memref.alloc() : memref<4x8xf32>
%O0 = memref.alloc() : memref<4x8xf32>
%O1 = memref.alloc() : memref<4x8xf32>
- linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+ linalg.fill(%v1, %lhs) : f32, memref<f32>
linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
- (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+ (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
- (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+ (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
%c0 = arith.constant 0 : index
%res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
# Define the buffer-based helper functions inside the module body.
with InsertionPoint(module.body):
# Takes three f32 memrefs (lhs, rhs, out). In this diff the lhs type changes
# from shape (4, 8) to shape (); rhs and out stay (4, 8).
# The body applies elemwise_unary to lhs writing into out, then
# elemwise_binary to (out, rhs) writing back into out.
# NOTE(review): no `fun` is passed, so both calls use their defaults --
# presumably exp and add, going by the function name; confirm.
@builtin.FuncOp.from_py_func(
- MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_exp_add_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(lhs, outs=[out])
linalg.elemwise_binary(out, rhs, outs=[out])
# Takes three f32 memrefs (lhs, rhs, out). In this diff the lhs type changes
# from shape (4, 8) to shape (); rhs and out stay (4, 8).
# The visible body applies elemwise_unary with fun=UnaryFn.log to lhs,
# writing into out.
# NOTE(review): the body may continue past this line in the full file (the
# name suggests a following multiply step) -- not visible in this excerpt.
@builtin.FuncOp.from_py_func(
- MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_log_mul_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
# Define the buffer-based helper functions inside the module body.
with InsertionPoint(module.body):
# Takes three f32 memrefs (lhs, rhs, out). In this diff the lhs type changes
# from shape (4, 8) to shape (); rhs and out stay (4, 8).
# The body applies elemwise_unary to lhs writing into out, then
# elemwise_binary to (out, rhs) writing back into out. Both calls pass
# emit_generic=True.
@builtin.FuncOp.from_py_func(
- MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_exp_add_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
@builtin.FuncOp.from_py_func(
- MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_log_mul_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(