From: Alex Zinenko Date: Tue, 2 Nov 2021 13:15:25 +0000 (+0100) Subject: [mlir][python] improve usability of Python affine construct bindings X-Git-Tag: upstream/15.0.7~26852 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fc7594cc4aa5e652fe61f278a13e865141797245;p=platform%2Fupstream%2Fllvm.git [mlir][python] improve usability of Python affine construct bindings - Provide the operator overloads for constructing (semi-)affine expressions in Python by combining existing expressions with constants. - Make AffineExpr, AffineMap and IntegerSet hashable in Python. - Expose the AffineExpr composition functionality. Reviewed By: gysit, aoyal Differential Revision: https://reviews.llvm.org/D113010 --- diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h index 5516f29..14e951d 100644 --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -39,6 +39,8 @@ DEFINE_C_API_STRUCT(MlirAffineExpr, const void); #undef DEFINE_C_API_STRUCT +struct MlirAffineMap; + /// Gets the context that owns the affine expression. MLIR_CAPI_EXPORTED MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr); @@ -86,6 +88,10 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsMultipleOf(MlirAffineExpr affineExpr, MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, intptr_t position); +/// Composes the given map with the given expression. +MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose( + MlirAffineExpr affineExpr, struct MlirAffineMap affineMap); + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 50a96c8c..da80cda 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -205,6 +205,18 @@ public: return PyAffineAddExpr(lhs.getContext(), expr); } + static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineAddExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineAddExpr::get); } @@ -222,6 +234,18 @@ public: return PyAffineMulExpr(lhs.getContext(), expr); } + static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineMulExpr(lhs.getContext(), expr); + } + + static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineMulExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineMulExpr::get); } @@ -239,6 +263,18 @@ public: return PyAffineModExpr(lhs.getContext(), expr); } + static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineModExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineModExpr::get); } @@ -256,6 +292,18 @@ public: return PyAffineFloorDivExpr(lhs.getContext(), expr); } + static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineFloorDivExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineFloorDivExpr::get); } @@ -273,6 +321,18 @@ public: return PyAffineCeilDivExpr(lhs.getContext(), expr); } + static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineCeilDivExpr(lhs.getContext(), expr); + } + + static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineCeilDivExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineCeilDivExpr::get); } @@ -435,17 +495,19 @@ void mlir::python::populateIRAffine(py::module &m) { .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) - .def("__add__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineAddExpr::get(self, other); - }) - .def("__mul__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineMulExpr::get(self, other); - }) - .def("__mod__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineModExpr::get(self, other); + .def("__add__", &PyAffineAddExpr::get) + .def("__add__", &PyAffineAddExpr::getRHSConstant) + .def("__radd__", &PyAffineAddExpr::getRHSConstant) + .def("__mul__", &PyAffineMulExpr::get) + .def("__mul__", &PyAffineMulExpr::getRHSConstant) + .def("__rmul__", &PyAffineMulExpr::getRHSConstant) + .def("__mod__", &PyAffineModExpr::get) + .def("__mod__", &PyAffineModExpr::getRHSConstant) + .def("__rmod__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineModExpr::get( + PyAffineConstantExpr::get(other, *self.getContext().get()), + self); }) .def("__sub__", [](PyAffineExpr &self, PyAffineExpr &other) { @@ -454,6 +516,17 @@ void mlir::python::populateIRAffine(py::module &m) { return PyAffineAddExpr::get(self, PyAffineMulExpr::get(negOne, other)); }) + .def("__sub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::get( + self, + PyAffineConstantExpr::get(-other, *self.getContext().get())); + }) + .def("__rsub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::getLHSConstant( + other, PyAffineMulExpr::getLHSConstant(-1, self)); + }) .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; }) .def("__eq__", @@ -474,24 +547,63 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyAffineExpr &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def_property_readonly( "context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) + .def("compose", + [](PyAffineExpr &self, PyAffineMap &other) { + return PyAffineExpr(self.getContext(), + mlirAffineExprCompose(self, other)); + }) .def_static( "get_add", &PyAffineAddExpr::get, "Gets an affine expression containing a sum of two expressions.") + .def_static("get_add", &PyAffineAddExpr::getLHSConstant, + "Gets an affine expression containing a sum of a constant " + "and another expression.") + .def_static("get_add", &PyAffineAddExpr::getRHSConstant, + "Gets an affine expression containing a sum of an expression " + "and a constant.") .def_static( "get_mul", &PyAffineMulExpr::get, "Gets an affine expression containing a product of two expressions.") + .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, + "Gets an affine expression containing a product of a " + "constant and another expression.") + .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, + "Gets an affine expression containing a product of an " + "expression and a constant.") .def_static("get_mod", &PyAffineModExpr::get, "Gets an affine expression containing the modulo of dividing " "one expression by another.") + .def_static("get_mod", &PyAffineModExpr::getLHSConstant, + "Gets a semi-affine expression containing the modulo of " + "dividing a constant by an expression.") + .def_static("get_mod", &PyAffineModExpr::getRHSConstant, + "Gets an affine expression containing the module of dividing" + "an expression by a constant.") .def_static("get_floor_div", &PyAffineFloorDivExpr::get, "Gets an affine expression containing the rounded-down " "result of dividing one expression by another.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-down " + "result of dividing a constant by an expression.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-down " + "result of dividing an expression by a constant.") .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, "Gets an affine expression containing the rounded-up result " "of dividing one expression by another.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-up " + "result of dividing a constant by an expression.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-up result " + "of dividing an expression by a constant.") .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), py::arg("context") = py::none(), "Gets a constant affine expression with the given value.") @@ -542,6 +654,10 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyAffineMap &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def_static("compress_unused_symbols", [](py::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; @@ -714,6 +830,10 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyIntegerSet &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def_property_readonly( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index 2d8bc3c..5b25ab5 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -56,6 +56,11 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, return unwrap(affineExpr).isFunctionOfDim(position); } +MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, + MlirAffineMap affineMap) { + return wrap(unwrap(affineExpr).compose(unwrap(affineMap))); +} + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index ef55537..1056f65 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1393,6 +1393,13 @@ int affineMapFromExprs(MlirContext ctx) { if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr)) return 3; + MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, 1); + MlirAffineExpr composed = mlirAffineExprCompose(affineDim2Expr, map); + // CHECK: s1 + mlirAffineExprDump(composed); + if (!mlirAffineExprEqual(composed, affineSymbolExpr)) + return 4; + return 0; } diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py index 1844668..9854b49 100644 --- a/mlir/test/python/ir/affine_expr.py +++ b/mlir/test/python/ir/affine_expr.py @@ -137,6 +137,14 @@ def testAffineAddExpr(): # CHECK: d1 + d2 print(d12op) + d1cst_op = d1 + 2 + # CHECK: d1 + 2 + print(d1cst_op) + + d1cst_op2 = 2 + d1 + # CHECK: d1 + 2 + print(d1cst_op2) + assert d12 == d12op assert d12.lhs == d1 assert d12.rhs == d2 @@ -156,7 +164,16 @@ def testAffineMulExpr(): op = d1 * c2 print(op) + # CHECK: d1 * 2 + op_cst = d1 * 2 + print(op_cst) + + # CHECK: d1 * 2 + op_cst2 = 2 * d1 + print(op_cst2) + assert expr == op + assert expr == op_cst assert expr.lhs == d1 assert expr.rhs == c2 @@ -175,10 +192,32 @@ def testAffineModExpr(): op = d1 % c2 print(op) + # CHECK: d1 mod 2 + op_cst = d1 % 2 + print(op_cst) + + # CHECK: 2 mod d1 + print(2 % d1) + assert expr == op + assert expr == op_cst assert expr.lhs == d1 assert expr.rhs == c2 + expr2 = AffineExpr.get_mod(c2, d1) + expr3 = AffineExpr.get_mod(2, d1) + expr4 = AffineExpr.get_mod(d1, 2) + + # CHECK: 2 mod d1 + print(expr2) + # CHECK: 2 mod d1 + print(expr3) + # CHECK: d1 mod 2 + print(expr4) + + assert expr2 == expr3 + assert expr4 == expr + # CHECK-LABEL: TEST: testAffineFloorDivExpr @run @@ -193,6 +232,20 @@ def testAffineFloorDivExpr(): assert expr.lhs == d1 assert expr.rhs == c2 + expr2 = AffineExpr.get_floor_div(c2, d1) + expr3 = AffineExpr.get_floor_div(2, d1) + expr4 = AffineExpr.get_floor_div(d1, 2) + + # CHECK: 2 floordiv d1 + print(expr2) + # CHECK: 2 floordiv d1 + print(expr3) + # CHECK: d1 floordiv 2 + print(expr4) + + assert expr2 == expr3 + assert expr4 == expr + # CHECK-LABEL: TEST: testAffineCeilDivExpr @run @@ -207,6 +260,20 @@ def testAffineCeilDivExpr(): assert expr.lhs == d1 assert expr.rhs == c2 + expr2 = AffineExpr.get_ceil_div(c2, d1) + expr3 = AffineExpr.get_ceil_div(2, d1) + expr4 = AffineExpr.get_ceil_div(d1, 2) + + # CHECK: 2 ceildiv d1 + print(expr2) + # CHECK: 2 ceildiv d1 + print(expr3) + # CHECK: d1 ceildiv 2 + print(expr4) + + assert expr2 == expr3 + assert expr4 == expr + # CHECK-LABEL: TEST: testAffineExprSub @run @@ -225,6 +292,15 @@ def testAffineExprSub(): # CHECK: -1 print(rhs.rhs) + # CHECK: d1 - 42 + print(d1 - 42) + # CHECK: -d1 + 42 + print(42 - d1) + + c42 = AffineConstantExpr.get(42) + assert d1 - 42 == d1 - c42 + assert 42 - d1 == c42 - d1 + # CHECK-LABEL: TEST: testClassHierarchy @run def testClassHierarchy(): @@ -289,3 +365,38 @@ def testIsInstance(): print(AffineMulExpr.isinstance(mul)) # CHECK: False print(AffineAddExpr.isinstance(mul)) + + +# CHECK-LABEL: TEST: testCompose +@run +def testCompose(): + with Context(): + # d0 + d2. + expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2)) + + # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2) + map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1)) + map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0)) + map3 = AffineAddExpr.get( + AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)), + AffineDimExpr.get(2)) + map = AffineMap.get(3, 2, [map1, map2, map3]) + + # CHECK: d0 + s1 + d0 + d1 + d2 + print(expr.compose(map)) + + +# CHECK-LABEL: TEST: testHash +@run +def testHash(): + with Context(): + d0 = AffineDimExpr.get(0) + s1 = AffineSymbolExpr.get(1) + assert hash(d0) == hash(AffineDimExpr.get(0)) + assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1)) + + dictionary = dict() + dictionary[d0] = 0 + dictionary[s1] = 1 + assert d0 in dictionary + assert s1 in dictionary diff --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py index da5d230..52c7261 100644 --- a/mlir/test/python/ir/affine_map.py +++ b/mlir/test/python/ir/affine_map.py @@ -9,9 +9,11 @@ def run(f): f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testAffineMapCapsule +@run def testAffineMapCapsule(): with Context() as ctx: am1 = AffineMap.get_empty(ctx) @@ -23,10 +25,8 @@ def testAffineMapCapsule(): assert am2.context is ctx -run(testAffineMapCapsule) - - # CHECK-LABEL: TEST: testAffineMapGet +@run def testAffineMapGet(): with Context() as ctx: d0 = AffineDimExpr.get(0) @@ -100,10 +100,8 @@ def testAffineMapGet(): print(e) -run(testAffineMapGet) - - # CHECK-LABEL: TEST: testAffineMapDerive +@run def testAffineMapDerive(): with Context() as ctx: map5 = AffineMap.get_identity(5) @@ -121,10 +119,8 @@ def testAffineMapDerive(): print(map34) -run(testAffineMapDerive) - - # CHECK-LABEL: TEST: testAffineMapProperties +@run def testAffineMapProperties(): with Context(): d0 = AffineDimExpr.get(0) @@ -147,10 +143,8 @@ def testAffineMapProperties(): print(map3.is_projected_permutation) -run(testAffineMapProperties) - - # CHECK-LABEL: TEST: testAffineMapExprs +@run def testAffineMapExprs(): with Context(): d0 = AffineDimExpr.get(0) @@ -181,10 +175,8 @@ def testAffineMapExprs(): assert list(map3.results) == [d2, d0, d1] -run(testAffineMapExprs) - - # CHECK-LABEL: TEST: testCompressUnusedSymbols +@run def testCompressUnusedSymbols(): with Context() as ctx: d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), @@ -210,10 +202,8 @@ def testCompressUnusedSymbols(): print(compressed_maps) -run(testCompressUnusedSymbols) - - # CHECK-LABEL: TEST: testReplace +@run def testReplace(): with Context() as ctx: d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), @@ -236,4 +226,16 @@ def testReplace(): print(replace3) -run(testReplace) +# CHECK-LABEL: TEST: testHash +@run +def testHash(): + with Context(): + d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1) + m1 = AffineMap.get(2, 0, [d0, d1]) + m2 = AffineMap.get(2, 0, [d1, d0]) + assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1])) + + dictionary = dict() + dictionary[m1] = 1 + dictionary[m2] = 2 + assert m1 in dictionary diff --git a/mlir/test/python/ir/integer_set.py b/mlir/test/python/ir/integer_set.py index bdec8af..b916d9ab3 100644 --- a/mlir/test/python/ir/integer_set.py +++ b/mlir/test/python/ir/integer_set.py @@ -8,9 +8,11 @@ def run(f): f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testIntegerSetCapsule +@run def testIntegerSetCapsule(): with Context() as ctx: is1 = IntegerSet.get_empty(1, 1, ctx) @@ -21,10 +23,9 @@ def testIntegerSetCapsule(): assert is1 == is2 assert is2.context is ctx -run(testIntegerSetCapsule) - # CHECK-LABEL: TEST: testIntegerSetGet +@run def testIntegerSetGet(): with Context(): d0 = AffineDimExpr.get(0) @@ -92,10 +93,9 @@ def testIntegerSetGet(): # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols print(e) -run(testIntegerSetGet) - # CHECK-LABEL: TEST: testIntegerSetProperties +@run def testIntegerSetProperties(): with Context(): d0 = AffineDimExpr.get(0) @@ -125,4 +125,17 @@ def testIntegerSetProperties(): print(cstr.expr, end='') print(" == 0" if cstr.is_eq else " >= 0") -run(testIntegerSetProperties) + +# CHECK_LABEL: TEST: testHash +@run +def testHash(): + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + set = IntegerSet.get(2, 0, [d0 + d1], [True]) + + assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True])) + + dictionary = dict() + dictionary[set] = 42 + assert set in dictionary