From 657f68b1f2fd38deb63c23d8f46d12b7fd357e63 Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Tue, 11 Oct 2022 20:53:25 +0000 Subject: [PATCH] [NFC][mlir][MemRef] Make use of InferTypeOpInterface The `InferTypeOpInterface` generates builders for things it can infer the types. Thanks to that interface we can: - Eliminate a builder for DimOp, and - Describe how to infer the result types of `extract_strided_metadata` from its source, and get a simpler builder as a result NFC Differential Revision: https://reviews.llvm.org/D135734 --- mlir/include/mlir/Dialect/MemRef/IR/MemRef.h | 1 + mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 5 ++-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 32 ++++++++++++++++++---- .../Transforms/SimplifyExtractStridedMetadata.cpp | 12 ++------ 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index bd99cf2..6538abc 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -15,6 +15,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 54394da..ba8fe81 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ShapedOpInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -581,7 +582,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [ let builders = [ OpBuilder<(ins "Value":$source, "int64_t":$index)>, - OpBuilder<(ins "Value":$source, "Value":$index)> ]; let extraClassDeclaration = [{ @@ -853,7 +853,8 @@ def MemRef_ExtractAlignedPointerAsIndexOp : def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [ DeclareOpInterfaceMethods, Pure, - SameVariadicResultSize]> { + SameVariadicResultSize, + DeclareOpInterfaceMethods]> { let summary = "Extracts a buffer base with offset and strides"; let description = [{ Extracts a base buffer, offset and strides. This op allows additional layers diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index fbc1ead..ab7311b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -807,12 +807,6 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source, build(builder, result, source, indexValue); } -void DimOp::build(OpBuilder &builder, OperationState &result, Value source, - Value index) { - auto indexTy = builder.getIndexType(); - build(builder, result, indexTy, source, index); -} - Optional DimOp::getConstantIndex() { if (auto constantOp = getIndex().getDefiningOp()) return constantOp.getValue().cast().getInt(); @@ -1254,6 +1248,32 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames( // ExtractStridedMetadataOp //===----------------------------------------------------------------------===// +/// The number and type of the results are inferred from the +/// shape of the source. +LogicalResult ExtractStridedMetadataOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions); + auto sourceType = extractAdaptor.getSource().getType().dyn_cast(); + if (!sourceType) + return failure(); + + unsigned sourceRank = sourceType.getRank(); + IndexType indexType = IndexType::get(context); + auto memrefType = + MemRefType::get({}, sourceType.getElementType(), + MemRefLayoutAttrInterface{}, sourceType.getMemorySpace()); + // Base. + inferredReturnTypes.push_back(memrefType); + // Offset. + inferredReturnTypes.push_back(indexType); + // Sizes and strides. + for (unsigned i = 0; i < sourceRank * 2; ++i) + inferredReturnTypes.push_back(indexType); + return success(); +} + void ExtractStridedMetadataOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getBaseBuffer(), "base_buffer"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp index 257a02b..2a8ffba 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -59,16 +59,12 @@ public: // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(subview(memref)). Location origLoc = op.getLoc(); - IndexType indexType = rewriter.getIndexType(); Value source = subview.getSource(); auto sourceType = source.getType().cast(); unsigned sourceRank = sourceType.getRank(); - SmallVector sizeStrideTypes(sourceRank, indexType); auto newExtractStridedMetadata = - rewriter.create( - origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, - sizeStrideTypes, source); + rewriter.create(origLoc, source); SmallVector sourceStrides; int64_t sourceOffset; @@ -486,16 +482,12 @@ public: // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(reassociative_reshape_like(memref)). Location origLoc = op.getLoc(); - IndexType indexType = rewriter.getIndexType(); Value source = reshape.getSrc(); auto sourceType = source.getType().cast(); unsigned sourceRank = sourceType.getRank(); - SmallVector sizeStrideTypes(sourceRank, indexType); auto newExtractStridedMetadata = - rewriter.create( - origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, - sizeStrideTypes, source); + rewriter.create(origLoc, source); // Collect statically known information. SmallVector strides; -- 2.7.4