Simplify API uses of `getContext()` (NFC)
authorMehdi Amini <aminim@google.com>
Wed, 27 Mar 2019 20:57:02 +0000 (13:57 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:47:11 +0000 (17:47 -0700)
The Pass base class is providing a convenience getContext() accessor.

PiperOrigin-RevId: 240634961

mlir/lib/IR/Function.cpp
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Transforms/Canonicalizer.cpp
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/lib/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Transforms/StripDebugInfo.cpp

index fa9328f..d8cc4ed 100644 (file)
@@ -156,7 +156,7 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
   dest->setAttrs(newAttrs.takeVector());
 
   // Clone the body.
-  body.cloneInto(&dest->body, mapper, dest->getContext());
+  body.cloneInto(&dest->body, mapper, getContext());
 }
 
 /// Create a deep copy of this function and all of its blocks, remapping
index d5f430f..b029cb5 100644 (file)
@@ -159,7 +159,6 @@ Type TypeConverter::convertIntegerType(IntegerType type) {
 }
 
 Type TypeConverter::convertFloatType(FloatType type) {
-  MLIRContext *context = type.getContext();
   switch (type.getKind()) {
   case mlir::StandardTypes::F32:
     return wrap(builder.getFloatTy());
@@ -168,8 +167,8 @@ Type TypeConverter::convertFloatType(FloatType type) {
   case mlir::StandardTypes::F16:
     return wrap(builder.getHalfTy());
   case mlir::StandardTypes::BF16:
-    return context->emitError(UnknownLoc::get(context),
-                              "unsupported type: BF16"),
+    return mlirContext->emitError(UnknownLoc::get(mlirContext),
+                                  "unsupported type: BF16"),
            Type();
   default:
     llvm_unreachable("non-float type in convertFloatType");
@@ -236,11 +235,11 @@ FunctionType TypeConverter::convertFunctionSignatureType(FunctionType type) {
 
   // If function does not return anything, return immediately.
   if (type.getNumResults() == 0)
-    return FunctionType::get(argTypes, {}, type.getContext());
+    return FunctionType::get(argTypes, {}, mlirContext);
 
   // Otherwise pack the result types into a struct.
   if (auto result = getPackedResultType(type.getResults()))
-    return FunctionType::get(argTypes, {result}, type.getContext());
+    return FunctionType::get(argTypes, {result}, mlirContext);
 
   return {};
 }
@@ -271,9 +270,8 @@ Type TypeConverter::convertMemRefType(MemRefType type) {
 // Convert a 1D vector type to an LLVM vector type.
 Type TypeConverter::convertVectorType(VectorType type) {
   if (type.getRank() != 1) {
-    MLIRContext *context = type.getContext();
-    context->emitError(UnknownLoc::get(context),
-                       "only 1D vectors are supported");
+    mlirContext->emitError(UnknownLoc::get(mlirContext),
+                           "only 1D vectors are supported");
     return {};
   }
 
@@ -300,12 +298,11 @@ Type TypeConverter::convertType(Type type) {
   if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
     return llvmType;
 
-  MLIRContext *context = type.getContext();
   std::string message;
   llvm::raw_string_ostream os(message);
   os << "unsupported type: ";
   type.print(os);
-  context->emitError(UnknownLoc::get(context), os.str());
+  mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
   return {};
 }
 
index 5457975..8a2002c 100644 (file)
@@ -45,7 +45,7 @@ void Canonicalizer::runOnFunction() {
   // TODO: Instead of adding all known patterns from the whole system lazily add
   // and cache the canonicalization patterns for ops we see in practice when
   // building the worklist.  For now, we just grab everything.
-  auto *context = func.getContext();
+  auto *context = &getContext();
   for (auto *op : context->getRegisteredOperations())
     op->getCanonicalizationPatterns(patterns, context);
 
index c151085..2a877c4 100644 (file)
@@ -745,7 +745,7 @@ void MaterializeVectorsPass::runOnFunction() {
   // Get the hardware vector type.
   // TODO(ntv): get elemental type from super-vector type rather than force f32.
   auto subVectorType =
-      VectorType::get(state.hwVectorSize, FloatType::getF32(f->getContext()));
+      VectorType::get(state.hwVectorSize, FloatType::getF32(&getContext()));
 
   // Capture terminators; i.e. vector_transfer_write ops involving a strict
   // super-vector of subVectorType.
index 7a9cfc7..4ff5367 100644 (file)
@@ -71,7 +71,7 @@ struct SimplifyAffineStructures
     FlatAffineConstraints fac(set);
     if (fac.isEmpty())
       return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
-                                     set.getContext());
+                                     &getContext());
     return set;
   }
 
index bc2b4b9..9d6b7a0 100644 (file)
@@ -30,7 +30,7 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
 
 void StripDebugInfo::runOnFunction() {
   Function &func = getFunction();
-  UnknownLoc unknownLoc = UnknownLoc::get(func.getContext());
+  UnknownLoc unknownLoc = UnknownLoc::get(&getContext());
 
   // Strip the debug info from the function and its instructions.
   func.setLoc(unknownLoc);