return preservedDims;
}
+namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
Success = 0,
NotLinalgOp,
OutputDimsNotParallel,
NonOutputDimNotReduction
};
+} // namespace mlir::linalg::detail
-static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
+mlir::linalg::detail::MatchConvolutionResult
+mlir::linalg::detail::isConvolutionInterfaceImpl(
+ Operation *op, ConvolutionDimensions *dimensions) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchConvolutionResult::NotLinalgOp;
llvm::SmallDenseSet<unsigned> outputDims =
getPreservedDims(indexingMaps.back());
llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
- // Make sure all loops are charecterized as one of:
+ // Make sure all loops are characterized as one of:
// - Batch loop : present in output, as non-convolved in input, not present in
// filter.
// - Output image dimension : present in output, convolved dims in input, not
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->batch.push_back(outputDim);
continue;
}
if (inputExprWalker.convolvedDims.count(outputDim) &&
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->outputImage.push_back(outputDim);
continue;
}
if (!inputExprWalker.convolvedDims.count(outputDim) &&
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->outputChannel.push_back(outputDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->depth.push_back(outputDim);
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
if (outputDims.count(filterDim) &&
!inputExprWalker.unConvolvedDims.count(filterDim) &&
!inputExprWalker.convolvedDims.count(filterDim)) {
- // Output channel dimension. THis is already seen, continue;
+ // Output channel dimension. This is already seen, continue;
+ assert((!dimensions ||
+ llvm::is_contained(dimensions->outputChannel, filterDim)) &&
+ "expected output channel to have been found from output dims");
continue;
}
if (inputExprWalker.convolvedDims.count(filterDim) &&
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
+ if (dimensions)
+ dimensions->filterLoop.push_back(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
+ if (dimensions)
+ dimensions->inputChannel.push_back(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
outputDims.count(filterDim)) {
// Depthwise loop. Already seen.
+ assert(
+ (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
+ "expected depthwise dimension to have been found from output dims");
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
+ if (dimensions) {
+ assert(dimensions->batch.size() + dimensions->outputImage.size() +
+ dimensions->outputChannel.size() +
+ dimensions->filterLoop.size() +
+ dimensions->inputChannel.size() + dimensions->depth.size() ==
+ linalgOp.getNumLoops() &&
+ "expected all loops to be classified");
+ }
+
return MatchConvolutionResult::Success;
}
-LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
- auto res = isConvolutionInterfaceImpl(op);
- if (res == MatchConvolutionResult::NotLinalgOp)
- return op->emitError("expected a LinalgOp");
- if (res == MatchConvolutionResult::WrongNumOperands)
- return op->emitError("expected op with 2 inputs and 1 output");
- if (res == MatchConvolutionResult::WrongInputIndexingMap)
- return op->emitError("unexpected input index map for convolutions");
- if (res == MatchConvolutionResult::NotProjectedPermutations) {
- return op->emitError(
- "expected output/filter indexing maps to be projected permutations");
- }
- if (res == MatchConvolutionResult::NonConvolutionLoop) {
- return op->emitError("unexpected loop dimension for convolution op");
- }
- if (res == MatchConvolutionResult::OutputDimsNotParallel) {
- return op->emitError(
- "expected all iterators used to access outputs to be parallel");
- }
- if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
- return op->emitError(
- "expected all iterators not used to access outputs to be reduction");
+StringRef
+mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
+ switch (res) {
+ case MatchConvolutionResult::NotLinalgOp:
+ return "expected a LinalgOp";
+ case MatchConvolutionResult::WrongNumOperands:
+ return "expected op with 2 inputs and 1 output";
+ case MatchConvolutionResult::WrongInputIndexingMap:
+ return "unexpected input index map for convolutions";
+ case MatchConvolutionResult::NotProjectedPermutations:
+ return "expected output/filter indexing maps to be projected permutations";
+ case MatchConvolutionResult::NonConvolutionLoop:
+ return "unexpected loop dimension for convolution op";
+ case MatchConvolutionResult::OutputDimsNotParallel:
+ return "expected all iterators used to access outputs to be parallel";
+ case MatchConvolutionResult::NonOutputDimNotReduction:
+ return "expected all iterators not used to access outputs to be reduction";
+ case MatchConvolutionResult::Success:
+ return "";
}
+ llvm_unreachable("unhandled MatchConvolutionResult case");
+}
+
+LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
+ MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
+ if (res != MatchConvolutionResult::Success)
+ return op->emitError(getMatchConvolutionMessage(res));
return success();
}