getPoolingInput<IndexedValueType>(op, indices.inputs);
}
-/// Emits the MLIR for the scalar part of the indexed generic op by:
-///   1. Emitting load ops for each input and output view in order. This is
-///      achieved by applying the appropriate input or output map to the
-///      enclosing induction variables.
-///   2. Emitting a call to `op.fun()` that takes as arguments the induction
-///      variables and the scalars from point 1. above.
-///   3. Emitting store ops to store the results of 2. to the output views.
-///
-/// An example output may resemble:
-///
-/// ```
-///    scf.for %i = %c0 to %0 step %c1 {
-///      scf.for %j = %c0 to %1 step %c1 {
-///        scf.for %k = %c0 to %4 step %c1 {
-///          %11 = load %arg0[%i, %j] :
-///            memref<?x?xf32, stride_specification>
-///          %12 = load %arg1[%i, %j, %k] :
-///            memref<?x?x?xf32, stride_specification>
-///          %13 = load %arg2[%i, %k, %j] :
-///            memref<?x?x?xf32, stride_specification>
-///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
-///            (index, index, index, f32, f32, f32) -> (f32, f32)
-///          store %14#0, %arg1[%i, %j, %k] :
-///            memref<?x?x?xf32, stride_specification>
-///          store %14#1, %arg2[%i, %k, %j] :
-///            memref<?x?x?xf32, stride_specification>
-///        }
-///      }
-///    }
-/// ```
-template <typename IndexedValueType>
-static void emitScalarImplementation(ArrayRef<Value> allIvs,
-                                     IndexedGenericOp indexedGenericOp) {
-  assert(indexedGenericOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  auto &b = ScopedContext::getBuilderRef();
-  auto loc = ScopedContext::getLocation();
-  unsigned nInputs = indexedGenericOp.getNumInputs();
-  unsigned nOutputs = indexedGenericOp.getNumOutputs();
-  unsigned nLoops = allIvs.size();
-  SmallVector<Value, 4> indexedValues;
-  indexedValues.reserve(nLoops + nInputs + nOutputs);
-  for (unsigned i = 0; i < nLoops; ++i)
-    indexedValues.push_back(allIvs[i]);
-
-  // TODO: Avoid the loads if the corresponding argument of the
-  // region has no uses.
-  // 1.a. Emit load from input views.
-  for (unsigned i = 0; i < nInputs; ++i) {
-    auto indexing = makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
-    // Passing input i through IndexedValueType emits the proper load op.
-    indexedValues.push_back(
-        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
-  }
-  // 1.b. Emit load from output views.
-  for (unsigned i = 0; i < nOutputs; ++i) {
-    auto indexing = makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
-    // Passing output i through IndexedValueType emits the proper load op.
-    indexedValues.push_back(
-        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
-  }
-
-  // TODO: When a region inliner exists, use it.
-  // 2. Inline region, currently only works for a single basic block.
-  // 3. Emit store.
-  SmallVector<SmallVector<Value, 8>, 8> indexing;
-  SmallVector<Value, 8> outputBuffers;
-  for (unsigned i = 0; i < nOutputs; ++i) {
-    indexing.push_back(makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
-  }
-  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
-                                             indexing, outputBuffers);
-}
-
template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
                                                 OpBuilder &builder) {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(linalgOp)
-            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
-                  IndexedGenericOp, LinalgOp>([&](auto op) {
-              emitScalarImplementation<IndexedValueTy>(allIvs, op);
-            })
+            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
+                [&](auto op) {
+                  emitScalarImplementation<IndexedValueTy>(allIvs, op);
+                })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
                                   LinalgOp linalgOp) {
+  // indexed_generic ops are not handled here: they are expected to have been
+  // converted to generic ops before being lowered to loops.
+  if (isa<IndexedGenericOp>(linalgOp))
+    return llvm::None;
+
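For context, the conversion the new comment refers to replaces the leading `index` block arguments of `linalg.indexed_generic` with `linalg.index` ops inside a plain `linalg.generic` body. A minimal before/after sketch; the trait `#pointwise_1d` and the function names are illustrative placeholders, not part of this patch:

```mlir
#pointwise_1d = {
  indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
  iterator_types = ["parallel"]
}

// Before: the induction value arrives as a leading block argument.
func @before(%in: memref<?xf32>, %out: memref<?xf32>) {
  linalg.indexed_generic #pointwise_1d
      ins(%in : memref<?xf32>)
     outs(%out : memref<?xf32>) {
    ^bb0(%i: index, %a: f32, %b: f32):
      %i_int = index_cast %i : index to i32
      %i_float = sitofp %i_int : i32 to f32
      %sum = addf %a, %i_float : f32
      linalg.yield %sum : f32
  }
  return
}

// After: a plain generic recovers the same value with linalg.index.
func @after(%in: memref<?xf32>, %out: memref<?xf32>) {
  linalg.generic #pointwise_1d
      ins(%in : memref<?xf32>)
     outs(%out : memref<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %i = linalg.index 0 : index
      %i_int = index_cast %i : index to i32
      %i_float = sitofp %i_int : i32 to f32
      %sum = addf %a, %i_float : f32
      linalg.yield %sum : f32
  }
  return
}
```

Once all indexed_generic ops are rewritten this way, the bail-out above never fires and the `LinalgOp` case of the `TypeSwitch` handles the resulting generic ops.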
  Optional<LinalgLoops> loopOps =
      linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
  if (loopOps.hasValue())
// CHECKPARALLEL: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
// CHECKPARALLEL: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-func @indexed_generic_region(
-    %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-    %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-    %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait4
-      ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
-     outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                         memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-    ^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
-      %result_1 = mulf %a, %b : f32
-
-      %ij = addi %i, %j : index
-      %ijk = addi %ij, %k : index
-      %ijk_int = index_cast %ijk : index to i32
-      %ijk_float = sitofp %ijk_int : i32 to f32
-
-      %result_2 = addf %c, %ijk_float : f32
-      linalg.yield %result_1, %result_2 : f32, f32
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @indexed_generic_region
-// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: scf.for %[[j:.*]] = {{.*}}
-// CHECKLOOP: scf.for %[[k:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
-// CHECKLOOP: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKLOOP: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
-// CHECKLOOP: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKLOOP: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
-// CHECKLOOP: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
-// CHECKLOOP: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
-// CHECKLOOP: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
-// CHECKLOOP: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKLOOP: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_region
-// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
-// CHECKPARALLEL: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKPARALLEL: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
-// CHECKPARALLEL: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKPARALLEL: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
-// CHECKPARALLEL: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
-// CHECKPARALLEL: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
-// CHECKPARALLEL: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
-// CHECKPARALLEL: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKPARALLEL: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-
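After the conversion, the coverage that `@indexed_generic_region` provided moves to the `linalg.generic` form. A sketch of the equivalent op, assuming `#trait4` stays defined in this file (the function name is hypothetical); lowering it produces the same loop nest, with the induction variables materialized by `linalg.index`:

```mlir
func @generic_region_with_index(
    %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
    %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
    %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.generic #trait4
      ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
     outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                         memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %i = linalg.index 0 : index
      %j = linalg.index 1 : index
      %k = linalg.index 2 : index
      %result_1 = mulf %a, %b : f32
      %ij = addi %i, %j : index
      %ijk = addi %ij, %k : index
      %ijk_int = index_cast %ijk : index to i32
      %ijk_float = sitofp %ijk_int : i32 to f32
      %result_2 = addf %c, %ijk_float : f32
      linalg.yield %result_1, %result_2 : f32, f32
  }
  return
}
```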
// -----
#broadcast_access = [
// CHECKPARALLEL: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
// CHECKPARALLEL: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
-{
-  linalg.indexed_generic #trait_broadcast
-      ins(%arg0 : memref<i32>)
-     outs(%arg1 : memref<3x4xi32>) {
-    ^bb(%i: index, %j: index, %a: i32, %b: i32) :
-      %ij = addi %i, %j : index
-      %ij_int = index_cast %ij : index to i32
-      %result = addi %a, %ij_int : i32
-      linalg.yield %result : i32
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @indexed_generic_op_zero_rank
-// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
-// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
-// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: scf.for %[[j:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][
-// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKLOOP: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
-// CHECKLOOP: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
-// CHECKLOOP: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_op_zero_rank
-// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
-// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
-// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][
-// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKPARALLEL: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
-// CHECKPARALLEL: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
-// CHECKPARALLEL: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-
#reduce_1D_access = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
// CHECKPARALLEL: %[[e:.*]] = addf %[[a]], %[[d]]
// CHECKPARALLEL: store %[[e]], %[[ARG2]][]
-func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
-                                   %arg1: memref<f32>,
-                                   %arg2: memref<f32>)
-{
-  linalg.indexed_generic #trait_reduce_init_1D
-      ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
-     outs(%arg2 : memref<f32>) {
-    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
-      %0 = constant 0 : index
-      %1 = cmpi eq, %0, %i : index
-      %2 = select %1, %b, %c : f32
-      %3 = addf %a, %2 : f32
-      linalg.yield %3 : f32
-  }
-  return
-}
-// CHECKLOOP-LABEL: @indexed_generic_op_1D_reduce
-// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
-// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
-// CHECKLOOP: %[[b:.*]] = memref.load %[[ARG1]][]
-// CHECKLOOP: %[[c:.*]] = memref.load %[[ARG2]][]
-// CHECKLOOP: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
-// CHECKLOOP: %[[e:.*]] = addf %[[a]], %[[d]]
-// CHECKLOOP: store %[[e]], %[[ARG2]][]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_op_1D_reduce
-// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
-// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}}
-// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
-// CHECKPARALLEL: %[[b:.*]] = memref.load %[[ARG1]][]
-// CHECKPARALLEL: %[[c:.*]] = memref.load %[[ARG2]][]
-// CHECKPARALLEL: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
-// CHECKPARALLEL: %[[e:.*]] = addf %[[a]], %[[d]]
-// CHECKPARALLEL: store %[[e]], %[[ARG2]][]
-
#trait_const_fill = {
  args_in = 0,
  args_out = 1,