}
}];
let verifier = [{ return ::verify(*this); }];
+
+ let hasFolder = 1;
}
def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
}
}];
let verifier = [{ return ::verify(*this); }];
+
+ let hasFolder = 1;
}
def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
}
}];
+
+ let hasFolder = 1;
}
def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
return ArrayAttr::get(iters, ctx);
}
}];
+
+ let hasFolder = 1;
}
def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
return ArrayAttr::get(iters, ctx);
}
}];
+
+ let hasFolder = 1;
}
def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
.cast<IntegerAttr>().getValue().getSExtValue();
}
}];
+
let verifier = [{ return ::verify(*this); }];
+
+ let hasFolder = 1;
}
def LinalgOperand: Type<
tensor SSA values are expected to be useful and will be added in the near
future.
}];
+
let verifier = [{ return ::verify(*this); }];
+
+ let hasFolder = 1;
}
def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
tensor SSA values are expected to be useful and will be added in the near
future.
}];
+
let verifier = [{ return ::verify(*this); }];
+
+ let hasFolder = 1;
}
#endif // LINALG_STRUCTURED_OPS
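Setting `let hasFolder = 1` makes ODS declare a `fold` hook on each generated op class; the corresponding definitions are added in LinalgOps.cpp below. As a rough sketch of the declarations this produces (assuming the usual ODS convention that the hook's shape depends on the op's result count; this is not the verbatim generated code):

```c++
// Ops with zero or multiple results (e.g. linalg.matmul, which writes into
// its output buffer) get the general form:
LogicalResult fold(ArrayRef<Attribute> operands,
                   SmallVectorImpl<OpFoldResult> &results);

// Ops with exactly one result (e.g. linalg.reshape, linalg.slice,
// linalg.transpose) get the single-result form:
OpFoldResult fold(ArrayRef<Attribute> operands);
```

Both forms appear in the implementations further down.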
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
using namespace mlir;
using namespace mlir::linalg;
+/// Determines whether the memref_cast can be folded away by the Linalg op
+/// that consumes its result:
+///
+/// ```mlir
+/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+/// %2 = linalg.slice %1 ... : memref<?x?xf32> ...
+/// // or
+/// %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// to memref<?x?xf32>
+/// linalg.generic(%1 ...) : memref<?x?xf32> ...
+/// ```
+///
+/// into
+///
+/// ```mlir
+/// %2 = linalg.slice %0 ... : memref<8x16xf32> ...
+/// // or
+/// linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+///
+static bool canFold(MemRefCastOp castOp) {
+ MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+ MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+ // Bail out unless both source and result types are MemRefTypes.
+ if (!sourceType || !resultType)
+ return false;
+
+ // If the result type carries layout maps, they must match the source type's
+ // maps for the cast to fold.
+ if (!resultType.getAffineMaps().empty() &&
+ sourceType.getAffineMaps() != resultType.getAffineMaps())
+ return false;
+
+ // Ensure that:
+ // 1. source is static
+ // 2. source and target have the same rank (will be extended when needed)
+ // 3. if result is partially static, ensure sizes match.
+ if (!sourceType.hasStaticShape() ||
+ sourceType.getRank() != resultType.getRank())
+ return false;
+
+ for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+ auto sourceSize = std::get<0>(it);
+ auto resultSize = std::get<1>(it);
+ if (ShapedType::isDynamic(resultSize))
+ continue;
+ if (sourceSize != resultSize)
+ return false;
+ }
+
+ // If the source has a layout map, the cast can only fold when that map is
+ // the canonical strided layout map.
+ if (sourceType.getAffineMaps().empty())
+ return true;
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(sourceType, strides, offset);
+ (void)res;
+ assert(succeeded(res));
+ auto stridedMap =
+ makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
+ AffineMap sourceMap = sourceType.getAffineMaps().front();
+ return sourceMap == stridedMap;
+}
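+
+// For illustration (not exhaustive): a cast such as
+//   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x16xf32>
+// satisfies all of the conditions above and folds, whereas
+//   %1 = memref_cast %0 : memref<8x?xf32> to memref<?x?xf32>
+// does not, because the source shape is not fully static.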
+
+/// This is a common class used for patterns of the form
+/// ```
+/// someop(memrefcast) -> someop
+/// ```
+/// It replaces each foldable memref_cast operand of the root operation with
+/// the cast's source.
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ if (castOp && canFold(castOp)) {
+ operand.set(castOp.getOperand());
+ folded = true;
+ }
+ }
+ return success(folded);
+}
+
///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
return getIndexingMaps(getOperation());
}
+
+// TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate
+// with Tablegen. This seems a desirable property in the context of OpInterfaces
+// where a Linalg "named" op **isa** LinalgOp.
+LogicalResult ConvOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult CopyOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult DotOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult FillOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult GenericOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return {};
+}
+OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return {};
+}
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return {};
+}
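These fold hooks are exercised by the canonicalizer, which the new test below drives through `mlir-opt -canonicalize`. A minimal sketch of doing the same programmatically, assuming the standard `PassManager` and `createCanonicalizerPass` entry points and the header layout of this vintage of MLIR (the helper name is illustrative):

```c++
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Runs canonicalization on a module; this is what picks up the fold hooks
// defined above, since folding runs as part of the canonicalization patterns.
static mlir::LogicalResult canonicalizeModule(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}
```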
--- /dev/null
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @memref_cast(
+func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c8 = constant 8 : index
+ %c16 = constant 16 : index
+ %1 = alloc (%b) : memref<?xi8>
+ %2 = view %1[][] : memref<?xi8> to memref<16x16xf32>
+ %3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
+ %r0 = linalg.range %c0:%c8:%c1 : !linalg.range
+
+ // CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
+ %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
+
+ // CHECK: linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
+ linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ return %4: memref<?x?xf32>
+}