[mlir][python] Add fused location
authorJacques Pienaar <jpienaar@google.com>
Sat, 11 Dec 2021 18:12:29 +0000 (10:12 -0800)
committerJacques Pienaar <jpienaar@google.com>
Sat, 11 Dec 2021 18:16:13 +0000 (10:16 -0800)
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/python/ir/location.py

index cd57552..3640a15 100644 (file)
@@ -47,6 +47,9 @@ static const char kContextGetCallSiteLocationDocstring[] =
 static const char kContextGetFileLocationDocstring[] =
     R"(Gets a Location representing a file, line and column)";
 
+static const char kContextGetFusedLocationDocstring[] =
+    R"(Gets a Location representing a fused location with optional metadata)";
+
 static const char kContextGetNameLocationDocString[] =
     R"(Gets a Location representing a named location with optional child location)";
 
@@ -2198,6 +2201,23 @@ void mlir::python::populateIRCore(py::module &m) {
           py::arg("filename"), py::arg("line"), py::arg("col"),
           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
       .def_static(
+          "fused",
+          [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
+             DefaultingPyMlirContext context) {
+            if (pyLocations.empty())
+              throw py::value_error("No locations provided");
+            llvm::SmallVector<MlirLocation, 4> locations;
+            locations.reserve(pyLocations.size());
+            for (auto &pyLocation : pyLocations)
+              locations.push_back(pyLocation.get());
+            MlirLocation location = mlirLocationFusedGet(
+                context->get(), locations.size(), locations.data(),
+                metadata ? metadata->get() : MlirAttribute{0});
+            return PyLocation(context->getRef(), location);
+          },
+          py::arg("locations"), py::arg("metadata") = py::none(),
+          py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
+      .def_static(
           "name",
           [](std::string name, llvm::Optional<PyLocation> childLoc,
              DefaultingPyMlirContext context) {
index e1a84dd..e61e34a 100644 (file)
@@ -658,6 +658,8 @@ class Location:
     @staticmethod
     def file(filename: str, line: int, col: int, context: Optional["Context"] = None) -> "Location": ...
     @staticmethod
+    def fused(locations: Sequence["Location"], metadata: Optional["Attribute"] = None, context: Optional["Context"] = None) -> "Location": ...
+    @staticmethod
     def name(name: str, childLoc: Optional["Location"] = None, context: Optional["Context"] = None) -> "Location": ...
     @staticmethod
     def unknown(context: Optional["Context"] = None) -> Any: ...
index 7bc7b28..1c13c48 100644 (file)
@@ -75,6 +75,27 @@ def testCallSite():
 run(testCallSite)
 
 
+# CHECK-LABEL: TEST: testFused
+def testFused():
+  with Context() as ctx:
+    loc = Location.fused(
+        [Location.name("apple"), Location.name("banana")])
+    attr = Attribute.parse('"sauteed"')
+    loc_attr = Location.fused([Location.name("carrot"),
+                               Location.name("potatoes")], attr)
+  ctx = None
+  # CHECK: file str: loc(fused["apple", "banana"])
+  print("file str:", str(loc))
+  # CHECK: file repr: loc(fused["apple", "banana"])
+  print("file repr:", repr(loc))
+  # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"])
+  print("file str:", str(loc_attr))
+  # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"])
+  print("file repr:", repr(loc_attr))
+
+run(testFused)
+
+
 # CHECK-LABEL: TEST: testLocationCapsule
 def testLocationCapsule():
   with Context() as ctx: