namespace intrinsics {
using linalg_fill = OperationBuilder<linalg::FillOp>;
-using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
+using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
using linalg_yield = OperationBuilder<linalg::YieldOp>;
} // namespace intrinsics
```
}];
- let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value view, "
- "ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
+ let builders = [
+ // Builder for a contracting reshape whose result type is computed from
+ // `view` and `reassociation`.
+ OpBuilder<"Builder *b, OperationState &result, Value view, "
+ "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Builder for a reshape whose result type is passed explicitly. This may be
+ // either a contracting or expanding reshape.
+ OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
+ "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">];
let extraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
[](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
}
-void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
- Value view, ArrayAttr reassociation,
- ArrayRef<NamedAttribute> attrs) {
- auto maps = getAffineMaps(reassociation);
+template <typename AffineExprTy>
+unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
+ unsigned pos = 0;
+ for (auto exprs : exprArrays) {
+ for (auto expr : exprs) {
+ expr.walk([&pos](AffineExpr e) {
+ if (auto d = e.dyn_cast<AffineExprTy>())
+ pos = std::max(pos, d.getPosition());
+ });
+ }
+ }
+ return pos;
+}
+
+static SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
+ unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
+ unsigned maxSym = getMaxPosOfType<AffineSymbolExpr>(reassociation);
+ assert(maxSym == 0 && "Expected symbol-less expressions");
+ SmallVector<AffineMap, 4> maps;
+ maps.reserve(reassociation.size());
+ for (auto exprs : reassociation)
+ maps.push_back(AffineMap::get(maxDim + 1, 0, exprs));
+ return maps;
+}
+
+void mlir::linalg::ReshapeOp::build(
+ Builder *b, OperationState &result, Value view,
+ ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto maps = getSymbolLessAffineMaps(reassociation);
auto memRefType = view.getType().cast<MemRefType>();
auto resultType = computeReshapeCollapsedType(memRefType, maps);
build(b, result, resultType, view, attrs);
- result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
+ result.addAttribute(ReshapeOp::getReassociationAttrName(),
+ b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::ReshapeOp::build(
+ Builder *b, OperationState &result, Type resultType, Value view,
+ ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto maps = getSymbolLessAffineMaps(reassociation);
+ build(b, result, resultType, view, attrs);
+ result.addAttribute(ReshapeOp::getReassociationAttrName(),
+ b->getAffineMapArrayAttr(maps));
}
static void print(OpAsmPrinter &p, ReshapeOp op) {
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/EDSC/Builders.h"
f.erase();
}
+// clang-format off
+// CHECK-LABEL: func @linalg_metadata_ops
+// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<4x8x16xf32> into memref<32x16xf32>
+// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<32x16xf32> into memref<4x8x16xf32>
+// clang-format on
+TEST_FUNC(linalg_metadata_ops) {
+ using namespace edsc;
+ using namespace edsc::intrinsics;
+
+ auto f32Type = FloatType::getF32(&globalContext());
+ auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
+ auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
+
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ AffineExpr i, j, k;
+ bindDims(&globalContext(), i, j, k);
+ ValueHandle v(f.getArgument(0));
+ auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+ linalg_reshape(memrefType, reshaped,
+ ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+
+ f.print(llvm::outs());
+ f.erase();
+}
+
int main() {
RUN_TESTS();
return 0;