[mlir][ods] Make Attr/Type def accessors match the dialect
authorMogball <jeffniu22@gmail.com>
Tue, 14 Jun 2022 05:04:56 +0000 (05:04 +0000)
committerMogball <jeffniu22@gmail.com>
Tue, 14 Jun 2022 05:13:24 +0000 (05:13 +0000)
The generated attribute and type def accessors are changed to match the setting on the dialect. Most importantly, "prefixed" will now correctly convert snake case to camel case (e.g. `weight_zp` -> `getWeightZp`)

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D127688

mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h

index c2aafca0831b0af481daef569040f40c3698d80d..0a23e0fed56cc8fe692c6fa80509a17bfc4facb4 100644 (file)
@@ -58,6 +58,9 @@ public:
   /// Get the parameter name.
   StringRef getName() const;
 
+  /// Get the parameter accessor name.
+  std::string getAccessorName() const;
+
   /// If specified, get the custom allocator code for this parameter.
   Optional<StringRef> getAllocator() const;
 
index b1d0b696009da4172a0148d1d3aef3f28f5163e8..6a14e6a0f4d477980ad3f8253e3e51c7d9f82717 100644 (file)
@@ -147,8 +147,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
       cast<tosa::NegateOp>(op).quantization_info()) {
     auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    int64_t inZp = quantizationInfo.getValue().getInput_zp();
-    int64_t outZp = quantizationInfo.getValue().getOutput_zp();
+    int64_t inZp = quantizationInfo.getValue().getInputZp();
+    int64_t outZp = quantizationInfo.getValue().getOutputZp();
 
     // Compute the maximum value that can occur in the intermediate buffer.
     int64_t zpAdd = inZp + outZp;
@@ -1847,7 +1847,7 @@ public:
       } else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
       } else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
-        int64_t value = padOp.quantization_info().getValue().getInput_zp();
+        int64_t value = padOp.quantization_info().getValue().getInputZp();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
       if (constantAttr)
index 2154cd98204e0597bffca93719b4c7aa8af92a30..bd3eb0feca647ebf45b5fc508e41cc0f34f06c62 100644 (file)
@@ -202,7 +202,7 @@ public:
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      int64_t iZp = quantizationInfo.getInput_zp();
+      int64_t iZp = quantizationInfo.getInputZp();
 
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -274,8 +274,8 @@ public:
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
-      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
+      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
+      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -366,8 +366,8 @@ public:
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
-      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
+      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
+      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
     }
 
     auto weightShape = weightTy.getShape();
@@ -378,7 +378,7 @@ public:
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      int64_t iZp = quantizationInfo.getInput_zp();
+      int64_t iZp = quantizationInfo.getInputZp();
 
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -542,9 +542,9 @@ public:
 
     auto quantizationInfo = op.quantization_info().getValue();
     auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
     auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
@@ -652,9 +652,9 @@ public:
 
     auto quantizationInfo = op.quantization_info().getValue();
     auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
     auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -892,8 +892,7 @@ public:
             if (op.quantization_info()) {
               auto quantizationInfo = op.quantization_info().getValue();
               auto inputZp = rewriter.create<arith::ConstantOp>(
-                  loc,
-                  b.getIntegerAttr(accETy, quantizationInfo.getInput_zp()));
+                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
               poolVal =
@@ -930,7 +929,7 @@ public:
               auto quantizationInfo = op.quantization_info().getValue();
               auto outputZp = rewriter.create<arith::ConstantOp>(
                   loc, b.getIntegerAttr(scaled.getType(),
-                                        quantizationInfo.getOutput_zp()));
+                                        quantizationInfo.getOutputZp()));
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
index 17f3412999f3f1eaab28fe5c59bd334591f79864..b588afa10647625f4f4d4b218077e8efda6d35d4 100644 (file)
@@ -145,7 +145,7 @@ spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
 
 DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
   if (auto entryPoint = spirv::lookupEntryPointABI(op))
-    return entryPoint.getLocal_size();
+    return entryPoint.getLocalSize();
 
   return {};
 }
index 3a01557f5c9de55cb28d4ce9b5c6b798bab8f514..3ea2224a4408ac4e787f5fc1ec943286dba0457f 100644 (file)
@@ -135,7 +135,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
       funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars);
 
   // Specifies the spv.ExecutionModeOp.
-  auto localSizeAttr = entryPointAttr.getLocal_size();
+  auto localSizeAttr = entryPointAttr.getLocalSize();
   if (localSizeAttr) {
     auto values = localSizeAttr.getValues<int32_t>();
     SmallVector<int32_t, 3> localSize(values);
index d84e5d218dc2a4a1dd1f42e68ec0a04753ad4eec..d7203c6afbe46f1c352ea371de040386dfea28e5 100644 (file)
@@ -347,7 +347,7 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
     } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
-      auto value = op.quantization_info().getValue().getInput_zp();
+      auto value = op.quantization_info().getValue().getInputZp();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 
index 3389dda46e1b0d8821235a470380fd1d2b46505c..1db101280ef239d52b95d69cfe9e5ff6d1bf8465 100644 (file)
@@ -214,7 +214,7 @@ public:
       weight = createOpAndInfer<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
           weightPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeight_zp()));
+          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
 
     } else {
       weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
@@ -278,7 +278,7 @@ public:
       input = createOpAndInfer<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInput_zp()));
+          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
     } else {
       input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(inputETy),
index 444db742bd32e5706f5d7f8d36fdc560b3f6d549..8467af0ee74c0fdad7c3c953b7b1d6c2453edcfa 100644 (file)
@@ -215,6 +215,11 @@ StringRef AttrOrTypeParameter::getName() const {
   return def->getArgName(index)->getValue();
 }
 
+std::string AttrOrTypeParameter::getAccessorName() const {
+  return "get" +
+         llvm::convertToCamelFromSnakeCase(getName(), /*capitalizeFirst=*/true);
+}
+
 Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
   return getDefValue<llvm::StringInit>("allocator");
 }
index 6e32b431bf823ca93c2d68c0276bb85664b8984d..6895cb15fcdb50773c179d843cd47310233d3eb4 100644 (file)
@@ -7,16 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "AttrOrTypeFormatGen.h"
-#include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Class.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/TableGen/Error.h"
@@ -31,13 +27,6 @@ using namespace mlir::tblgen;
 // Utility Functions
 //===----------------------------------------------------------------------===//
 
-std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
-  assert(!name.empty() && "parameter has empty name");
-  auto ret = "get" + name.str();
-  ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
-  return ret;
-}
-
 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
 /// specified and can only find one dialect's defs, use that.
 static void collectAllDefs(StringRef selectedDialect,
@@ -288,7 +277,7 @@ void DefGen::emitParserPrinter() {
 void DefGen::emitAccessors() {
   for (auto &param : params) {
     Method *m = defCls.addMethod(
-        param.getCppAccessorType(), getParameterAccessorName(param.getName()),
+        param.getCppAccessorType(), param.getAccessorName(),
         def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
     // Generate accessor definitions only if we also generate the storage
     // class. Otherwise, let the user define the exact accessor definition.
index e1b0774e2a2b819da52a27310bdf9d41af052f78..a31943790aadae3b3d26590636974056131a1b98 100644 (file)
@@ -58,7 +58,7 @@ public:
 
   /// Generate the code to check whether the parameter should be printed.
   MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
-    std::string self = getParameterAccessorName(getName()) + "()";
+    std::string self = param.getAccessorName() + "()";
     ctx.withSelf(self);
     os << tgfmt("($_self", &ctx);
     if (llvm::Optional<StringRef> defaultValue = getParam().getDefaultValue()) {
@@ -718,7 +718,7 @@ void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
                                    MethodBody &os, bool skipGuard) {
   const AttrOrTypeParameter &param = el->getParam();
-  ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+  ctx.withSelf(param.getAccessorName() + "()");
 
   // Guard the printer on the presence of optional parameters and that they
   // aren't equal to their default values (if they have one).
@@ -812,8 +812,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
     if (auto *ref = dyn_cast<RefDirective>(arg))
       param = ref->getArg();
     os << ",\n"
-       << getParameterAccessorName(cast<ParameterElement>(param)->getName())
-       << "()";
+       << cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
   }
   os.unindent() << ");\n";
 }
index c371aee268b421662d498c066ed703517a455f1e..d4711532a79bb30ab58b93a3cc86a426c1419b8f 100644 (file)
@@ -20,11 +20,6 @@ class AttrOrTypeDef;
 void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
                               MethodBody &printer);
 
-/// From the parameter name, get the name of the accessor function in camelcase.
-/// The first letter of the parameter is upper-cased and prefixed with "get".
-/// E.g. 'value' -> 'getValue'.
-std::string getParameterAccessorName(llvm::StringRef name);
-
 } // namespace tblgen
 } // namespace mlir