NFC: Remove the 'context' parameter from OperationState.
authorRiver Riddle <riverriddle@google.com>
Sat, 22 Jun 2019 18:08:52 +0000 (11:08 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 22 Jun 2019 20:05:10 +0000 (13:05 -0700)
Now that Locations are Attributes they contain a direct reference to the MLIRContext, i.e. the context can be directly accessed from the given location instead of being explicitly passed in.

PiperOrigin-RevId: 254568329

13 files changed:
mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/EDSC/Builders.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
mlir/lib/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/lib/Transforms/Vectorize.cpp

index a9c044d..1abcc97 100644 (file)
@@ -231,7 +231,7 @@ private:
 
     // Build the MLIR operation from the name and the two operands. The return
     // type is always a generic array for binary operators.
-    mlir::OperationState result(&context, location, op_name);
+    mlir::OperationState result(location, op_name);
     result.types.push_back(getType(VarType{}));
     result.operands.push_back(L);
     result.operands.push_back(R);
@@ -253,7 +253,7 @@ private:
   bool mlirGen(ReturnExprAST &ret) {
     auto location = loc(ret.loc());
     // `return` takes an optional expression, we need to account for it here.
-    mlir::OperationState result(&context, location, "toy.return");
+    mlir::OperationState result(location, "toy.return");
     if (ret.getExpr().hasValue()) {
       auto *expr = mlirGen(*ret.getExpr().getValue());
       if (!expr)
@@ -306,7 +306,7 @@ private:
                      .cast<mlir::DenseElementsAttr>());
 
     // Build the MLIR op `toy.constant`, only boilerplate below.
-    mlir::OperationState result(&context, location, "toy.constant");
+    mlir::OperationState result(location, "toy.constant");
     result.types.push_back(type);
     result.attributes.push_back(dataAttribute);
     return builder->createOperation(result)->getResult(0);
@@ -348,7 +348,7 @@ private:
     }
     // builtin have their custom operation, this is a straightforward emission.
     if (callee == "transpose") {
-      mlir::OperationState result(&context, location, "toy.transpose");
+      mlir::OperationState result(location, "toy.transpose");
       result.types.push_back(getType(VarType{}));
       result.operands = std::move(operands);
       return builder->createOperation(result)->getResult(0);
@@ -356,7 +356,7 @@ private:
 
     // Calls to user-defined functions are mapped to a custom call that takes
     // the callee name as an attribute.
-    mlir::OperationState result(&context, location, "toy.generic_call");
+    mlir::OperationState result(location, "toy.generic_call");
     result.types.push_back(getType(VarType{}));
     result.operands = std::move(operands);
     auto calleeAttr = builder->getStringAttr(call.getCallee());
@@ -372,7 +372,7 @@ private:
     if (!arg)
       return false;
     auto location = loc(call.loc());
-    mlir::OperationState result(&context, location, "toy.print");
+    mlir::OperationState result(location, "toy.print");
     result.operands.push_back(arg);
     builder->createOperation(result);
     return true;
@@ -381,7 +381,7 @@ private:
   // Emit a constant for a single number (FIXME: semantic? broadcast?)
   mlir::Value *mlirGen(NumberExprAST &num) {
     auto location = loc(num.loc());
-    mlir::OperationState result(&context, location, "toy.constant");
+    mlir::OperationState result(location, "toy.constant");
     mlir::Type elementType = mlir::FloatType::getF64(&context);
     result.types.push_back(builder->getMemRefType({1}, elementType));
     auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
@@ -427,7 +427,7 @@ private:
       // with specific shape, we emit a "reshape" operation. It will get
       // optimized out later as needed.
       if (!vardecl.getType().shape.empty()) {
-        mlir::OperationState result(&context, location, "toy.reshape");
+        mlir::OperationState result(location, "toy.reshape");
         result.types.push_back(getType(vardecl.getType()));
         result.operands.push_back(value);
         value = builder->createOperation(result)->getResult(0);
index c04ca7a..ce35336 100644 (file)
@@ -297,7 +297,7 @@ public:
   /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args... args) {
-    OperationState state(getContext(), location, OpTy::getOperationName());
+    OperationState state(location, OpTy::getOperationName());
     OpTy::build(this, &state, args...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
index 5ee3a69..52aa617 100644 (file)
@@ -896,7 +896,7 @@ void ensureRegionTerminator(
 template <typename OpTy>
 void ensureRegionTerminator(Region &region, Builder &builder, Location loc) {
   ensureRegionTerminator(region, loc, [&] {
-    OperationState state(loc->getContext(), loc, OpTy::getOperationName());
+    OperationState state(loc, OpTy::getOperationName());
     OpTy::build(&builder, &state);
     return Operation::create(state);
   });
index 035421f..bd99e38 100644 (file)
@@ -247,13 +247,12 @@ struct OperationState {
   bool resizableOperandList = false;
 
 public:
-  OperationState(MLIRContext *context, Location location, StringRef name);
+  OperationState(Location location, StringRef name);
 
-  OperationState(MLIRContext *context, Location location, OperationName name);
+  OperationState(Location location, OperationName name);
 
-  OperationState(MLIRContext *context, Location location, StringRef name,
-                 ArrayRef<Value *> operands, ArrayRef<Type> types,
-                 ArrayRef<NamedAttribute> attributes,
+  OperationState(Location location, StringRef name, ArrayRef<Value *> operands,
+                 ArrayRef<Type> types, ArrayRef<NamedAttribute> attributes,
                  ArrayRef<Block *> successors = {},
                  MutableArrayRef<std::unique_ptr<Region>> regions = {},
                  bool resizableOperandList = false);
@@ -270,7 +269,7 @@ public:
 
   /// Add an attribute with the specified name.
   void addAttribute(StringRef name, Attribute attr) {
-    addAttribute(Identifier::get(name, context), attr);
+    addAttribute(Identifier::get(name, getContext()), attr);
   }
 
   /// Add an attribute with the specified name.
@@ -299,6 +298,9 @@ public:
   void setOperandListToResizable(bool isResizable = true) {
     resizableOperandList = isResizable;
   }
+
+  /// Get the context held by this operation state.
+  MLIRContext *getContext() { return location->getContext(); }
 };
 
 namespace detail {
index 9efba12..149e554 100644 (file)
@@ -272,7 +272,7 @@ public:
   /// without verifying to see if it is valid.
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args... args) {
-    OperationState state(getContext(), location, OpTy::getOperationName());
+    OperationState state(location, OpTy::getOperationName());
     OpTy::build(this, &state, args...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
@@ -285,7 +285,7 @@ public:
   /// and return null.
   template <typename OpTy, typename... Args>
   OpTy createChecked(Location location, Args... args) {
-    OperationState state(getContext(), location, OpTy::getOperationName());
+    OperationState state(location, OpTy::getOperationName());
     OpTy::build(this, &state, args...);
     auto *op = createOperation(state);
 
index af59d3a..43c5345 100644 (file)
@@ -117,8 +117,7 @@ OperationHandle OperationHandle::create(StringRef name,
                                         ArrayRef<ValueHandle> operands,
                                         ArrayRef<Type> resultTypes,
                                         ArrayRef<NamedAttribute> attributes) {
-  OperationState state(ScopedContext::getContext(),
-                       ScopedContext::getLocation(), name);
+  OperationState state(ScopedContext::getLocation(), name);
   SmallVector<Value *, 4> ops(operands.begin(), operands.end());
   state.addOperands(ops);
   state.addTypes(resultTypes);
index 0e306ee..7857f04 100644 (file)
@@ -29,22 +29,21 @@ using namespace mlir;
 // OperationState
 //===----------------------------------------------------------------------===//
 
-OperationState::OperationState(MLIRContext *context, Location location,
-                               StringRef name)
-    : context(context), location(location), name(name, context) {}
+OperationState::OperationState(Location location, StringRef name)
+    : context(location->getContext()), location(location),
+      name(name, location->getContext()) {}
 
-OperationState::OperationState(MLIRContext *context, Location location,
-                               OperationName name)
-    : context(context), location(location), name(name) {}
+OperationState::OperationState(Location location, OperationName name)
+    : context(location->getContext()), location(location), name(name) {}
 
-OperationState::OperationState(MLIRContext *context, Location location,
-                               StringRef name, ArrayRef<Value *> operands,
-                               ArrayRef<Type> types,
+OperationState::OperationState(Location location, StringRef name,
+                               ArrayRef<Value *> operands, ArrayRef<Type> types,
                                ArrayRef<NamedAttribute> attributes,
                                ArrayRef<Block *> successors,
                                MutableArrayRef<std::unique_ptr<Region>> regions,
                                bool resizableOperandList)
-    : context(context), location(location), name(name, context),
+    : context(location->getContext()), location(location),
+      name(name, location->getContext()),
       operands(operands.begin(), operands.end()),
       types(types.begin(), types.end()),
       attributes(attributes.begin(), attributes.end()),
index 088cf75..fb77f43 100644 (file)
@@ -2823,7 +2823,7 @@ Operation *OperationParser::parseGenericOperation() {
 
   consumeToken(Token::string);
 
-  OperationState result(builder.getContext(), srcLocation, name);
+  OperationState result(srcLocation, name);
 
   // Generic operations have a resizable operation list.
   result.setOperandListToResizable();
@@ -3325,7 +3325,7 @@ Operation *OperationParser::parseCustomOperation() {
   auto srcLocation = getEncodedSourceLocation(opLoc);
 
   // Have the op implementation take a crack and parsing this.
-  OperationState opState(builder.getContext(), srcLocation, opDefinition->name);
+  OperationState opState(srcLocation, opDefinition->name);
   CleanupOpStateRegions guard{opState};
   if (opAsmParser.parseOperation(opDefinition, &opState))
     return nullptr;
index e529927..95908b0 100644 (file)
@@ -42,8 +42,7 @@ Block *createOneBlockFunction(Builder builder, Module *module) {
   auto *block = new Block();
   fn->push_back(block);
 
-  OperationState state(builder.getContext(), builder.getUnknownLoc(),
-                       ReturnOp::getOperationName());
+  OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName());
   ReturnOp::build(&builder, &state);
   block->push_back(Operation::create(state));
 
index 9c6d00e..8341c88 100644 (file)
@@ -177,8 +177,7 @@ LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
 
 spirv::ModuleOp Deserializer::createModuleOp() {
   Builder builder(context);
-  OperationState state(context, unknownLoc,
-                       spirv::ModuleOp::getOperationName());
+  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
   // TODO(antiagainst): use target environment to select the version
   state.addAttribute("major_version", builder.getI32IntegerAttr(1));
   state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
index 842feef..cd92198 100644 (file)
@@ -428,9 +428,8 @@ static Operation *instantiate(OpBuilder b, Operation *opInst,
 
   auto attrs = materializeAttributes(opInst, hwVectorType);
 
-  OperationState state(b.getContext(), opInst->getLoc(),
-                       opInst->getName().getStringRef(), operands,
-                       {hwVectorType}, attrs);
+  OperationState state(opInst->getLoc(), opInst->getName().getStringRef(),
+                       operands, {hwVectorType}, attrs);
   return b.createOperation(state);
 }
 
index 2e2bc08..876f44b 100644 (file)
@@ -114,8 +114,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
     unsigned memRefOperandPos = getMemRefOperandPos();
 
     // Construct the new operation using this memref.
-    OperationState state(opInst->getContext(), opInst->getLoc(),
-                         opInst->getName());
+    OperationState state(opInst->getLoc(), opInst->getName());
     state.setOperandListToResizable(opInst->hasResizableOperandsList());
     state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
     // Insert the non-memref operands.
index 8a28517..39a05d8 100644 (file)
@@ -926,8 +926,7 @@ static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
   auto attr = DenseElementsAttr::get(vectorType, constant.getValue());
   auto *constantOpInst = constant.getOperation();
 
-  OperationState state(b.getContext(), loc,
-                       constantOpInst->getName().getStringRef(), {},
+  OperationState state(loc, constantOpInst->getName().getStringRef(), {},
                        {vectorType}, {b.getNamedAttr("value", attr)});
 
   return b.createOperation(state)->getResult(0);
@@ -1055,9 +1054,9 @@ static Operation *vectorizeOneOperation(Operation *opInst,
   // TODO(ntv): Is it worth considering an Operation.clone operation which
   // changes the type so we can promote an Operation with less boilerplate?
   OpBuilder b(opInst);
-  OperationState newOp(b.getContext(), opInst->getLoc(),
-                       opInst->getName().getStringRef(), vectorOperands,
-                       vectorTypes, opInst->getAttrs(), /*successors=*/{},
+  OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(),
+                       vectorOperands, vectorTypes, opInst->getAttrs(),
+                       /*successors=*/{},
                        /*regions=*/{}, opInst->hasResizableOperandsList());
   return b.createOperation(newOp);
 }