Add Python bindings for affine expressions with binary operators.
authorMLIR Team <no-reply@google.com>
Tue, 3 Dec 2019 18:11:40 +0000 (10:11 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Dec 2019 18:12:11 +0000 (10:12 -0800)
PiperOrigin-RevId: 283569325

mlir/bindings/python/pybind.cpp
mlir/bindings/python/test/test_py2and3.py

index 90b24cdbf4c70183dd3485e124a363863a7d0da0..b1be0d2133679b0c0b01d0be7b62711ebc9e5398 100644 (file)
@@ -207,6 +207,9 @@ struct PythonMLIRModule {
   // Creates an affine symbol expression.
   PythonAffineExpr affineSymbolExpr(unsigned position);
 
+  // Creates an affine dimension expression.
+  PythonAffineExpr affineDimExpr(unsigned position);
+
   // Creates a single constant result affine map.
   PythonAffineMap affineConstantMap(int64_t value);
 
@@ -565,6 +568,8 @@ struct PythonAffineExpr {
   operator AffineExpr() const { return affine_expr; }
   operator AffineExpr &() { return affine_expr; }
 
+  AffineExpr get() const { return affine_expr; }
+
   std::string str() const {
     std::string res;
     llvm::raw_string_ostream os(res);
@@ -724,6 +729,10 @@ PythonAffineExpr PythonMLIRModule::affineSymbolExpr(unsigned position) {
   return PythonAffineExpr(getAffineSymbolExpr(position, &mlirContext));
 }
 
+PythonAffineExpr PythonMLIRModule::affineDimExpr(unsigned position) {
+  return PythonAffineExpr(getAffineDimExpr(position, &mlirContext));
+}
+
 PythonAffineMap PythonMLIRModule::affineConstantMap(int64_t value) {
   return PythonAffineMap(AffineMap::getConstantMap(value, &mlirContext));
 }
@@ -937,6 +946,8 @@ PYBIND11_MODULE(pybind, m) {
            "Returns an affine constant expression.")
       .def("affine_symbol_expr", &PythonMLIRModule::affineSymbolExpr,
            "Returns an affine symbol expression.")
+      .def("affine_dim_expr", &PythonMLIRModule::affineDimExpr,
+           "Returns an affine dim expression.")
       .def("affine_constant_map", &PythonMLIRModule::affineConstantMap,
            "Returns an affine map with single constant result.")
       .def("affine_map", &PythonMLIRModule::affineMap, "Returns an affine map.",
@@ -1054,6 +1065,58 @@ PYBIND11_MODULE(pybind, m) {
   py::class_<PythonAffineExpr>(m, "AffineExpr",
                                "A wrapper around mlir::AffineExpr")
       .def(py::init<PythonAffineExpr>())
+      .def("__add__",
+           [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() + rhs);
+           })
+      .def("__add__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() + rhs.get());
+           })
+      .def("__neg__",
+           [](PythonAffineExpr lhs) -> PythonAffineExpr {
+             return PythonAffineExpr(-lhs.get());
+           })
+      .def("__sub__",
+           [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() - rhs);
+           })
+      .def("__sub__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() - rhs.get());
+           })
+      .def("__mul__",
+           [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() * rhs);
+           })
+      .def("__mul__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() * rhs.get());
+           })
+      .def("__floordiv__",
+           [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().floorDiv(rhs));
+           })
+      .def("__floordiv__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().floorDiv(rhs.get()));
+           })
+      .def("ceildiv",
+           [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().ceilDiv(rhs));
+           })
+      .def("ceildiv",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get().ceilDiv(rhs.get()));
+           })
+      .def("__mod__",
+           [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() % rhs);
+           })
+      .def("__mod__",
+           [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
+             return PythonAffineExpr(lhs.get() % rhs.get());
+           })
       .def("__str__", &PythonAffineExpr::str);
 
   py::class_<PythonAffineMap>(m, "AffineMap",
index cd4d7bfcecda54a5fbd469392546d35538866d9d..678e5023173c9f6986eaad9f9aca670c9902d6e8 100644 (file)
@@ -289,22 +289,30 @@ class EdscTest:
     self.setUp()
     a1 = self.module.affine_constant_expr(23)
     a2 = self.module.affine_constant_expr(44)
+    a3 = self.module.affine_dim_expr(1)
     s0 = self.module.affine_symbol_expr(0)
     aMap1 = self.module.affine_map(2, 0, [a1, a2, s0])
     aMap2 = self.module.affine_constant_map(42)
+    aMap3 = self.module.affine_map(
+        2, 0,
+        [a1 + a2 * a3, a1 // a3 % a2,
+         a1.ceildiv(a2), a1 - 2, a2 * 2, -a3])
+
     affineAttr1 = self.module.affineMapAttr(aMap1)
     affineAttr2 = self.module.affineMapAttr(aMap2)
+    affineAttr3 = self.module.affineMapAttr(aMap3)
 
     t = self.module.make_memref_type(self.f32Type, [10])
     t_with_attr = t({
         "affine_attr_1": affineAttr1,
-        "affine_attr_2": affineAttr2
+        "affine_attr_2": affineAttr2,
+        "affine_attr_3": affineAttr3,
     })
 
     f = self.module.declare_function("foo", [t, t_with_attr], [])
     printWithCurrentFunctionName(str(self.module))
     # CHECK-LABEL: testFunctionDeclarationWithAffineAttr
-    #       CHECK:  func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42)})
+    #       CHECK:  func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42), affine_attr_3 = (d0, d1) -> (d1 * 44 + 23, (23 floordiv d1) mod 44, 1, 21, 88, -d1)})
 
   def testFunctionDeclarationWithArrayAttr(self):
     self.setUp()