[mlir][Linalg] Update ReshapeOp::build to be more idiomatic
authorNicolas Vasilache <ntv@google.com>
Mon, 13 Jan 2020 03:38:57 +0000 (22:38 -0500)
committerNicolas Vasilache <ntv@google.com>
Mon, 13 Jan 2020 15:56:07 +0000 (10:56 -0500)
Summary:
This diff makes it easier to create a `linalg.reshape` op
and adds an EDSC builder api test to exercise the new builders.

Reviewers: ftynse, jpienaar

Subscribers: mehdi_amini, rriddle, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72580

mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/EDSC/builder-api-test.cpp

index 42b286d..7777f5c 100644 (file)
@@ -17,7 +17,7 @@ namespace edsc {
 namespace intrinsics {
 
 using linalg_fill = OperationBuilder<linalg::FillOp>;
-using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
+using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
 using linalg_yield = OperationBuilder<linalg::YieldOp>;
 
 } // namespace intrinsics
index 8e1f40f..db15f51 100644 (file)
@@ -100,9 +100,17 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
     ```
   }];
 
-  let builders = [OpBuilder<
-    "Builder *b, OperationState &result, Value view, "
-    "ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
+  let builders = [
+    // Builder for a contracting reshape whose result type is computed from
+    // `view` and `reassociation`.
+    OpBuilder<"Builder *b, OperationState &result, Value view, "
+    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+    "ArrayRef<NamedAttribute> attrs = {}">,
+    // Builder for a reshape whose result type is passed explicitly. This may be
+    // either a contracting or expanding reshape.
+    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
+    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+    "ArrayRef<NamedAttribute> attrs = {}">];
 
   let extraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
index cf27a81..e244542 100644 (file)
@@ -465,14 +465,52 @@ static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
       [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
 }
 
-void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
-                                    Value view, ArrayAttr reassociation,
-                                    ArrayRef<NamedAttribute> attrs) {
-  auto maps = getAffineMaps(reassociation);
+template <typename AffineExprTy>
+unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
+  unsigned pos = 0;
+  for (auto exprs : exprArrays) {
+    for (auto expr : exprs) {
+      expr.walk([&pos](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineExprTy>())
+          pos = std::max(pos, d.getPosition());
+      });
+    }
+  }
+  return pos;
+}
+
+static SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
+  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
+  unsigned maxSym = getMaxPosOfType<AffineSymbolExpr>(reassociation);
+  assert(maxSym == 0 && "Expected symbol-less expressions");
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(reassociation.size());
+  for (auto exprs : reassociation)
+    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs));
+  return maps;
+}
+
+void mlir::linalg::ReshapeOp::build(
+    Builder *b, OperationState &result, Value view,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
   auto memRefType = view.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(memRefType, maps);
   build(b, result, resultType, view, attrs);
-  result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
+  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::ReshapeOp::build(
+    Builder *b, OperationState &result, Type resultType, Value view,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  build(b, result, resultType, view, attrs);
+  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
 }
 
 static void print(OpAsmPrinter &p, ReshapeOp op) {
index fcd5e37..7ddfe50 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/EDSC/Builders.h"
@@ -962,6 +963,32 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
   f.erase();
 }
 
+// clang-format off
+// CHECK-LABEL: func @linalg_metadata_ops
+//       CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<4x8x16xf32> into memref<32x16xf32>
+//       CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<32x16xf32> into memref<4x8x16xf32>
+// clang-format on
+TEST_FUNC(linalg_metadata_ops) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
+  auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  AffineExpr i, j, k;
+  bindDims(&globalContext(), i, j, k);
+  ValueHandle v(f.getArgument(0));
+  auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+  linalg_reshape(memrefType, reshaped,
+                 ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;