[VectorOps] minor cleanup: vector dialect "subscripts" are i32
authorAart Bik <ajcbik@google.com>
Thu, 19 Dec 2019 19:47:45 +0000 (11:47 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 19 Dec 2019 19:51:08 +0000 (11:51 -0800)
Introduces some centralized methods to move towards
consistent use of i32 as vector subscripts.

Note: sizes/strides/offsets attributes are still i64
PiperOrigin-RevId: 286434133

mlir/include/mlir/Dialect/VectorOps/VectorOps.h
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index 06672c7..57e2d9e 100644 (file)
@@ -52,6 +52,13 @@ void populateVectorToVectorCanonicalizationPatterns(
 void populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Returns the integer type required for subscripts in the vector dialect.
+IntegerType getVectorSubscriptType(Builder &builder);
+
+/// Returns an integer array attribute containing the given values using
+/// the integer type required for subscripts in the vector dialect.
+ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int32_t> values);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/VectorOps/VectorOps.h.inc"
 
index 401e424..98c0610 100644 (file)
@@ -271,7 +271,7 @@ def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [NoSideEffect,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, Index:$position)>,
+    Arguments<(ins AnyVector:$vector, I32:$position)>,
     Results<(outs AnyType)> {
   let summary = "extractelement operation";
   let description = [{
@@ -379,7 +379,7 @@ def Vector_InsertElementOp :
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      PredOpTrait<"dest operand and result have same type",
                  TCresIsSameAsOpBase<0, 1>>]>,
-     Arguments<(ins AnyType:$source, AnyVector:$dest, Index:$position)>,
+     Arguments<(ins AnyType:$source, AnyVector:$dest, I32:$position)>,
      Results<(outs AnyVector)> {
   let summary = "insertelement operation";
   let description = [{
index 4ed0902..ff28334 100644 (file)
@@ -58,6 +58,15 @@ Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
   return builder.create<ConstantOp>(loc, type, value);
 }
 
+IntegerType vector::getVectorSubscriptType(Builder &builder) {
+  return builder.getIntegerType(32);
+}
+
+ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
+                                         ArrayRef<int32_t> values) {
+  return builder.getI32ArrayAttr(values);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//
@@ -397,7 +406,7 @@ static Type inferExtractOpResultType(VectorType vectorType,
 void vector::ExtractOp::build(Builder *builder, OperationState &result,
                               Value *source, ArrayRef<int32_t> position) {
   result.addOperands(source);
-  auto positionAttr = builder->getI32ArrayAttr(position);
+  auto positionAttr = getVectorSubscriptAttr(*builder, position);
   result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
                                            positionAttr));
   result.addAttribute(getPositionAttrName(), positionAttr);
@@ -641,7 +650,7 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
 void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
                       Value *v2, ArrayRef<int32_t> mask) {
   result.addOperands({v1, v2});
-  auto maskAttr = builder->getI32ArrayAttr(mask);
+  auto maskAttr = getVectorSubscriptAttr(*builder, mask);
   result.addTypes(v1->getType());
   result.addAttribute(getMaskAttrName(), maskAttr);
 }
@@ -765,7 +774,7 @@ static LogicalResult verify(InsertElementOp op) {
 void InsertOp::build(Builder *builder, OperationState &result, Value *source,
                      Value *dest, ArrayRef<int32_t> position) {
   result.addOperands({source, dest});
-  auto positionAttr = builder->getI32ArrayAttr(position);
+  auto positionAttr = getVectorSubscriptAttr(*builder, position);
   result.addTypes(dest->getType());
   result.addAttribute(getPositionAttrName(), positionAttr);
 }
index d3b1d40..3700217 100644 (file)
@@ -281,13 +281,13 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
 //       CHECK:   llvm.return %[[i3]] : !llvm<"[3 x <4 x float>]">
 
 func @extract_element(%arg0: vector<16xf32>) -> f32 {
-  %0 = constant 15 : index
-  %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
+  %0 = constant 15 : i32
+  %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
   return %1 : f32
 }
 // CHECK-LABEL: extract_element(%arg0: !llvm<"<16 x float>">)
-//       CHECK:   %[[c:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
-//       CHECK:   %[[x:.*]] = llvm.extractelement %arg0[%[[c]] : !llvm.i64] : !llvm<"<16 x float>">
+//       CHECK:   %[[c:.*]] = llvm.mlir.constant(15 : i32) : !llvm.i32
+//       CHECK:   %[[x:.*]] = llvm.extractelement %arg0[%[[c]] : !llvm.i32] : !llvm<"<16 x float>">
 //       CHECK:   llvm.return %[[x]] : !llvm.float
 
 func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
@@ -326,13 +326,13 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 //       CHECK:   llvm.return {{.*}} : !llvm.float
 
 func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
-  %0 = constant 3 : index
-  %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32>
+  %0 = constant 3 : i32
+  %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
   return %1 : vector<4xf32>
 }
 // CHECK-LABEL: insert_element(%arg0: !llvm.float, %arg1: !llvm<"<4 x float>">)
-//       CHECK:   %[[c:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
-//       CHECK:   %[[x:.*]] = llvm.insertelement %arg0, %arg1[%[[c]] : !llvm.i64] : !llvm<"<4 x float>">
+//       CHECK:   %[[c:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
+//       CHECK:   %[[x:.*]] = llvm.insertelement %arg0, %arg1[%[[c]] : !llvm.i32] : !llvm<"<4 x float>">
 //       CHECK:   llvm.return %[[x]] : !llvm<"<4 x float>">
 
 func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
index 7e8fce9..d79c035 100644 (file)
@@ -67,9 +67,9 @@ func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
 // -----
 
 func @extract_element(%arg0: vector<4x4xf32>) {
-  %c = constant 3 : index
+  %c = constant 3 : i32
   // expected-error@+1 {{'vector.extractelement' op expected 1-D vector}}
-  %1 = vector.extractelement %arg0[%c : index] : vector<4x4xf32>
+  %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32>
 }
 
 // -----
@@ -124,9 +124,9 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
 // -----
 
 func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
-  %c = constant 3 : index
+  %c = constant 3 : i32
   // expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
-  %0 = vector.insertelement %arg0, %arg1[%c : index] : vector<4x4xf32>
+  %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
 }
 
 // -----
index b43c675..06d5728 100644 (file)
@@ -55,10 +55,10 @@ func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
 
 // CHECK-LABEL: @extract_element
 func @extract_element(%a: vector<16xf32>) -> f32 {
-  // CHECK:      %[[C15:.*]] = constant 15 : index
-  %c = constant 15 : index
-  // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : index] : vector<16xf32>
-  %1 = vector.extractelement %a[%c : index] : vector<16xf32>
+  // CHECK:      %[[C15:.*]] = constant 15 : i32
+  %c = constant 15 : i32
+  // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32>
+  %1 = vector.extractelement %a[%c : i32] : vector<16xf32>
   return %1 : f32
 }
 
@@ -75,10 +75,10 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f
 
 // CHECK-LABEL: @insert_element
 func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
-  // CHECK:      %[[C15:.*]] = constant 15 : index
-  %c = constant 15 : index
-  // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : index] : vector<16xf32>
-  %1 = vector.insertelement %a, %b[%c : index] : vector<16xf32>
+  // CHECK:      %[[C15:.*]] = constant 15 : i32
+  %c = constant 15 : i32
+  // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32>
+  %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32>
   return %1 : vector<16xf32>
 }