[mlir][Memref] NFC - Addresult pretty printing to MemrefOps
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 30 Sep 2022 11:07:43 +0000 (04:07 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 4 Oct 2022 07:05:16 +0000 (00:05 -0700)
Differential Revision: https://reviews.llvm.org/D134968

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

index 494d9ff..1381b34 100644 (file)
@@ -17,6 +17,7 @@ include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 /// A TypeAttr for memref types.
@@ -135,7 +136,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
 // AllocOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
+def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
   let summary = "memory allocation operation";
   let description = [{
     The `alloc` operation allocates a region of memory, as specified by its
@@ -276,7 +278,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
 // AllocaOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
+def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
   let summary = "stack memory allocation operation";
   let description = [{
     The `alloca` operation allocates memory on the stack, to be automatically
@@ -398,10 +401,12 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
 //===----------------------------------------------------------------------===//
 
 def MemRef_CastOp : MemRef_Op<"cast", [
-      NoSideEffect, SameOperandsAndResultShape,
       DeclareOpInterfaceMethods<CastOpInterface>,
-      ViewLikeOpInterface,
-      MemRefsNormalizable
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      MemRefsNormalizable,
+      NoSideEffect,
+      SameOperandsAndResultShape,
+      ViewLikeOpInterface
     ]> {
   let summary = "memref cast operation";
   let description = [{
@@ -477,8 +482,8 @@ def MemRef_CastOp : MemRef_Op<"cast", [
 // CopyOp
 //===----------------------------------------------------------------------===//
 
-def CopyOp : MemRef_Op<"copy",
-    [CopyOpInterface, SameOperandsElementType, SameOperandsShape]> {
+def CopyOp : MemRef_Op<"copy", [CopyOpInterface, SameOperandsElementType, 
+    SameOperandsShape]> {
 
   let description = [{
     Copies the data from the source to the destination memref.
@@ -536,8 +541,11 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
 // DimOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable,
-                                     ShapedDimOpInterface]> {
+def MemRef_DimOp : MemRef_Op<"dim", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    MemRefsNormalizable,
+    NoSideEffect,
+    ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
     The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -798,8 +806,11 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
 // ExtractAlignedPointerAsIndexOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_ExtractAlignedPointerAsIndexOp : MemRef_Op<"extract_aligned_pointer_as_index",
-    [NoSideEffect, SameVariadicResultSize]> {
+def MemRef_ExtractAlignedPointerAsIndexOp : 
+  MemRef_Op<"extract_aligned_pointer_as_index", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    NoSideEffect, 
+    SameVariadicResultSize]> {
   let summary = "Extracts a memref's underlying aligned pointer as an index";
   let description = [{
     Extracts the underlying aligned pointer as an index.
@@ -836,8 +847,10 @@ def MemRef_ExtractAlignedPointerAsIndexOp : MemRef_Op<"extract_aligned_pointer_a
 // ExtractStridedMetadataOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata",
-    [NoSideEffect, SameVariadicResultSize]> {
+def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    NoSideEffect, 
+    SameVariadicResultSize]> {
   let summary = "Extracts a buffer base with offset and strides";
   let description = [{
     Extracts a base buffer, offset and strides. This op allows additional layers
@@ -1193,8 +1206,12 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
 
 def MemRef_ReinterpretCastOp
   : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
-      NoSideEffect, AttrSizedOperandSegments, ViewLikeOpInterface,
-      OffsetSizeAndStrideOpInterface, MemRefsNormalizable
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      AttrSizedOperandSegments, 
+      MemRefsNormalizable,
+      NoSideEffect, 
+      OffsetSizeAndStrideOpInterface, 
+      ViewLikeOpInterface
     ]> {
   let summary = "memref reinterpret cast operation";
   let description = [{
@@ -1313,7 +1330,9 @@ def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 
 def MemRef_ReshapeOp: MemRef_Op<"reshape", [
-    ViewLikeOpInterface, NoSideEffect]>  {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    NoSideEffect,
+    ViewLikeOpInterface]>  {
   let summary = "memref reshape operation";
   let description = [{
     The `reshape` operation converts a memref from one type to an
@@ -1412,7 +1431,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
   let hasVerifier = 1;
 }
 
-def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
+def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
@@ -1491,7 +1511,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
   let hasVerifier = 1;
 }
 
-def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
+def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1656,8 +1677,11 @@ def MemRef_StoreOp : MemRef_Op<"store",
 //===----------------------------------------------------------------------===//
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
-    DeclareOpInterfaceMethods<ViewLikeOpInterface>, NoSideEffect,
-    AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+    AttrSizedOperandSegments,
+    OffsetSizeAndStrideOpInterface,
+    NoSideEffect
   ]> {
   let summary = "memref subview operation";
   let description = [{
@@ -1940,7 +1964,9 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
+def MemRef_TransposeOp : MemRef_Op<"transpose", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    NoSideEffect]>,
     Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
   let summary = "`transpose` produces a new strided memref (metadata-only)";
@@ -1975,7 +2001,9 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
 //===----------------------------------------------------------------------===//
 
 def MemRef_ViewOp : MemRef_Op<"view", [
-    DeclareOpInterfaceMethods<ViewLikeOpInterface>, NoSideEffect]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ViewLikeOpInterface>, 
+    NoSideEffect]> {
   let summary = "memref view operation";
   let description = [{
     The "view" operation extracts an N-D contiguous memref with empty layout map
index 99d115a..26814ab 100644 (file)
@@ -113,6 +113,16 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
 
+void AllocOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "alloc");
+}
+
+void AllocaOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "alloca");
+}
+
 template <typename AllocLikeOp>
 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
@@ -522,6 +532,10 @@ LogicalResult AssumeAlignmentOp::verify() {
 // CastOp
 //===----------------------------------------------------------------------===//
 
+void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "cast");
+}
+
 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
 /// source memref. This is useful to to fold a memref.cast into a consuming op
 /// and implement canonicalization patterns for ops in different dialects that
@@ -782,6 +796,10 @@ LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
 // DimOp
 //===----------------------------------------------------------------------===//
 
+void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "dim");
+}
+
 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                   int64_t index) {
   auto loc = result.location;
@@ -1210,6 +1228,31 @@ LogicalResult DmaWaitOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "intptr");
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractStridedMetadataOp
+//===----------------------------------------------------------------------===//
+
+void ExtractStridedMetadataOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getBaseBuffer(), "base_buffer");
+  setNameFn(getOffset(), "offset");
+  // For multi-result to work properly with pretty names and packed syntax `x:3`
+  // we can only give a pretty name to the first value in the pack.
+  if (!getSizes().empty()) {
+    setNameFn(getSizes().front(), "sizes");
+    setNameFn(getStrides().front(), "strides");
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
 
@@ -1508,6 +1551,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
 // ReinterpretCastOp
 //===----------------------------------------------------------------------===//
 
+void ReinterpretCastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reinterpret_cast");
+}
+
 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
 /// `staticSizes` and `staticStrides` are automatically filled with
 /// source-memref-rank sentinel values that encode dynamic entries.
@@ -1709,6 +1757,16 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
 
+void CollapseShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "collapse_shape");
+}
+
+void ExpandShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "expand_shape");
+}
+
 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
 /// result and operand. Layout maps are verified separately.
 ///
@@ -2120,6 +2178,11 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
+void ReshapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reshape");
+}
+
 LogicalResult ReshapeOp::verify() {
   Type operandType = getSource().getType();
   Type resultType = getResult().getType();
@@ -2170,6 +2233,11 @@ LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
 // SubViewOp
 //===----------------------------------------------------------------------===//
 
+void SubViewOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "subview");
+}
+
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
@@ -2735,6 +2803,11 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
+void TransposeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "transpose");
+}
+
 /// Build a strided memref type by applying `permutationMap` tp `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
@@ -2826,6 +2899,10 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
 // ViewOp
 //===----------------------------------------------------------------------===//
 
+void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "view");
+}
+
 LogicalResult ViewOp::verify() {
   auto baseType = getOperand(0).getType().cast<MemRefType>();
   auto viewType = getType();