Python bindings: expose various Ops through declarative builders
authorAlex Zinenko <zinenko@google.com>
Thu, 14 Mar 2019 12:04:38 +0000 (05:04 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:17:27 +0000 (17:17 -0700)
In particular, expose `cond_br`, `select` and `call` operations with syntax
similar to that of the previous emitter-based EDSC interface.  These are
provided for backwards-compatibility.  Ideally, we want them to be
Table-generated from the Op definitions when those definitions are declarative.

Additionally, expose the ability to construct any op given its canonical name,
which also exercises the construction of unregistered ops.

PiperOrigin-RevId: 238421583

mlir/bindings/python/pybind.cpp
mlir/bindings/python/test/test_py2and3.py

index fb48aa9..d03a8f8 100644 (file)
@@ -76,6 +76,17 @@ struct PythonValueHandle {
     return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
   }
 
+  PythonValueHandle call(const std::vector<PythonValueHandle> &args) {
+    assert(value.hasType() && value.getType().isa<FunctionType>() &&
+           "can only call function-typed values");
+
+    std::vector<Value *> argValues;
+    argValues.reserve(args.size());
+    for (auto arg : args)
+      argValues.push_back(arg.value.getValue());
+    return ValueHandle::create<CallIndirectOp>(value, argValues);
+  }
+
   mlir::edsc::ValueHandle value;
 };
 
@@ -958,6 +969,38 @@ PYBIND11_MODULE(pybind, m) {
         return PythonValueHandle(nullptr);
       },
       py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
+  m.def(
+      "cond_br",
+      [](PythonValueHandle condition, const PythonBlockHandle &trueDest,
+         const std::vector<PythonValueHandle> &trueArgs,
+         const PythonBlockHandle &falseDest,
+         const std::vector<PythonValueHandle> &falseArgs) -> PythonValueHandle {
+        std::vector<ValueHandle> trueArguments(trueArgs.begin(),
+                                               trueArgs.end());
+        std::vector<ValueHandle> falseArguments(falseArgs.begin(),
+                                                falseArgs.end());
+        intrinsics::COND_BR(condition, trueDest, trueArguments, falseDest,
+                            falseArguments);
+        return PythonValueHandle(nullptr);
+      });
+  m.def("select",
+        [](PythonValueHandle condition, PythonValueHandle trueValue,
+           PythonValueHandle falseValue) -> PythonValueHandle {
+          return ValueHandle::create<SelectOp>(condition.value, trueValue.value,
+                                               falseValue.value);
+        });
+  m.def("op",
+        [](const std::string &name,
+           const std::vector<PythonValueHandle> &operands,
+           const std::vector<PythonType> &resultTypes) -> PythonValueHandle {
+          std::vector<ValueHandle> operandHandles(operands.begin(),
+                                                  operands.end());
+          std::vector<Type> types;
+          types.reserve(resultTypes.size());
+          for (auto t : resultTypes)
+            types.push_back(Type::getFromOpaquePointer(t.type));
+          return ValueHandle::create(name, operandHandles, types);
+        });
 
   m.def("Max", [](const py::list &args) {
     SmallVector<edsc_expr_t, 8> owning;
@@ -1163,7 +1206,8 @@ PYBIND11_MODULE(pybind, m) {
                  -> PythonValueHandle { return lhs.value / rhs.value; })
         .def("__mod__",
              [](PythonValueHandle lhs, PythonValueHandle rhs)
-                 -> PythonValueHandle { return lhs.value % rhs.value; });
+                 -> PythonValueHandle { return lhs.value % rhs.value; })
+        .def("__call__", &PythonValueHandle::call);
   }
 
   py::class_<PythonBlockAppender>(
index b89e90c..779135f 100644 (file)
@@ -227,6 +227,18 @@ class EdscTest(unittest.TestCase):
     self.assertIn("^bb1(%0: index, %1: index):", code)
     self.assertIn("  br ^bb1(%1, %0 : index, index)", code)
 
+  def testCondBr(self):
+    with self.module.function_context("foo", [self.boolType], []) as fun:
+      with E.BlockContext() as blk1:
+        E.ret([])
+      with E.BlockContext([self.indexType]) as blk2:
+        E.ret([])
+      cst = E.constant_index(0)
+      E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
+
+    code = str(fun)
+    self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", code)
+
   def testRet(self):
     with self.module.function_context("foo", [],
                                       [self.indexType, self.indexType]) as fun:
@@ -238,6 +250,35 @@ class EdscTest(unittest.TestCase):
     self.assertIn("  %c0 = constant 0 : index", code)
     self.assertIn("  return %c42, %c0 : index, index", code)
 
+  def testSelectOp(self):
+    with self.module.function_context("foo", [self.boolType],
+                                      [self.i32Type]) as fun:
+      a = E.constant_int(42, 32)
+      b = E.constant_int(0, 32)
+      E.ret([E.select(fun.arg(0), a, b)])
+
+    code = str(fun)
+    self.assertIn("%0 = select %arg0, %c42_i32, %c0_i32 : i32", code)
+
+  def testCallOp(self):
+    callee = self.module.declare_function("sqrtf", [self.f32Type],
+                                          [self.f32Type])
+    with self.module.function_context("call", [self.f32Type], []) as fun:
+      funCst = E.constant_function(callee)
+      funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
+
+    code = str(self.module)
+    self.assertIn("func @sqrtf(f32) -> f32", code)
+    self.assertIn("%f = constant @sqrtf : (f32) -> f32", code)
+    self.assertIn("%0 = call_indirect %f(%arg0) : (f32) -> f32", code)
+
+  def testCustom(self):
+    with self.module.function_context("custom", [self.indexType, self.f32Type],
+                                      []) as fun:
+      E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
+    code = str(fun)
+    self.assertIn('%0 = "foo"(%arg0) : (index) -> f32', code)
+    self.assertIn("%1 = addf %0, %arg1 : f32", code)
 
   def testConstants(self):
     with self.module.function_context("constants", [], []) as fun: