(De)serialize float scalar spv.constant
authorLei Zhang <antiagainst@google.com>
Mon, 22 Jul 2019 13:01:34 +0000 (06:01 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 22 Jul 2019 13:02:32 +0000 (06:02 -0700)
This CL adds support for float scalar spv.constant in (de)serialization.

PiperOrigin-RevId: 259311776

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/SPIRV/Serialization/constant.mlir

index 33bd50e..60a44c4 100644 (file)
@@ -443,8 +443,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
       floatTy = opBuilder.getF64Type();
       break;
     default:
-      return emitError(unknownLoc, "unsupported bitwdith ")
-             << operands[1] << " with OpTypeFloat";
+      return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ")
+             << operands[1];
     }
     typeMap[operands[0]] = floatTy;
   } break;
@@ -556,6 +556,31 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
 
     auto attr = opBuilder.getIntegerAttr(intType, value);
     op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
+  } else if (auto floatType = resultType.dyn_cast<FloatType>()) {
+    auto bitwidth = floatType.getWidth();
+    if (failed(checkOperandSizeForBitwidth(bitwidth))) {
+      return failure();
+    }
+
+    APFloat value(0.f);
+    if (floatType.isF64()) {
+      // Double values are represented with two SPIR-V words. According to
+      // SPIR-V spec: "When the type’s bit width is larger than one word, the
+      // literal’s low-order words appear first."
+      struct DoubleWord {
+        uint32_t word1;
+        uint32_t word2;
+      } words = {operands[2], operands[3]};
+      value = APFloat(llvm::bit_cast<double>(words));
+    } else if (floatType.isF32()) {
+      value = APFloat(llvm::bit_cast<float>(operands[2]));
+    } else if (floatType.isF16()) {
+      APInt data(16, operands[2]);
+      value = APFloat(APFloat::IEEEhalf(), data);
+    }
+
+    auto attr = opBuilder.getFloatAttr(floatType, value);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
   } else {
     return emitError(unknownLoc, "OpConstant can only generate values of "
                                  "scalar integer or floating-point type");
@@ -564,6 +589,7 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
   valueMap[operands[1]] = op.getResult();
   return success();
 }
+
 LogicalResult Deserializer::processConstantBool(bool isTrue,
                                                 ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
index ce440b9..0d0434b 100644 (file)
@@ -173,6 +173,8 @@ private:
 
   uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
 
+  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
+
   //===--------------------------------------------------------------------===//
   // Operations
   //===--------------------------------------------------------------------===//
@@ -488,6 +490,9 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
 
 uint32_t Serializer::prepareConstant(Location loc, Type constType,
                                      Attribute valueAttr) {
+  if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
+    return prepareConstantFp(loc, floatAttr);
+  }
   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
     return prepareConstantInt(loc, intAttr);
   }
@@ -575,6 +580,50 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
   return constIDMap[intAttr] = resultID;
 }
 
+uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
+  if (auto id = findConstantID(floatAttr)) {
+    return id;
+  }
+
+  // Process the type for this float literal
+  uint32_t typeID = 0;
+  if (failed(processType(loc, floatAttr.getType(), typeID))) {
+    return 0;
+  }
+
+  auto resultID = getNextID();
+  APFloat value = floatAttr.getValue();
+  APInt intValue = value.bitcastToAPInt();
+
+  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
+    uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+                          {typeID, resultID, word});
+  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
+    struct DoubleWord {
+      uint32_t word1;
+      uint32_t word2;
+    } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+                          {typeID, resultID, words.word1, words.word2});
+  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
+    uint32_t word =
+        static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+                          {typeID, resultID, word});
+  } else {
+    std::string valueStr;
+    llvm::raw_string_ostream rss(valueStr);
+    value.print(rss);
+
+    emitError(loc, "cannot serialize ")
+        << floatAttr.getType() << "-typed float literal: " << rss.str();
+    return 0;
+  }
+
+  return constIDMap[floatAttr] = resultID;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation
 //===----------------------------------------------------------------------===//
index fda6d05..3ac9b57 100644 (file)
@@ -32,6 +32,33 @@ func @spirv_module() -> () {
     %9 = spv.constant -32768 : i16 // -2^15
     // CHECK: spv.constant 32767 : i16
     %10 = spv.constant 32767 : i16 //  2^15 - 1
+
+    // float
+    // CHECK: spv.constant 0.000000e+00 : f32
+    %11 = spv.constant 0. : f32
+    // CHECK: spv.constant 1.000000e+00 : f32
+    %12 = spv.constant 1. : f32
+    // CHECK: spv.constant -0.000000e+00 : f32
+    %13 = spv.constant -0. : f32
+    // CHECK: spv.constant -1.000000e+00 : f32
+    %14 = spv.constant -1. : f32
+    // CHECK: spv.constant 7.500000e-01 : f32
+    %15 = spv.constant 0.75 : f32
+    // CHECK: spv.constant -2.500000e-01 : f32
+    %16 = spv.constant -0.25 : f32
+
+    // double
+    // TODO(antiagainst): test range boundary values
+    // CHECK: spv.constant 1.024000e+03 : f64
+    %17 = spv.constant 1024. : f64
+    // CHECK: spv.constant -1.024000e+03 : f64
+    %18 = spv.constant -1024. : f64
+
+    // half
+    // CHECK: spv.constant 5.120000e+02 : f16
+    %19 = spv.constant 512. : f16
+    // CHECK: spv.constant -5.120000e+02 : f16
+    %20 = spv.constant -512. : f16
   }
   return
 }