[python-bindings] Added support for getting a module's functions, iterating f/b over...
authorMichael Gottesman <mgottesman@apple.com>
Wed, 11 Sep 2013 00:52:47 +0000 (00:52 +0000)
committerMichael Gottesman <mgottesman@apple.com>
Wed, 11 Sep 2013 00:52:47 +0000 (00:52 +0000)
Tests are included as well.

llvm-svn: 190471

llvm/bindings/python/llvm/core.py
llvm/bindings/python/llvm/tests/test_core.py

index fa5486a..3da69d3 100644 (file)
@@ -23,6 +23,8 @@ __all__ = [
     "OpCode",
     "MemoryBuffer",
     "Module",
+    "Value",
+    "Function",
     "Context",
     "PassRegistry"
 ]
@@ -91,6 +93,18 @@ class MemoryBuffer(LLVMObject):
     def __len__(self):
         return lib.LLVMGetBufferSize(self)
 
+class Value(LLVMObject):
+    
+    def __init__(self, value):
+        LLVMObject.__init__(self, value)
+
+    @property
+    def name(self):
+        return lib.LLVMGetValueName(self)
+
+    def dump(self):
+        lib.LLVMDumpValue(self)
+
 class Module(LLVMObject):
     """Represents the top-level structure of an llvm program in an opaque object."""
 
@@ -124,6 +138,42 @@ class Module(LLVMObject):
     def dump(self):
         lib.LLVMDumpModule(self)
 
+    class __function_iterator(object):
+        def __init__(self, module, reverse=False):
+            self.module = module
+            self.reverse = reverse
+            if self.reverse:
+                self.function = self.module.last
+            else:
+                self.function = self.module.first
+        
+        def __iter__(self):
+            return self
+        
+        def next(self):
+            if not isinstance(self.function, Function):
+                raise StopIteration("")
+            result = self.function
+            if self.reverse:
+                self.function = self.function.prev
+            else:
+                self.function = self.function.next
+            return result
+    
+    def __iter__(self):
+        return Module.__function_iterator(self)
+
+    def __reversed__(self):
+        return Module.__function_iterator(self, reverse=True)
+
+    @property
+    def first(self):
+        return Function(lib.LLVMGetFirstFunction(self))
+
+    @property
+    def last(self):
+        return Function(lib.LLVMGetLastFunction(self))
+
     def print_module_to_file(self, filename):
         out = c_char_p(None)
         # Result is inverted so 0 means everything was ok.
@@ -131,6 +181,21 @@ class Module(LLVMObject):
         if result:
             raise RuntimeError("LLVM Error: %s" % out.value)
 
+class Function(Value):
+
+    def __init__(self, value):
+        Value.__init__(self, value)
+    
+    @property
+    def next(self):
+        f = lib.LLVMGetNextFunction(self)
+        return f and Function(f)
+    
+    @property
+    def prev(self):
+        f = lib.LLVMGetPreviousFunction(self)
+        return f and Function(f)
+    
 class Context(LLVMObject):
 
     def __init__(self, context=None):
@@ -241,6 +306,25 @@ def register_library(library):
                                               POINTER(c_char_p)]
     library.LLVMPrintModuleToFile.restype = bool
 
+    library.LLVMGetFirstFunction.argtypes = [Module]
+    library.LLVMGetFirstFunction.restype = c_object_p
+
+    library.LLVMGetLastFunction.argtypes = [Module]
+    library.LLVMGetLastFunction.restype = c_object_p
+
+    library.LLVMGetNextFunction.argtypes = [Function]
+    library.LLVMGetNextFunction.restype = c_object_p
+
+    library.LLVMGetPreviousFunction.argtypes = [Function]
+    library.LLVMGetPreviousFunction.restype = c_object_p
+
+    # Value declarations.
+    library.LLVMGetValueName.argtypes = [Value]
+    library.LLVMGetValueName.restype = c_char_p
+
+    library.LLVMDumpValue.argtypes = [Value]
+    library.LLVMDumpValue.restype = None
+
 def register_enumerations():
     for name, value in enumerations.OpCodes:
         OpCode.register(name, value)
index 07a574e..a1f79a4 100644 (file)
@@ -4,6 +4,7 @@ from ..core import MemoryBuffer
 from ..core import PassRegistry
 from ..core import Context
 from ..core import Module
+from ..bit_reader import parse_bitcode
 
 class TestCore(TestBase):
     def test_opcode(self):
@@ -61,3 +62,19 @@ class TestCore(TestBase):
         m.target = target
         m.print_module_to_file("test2.ll")
     
+    def test_module_function_iteration(self):
+        m = parse_bitcode(MemoryBuffer(filename=self.get_test_bc()))
+        i = 0
+        functions = ["f", "f2", "f3", "f4", "f5", "f6", "g1", "g2", "h1", "h2",
+                     "h3"]
+        # Forward
+        for f in m:
+            self.assertEqual(f.name, functions[i])
+            f.dump()
+            i += 1
+        # Backwards
+        for f in reversed(m):
+            i -= 1
+            self.assertEqual(f.name, functions[i])
+            f.dump()
+