namespace {
+// Iteration graph sorting.
+enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
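// Illustration only (not part of this change): the masks are independent bit
// flags, so call sites can OR them together and each bit can be tested on its
// own; kSparseOnly (0x0) sets no bits at all. The two helpers below are
// hypothetical, shown only to spell out the intended bit tests.
inline bool includeDenseDims(unsigned mask) {
  return (mask & SortMask::kIncludeDense) != 0;
}
inline bool includeUndefDims(unsigned mask) {
  return (mask & SortMask::kIncludeUndef) != 0;
}
// e.g. includeDenseDims(SortMask::kIncludeDense | SortMask::kIncludeUndef) -> true
//      includeDenseDims(SortMask::kSparseOnly)                             -> false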
+
// Code generation.
struct CodeGen {
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
std::vector<unsigned> &topSort,
- bool sparseOnly) {
+ unsigned mask) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
unsigned n = op.getNumLoops();
auto map = op.getTiedIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
assert(map.getNumDims() == n);
- // Skip dense tensor constraints when sparse only is requested.
- if (sparseOnly && !enc)
+      // Skip dense tensor constraints when the mask does not request them.
+ if (!(mask & SortMask::kIncludeDense) && !enc)
continue;
// Each tensor expression and optional dimension ordering (row-major
// by default) puts an ordering constraint on the loop indices. For
unsigned t = map.getDimPosition(perm(enc, d));
adjM[f][t] = true;
}
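// Illustration: for a 2-d tensor accessed through the identity map with the
// default (row-major) dimension order, the outer loop index must be generated
// before the inner one, so the single edge d0 -> d1 (adjM[d0][d1] = true) is
// added for that tensor.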
+      // Push loops that are unrelated to this tensor into its sparse
+      // iteration space, so that they are skipped more often.
+ if (mask & SortMask::kIncludeUndef) {
+ unsigned tensor = t->getOperandNumber();
+ for (unsigned i = 0; i < n; i++)
+ if (merger.isDim(tensor, i, Dim::kSparse))
+ for (unsigned j = 0; j < n; j++)
+ if (merger.isDim(tensor, j, Dim::kUndef))
+ adjM[i][j] = true;
+ }
}
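// Sketch (illustration only, not part of this change) of the undef edges in
// isolation: DimKind mirrors the merger's per-tensor dimension kinds, and the
// function name is hypothetical.
#include <vector>

enum class DimKind { kSparse, kDense, kUndef };

static void addUndefEdges(std::vector<std::vector<bool>> &adjM,
                          const std::vector<DimKind> &dims) {
  unsigned n = dims.size();
  for (unsigned i = 0; i < n; i++)
    if (dims[i] == DimKind::kSparse)
      for (unsigned j = 0; j < n; j++)
        if (dims[j] == DimKind::kUndef)
          adjM[i][j] = true; // nest unrelated loop j inside sparse loop i
}
// E.g. with dims = {kSparse, kDense, kUndef} only adjM[0][2] is set, forcing
// the unrelated loop i_2 under the sparse loop i_0.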
// Topologically sort the iteration graph to determine loop order.
// This assumes that higher-level passes have already put the
// tensors in each tensor expression in a feasible order.
std::vector<unsigned> topSort;
- if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
- !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
+ if (!computeIterationGraph(merger, op, topSort,
+ SortMask::kIncludeUndef |
+ SortMask::kIncludeDense) &&
+ !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
+ !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
+ !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
return failure();
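// Sketch (illustration only): the calls above try the most constrained mask
// first and progressively relax it until some iteration graph admits a
// topological order. One way to sort the boolean adjacency matrix is Kahn's
// algorithm, as below; the in-tree helper may instead use DFS, but the
// feasibility test is the same (a cycle means no valid loop order).
#include <vector>

static bool topoSortSketch(const std::vector<std::vector<bool>> &adjM,
                           std::vector<unsigned> &topSort) {
  unsigned n = adjM.size();
  std::vector<unsigned> inDegree(n, 0);
  std::vector<bool> placed(n, false);
  for (unsigned i = 0; i < n; i++)
    for (unsigned j = 0; j < n; j++)
      if (adjM[i][j])
        inDegree[j]++;
  for (unsigned round = 0; round < n; round++) {
    unsigned next = n;
    for (unsigned j = 0; j < n; j++)
      if (!placed[j] && inDegree[j] == 0) {
        next = j;
        break;
      }
    if (next == n)
      return false; // cycle: no feasible loop order under this mask
    placed[next] = true;
    topSort.push_back(next);
    for (unsigned j = 0; j < n; j++)
      if (adjM[next][j])
        inDegree[j]--;
  }
  return true;
}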
// Builds the tensor expression for the Linalg operation in SSA form.
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : memref<?x?xf32>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref<?xf32>
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
-// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
-// CHECK: %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK: memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<?x?xf32>
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_27]], %[[VAL_34]] : f32
+// CHECK: %[[VAL_36:.*]] = addf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK: scf.yield %[[VAL_36]] : f32
// CHECK: }
+// CHECK: memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
-// CHECK: return %[[VAL_35]] : tensor<?x?xf32>
+// CHECK: %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
+// CHECK: return %[[VAL_38]] : tensor<?x?xf32>
// CHECK: }
func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
%arga: tensor<?x?xf32>,