The `memref.global` operation declares or defines a named global memref
variable. The backing memory for the variable is allocated statically and is
described by the type of the variable (which should be a statically shaped
- memref type). The operation is a declaration if no `inital_value` is
+ memref type). The operation is a declaration if no `initial_value` is
specified, else it is a definition. The `initial_value` can either be a unit
attribute to represent a definition of an uninitialized global variable, or
an elements attribute to represent the definition of a global variable with
bool isUninitialized() {
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
}
+ /// Returns the constant initial value if the memref.global is a constant,
+ /// or null otherwise.
+ ElementsAttr getConstantInitValue();
}];
let hasVerifier = 1;
}
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
/// load(memrefcast) -> load
if (succeeded(foldMemRefCast(*this)))
return getResult();
- return OpFoldResult();
+
+ // Fold load from a global constant memref.
+ auto getGlobalOp = memref().getDefiningOp<memref::GetGlobalOp>();
+ if (!getGlobalOp)
+ return {};
+ // Get to the memref.global defining the symbol.
+ auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symbolTableOp)
+ return {};
+ auto global = dyn_cast_or_null<memref::GlobalOp>(
+ SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
+ if (!global)
+ return {};
+ if (auto cstAttr =
+ global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ // We can fold only if we know the indices.
+ if (!getAffineMap().isConstant())
+ return {};
+ auto indices = llvm::to_vector<4>(
+ llvm::map_range(getAffineMap().getConstantResults(),
+ [](int64_t v) -> uint64_t { return v; }));
+ return cstAttr.getValues<Attribute>()[indices];
+ }
+ return {};
}
//===----------------------------------------------------------------------===//
return success();
}
+ElementsAttr GlobalOp::getConstantInitValue() {
+ auto initVal = initial_value();
+ if (constant() && initVal.hasValue())
+ return initVal.getValue().cast<ElementsAttr>();
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
return %0, %1: index, index
}
+
+// -----
+
+module {
+ memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+ // CHECK-LABEL: func @fold_const_init_global_memref
+ func @fold_const_init_global_memref() -> (f32, f32) {
+ %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
+ %v0 = affine.load %m[0, 0, 0] : memref<1x5x1xf32>
+ %v1 = affine.load %m[0, 1, 0] : memref<1x5x1xf32>
+ return %v0, %v1 : f32, f32
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
+ // CHECK-NEXT: return
+ // CHECK-SAME: %[[C0]]
+ // CHECK-SAME: %[[C1]]
+ }
+}