[mlir][linalg] expose convolution dimension classifier
authorAlex Zinenko <zinenko@google.com>
Tue, 7 Feb 2023 17:29:21 +0000 (17:29 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 14 Feb 2023 10:12:01 +0000 (10:12 +0000)
Make available through functions in the `linalg::detail` namespace the
classification of Linalg op dimensions as different kinds (batch, image,
channel, etc) of convolution dimensions. This is useful for identifying
which dimensions to target with transformations.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D143584

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

index d8d59b5..cb93e8a 100644 (file)
@@ -42,6 +42,31 @@ bool isaContractionOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+/// Result of matching a Linalg generic against the predicates of it being a
+/// convolution.
+enum class MatchConvolutionResult;
+
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// convolution dimension.
+struct ConvolutionDimensions {
+  SmallVector<unsigned, 2> batch;
+  SmallVector<unsigned, 2> outputImage;
+  SmallVector<unsigned, 2> outputChannel;
+  SmallVector<unsigned, 2> filterLoop;
+  SmallVector<unsigned, 2> inputChannel;
+  SmallVector<unsigned, 2> depth;
+};
+
+/// Checks whether `op` conforms to ConvolutionOpInterface and populates
+/// `dimensions` with indexes of the different kinds of dimensions when present.
+MatchConvolutionResult
+isConvolutionInterfaceImpl(Operation *op,
+                           ConvolutionDimensions *dimensions = nullptr);
+
+/// Returns the error message corresponding to the convolution checking return
+/// code.
+StringRef getMatchConvolutionMessage(MatchConvolutionResult res);
+
 /// Verify that `op` conforms to ContractionOpInterface.
 LogicalResult verifyContractionInterface(Operation *op);
 
index e5e0bdd..a5c6dc6 100644 (file)
@@ -268,6 +268,7 @@ static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
   return preservedDims;
 }
 
+namespace mlir::linalg::detail {
 enum class MatchConvolutionResult {
   Success = 0,
   NotLinalgOp,
@@ -278,8 +279,11 @@ enum class MatchConvolutionResult {
   OutputDimsNotParallel,
   NonOutputDimNotReduction
 };
+} // namespace mlir::linalg::detail
 
-static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
+mlir::linalg::detail::MatchConvolutionResult
+mlir::linalg::detail::isConvolutionInterfaceImpl(
+    Operation *op, ConvolutionDimensions *dimensions) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
   if (!linalgOp)
     return MatchConvolutionResult::NotLinalgOp;
@@ -307,7 +311,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
   llvm::SmallDenseSet<unsigned> outputDims =
       getPreservedDims(indexingMaps.back());
   llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
-  // Make sure all loops are charecterized as one of:
+  // Make sure all loops are characterized as one of:
   // - Batch loop : present in output, as non-convolved in input, not present in
   //   filter.
   // - Output image dimension : present in output, convolved dims in input, not
@@ -329,6 +333,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->batch.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.convolvedDims.count(outputDim) &&
@@ -337,6 +343,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->outputImage.push_back(outputDim);
       continue;
     }
     if (!inputExprWalker.convolvedDims.count(outputDim) &&
@@ -346,6 +354,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->outputChannel.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -354,6 +364,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->depth.push_back(outputDim);
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -363,7 +375,10 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (outputDims.count(filterDim) &&
         !inputExprWalker.unConvolvedDims.count(filterDim) &&
         !inputExprWalker.convolvedDims.count(filterDim)) {
-      // Output channel dimension. THis is already seen, continue;
+      // Output channel dimension. This is already seen, continue;
+      assert((!dimensions ||
+              llvm::is_contained(dimensions->outputChannel, filterDim)) &&
+             "expected output channel to have been found from output dims");
       continue;
     }
     if (inputExprWalker.convolvedDims.count(filterDim) &&
@@ -374,6 +389,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
+      if (dimensions)
+        dimensions->filterLoop.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
@@ -384,11 +401,16 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
+      if (dimensions)
+        dimensions->inputChannel.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         outputDims.count(filterDim)) {
       // Depthwise loop. Already seen.
+      assert(
+          (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
+          "expected depthwise dimension to have been found from output dims");
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -397,32 +419,45 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
+  if (dimensions) {
+    assert(dimensions->batch.size() + dimensions->outputImage.size() +
+                   dimensions->outputChannel.size() +
+                   dimensions->filterLoop.size() +
+                   dimensions->inputChannel.size() + dimensions->depth.size() ==
+               linalgOp.getNumLoops() &&
+           "expected all loops to be classified");
+  }
+
   return MatchConvolutionResult::Success;
 }
 
-LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
-  auto res = isConvolutionInterfaceImpl(op);
-  if (res == MatchConvolutionResult::NotLinalgOp)
-    return op->emitError("expected a LinalgOp");
-  if (res == MatchConvolutionResult::WrongNumOperands)
-    return op->emitError("expected op with 2 inputs and 1 output");
-  if (res == MatchConvolutionResult::WrongInputIndexingMap)
-    return op->emitError("unexpected input index map for convolutions");
-  if (res == MatchConvolutionResult::NotProjectedPermutations) {
-    return op->emitError(
-        "expected output/filter indexing maps to be projected permutations");
-  }
-  if (res == MatchConvolutionResult::NonConvolutionLoop) {
-    return op->emitError("unexpected loop dimension for convolution op");
-  }
-  if (res == MatchConvolutionResult::OutputDimsNotParallel) {
-    return op->emitError(
-        "expected all iterators used to access outputs to be parallel");
-  }
-  if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
-    return op->emitError(
-        "expected all iterators not used to access outputs to be reduction");
+StringRef
+mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
+  switch (res) {
+  case MatchConvolutionResult::NotLinalgOp:
+    return "expected a LinalgOp";
+  case MatchConvolutionResult::WrongNumOperands:
+    return "expected op with 2 inputs and 1 output";
+  case MatchConvolutionResult::WrongInputIndexingMap:
+    return "unexpected input index map for convolutions";
+  case MatchConvolutionResult::NotProjectedPermutations:
+    return "expected output/filter indexing maps to be projected permutations";
+  case MatchConvolutionResult::NonConvolutionLoop:
+    return "unexpected loop dimension for convolution op";
+  case MatchConvolutionResult::OutputDimsNotParallel:
+    return "expected all iterators used to access outputs to be parallel";
+  case MatchConvolutionResult::NonOutputDimNotReduction:
+    return "expected all iterators not used to access outputs to be reduction";
+  case MatchConvolutionResult::Success:
+    return "";
   }
+  llvm_unreachable("unhandled MatchConvolutionResult case");
+}
+
+LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
+  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
+  if (res != MatchConvolutionResult::Success)
+    return op->emitError(getMatchConvolutionMessage(res));
   return success();
 }