Simplify the emission of various diagnostics emitted by the different dialects...
authorRiver Riddle <riverriddle@google.com>
Mon, 6 May 2019 16:46:11 +0000 (09:46 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:22:24 +0000 (19:22 -0700)
--

PiperOrigin-RevId: 246842016

mlir/include/mlir/IR/Diagnostics.h
mlir/lib/Dialect/Traits.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Quantization/IR/QuantTypes.cpp
mlir/lib/Quantization/Utils/FakeQuantSupport.cpp
mlir/lib/VectorOps/VectorOps.cpp
mlir/test/Quantization/convert-fakequant-invalid.mlir

index 31ce96c..361f2c8 100644 (file)
@@ -56,6 +56,7 @@ public:
   /// Enum that represents the different kinds of diagnostic arguments
   /// supported.
   enum class DiagnosticArgumentKind {
+    Double,
     Integer,
     String,
     Type,
@@ -68,10 +69,10 @@ public:
   /// Returns the kind of this argument.
   DiagnosticArgumentKind getKind() const { return kind; }
 
-  /// Returns this argument as a string.
-  StringRef getAsString() const {
-    assert(getKind() == DiagnosticArgumentKind::String);
-    return stringVal;
+  /// Returns this argument as a double.
+  double getAsDouble() const {
+    assert(getKind() == DiagnosticArgumentKind::Double);
+    return doubleVal;
   }
 
   /// Returns this argument as a signed integer.
@@ -80,6 +81,12 @@ public:
     return static_cast<int64_t>(opaqueVal);
   }
 
+  /// Returns this argument as a string.
+  StringRef getAsString() const {
+    assert(getKind() == DiagnosticArgumentKind::String);
+    return stringVal;
+  }
+
   /// Returns this argument as a Type.
   Type getAsType() const;
 
@@ -92,6 +99,11 @@ public:
 private:
   friend class Diagnostic;
 
+  // Construct from a floating point number.
+  explicit DiagnosticArgument(double val)
+      : kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
+  explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
+
   // Construct from a signed integer.
   explicit DiagnosticArgument(int64_t val)
       : kind(DiagnosticArgumentKind::Integer), opaqueVal(val) {}
@@ -115,6 +127,7 @@ private:
 
   /// The value of this argument.
   union {
+    double doubleVal;
     intptr_t opaqueVal;
     StringRef stringVal;
   };
index 40092b2..8988dcd 100644 (file)
@@ -203,10 +203,10 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
                           retType.isa<UnrankedTensorType>() ||
                           isSameShapedVectorOrTensor(retType, broadcastedType);
   if (!hasCompatRetType)
-    return op->emitOpError(
-        llvm::formatv("result type '{0}' does not have the same shape as the "
-                      "broadcasted type '{1}' computed from the operand types",
-                      retType, broadcastedType));
+    return op->emitOpError()
+           << "result type '" << retType
+           << "' does not have the same shape as the broadcasted type '"
+           << broadcastedType << "' computed from the operand types";
 
   return success();
 }
index 35ad35e..877be72 100644 (file)
@@ -44,6 +44,9 @@ Type DiagnosticArgument::getAsType() const {
 /// Outputs this argument to a stream.
 void DiagnosticArgument::print(raw_ostream &os) const {
   switch (kind) {
+  case DiagnosticArgumentKind::Double:
+    os << getAsDouble();
+    break;
   case DiagnosticArgumentKind::Integer:
     os << getAsInteger();
     break;
index a28b3b8..20d5463 100644 (file)
@@ -902,8 +902,8 @@ LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
                                                       NamedAttribute argAttr) {
   // Check that llvm.noalias is a boolean attribute.
   if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
-    return func->emitError(
-        "llvm.noalias argument attribute of non boolean type");
+    return func->emitError()
+           << "llvm.noalias argument attribute of non boolean type";
   return success();
 }
 
index 587bd6b..5f57ebc 100644 (file)
@@ -1133,11 +1133,8 @@ Type LLVMLowering::convertType(Type t) {
     return result;
 
   auto *mlirContext = llvmDialect->getContext();
-  std::string message;
-  llvm::raw_string_ostream os(message);
-  os << "unsupported type: ";
-  t.print(os);
-  mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
+  mlirContext->emitError(UnknownLoc::get(mlirContext))
+      << "unsupported type: " << t;
   return {};
 }
 
index d8f9b13..958218d 100644 (file)
@@ -167,24 +167,24 @@ LogicalResult mlir::SliceOp::verify() {
     return emitOpError("first operand must come from a ViewOp");
   unsigned rank = getBaseViewRank();
   if (llvm::size(getIndexings()) != rank) {
-    return emitOpError("requires at least a view operand followed by " +
-                       Twine(rank) + " indexings");
+    return emitOpError("requires at least a view operand followed by ")
+           << rank << " indexings";
   }
   unsigned index = 0;
   for (auto indexing : getIndexings()) {
     if (!indexing->getType().isa<RangeType>() &&
         !indexing->getType().isa<IndexType>()) {
-      return emitOpError(Twine(index) +
-                         "^th index must be of range or index type");
+      return emitOpError() << index
+                           << "^th index must be of range or index type";
     }
     if (indexing->getType().isa<IndexType>())
       --rank;
     ++index;
   }
   if (getRank() != rank) {
-    return emitOpError("the rank of the view must be the number of its range "
-                       "indices (" +
-                       Twine(rank) + ") but got: " + Twine(getRank()));
+    return emitOpError()
+           << "the rank of the view must be the number of its range indices ("
+           << rank << ") but got: " << getRank();
   }
   return success();
 }
@@ -296,13 +296,13 @@ LogicalResult mlir::ViewOp::verify() {
   unsigned index = 0;
   for (auto indexing : getIndexings()) {
     if (!indexing->getType().isa<RangeType>()) {
-      return emitOpError(Twine(index) + "^th index must be of range type");
+      return emitOpError() << index << "^th index must be of range type";
     }
     ++index;
   }
   if (getViewType().getRank() != index)
-    return emitOpError(
-        "the rank of the view must be the number of its indexings");
+    return emitOpError()
+           << "the rank of the view must be the number of its indexings";
   return success();
 }
 
index efe0924..88ba4ae 100644 (file)
@@ -50,8 +50,7 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
   // Verify storage width.
   if (integralWidth == 0 || integralWidth > MaxStorageBits) {
     if (loc) {
-      context->emitError(*loc,
-                         "illegal storage type size: " + Twine(integralWidth));
+      context->emitError(*loc, "illegal storage type size: ") << integralWidth;
     }
     return failure();
   }
@@ -67,9 +66,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
       storageTypeMin < defaultIntegerMin ||
       storageTypeMax > defaultIntegerMax) {
     if (loc) {
-      context->emitError(*loc, "illegal storage min and storage max: (" +
-                                   Twine(storageTypeMin) + ":" +
-                                   Twine(storageTypeMax) + ")");
+      context->emitError(*loc, "illegal storage min and storage max: (")
+          << storageTypeMin << ":" << storageTypeMax << ")";
     }
     return failure();
   }
@@ -313,8 +311,7 @@ LogicalResult UniformQuantizedType::verifyConstructionInvariants(
   // Verify scale.
   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
     if (loc) {
-      context->emitError(*loc,
-                         "illegal scale: " + Twine(std::to_string(scale)));
+      context->emitError(*loc) << "illegal scale: " << scale;
     }
     return failure();
   }
@@ -383,9 +380,8 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
   // Ensure that the number of scales and zeroPoints match.
   if (scales.size() != zeroPoints.size()) {
     if (loc) {
-      context->emitError(*loc, "illegal number of scales and zeroPoints: " +
-                                   Twine(scales.size()) + ", " +
-                                   Twine(zeroPoints.size()));
+      context->emitError(*loc, "illegal number of scales and zeroPoints: ")
+          << scales.size() << ", " << zeroPoints.size();
     }
     return failure();
   }
@@ -394,8 +390,7 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
   for (double scale : scales) {
     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
       if (loc) {
-        context->emitError(*loc,
-                           "illegal scale: " + Twine(std::to_string(scale)));
+        context->emitError(*loc) << "illegal scale: " << scale;
       }
       return failure();
     }
index 3c61e48..26e88b9 100644 (file)
@@ -44,8 +44,7 @@ UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
     qmin = -32768;
     qmax = 32767;
   } else {
-    ctx->emitError(loc,
-                   "unsupported FakeQuant number of bits: " + Twine(numBits));
+    ctx->emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
     return nullptr;
   }
 
@@ -56,9 +55,8 @@ UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
 
   // Range must straddle zero.
   if (rmin > 0.0 || rmax < 0.0) {
-    return (ctx->emitError(loc, "FakeQuant range must straddle zero: [" +
-                                    Twine(std::to_string(rmin)) + "," +
-                                    Twine(std::to_string(rmax)) + "]"),
+    return (ctx->emitError(loc, "FakeQuant range must straddle zero: [")
+                << rmin << "," << rmax << "]",
             nullptr);
   }
 
index 263a0b3..6f41693 100644 (file)
@@ -195,9 +195,9 @@ LogicalResult VectorTransferReadOp::verify() {
                                  (optionalPaddingValue ? 1 : 0);
   // Checks on the actual operands and their types.
   if (getNumOperands() != expectedNumOperands) {
-    return emitOpError("expects " + Twine(expectedNumOperands) + " operands " +
-                       "(of which " + Twine(memrefType.getRank()) +
-                       " indices)");
+    return emitOpError("expects ")
+           << expectedNumOperands << " operands (of which "
+           << memrefType.getRank() << " indices)";
   }
   // Consistency of padding value with vector type.
   if (optionalPaddingValue) {
@@ -221,8 +221,8 @@ LogicalResult VectorTransferReadOp::verify() {
     ++numIndices;
   }
   if (numIndices != memrefType.getRank()) {
-    return emitOpError("requires at least a memref operand followed by " +
-                       Twine(memrefType.getRank()) + " indices");
+    return emitOpError("requires at least a memref operand followed by ")
+           << memrefType.getRank() << " indices";
   }
 
   // Consistency of AffineMap attribute.
@@ -242,9 +242,8 @@ LogicalResult VectorTransferReadOp::verify() {
   }
   if (permutationMap.getNumResults() != vectorType.getRank()) {
     return emitOpError("requires a permutation_map with result dims of the "
-                       "same rank as the vector type (" +
-                       Twine(permutationMap.getNumResults()) + " vs " +
-                       Twine(vectorType.getRank()));
+                       "same rank as the vector type (")
+           << permutationMap.getNumResults() << " vs " << vectorType.getRank();
   }
   return verifyPermutationMap(permutationMap,
                               [this](Twine t) { return emitOpError(t); });
@@ -340,9 +339,9 @@ LogicalResult VectorTransferWriteOp::verify() {
       Offsets::FirstIndexOffset + memrefType.getRank();
   // Checks on the actual operands and their types.
   if (getNumOperands() != expectedNumOperands) {
-    return emitOpError("expects " + Twine(expectedNumOperands) + " operands " +
-                       "(of which " + Twine(memrefType.getRank()) +
-                       " indices)");
+    return emitOpError() << "expects " << expectedNumOperands
+                         << " operands (of which " << memrefType.getRank()
+                         << " indices)";
   }
   // Consistency of indices types.
   unsigned numIndices = 0;
@@ -354,8 +353,8 @@ LogicalResult VectorTransferWriteOp::verify() {
     numIndices++;
   }
   if (numIndices != memrefType.getRank()) {
-    return emitOpError("requires at least a memref operand followed by " +
-                       Twine(memrefType.getRank()) + " indices");
+    return emitOpError("requires at least a memref operand followed by ")
+           << memrefType.getRank() << " indices";
   }
 
   // Consistency of AffineMap attribute.
@@ -375,9 +374,8 @@ LogicalResult VectorTransferWriteOp::verify() {
   }
   if (permutationMap.getNumResults() != vectorType.getRank()) {
     return emitOpError("requires a permutation_map with result dims of the "
-                       "same rank as the vector type (" +
-                       Twine(permutationMap.getNumResults()) + " vs " +
-                       Twine(vectorType.getRank()));
+                       "same rank as the vector type (")
+           << permutationMap.getNumResults() << " vs " << vectorType.getRank();
   }
   return verifyPermutationMap(permutationMap,
                               [this](Twine t) { return emitOpError(t); });
index e2ba76a..b108033 100644 (file)
@@ -4,7 +4,7 @@
 // Verify that a mismatched range errors.
 func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
-  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.500000]}}
+  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.500000e+00]}}
   %0 = "quant.const_fake_quant"(%arg0) {
     min: 1.1 : f32, max: 1.5 : f32, num_bits: 8
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
@@ -15,7 +15,7 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 // Verify that a valid range errors.
 func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
-  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.000000}}
+  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.000000e+00}}
   %0 = "quant.const_fake_quant"(%arg0) {
     min: 1.1 : f32, max: 1.0 : f32, num_bits: 8
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>