[mlir] Provide minimal Python bindings for the math dialect
authorAlex Zinenko <zinenko@google.com>
Thu, 10 Jun 2021 17:00:34 +0000 (19:00 +0200)
committerAlex Zinenko <zinenko@google.com>
Fri, 11 Jun 2021 11:21:26 +0000 (13:21 +0200)
Reviewed By: ulysseB

Differential Revision: https://reviews.llvm.org/D104045

mlir/python/mlir/dialects/CMakeLists.txt
mlir/python/mlir/dialects/MathOps.td [new file with mode: 0644]
mlir/python/mlir/dialects/math.py [new file with mode: 0644]
mlir/test/python/dialects/math.py [new file with mode: 0644]

index 5eeb6d6..3c04344 100644 (file)
@@ -25,6 +25,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
   DEPENDS LinalgOdsGen)
 add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps)
 
+add_mlir_dialect_python_bindings(MLIRBindingsPythonMathOps
+  TD_FILE MathOps.td
+  DIALECT_NAME math)
+add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMathOps)
+
 add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps
   TD_FILE MemRefOps.td
   DIALECT_NAME memref)
diff --git a/mlir/python/mlir/dialects/MathOps.td b/mlir/python/mlir/dialects/MathOps.td
new file mode 100644 (file)
index 0000000..03d1fde
--- /dev/null
@@ -0,0 +1,15 @@
+//===-- MathOps.td - Entry point for MathOps bindings ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_MATH_OPS
+#define PYTHON_BINDINGS_MATH_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/Math/IR/MathOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/math.py b/mlir/python/mlir/dialects/math.py
new file mode 100644 (file)
index 0000000..f082bf4
--- /dev/null
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._math_ops_gen import *
diff --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py
new file mode 100644 (file)
index 0000000..73246e2
--- /dev/null
@@ -0,0 +1,26 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.math as mlir_math
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+
+# CHECK-LABEL: TEST: testMathOps
+@run
+def testMathOps():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(F32Type.get())
+      def emit_sqrt(arg):
+        return mlir_math.SqrtOp(F32Type.get(), arg)
+
+    # CHECK-LABEL: func @emit_sqrt(
+    # CHECK-SAME:                  %[[ARG:.*]]: f32) {
+    # CHECK:         math.sqrt %[[ARG]] : f32
+    # CHECK:         return
+    # CHECK:       }
+    print(module)