[EDSC] Fix Stmt::operator= and allow DimOp in For loops
authorNicolas Vasilache <ntv@google.com>
Wed, 20 Feb 2019 08:41:42 +0000 (00:41 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:33:26 +0000 (16:33 -0700)
This CL fixes 2 recent issues with EDSCs:
1. the type of the LHS in Stmt::operator=(Expr rhs) should be the same as the (asserted unique) return type;
2. symbols coming from DimOp should be admissible as lower / upper bounds in For

The relevant tests are added.

PiperOrigin-RevId: 234750249

mlir/lib/EDSC/LowerEDSCTestPass.cpp
mlir/lib/EDSC/MLIREmitter.cpp
mlir/lib/EDSC/Types.cpp
mlir/test/EDSC/for-loops.mlir

index 25fdadd397d8c66128941c051e465a9b707b3d0a..c812ce90f0b7fb7fc45987dd154608543446d94c 100644 (file)
@@ -118,6 +118,37 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
     return success();
   }
 
+  // Inject an EDSC-constructed computation that assigns Stmt and uses the LHS.
+  if (f->getName().strref().contains("assignments")) {
+    FuncBuilder builder(f);
+    edsc::ScopedEDSCContext context;
+    edsc::MLIREmitter emitter(&builder, f->getLoc());
+
+    edsc::Expr zero = emitter.zero();
+    edsc::Expr one = emitter.one();
+    auto args = emitter.makeBoundFunctionArguments(f);
+    auto views = emitter.makeBoundMemRefViews(args.begin(), args.end());
+
+    Type indexType = builder.getIndexType();
+    edsc::Expr i(indexType);
+    edsc::Expr A = args[0], B = args[1], C = args[2];
+    edsc::Expr M = views[0].dim(0);
+    // clang-format off
+    using namespace edsc::op;
+    edsc::Stmt scalarA, scalarB, tmp;
+    auto block = edsc::block({
+      For(i, zero, M, one, {
+        scalarA = load(A, {i}),
+        scalarB = load(B, {i}),
+        tmp = scalarA * scalarB,
+        store(tmp, C, {i})
+      }),
+    });
+    // clang-format on
+
+    emitter.emitStmts(block.getBody());
+  }
+
   f->walk([](Instruction *op) {
     if (op->getName().getStringRef() == "print") {
       auto opName = op->getAttrOfType<StringAttr>("op");
index 38cb9927c99fee6c2498728b3f2b190bac074788..cbaca8201640eef8510b4ee27c0ee8f36eabca0b 100644 (file)
@@ -117,12 +117,14 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
       auto *lbDef = lb->getDefiningInst();
       (void)lbDef;
       assert((!lbDef || lbDef->isa<ConstantIndexOp>() ||
-              lbDef->isa<AffineApplyOp>() || lbDef->isa<AffineForOp>()) &&
+              lbDef->isa<AffineApplyOp>() || lbDef->isa<AffineForOp>() ||
+              lbDef->isa<DimOp>()) &&
              "lower bound expression does not have affine provenance");
       auto *ubDef = ub->getDefiningInst();
       (void)ubDef;
       assert((!ubDef || ubDef->isa<ConstantIndexOp>() ||
-              ubDef->isa<AffineApplyOp>() || ubDef->isa<AffineForOp>()) &&
+              ubDef->isa<AffineApplyOp>() || ubDef->isa<AffineForOp>() ||
+              ubDef->isa<DimOp>()) &&
              "upper bound expression does not have affine provenance");
 
       // Step must be a static constant.
index e09b8e571243df0f4ad0424bcf0be2f80a35f9fa..e43c35b5da1edd270660e0e51b6de27432ac224b 100644 (file)
@@ -760,7 +760,9 @@ edsc_stmt_t makeStmt(edsc_expr_t e) {
 }
 
 Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
-  Stmt res(Bindable(Expr(Type())), expr, {});
+  auto types = expr.getResultTypes();
+  assert(types.size() == 1 && "single result Expr expected in Stmt::operator=");
+  Stmt res(Bindable(Expr(types.front())), expr, {});
   std::swap(res.storage, this->storage);
   return *this;
 }
index 630e45c375c5c8a042549c0857ee498808a6c49c..eb6a661430c45a0f51316988d34c05c4675a31ad 100644 (file)
@@ -41,3 +41,27 @@ func @dynamic_for_func_args(%arg0 : index, %arg1 : index) {
 func @dynamic_for(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) {
   return
 }
+
+// These functions will be detected by the test pass that will insert an
+// EDSC-constructed 1-D pointwise-add loop with assignments to scalars before
+// the `return` instruction.
+//
+// CHECK-LABEL: func @assignments_1(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
+// CHECK: for %[[iv:.*]] = 0 to 4 {
+// CHECK:   %[[a:.*]] = load %arg0[%[[iv]]] : memref<4xf32>
+// CHECK:   %[[b:.*]] = load %arg1[%[[iv]]] : memref<4xf32>
+// CHECK:   %[[tmp:.*]] = mulf %[[a]], %[[b]] : f32
+// CHECK:   store %[[tmp]], %arg2[%[[iv]]] : memref<4xf32>
+func @assignments_1(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
+  return
+}
+
+// CHECK-LABEL: func @assignments_2(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+// CHECK: for %[[iv:.*]] = {{.*}} to {{.*}} {
+// CHECK:   %[[a:.*]] = load %arg0[%[[iv]]] : memref<?xf32>
+// CHECK:   %[[b:.*]] = load %arg1[%[[iv]]] : memref<?xf32>
+// CHECK:   %[[tmp:.*]] = mulf %[[a]], %[[b]] : f32
+// CHECK:   store %[[tmp]], %arg2[%[[iv]]] : memref<?xf32>
+func @assignments_2(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+  return
+}