From 8d541a1fbe6d92a3fadf6d7d8e8209ed6c76e092 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 31 Dec 2020 14:46:08 -0800 Subject: [PATCH] [mlir][shape] Add shape.lib attribute Enable querying shape function library ops from the module. Currently supports singular or array of them (as long as array has all unique ops in mappings). The preferred canonical form would have one library, but given the invariant on the mapping, this can easily be achieved by a simple merging pass. Preferred the attribute approach vs naming convention as these could be added in multiple different ways. --- mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td | 1 + mlir/lib/Dialect/Shape/IR/Shape.cpp | 50 ++++++++++++ mlir/test/Analysis/test-shape-fn-report.mlir | 4 + mlir/test/Dialect/Shape/invalid.mlir | 92 ++++++++++++++++++++++ mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp | 54 ++++++++----- 5 files changed, 182 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td index a7868e7..1cccb59 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -37,6 +37,7 @@ def ShapeDialect : Dialect { let cppNamespace = "::mlir::shape"; let hasConstantMaterializer = 1; + let hasOperationAttrVerify = 1; } def Shape_ShapeType : DialectTypehasTrait()) + return op->emitError( + "shape.lib attribute may only be on op implementing SymbolTable"); + + if (auto symbolRef = attribute.second.dyn_cast()) { + auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); + if (!symbol) + return op->emitError("shape function library ") + << symbolRef << " not found"; + return isa(symbol) + ? success() + : op->emitError() + << symbolRef << " required to be shape function library"; + } + + if (auto arr = attribute.second.dyn_cast()) { + // Verify all entries are function libraries and mappings in libraries + // refer to unique ops. + DenseSet key; + for (auto it : arr) { + if (!it.isa()) + return op->emitError( + "only SymbolRefAttr allowed in shape.lib attribute array"); + + auto shapeFnLib = dyn_cast( + SymbolTable::lookupSymbolIn(op, it.cast())); + if (!shapeFnLib) + return op->emitError() + << it << " does not refer to FunctionLibraryOp"; + for (auto mapping : shapeFnLib.mapping()) { + if (!key.insert(mapping.first).second) { + return op->emitError("only one op to shape mapping allowed, found " + "multiple for `") + << mapping.first << "`"; + } + } + } + return success(); + } + + return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " + "allowed as shape.lib attribute"); + } + return success(); +} + //===----------------------------------------------------------------------===// // AnyOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir index ad5c8e6..b015935 100644 --- a/mlir/test/Analysis/test-shape-fn-report.mlir +++ b/mlir/test/Analysis/test-shape-fn-report.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics +module attributes {shape.lib = [@shape_lib]} { + // expected-remark@+1 {{associated shape function: same_result_shape}} func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32> attributes {shape.function = @shape_lib::@same_result_shape} { @@ -20,3 +22,5 @@ shape.function_library @shape_lib { } mapping { test.same_operand_result_type = @same_result_shape } + +} diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir index eb0ae5a..d2f5af2 100644 --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -154,3 +154,95 @@ func @broadcast(%arg0 : !shape.shape, %arg1 : tensor) -> tensor -> tensor return %result : tensor } + +// ----- + +// Test using an unsupported shape.lib attribute type. + +// expected-error@+1 {{only SymbolRefAttr allowed in shape.lib attribute array}} +module attributes {shape.lib = [@shape_lib, "shape_lib"]} { + +shape.function_library @shape_lib { + // Test shape function that returns the shape of input arg as result shape. + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + return %0 : !shape.shape + } +} mapping { + test.same_operand_result_type = @same_result_shape +} + +} + +// ----- + +// Test that duplicate op to shape function mappings are flagged, this uses +// the same library twice for easy overlap. + +// expected-error@+1 {{only one op to shape mapping allowed}} +module attributes {shape.lib = [@shape_lib, @shape_lib]} { + +shape.function_library @shape_lib { + // Test shape function that returns the shape of input arg as result shape. + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + return %0 : !shape.shape + } +} mapping { + test.same_operand_result_type = @same_result_shape +} + +} + +// ----- + +// Test that duplicate op to shape function mappings are flagged (this is +// more an invariant of using the dictionary attribute here than anything +// specific to function library op). + +module attributes {shape.lib = [@shape_lib]} { + +shape.function_library @shape_lib { + // Test shape function that returns the shape of input arg as result shape. + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + return %0 : !shape.shape + } +} mapping { + // expected-error @+2 {{duplicate key}} + test.same_operand_result_type = @same_result_shape, + test.same_operand_result_type = @same_result_shape +} + +} + +// ----- + +// Test that op referred to by shape lib is a shape function library. + +// expected-error@+1 {{required to be shape function library}} +module attributes {shape.lib = @fn} { + +func @fn(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + return %0 : !shape.shape +} + +} + +// ----- + +// Test that op referred to by shape lib is a shape function library. + +func @fn(%arg: !shape.value_shape) -> !shape.shape { + // expected-error@+1 {{SymbolTable}} + %0 = shape.shape_of %arg {shape.lib = @fn} : !shape.value_shape -> !shape.shape + return %0 : !shape.shape +} + +// ----- + +// Test that shape function library is defined. + +// expected-error@+1 {{@fn not found}} +module attributes {shape.lib = @fn} { } diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp index b7127c5..4477eb1 100644 --- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp +++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp @@ -26,41 +26,57 @@ struct ReportShapeFnPass void ReportShapeFnPass::runOnOperation() { auto module = getOperation(); - // Lookup shape function library. - shape::FunctionLibraryOp shapeFnLib = nullptr; - for (auto lib : module.getOps()) { - if (shapeFnLib) { - lib.emitError("duplicate shape library op") - .attachNote(shapeFnLib.getLoc()) - << "previous mapping"; - return signalPassFailure(); - } - shapeFnLib = lib; - }; - // Report the shape function available to refine the op. auto shapeFnId = Identifier::get("shape.function", &getContext()); - auto remarkShapeFn = [&](Operation *op) { + auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) { if (op->isKnownTerminator()) - return; + return true; if (auto typeInterface = dyn_cast(op)) { op->emitRemark() << "implements InferType op interface"; - } else if (auto fn = shapeFnLib.getShapeFunction(op)) { + return true; + } + if (auto fn = shapeFnLib.getShapeFunction(op)) { op->emitRemark() << "associated shape function: " << fn.getName(); - } else if (auto symbol = op->getAttrOfType(shapeFnId)) { + return true; + } + if (auto symbol = op->getAttrOfType(shapeFnId)) { auto fn = cast(SymbolTable::lookupSymbolIn(module, symbol)); op->emitRemark() << "associated shape function: " << fn.getName(); - } else { - op->emitRemark() << "no associated way to refine shape"; + return true; } + return false; }; + // Lookup shape function library. + SmallVector libraries; + auto attr = module.getAttr("shape.lib"); + if (attr) { + auto lookup = [&](Attribute attr) { + return cast( + SymbolTable::lookupSymbolIn(module, attr.cast())); + }; + if (auto arrayAttr = attr.dyn_cast()) { + libraries.reserve(arrayAttr.size()); + for (auto attr : arrayAttr) + libraries.push_back(lookup(attr)); + } else { + libraries.reserve(1); + libraries.push_back(lookup(attr)); + } + } + module.getBodyRegion().walk([&](FuncOp func) { // Skip ops in the shape function library. if (isa(func->getParentOp())) return; - func.walk([&](Operation *op) { remarkShapeFn(op); }); + func.walk([&](Operation *op) { + bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) { + return remarkShapeFn(lib, op); + }); + if (!found) + op->emitRemark() << "no associated way to refine shape"; + }); }); } -- 2.7.4