[mlir][sparse] adjust output shape inference to new tensor abstraction
authorAart Bik <ajcbik@google.com>
Tue, 5 Jan 2021 21:29:28 +0000 (13:29 -0800)
committerAart Bik <ajcbik@google.com>
Tue, 5 Jan 2021 23:31:39 +0000 (15:31 -0800)
Nicolas changed the tensor abstraction so that every output has
its own shape definition. This simplifies the "inference" that
was used in the sparse compiler.

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D94119

mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_2d.mlir

index a6b7277..ed81d5e 100644 (file)
@@ -538,15 +538,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
       // Find lower and upper bound in current dimension.
       Value up;
       if (shape[d] == TensorType::kDynamicSize) {
-        // For the output tensor, we may need to infer the upper bound.
-        // For all others, we look at the incoming argument.
-        if (t == numInputs && !op.getNumInitTensors()) {
-          up = codegen.sizes[i];
-          assert(up); // TODO: what else?
-        } else {
-          Value arg = t < numInputs ? op.getInput(t) : op.getInitTensors()[0];
-          up = rewriter.create<DimOp>(loc, arg, d);
-        }
+        Value arg = t < numInputs ? op.getInput(t) : op.getOutput(0);
+        up = rewriter.create<DimOp>(loc, arg, d);
         args.push_back(up);
       } else {
         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
index 6612a72..9bb68ca 100644 (file)
@@ -1139,19 +1139,19 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
 // CHECK:           %[[VAL_2:.*]] = constant 999 : index
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
+// CHECK:           %[[VAL_5:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
-// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
-// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref<?x?xf64>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
+// CHECK:           %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
+// CHECK:           %[[VAL_9:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
+// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_8]], %[[VAL_9]]) : memref<?x?xf64>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
 // CHECK:             %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref<?xindex>
 // CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] {
-// CHECK:               %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK:               %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xf64>
+// CHECK:               %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:               %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xf64>
 // CHECK:               %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64
 // CHECK:               store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<?x?xf64>
 // CHECK:             }