[mlir][Linalg] Allow specifiying zero-rank shaped type operands to linalg.generic...
authorMaheshRavishankar <ravishankarm@google.com>
Tue, 18 Feb 2020 17:50:47 +0000 (09:50 -0800)
committerMaheshRavishankar <ravishankarm@google.com>
Tue, 18 Feb 2020 21:23:28 +0000 (13:23 -0800)
Fixing a bug where using a zero-rank shaped type operand to
linalg.generic ops hit an unrelated assert. This also meant that
lowering the operation to loops was not supported. Adding roundtrip
tests and lowering to loops test for zero-rank shaped type operand
with fixes to make the test pass.

Differential Revision: https://reviews.llvm.org/D74638

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index 9adc20c..a2fe018 100644 (file)
@@ -361,11 +361,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
       if (!cst || cst.getValue() != 0)
         return op.emitOpError("expected indexing_map #")
                << idx << " to be 0 to match 0-D view: " << view;
-    }
-
-    if (m.getNumResults() != view.getRank())
+    } else if (m.getNumResults() != view.getRank()) {
       return op.emitOpError("expected indexing_map #")
              << idx << " results to match view rank: " << view;
+    }
   }
 
   auto concatMap = concatAffineMaps(indexingMaps);
index a160ccd..4d3ee89 100644 (file)
@@ -238,9 +238,14 @@ public:
 
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs));
-      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
+      Value input = genericOp.getInput(i);
+      if (!input.getType().cast<ShapedType>().getRank()) {
+        indexedValues[i] = std_load(input);
+      } else {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getInputIndexingMap(i), allIvs));
+        indexedValues[i] = std_load(input, indexing);
+      }
     }
 
     // 1.b. Emit std_load from output views.
index 53e6949..bbc88e9 100644 (file)
@@ -351,12 +351,12 @@ AffineMap mlir::inversePermutation(AffineMap map) {
 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
   unsigned numResults = 0;
   for (auto m : maps)
-    numResults += m ? m.getNumResults() : 0;
+    numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
   unsigned numDims = 0;
   SmallVector<AffineExpr, 8> results;
   results.reserve(numResults);
   for (auto m : maps) {
-    if (!m)
+    if (!m || m.isSingleConstant())
       continue;
     assert(m.getNumSymbols() == 0 && "expected map without symbols");
     results.append(m.getResults().begin(), m.getResults().end());
index 260f602..3d8d7f5 100644 (file)
@@ -356,3 +356,36 @@ func @indexed_generic_region(
 // CHECK:       %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
 // CHECK:       store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
 // CHECK:       store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
+
+// -----
+
+
+#broadcast_access = [
+  affine_map<(i, j) -> (0)>,
+  affine_map<(i,j) -> (i,j)>
+]
+
+#trait_broadcast = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #broadcast_access,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_broadcast_external_fn"
+}
+
+func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1:  memref<3x4xf32>)
+{
+  linalg.generic #trait_broadcast %arg0, %arg1 {
+    ^bb(%a: f32, %b : f32) :
+      linalg.yield %a : f32
+  } : memref<f32>, memref<3x4xf32>
+  return
+}
+
+// CHECK-LABEL: @generic_op_zero_rank
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   loop.for %[[j:.*]] = {{.*}}
+// CHECK:     %[[a:.*]] = load %[[ARG0]][]
+// CHECK:     store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
index 4b81dce..ec28510 100644 (file)
@@ -345,6 +345,30 @@ func @indexed_generic_with_tensor_input_and_output(
 
 // -----
 
+#broadcast_access = [
+  affine_map<(i, j) -> (0)>,
+  affine_map<(i,j) -> (i,j)>
+]
+
+#trait_broadcast = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #broadcast_access,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_broadcast_external_fn"
+}
+
+func @generic_op_zero_rank(%arg0 : tensor<f32>) ->  (tensor<3x4xf32>)
+{
+  %0 = linalg.generic #trait_broadcast %arg0 {
+    ^bb(%a: f32) :
+      linalg.yield %a : f32
+  } : tensor<f32> -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>