}
//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertOp : Tensor_Op<"insert",
+ [NoSideEffect,
+ TypesMatchWith<"result type matches type of dest",
+ "dest", "result",
+ "$_self.cast<ShapedType>()">,
+ TypesMatchWith<"scalar type matches element type of dest",
+ "dest", "scalar",
+ "$_self.cast<ShapedType>().getElementType()">]> {
+ let summary = "element insertion operation";
+ let description = [{
+ The `tensor.insert` op writes a tensor into a tensor `dest`as specified by
+ the operation's indices.
+
+ It returns a copy of `dest` with the proper subtensor updated with the value
+ of `scalar`.
+
+ The arity of indices must match the rank of the tensor `dest` (i.e., if a
+ tensor is of rank 3, then 3 indices are required for the extract. The
+ indices should all be of `index` type.
+
+ Example:
+
+ ```mlir
+ %4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
+ %5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
+ %6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32>
+ ```
+ }];
+
+ let arguments = (ins AnyType:$scalar,
+ AnyTensor:$dest,
+ Variadic<Index>:$indices);
+ let results = (outs AnyTensor:$result);
+ let assemblyFormat = [{
+ $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value":$scalar, "Value":$dest,
+ CArg<"ValueRange", "{}">:$indices), [{
+ auto resType = dest.getType();
+ build($_builder, $_state, resType, scalar, dest, indices);
+ }]>];
+
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(InsertOp op) {
+ // Verify the # indices match if we have a ranked type.
+ if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
+ if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
+ return op.emitOpError("incorrect number of indices");
+ return success();
+}
+
+OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
+ Attribute scalar = operands[0];
+ Attribute dest = operands[1];
+ if (scalar && dest)
+ if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
+ if (scalar == splatDest.getSplatValue())
+ return dest;
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
// -----
+// CHECK-LABEL: func @fold_insert
+func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
+ // Fold an insert into a splat.
+ // CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32>
+ %0 = constant dense<4.0> : tensor<4xf32>
+ %1 = constant 4.0 : f32
+ %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
+ // CHECK-NEXT: return %[[C4]]
+ return %ins_1 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
// -----
+func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+ // expected-error@+1 {{incorrect number of indices}}
+ %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
+ return
+}
+
+// -----
+
func @tensor.from_elements_wrong_result_type() {
// expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
%c0 = constant 0 : i32
return
}
+// CHECK-LABEL: func @insert(
+// CHECK-SAME: %[[SCALAR:.*]]: f32
+// CHECK-SAME: %[[INDEX:.*]]: index
+// CHECK-SAME: %[[DEST1:.*]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32>
+func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
+ // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
+ %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+ // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
+ %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
+ return
+}
+
// CHECK-LABEL: func @tensor.from_elements() {
func @tensor.from_elements() {
%c0 = "std.constant"() {value = 0: index} : () -> index