/// Get the parameter name.
StringRef getName() const;
+ /// Get the parameter accessor name.
+ std::string getAccessorName() const;
+
/// If specified, get the custom allocator code for this parameter.
Optional<StringRef> getAllocator() const;
cast<tosa::NegateOp>(op).quantization_info()) {
auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- int64_t inZp = quantizationInfo.getValue().getInput_zp();
- int64_t outZp = quantizationInfo.getValue().getOutput_zp();
+ int64_t inZp = quantizationInfo.getValue().getInputZp();
+ int64_t outZp = quantizationInfo.getValue().getOutputZp();
// Compute the maximum value that can occur in the intermediate buffer.
int64_t zpAdd = inZp + outZp;
} else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
} else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
- int64_t value = padOp.quantization_info().getValue().getInput_zp();
+ int64_t value = padOp.quantization_info().getValue().getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (constantAttr)
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
- int64_t iZp = quantizationInfo.getInput_zp();
+ int64_t iZp = quantizationInfo.getInputZp();
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
- auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
- auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
+ auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
+ auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
- iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
- kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
+ iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
+ kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
}
auto weightShape = weightTy.getShape();
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
- int64_t iZp = quantizationInfo.getInput_zp();
+ int64_t iZp = quantizationInfo.getInputZp();
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
auto quantizationInfo = op.quantization_info().getValue();
auto aZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp()));
+ loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
auto bZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp()));
+ loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
auto quantizationInfo = op.quantization_info().getValue();
auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()));
+ loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()));
+ loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
Value matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
if (op.quantization_info()) {
auto quantizationInfo = op.quantization_info().getValue();
auto inputZp = rewriter.create<arith::ConstantOp>(
- loc,
- b.getIntegerAttr(accETy, quantizationInfo.getInput_zp()));
+ loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
poolVal =
auto quantizationInfo = op.quantization_info().getValue();
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(scaled.getType(),
- quantizationInfo.getOutput_zp()));
+ quantizationInfo.getOutputZp()));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
if (auto entryPoint = spirv::lookupEntryPointABI(op))
- return entryPoint.getLocal_size();
+ return entryPoint.getLocalSize();
return {};
}
funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars);
// Specifies the spv.ExecutionModeOp.
- auto localSizeAttr = entryPointAttr.getLocal_size();
+ auto localSizeAttr = entryPointAttr.getLocalSize();
if (localSizeAttr) {
auto values = localSizeAttr.getValues<int32_t>();
SmallVector<int32_t, 3> localSize(values);
} else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
} else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
- auto value = op.quantization_info().getValue().getInput_zp();
+ auto value = op.quantization_info().getValue().getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
weight = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeight_zp()));
+ rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
} else {
weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
input = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInput_zp()));
+ rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
} else {
input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
UnrankedTensorType::get(inputETy),
return def->getArgName(index)->getValue();
}
+std::string AttrOrTypeParameter::getAccessorName() const {
+ return "get" +
+ llvm::convertToCamelFromSnakeCase(getName(), /*capitalizeFirst=*/true);
+}
+
Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
return getDefValue<llvm::StringInit>("allocator");
}
//===----------------------------------------------------------------------===//
#include "AttrOrTypeFormatGen.h"
-#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
// Utility Functions
//===----------------------------------------------------------------------===//
-std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
- assert(!name.empty() && "parameter has empty name");
- auto ret = "get" + name.str();
- ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
- return ret;
-}
-
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect,
void DefGen::emitAccessors() {
for (auto ¶m : params) {
Method *m = defCls.addMethod(
- param.getCppAccessorType(), getParameterAccessorName(param.getName()),
+ param.getCppAccessorType(), param.getAccessorName(),
def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
// Generate accessor definitions only if we also generate the storage
// class. Otherwise, let the user define the exact accessor definition.
/// Generate the code to check whether the parameter should be printed.
MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
- std::string self = getParameterAccessorName(getName()) + "()";
+ std::string self = param.getAccessorName() + "()";
ctx.withSelf(self);
os << tgfmt("($_self", &ctx);
if (llvm::Optional<StringRef> defaultValue = getParam().getDefaultValue()) {
void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
MethodBody &os, bool skipGuard) {
const AttrOrTypeParameter ¶m = el->getParam();
- ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+ ctx.withSelf(param.getAccessorName() + "()");
// Guard the printer on the presence of optional parameters and that they
// aren't equal to their default values (if they have one).
if (auto *ref = dyn_cast<RefDirective>(arg))
param = ref->getArg();
os << ",\n"
- << getParameterAccessorName(cast<ParameterElement>(param)->getName())
- << "()";
+ << cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
}
os.unindent() << ");\n";
}
void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
MethodBody &printer);
-/// From the parameter name, get the name of the accessor function in camelcase.
-/// The first letter of the parameter is upper-cased and prefixed with "get".
-/// E.g. 'value' -> 'getValue'.
-std::string getParameterAccessorName(llvm::StringRef name);
-
} // namespace tblgen
} // namespace mlir