Python bindings: export index_cast
authorAlex Zinenko <zinenko@google.com>
Thu, 10 Oct 2019 17:25:46 +0000 (10:25 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 10 Oct 2019 17:27:04 +0000 (10:27 -0700)
We are now properly enforcing the absence of index elements in memrefs and
tensors. Instead, users are expected to store sized integers and cast them to
index type if necessary. Expose the respective operation to Python bindings.

PiperOrigin-RevId: 273985856

mlir/bindings/python/pybind.cpp
mlir/bindings/python/test/test_py2and3.py

index 3e4f5db..399adf4 100644 (file)
@@ -692,6 +692,11 @@ PYBIND11_MODULE(pybind, m) {
                             falseArguments);
         return PythonValueHandle(nullptr);
       });
+  m.def("index_cast",
+        [](PythonValueHandle element, PythonType type) -> PythonValueHandle {
+          return ValueHandle::create<IndexCastOp>(
+              element.value, Type::getFromOpaquePointer(type.type));
+        });
   m.def("select",
         [](PythonValueHandle condition, PythonValueHandle trueValue,
            PythonValueHandle falseValue) -> PythonValueHandle {
index c658c94..12013b1 100644 (file)
@@ -297,6 +297,15 @@ class EdscTest:
     #       CHECK: func @foo_0()
     #       CHECK: %{{.*}} = constant 0 : index
 
+  def testIndexCast(self):
+    self.setUp()
+    with self.module.function_context("testIndexCast", [], []):
+      index = E.constant_index(0)
+      E.index_cast(index, self.module.make_scalar_type("i", 32))
+    printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testIndexCast
+    #       CHECK: index_cast %{{.*}} : index to i32
+
   def testIndexedValue(self):
     self.setUp()
     memrefType = self.module.make_memref_type(self.f32Type, [10, 42])