Support SPIR-V constant op to take DenseElementsAttr as input.
authorHanhan Wang <hanchung@google.com>
Tue, 19 Nov 2019 04:01:28 +0000 (20:01 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 19 Nov 2019 04:02:05 +0000 (20:02 -0800)
Iterates each element to build the array. This includes a little refactor to
combine bool/int/float into a function, since they are similar. The only
difference is calling different function in the end.

PiperOrigin-RevId: 281210288

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/constant.mlir
mlir/test/Dialect/SPIRV/structure-ops.mlir

index 8964963..4c9dd5b 100644 (file)
@@ -1079,12 +1079,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
   if (parser.parseAttribute(value, kValueAttrName, state.attributes))
     return failure();
 
-  Type type;
-  if (value.getType().isa<NoneType>()) {
+  Type type = value.getType();
+  if (type.isa<NoneType>() || type.isa<TensorType>()) {
     if (parser.parseColonType(type))
       return failure();
-  } else {
-    type = value.getType();
   }
 
   return parser.addTypeToList(type, state.types);
@@ -1108,14 +1106,46 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
   switch (value.getKind()) {
   case StandardAttributes::Bool:
   case StandardAttributes::Integer:
-  case StandardAttributes::Float:
-  case StandardAttributes::DenseElements:
-  case StandardAttributes::SparseElements: {
+  case StandardAttributes::Float: {
     if (valueType != opType)
       return constOp.emitOpError("result type (")
              << opType << ") does not match value type (" << valueType << ")";
     return success();
   } break;
+  case StandardAttributes::DenseElements:
+  case StandardAttributes::SparseElements: {
+    if (valueType == opType)
+      break;
+    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
+    auto shapedType = valueType.dyn_cast<ShapedType>();
+    if (!arrayType) {
+      return constOp.emitOpError(
+          "must have spv.array result type for array value");
+    }
+
+    int numElements = arrayType.getNumElements();
+    auto opElemType = arrayType.getElementType();
+    while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
+      numElements *= t.getNumElements();
+      opElemType = t.getElementType();
+    }
+    if (!opElemType.isIntOrFloat()) {
+      return constOp.emitOpError("only support nested array result type");
+    }
+
+    auto valueElemType = shapedType.getElementType();
+    if (valueElemType != opElemType) {
+      return constOp.emitOpError("result element type (")
+             << opElemType << ") does not match value element type ("
+             << valueElemType << ")";
+    }
+
+    if (numElements != shapedType.getNumElements()) {
+      return constOp.emitOpError("result number of elements (")
+             << numElements << ") does not match value number of elements ("
+             << shapedType.getNumElements() << ")";
+    }
+  } break;
   case StandardAttributes::Array: {
     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
     if (!arrayType)
index 0ff79d9..ebe3ceb 100644 (file)
@@ -240,29 +240,21 @@ private:
   /// constants.
   uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
 
-  /// Prepares bool ElementsAttr serialization. This method updates `opcode`
-  /// with a proper OpConstant* instruction and pushes literal values for the
-  /// constant to `operands`.
-  LogicalResult prepareBoolVectorConstant(Location loc,
-                                          DenseIntElementsAttr elementsAttr,
-                                          spirv::Opcode &opcode,
-                                          SmallVectorImpl<uint32_t> &operands);
-
-  /// Prepares int ElementsAttr serialization. This method updates `opcode` with
-  /// a proper OpConstant* instruction and pushes literal values for the
-  /// constant to `operands`.
-  LogicalResult prepareIntVectorConstant(Location loc,
-                                         DenseIntElementsAttr elementsAttr,
-                                         spirv::Opcode &opcode,
-                                         SmallVectorImpl<uint32_t> &operands);
-
-  /// Prepares float ElementsAttr serialization. This method updates `opcode`
-  /// with a proper OpConstant* instruction and pushes literal values for the
-  /// constant to `operands`.
-  LogicalResult prepareFloatVectorConstant(Location loc,
-                                           DenseFPElementsAttr elementsAttr,
-                                           spirv::Opcode &opcode,
-                                           SmallVectorImpl<uint32_t> &operands);
+  /// Prepares array attribute serialization. This method emits corresponding
+  /// OpConstant* and returns the result <id> associated with it. Returns 0 if
+  /// failed.
+  uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
+
+  /// Prepares bool/int/float DenseElementsAttr serialization. This method
+  /// iterates the DenseElementsAttr to construct the constant array, and
+  /// returns the result <id>  associated with it. Returns 0 if failed. Note
+  /// that the size of `index` must match the rank.
+  /// TODO(hanchung): Consider to enhance splat elements cases. For splat cases,
+  /// we don't need to loop over all elements, especially when the splat value
+  /// is zero. We can use OpConstantNull when the value is zero.
+  uint32_t prepareDenseElementsConstant(Location loc, Type constType,
+                                        DenseElementsAttr valueAttr, int dim,
+                                        MutableArrayRef<uint64_t> index);
 
   /// Prepares scalar attribute serialization. This method emits corresponding
   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
@@ -1064,6 +1056,7 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
   if (auto id = prepareConstantScalar(loc, valueAttr)) {
     return id;
   }
+
   // This is a composite literal. We need to handle each component separately
   // and then emit an OpConstantComposite for the whole.
 
@@ -1075,179 +1068,92 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
   if (failed(processType(loc, constType, typeID))) {
     return 0;
   }
-  auto resultID = getNextID();
-
-  spirv::Opcode opcode = spirv::Opcode::OpNop;
-  SmallVector<uint32_t, 4> operands;
-  operands.push_back(typeID);
-  operands.push_back(resultID);
 
-  if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
-    if (vectorAttr.getType().getElementType().isInteger(1)) {
-      if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
-        return 0;
-    } else if (failed(
-                   prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
-      return 0;
-  } else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
-    if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
-      return 0;
+  uint32_t resultID = 0;
+  if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
+    int rank = attr.getType().dyn_cast<ShapedType>().getRank();
+    SmallVector<uint64_t, 4> index(rank);
+    resultID = prepareDenseElementsConstant(loc, constType, attr,
+                                            /*dim=*/0, index);
   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
-    opcode = spirv::Opcode::OpConstantComposite;
-    operands.reserve(arrayAttr.size() + 2);
+    resultID = prepareArrayConstant(loc, constType, arrayAttr);
+  }
 
-    auto elementType = constType.cast<spirv::ArrayType>().getElementType();
-    for (Attribute elementAttr : arrayAttr)
-      if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
-        operands.push_back(elementID);
-      } else {
-        return 0;
-      }
-  } else {
+  if (resultID == 0) {
     emitError(loc, "cannot serialize attribute: ") << valueAttr;
     return 0;
   }
 
-  encodeInstructionInto(typesGlobalValues, opcode, operands);
   constIDMap[valueAttr] = resultID;
   return resultID;
 }
 
-LogicalResult Serializer::prepareBoolVectorConstant(
-    Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
-    SmallVectorImpl<uint32_t> &operands) {
-  auto type = elementsAttr.getType();
-  assert(type.hasRank() && type.getRank() == 1 &&
-         "spv.constant should have verified only vector literal uses "
-         "ElementsAttr");
-  assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr");
-  auto count = type.getNumElements();
-
-  // Operands for constructing the SPIR-V OpConstant* instruction
-  operands.reserve(count + 2);
-
-  // For splat cases, we don't need to loop over all elements, especially when
-  // the splat value is zero.
-  if (elementsAttr.isSplat()) {
-    // We can use OpConstantNull if this bool ElementsAttr is splatting false.
-    if (!elementsAttr.getSplatValue<bool>()) {
-      opcode = spirv::Opcode::OpConstantNull;
-      return success();
-    }
-
-    if (auto id =
-            prepareConstantBool(loc, elementsAttr.getSplatValue<BoolAttr>())) {
-      opcode = spirv::Opcode::OpConstantComposite;
-      operands.append(count, id);
-      return success();
-    }
-
-    return failure();
+uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
+                                          ArrayAttr attr) {
+  uint32_t typeID = 0;
+  if (failed(processType(loc, constType, typeID))) {
+    return 0;
   }
 
-  // Otherwise, we need to process each element and compose them with
-  // OpConstantComposite.
-  opcode = spirv::Opcode::OpConstantComposite;
-  for (auto boolAttr : elementsAttr.getValues<BoolAttr>()) {
-    // We are constructing an BoolAttr for each value here. But given that
-    // we only use ElementsAttr for vectors with no more than 4 elements, it
-    // should be fine here.
-    if (auto elementID = prepareConstantBool(loc, boolAttr)) {
+  uint32_t resultID = getNextID();
+  SmallVector<uint32_t, 4> operands = {typeID, resultID};
+  operands.reserve(attr.size() + 2);
+  auto elementType = constType.cast<spirv::ArrayType>().getElementType();
+  for (Attribute elementAttr : attr) {
+    if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
       operands.push_back(elementID);
     } else {
-      return failure();
-    }
-  }
-  return success();
-}
-
-LogicalResult Serializer::prepareIntVectorConstant(
-    Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
-    SmallVectorImpl<uint32_t> &operands) {
-  auto type = elementsAttr.getType();
-  assert(type.hasRank() && type.getRank() == 1 &&
-         "spv.constant should have verified only vector literal uses "
-         "ElementsAttr");
-  assert(!type.getElementType().isInteger(1) &&
-         "must be non-bool ElementsAttr");
-  auto count = type.getNumElements();
-
-  // Operands for constructing the SPIR-V OpConstant* instruction
-  operands.reserve(count + 2);
-
-  // For splat cases, we don't need to loop over all elements, especially when
-  // the splat value is zero.
-  if (elementsAttr.isSplat()) {
-    auto splatAttr = elementsAttr.getSplatValue<IntegerAttr>();
-
-    // We can use OpConstantNull if this int ElementsAttr is splatting 0.
-    if (splatAttr.getValue().isNullValue()) {
-      opcode = spirv::Opcode::OpConstantNull;
-      return success();
-    }
-
-    if (auto id = prepareConstantInt(loc, splatAttr)) {
-      opcode = spirv::Opcode::OpConstantComposite;
-      operands.append(count, id);
-      return success();
+      return 0;
     }
-    return failure();
   }
+  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
+  encodeInstructionInto(typesGlobalValues, opcode, operands);
 
-  // Otherwise, we need to process each element and compose them with
-  // OpConstantComposite.
-  opcode = spirv::Opcode::OpConstantComposite;
-  for (auto intAttr : elementsAttr.getValues<IntegerAttr>()) {
-    // We are constructing an IntegerAttr for each value here. But given that
-    // we only use ElementsAttr for vectors with no more than 4 elements, it
-    // should be fine here.
-    // TODO(antiagainst): revisit this if special extensions enabling large
-    // vectors are supported.
-    if (auto elementID = prepareConstantInt(loc, intAttr)) {
-      operands.push_back(elementID);
-    } else {
-      return failure();
-    }
-  }
-  return success();
+  return resultID;
 }
 
-LogicalResult Serializer::prepareFloatVectorConstant(
-    Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
-    SmallVectorImpl<uint32_t> &operands) {
-  auto type = elementsAttr.getType();
-  assert(type.hasRank() && type.getRank() == 1 &&
-         "spv.constant should have verified only vector literal uses "
-         "ElementsAttr");
-  auto count = type.getNumElements();
-
-  operands.reserve(count + 2);
-
-  if (elementsAttr.isSplat()) {
-    FloatAttr splatAttr = elementsAttr.getSplatValue<FloatAttr>();
-    if (splatAttr.getValue().isZero()) {
-      opcode = spirv::Opcode::OpConstantNull;
-      return success();
+// TODO(hanchung): Turn the below function into iterative function, instead of
+// recursive function.
+uint32_t
+Serializer::prepareDenseElementsConstant(Location loc, Type constType,
+                                         DenseElementsAttr valueAttr, int dim,
+                                         MutableArrayRef<uint64_t> index) {
+  auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
+  assert(dim <= shapedType.getRank());
+  if (shapedType.getRank() == dim) {
+    if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
+      return attr.getType().getElementType().isInteger(1)
+                 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
+                 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
     }
-
-    if (auto id = prepareConstantFp(loc, splatAttr)) {
-      opcode = spirv::Opcode::OpConstantComposite;
-      operands.append(count, id);
-      return success();
+    if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
+      return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
     }
+    return 0;
+  }
 
-    return failure();
+  uint32_t typeID = 0;
+  if (failed(processType(loc, constType, typeID))) {
+    return 0;
   }
 
-  opcode = spirv::Opcode::OpConstantComposite;
-  for (auto fpAttr : elementsAttr.getValues<FloatAttr>()) {
-    if (auto elementID = prepareConstantFp(loc, fpAttr)) {
+  uint32_t resultID = getNextID();
+  SmallVector<uint32_t, 4> operands = {typeID, resultID};
+  operands.reserve(shapedType.getDimSize(dim) + 2);
+  auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
+  for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
+    index[dim] = i;
+    if (auto elementID = prepareDenseElementsConstant(
+            loc, elementType, valueAttr, dim + 1, index)) {
       operands.push_back(elementID);
     } else {
-      return failure();
+      return 0;
     }
   }
-  return success();
+  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
+  encodeInstructionInto(typesGlobalValues, opcode, operands);
+
+  return resultID;
 }
 
 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
index 9531209..50005be 100644 (file)
@@ -178,4 +178,18 @@ spv.module "Logical" "GLSL450" {
     %4 = spv.IAdd %0, %3 : i32
     spv.Return
   }
+
+  // CHECK-LABEL: @multi_dimensions_const
+  func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) {
+    // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+    %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+    spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+  }
+
+  // CHECK-LABEL: @multi_dimensions_splat_const
+  func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) {
+    // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+    %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+    spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+  }
 }
index 2bbb03f..8fe03f4 100644 (file)
@@ -48,12 +48,20 @@ func @const() -> () {
   // CHECK: %2 = spv.constant 5.000000e-01 : f32
   // CHECK: %3 = spv.constant dense<[2, 3]> : vector<2xi32>
   // CHECK: %4 = spv.constant [dense<3.000000e+00> : vector<2xf32>] : !spv.array<1 x vector<2xf32>>
+  // CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
+  // CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
+  // CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
+  // CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
 
   %0 = spv.constant true
   %1 = spv.constant 42 : i32
   %2 = spv.constant 0.5 : f32
   %3 = spv.constant dense<[2, 3]> : vector<2xi32>
   %4 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
+  %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
+  %6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
+  %7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
+  %8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
   return
 }
 
@@ -83,11 +91,33 @@ func @array_constant() -> () {
 
 // -----
 
+func @non_nested_array_constant() -> () {
+  // expected-error @+1 {{only support nested array result type}}
+  %0 = spv.constant dense<3.0> : tensor<2x2xf32> : !spv.array<2xvector<2xf32>>
+  return
+}
+
+// -----
+
 func @value_result_type_mismatch() -> () {
-  // expected-error @+1 {{result type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}}
+  // expected-error @+1 {{must have spv.array result type for array value}}
   %0 = "spv.constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
 }
 
+// -----
+
+func @value_result_type_mismatch() -> () {
+  // expected-error @+1 {{result element type ('i32') does not match value element type ('f32')}}
+  %0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
+}
+
+// -----
+
+func @value_result_num_elements_mismatch() -> () {
+  // expected-error @+1 {{result number of elements (6) does not match value number of elements (4)}}
+  %0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
+  return
+}
 
 // -----