From b4baccc2a760ea13901f201e6ca326284254d205 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Sun, 13 Jun 2021 13:45:33 -0700 Subject: [PATCH] Introduce tensor.insert op to Tensor dialect. Add `tensor.insert` op to make `tensor.extract`/`tensor.insert` work in pairs for `scalar` domain. Like `subtensor`/`subtensor_insert` work in pairs in `tensor` domain, and `vector.transfer_read`/`vector.transfer_write` work in pairs in `vector` domain. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D104139 --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 51 ++++++++++++++++++++++++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 ++++++++++ mlir/test/Dialect/Tensor/canonicalize.mlir | 13 ++++++ mlir/test/Dialect/Tensor/invalid.mlir | 8 ++++ mlir/test/Dialect/Tensor/ops.mlir | 13 ++++++ 5 files changed, 107 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 17141da..6b06099 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -184,6 +184,57 @@ def Tensor_GenerateOp : Tensor_Op<"generate", } //===----------------------------------------------------------------------===// +// InsertOp +//===----------------------------------------------------------------------===// + +def Tensor_InsertOp : Tensor_Op<"insert", + [NoSideEffect, + TypesMatchWith<"result type matches type of dest", + "dest", "result", + "$_self.cast()">, + TypesMatchWith<"scalar type matches element type of dest", + "dest", "scalar", + "$_self.cast().getElementType()">]> { + let summary = "element insertion operation"; + let description = [{ + The `tensor.insert` op writes a tensor into a tensor `dest`as specified by + the operation's indices. + + It returns a copy of `dest` with the proper subtensor updated with the value + of `scalar`. + + The arity of indices must match the rank of the tensor `dest` (i.e., if a + tensor is of rank 3, then 3 indices are required for the extract. The + indices should all be of `index` type. + + Example: + + ```mlir + %4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32> + %5 = tensor.insert %rt into %dest[%1, %2] : tensor + %6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32> + ``` + }]; + + let arguments = (ins AnyType:$scalar, + AnyTensor:$dest, + Variadic:$indices); + let results = (outs AnyTensor:$result); + let assemblyFormat = [{ + $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest) + }]; + + let builders = [ + OpBuilder<(ins "Value":$scalar, "Value":$dest, + CArg<"ValueRange", "{}">:$indices), [{ + auto resType = dest.getType(); + build($_builder, $_state, resType, scalar, dest, indices); + }]>]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 2c9680a..9b1592e 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -287,6 +287,28 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// InsertOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(InsertOp op) { + // Verify the # indices match if we have a ranked type. + if (auto destType = op.dest().getType().dyn_cast()) + if (destType.getRank() != static_cast(op.indices().size())) + return op.emitOpError("incorrect number of indices"); + return success(); +} + +OpFoldResult InsertOp::fold(ArrayRef operands) { + Attribute scalar = operands[0]; + Attribute dest = operands[1]; + if (scalar && dest) + if (auto splatDest = dest.dyn_cast()) + if (scalar == splatDest.getSplatValue()) + return dest; + return {}; +} + +//===----------------------------------------------------------------------===// // GenerateOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 478117b..e4f5cc7 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -96,6 +96,19 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) { // ----- +// CHECK-LABEL: func @fold_insert +func @fold_insert(%arg0 : index) -> (tensor<4xf32>) { + // Fold an insert into a splat. + // CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32> + %0 = constant dense<4.0> : tensor<4xf32> + %1 = constant 4.0 : f32 + %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32> + // CHECK-NEXT: return %[[C4]] + return %ins_1 : tensor<4xf32> +} + +// ----- + // CHECK-LABEL: func @extract_from_tensor.cast // CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32> func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 { diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 79fef8c..edbea9a 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -16,6 +16,14 @@ func @extract_too_many_indices(%arg0: tensor) { // ----- +func @insert_too_many_indices(%arg0: f32, %arg1: tensor) { + // expected-error@+1 {{incorrect number of indices}} + %0 = tensor.insert %arg0 into %arg1[] : tensor + return +} + +// ----- + func @tensor.from_elements_wrong_result_type() { // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} %c0 = constant 0 : i32 diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 450da06..a8bc699 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -22,6 +22,19 @@ func @extract(%arg0: tensor, %arg1: index) { return } +// CHECK-LABEL: func @insert( +// CHECK-SAME: %[[SCALAR:.*]]: f32 +// CHECK-SAME: %[[INDEX:.*]]: index +// CHECK-SAME: %[[DEST1:.*]]: tensor +// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32> +func @insert(%arg0: f32, %arg1: index, %arg2: tensor, %arg3: tensor<*xf32>) { + // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor + %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor + // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32> + %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32> + return +} + // CHECK-LABEL: func @tensor.from_elements() { func @tensor.from_elements() { %c0 = "std.constant"() {value = 0: index} : () -> index -- 2.7.4