[mlir][Vector] Thread 0-d vectors through InsertElementOp.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 23 Nov 2021 12:01:53 +0000 (12:01 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 23 Nov 2021 12:55:11 +0000 (12:55 +0000)
This revision makes concrete use of 0-d vectors to extend the semantics of
InsertElementOp.

Reviewed By: dcaballe, pifon2a

Differential Revision: https://reviews.llvm.org/D114388

mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

index d8ffdff..f274ff6 100644 (file)
@@ -666,16 +666,18 @@ def Vector_InsertElementOp :
                     "result", "source",
                     "$_self.cast<ShapedType>().getElementType()">,
      AllTypesMatch<["dest", "result"]>]>,
-     Arguments<(ins AnyType:$source, AnyVector:$dest,
-                    AnySignlessIntegerOrIndex:$position)>,
-     Results<(outs AnyVector:$result)> {
+     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
+                    Optional<AnySignlessIntegerOrIndex>:$position)>,
+     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "insertelement operation";
   let description = [{
-    Takes a scalar source, an 1-D destination vector and a dynamic index
-    position and inserts the source into the destination at the proper
-    position.  Note that this instruction resembles vector.insert, but
-    is restricted to 1-D vectors and relaxed to dynamic indices. It is
-    meant to be closer to LLVM's version:
+    Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
+    position and inserts the source into the destination at the proper position.
+
+    Note that this instruction resembles vector.insert, but is restricted to 0-D
+    and 1-D vectors and relaxed to dynamic indices. 
+
+    It is meant to be closer to LLVM's version:
     https://llvm.org/docs/LangRef.html#insertelement-instruction
 
     Example:
@@ -684,14 +686,18 @@ def Vector_InsertElementOp :
     %c = arith.constant 15 : i32
     %f = arith.constant 0.0f : f32
     %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
+    %2 = vector.insertelement %f, %z[]: vector<f32>
     ```
   }];
   let assemblyFormat = [{
-    $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
+    $source `,` $dest `[` ($position^ `:` type($position))? `]`  attr-dict `:`
     type($result)
   }];
 
   let builders = [
+    // 0-D builder.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>,
+    // 1-D + position builder.
     OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
   ];
   let extraClassDeclaration = [{
index c74eca5..108e664 100644 (file)
@@ -663,6 +663,17 @@ public:
     if (!llvmType)
       return failure();
 
+    if (vectorType.getRank() == 0) {
+      Location loc = insertEltOp.getLoc();
+      auto idxType = rewriter.getIndexType();
+      auto zero = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter->convertType(idxType),
+          rewriter.getIntegerAttr(idxType, 0));
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
         adaptor.position());
index d8cd3c1..e4438fc 100644 (file)
@@ -1554,6 +1554,12 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
 //===----------------------------------------------------------------------===//
 
 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
+                            Value source, Value dest) {
+  result.addOperands({source, dest});
+  result.addTypes(dest.getType());
+}
+
+void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, Value position) {
   result.addOperands({source, dest, position});
   result.addTypes(dest.getType());
@@ -1561,8 +1567,15 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
 
 static LogicalResult verify(InsertElementOp op) {
   auto dstVectorType = op.getDestVectorType();
+  if (dstVectorType.getRank() == 0) {
+    if (op.position())
+      return op.emitOpError("expected position to be empty with 0-D vector");
+    return success();
+  }
   if (dstVectorType.getRank() != 1)
-    return op.emitOpError("expected 1-D vector");
+    return op.emitOpError("unexpected >1 vector rank");
+  if (!op.position())
+    return op.emitOpError("expected position for 1-D vector");
   return success();
 }
 
index 9cce66c..033e2d8 100644 (file)
@@ -512,6 +512,19 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @insert_element_0d
+// CHECK-SAME: %[[A:.*]]: f32,
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+  // CHECK: %[[B:.*]] =  builtin.unrealized_conversion_cast %{{.*}} :
+  // CHECK:   vector<f32> to vector<1xf32>
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1 : vector<f32>
+}
+
+// -----
+
 func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : i32
   %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
index c327bfe..593686a 100644 (file)
@@ -79,7 +79,7 @@ func @extract_element(%arg0: vector<f32>) {
 }
 
 // -----
-
 func @extract_element(%arg0: vector<4xf32>) {
   %c = arith.constant 3 : i32
   // expected-error@+1 {{expected position for 1-D vector}}
@@ -138,9 +138,25 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func @insert_element(%arg0: f32, %arg1: vector<f32>) {
+  %c = arith.constant 3 : i32
+  // expected-error@+1 {{expected position to be empty with 0-D vector}}
+  %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
+}
+
+// -----
+
+func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
+  %c = arith.constant 3 : i32
+  // expected-error@+1 {{expected position for 1-D vector}}
+  %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
+}
+
+// -----
+
 func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
   %c = arith.constant 3 : i32
-  // expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
+  // expected-error@+1 {{unexpected >1 vector rank}}
   %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
 }
 
index 3f7fe75..11b986f 100644 (file)
@@ -192,6 +192,13 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32
   return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
+// CHECK-LABEL: @insert_element_0d
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+  // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1 : vector<f32>
+}
+
 // CHECK-LABEL: @insert_element
 func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
   // CHECK:      %[[C15:.*]] = arith.constant 15 : i32
index 0921bfc..b3052eb 100644 (file)
@@ -10,8 +10,15 @@ func @extract_element_0d(%a: vector<f32>) {
   return
 }
 
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1: vector<f32>
+}
+
 func @entry() {
-  %1 = arith.constant dense<42.0> : vector<f32>
-  call  @extract_element_0d(%1) : (vector<f32>) -> ()
+  %0 = arith.constant 42.0 : f32
+  %1 = arith.constant dense<0.0> : vector<f32>
+  %2 = call  @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
+  call  @extract_element_0d(%2) : (vector<f32>) -> ()
   return
 }