[lldb] Accept negative indexes in __getitem__
authorDave Lee <davelee.com@gmail.com>
Fri, 3 Feb 2023 16:45:44 +0000 (08:45 -0800)
committerDave Lee <davelee.com@gmail.com>
Wed, 8 Feb 2023 18:46:26 +0000 (10:46 -0800)
To the Python bindings, add support for Python-like negative indexes.

While was using `script`, I tried to access a thread's bottom frame with
`thread.frame[-1]`, but that failed. This change updates the `__getitem__`
implementations to support negative indexes as one would expect in Python.

Differential Revision: https://reviews.llvm.org/D143282

13 files changed:
lldb/bindings/interface/SBBreakpoint.i
lldb/bindings/interface/SBInstructionList.i
lldb/bindings/interface/SBModule.i
lldb/bindings/interface/SBProcess.i
lldb/bindings/interface/SBSymbolContextList.i
lldb/bindings/interface/SBTarget.i
lldb/bindings/interface/SBThread.i
lldb/bindings/interface/SBTypeCategory.i
lldb/bindings/interface/SBTypeEnumMember.i
lldb/bindings/interface/SBValue.i
lldb/bindings/interface/SBValueList.i
lldb/test/API/python_api/breakpoint/TestBreakpointAPI.py
lldb/test/API/python_api/thread/TestThreadAPI.py

index a704830..a61874d 100644 (file)
@@ -273,8 +273,11 @@ public:
                 return 0
 
             def __getitem__(self, key):
-                if type(key) is int and key < len(self):
-                    return self.sbbreakpoint.GetLocationAtIndex(key)
+                if isinstance(key, int):
+                    count = len(self)
+                    if -count <= key < count:
+                        key %= count
+                        return self.sbbreakpoint.GetLocationAtIndex(key)
                 return None
 
         def get_locations_access_object(self):
index b51c037..e80452e 100644 (file)
@@ -83,7 +83,9 @@ public:
             '''Access instructions by integer index for array access or by lldb.SBAddress to find an instruction that matches a section offset address object.'''
             if type(key) is int:
                 # Find an instruction by index
-                if key < len(self):
+                count = len(self)
+                if -count <= key < count:
+                    key %= count
                     return self.GetInstructionAtIndex(key)
             elif type(key) is SBAddress:
                 # Find an instruction using a lldb.SBAddress object
index de476f7..f181d96 100644 (file)
@@ -415,7 +415,8 @@ public:
             def __getitem__(self, key):
                 count = len(self)
                 if type(key) is int:
-                    if key < count:
+                    if -count <= key < count:
+                        key %= count
                         return self.sbmodule.GetSymbolAtIndex(key)
                 elif type(key) is str:
                     matches = []
@@ -476,7 +477,8 @@ public:
             def __getitem__(self, key):
                 count = len(self)
                 if type(key) is int:
-                    if key < count:
+                    if -count <= key < count:
+                        key %= count
                         return self.sbmodule.GetSectionAtIndex(key)
                 elif type(key) is str:
                     for idx in range(count):
@@ -511,7 +513,8 @@ public:
             def __getitem__(self, key):
                 count = len(self)
                 if type(key) is int:
-                    if key < count:
+                    if -count <= key < count:
+                        key %= count
                         return self.sbmodule.GetCompileUnitAtIndex(key)
                 elif type(key) is str:
                     is_full_path = key[0] == '/'
index 0ef5584..01da427 100644 (file)
@@ -487,8 +487,11 @@ public:
                 return 0
 
             def __getitem__(self, key):
-                if type(key) is int and key < len(self):
-                    return self.sbprocess.GetThreadAtIndex(key)
+                if isinstance(key, int):
+                    count = len(self)
+                    if -count <= key < count:
+                        key %= count
+                        return self.sbprocess.GetThreadAtIndex(key)
                 return None
 
         def get_threads_access_object(self):
index 14566b3..7bbeed7 100644 (file)
@@ -74,8 +74,9 @@ public:
 
         def __getitem__(self, key):
             count = len(self)
-            if type(key) is int:
-                if key < count:
+            if isinstance(key, int):
+                if -count <= key < count:
+                    key %= count
                     return self.GetContextAtIndex(key)
                 else:
                     raise IndexError
index e887762..6529f8f 100644 (file)
@@ -1001,7 +1001,8 @@ public:
             def __getitem__(self, key):
                 num_modules = self.sbtarget.GetNumModules()
                 if type(key) is int:
-                    if key < num_modules:
+                    if -num_modules <= key < num_modules:
+                        key %= num_modules
                         return self.sbtarget.GetModuleAtIndex(key)
                 elif type(key) is str:
                     if key.find('/') == -1:
index 1e46bd6..8317e17 100644 (file)
@@ -434,8 +434,11 @@ public:
                 return 0
 
             def __getitem__(self, key):
-                if type(key) is int and key < self.sbthread.GetNumFrames():
-                    return self.sbthread.GetFrameAtIndex(key)
+                if isinstance(key, int):
+                    count = len(self)
+                    if -count <= key < count:
+                        key %= count
+                        return self.sbthread.GetFrameAtIndex(key)
                 return None
 
         def get_frames_access_object(self):
index b762bf8..f8af390 100644 (file)
@@ -147,7 +147,8 @@ namespace lldb {
                 def __getitem__(self, key):
                     num_items = len(self)
                     if type(key) is int:
-                        if key < num_items:
+                        if -num_items <= key < num_items:
+                            key %= num_items
                             return self.get_at_index_function(self.sbcategory,key)
                     elif type(key) is str:
                         return self.get_by_name_function(self.sbcategory,SBTypeNameSpecifier(key))
index b419010..986cb87 100644 (file)
@@ -121,7 +121,8 @@ public:
         def __getitem__(self, key):
           num_elements = self.GetSize()
           if type(key) is int:
-              if key < num_elements:
+              if -num_elements <= key < num_elements:
+                  key %= num_elements
                   return self.GetTypeEnumMemberAtIndex(key)
           elif type(key) is str:
               for idx in range(num_elements):
index bc66a4a..335667f 100644 (file)
@@ -459,8 +459,11 @@ public:
                 return 0
 
             def __getitem__(self, key):
-                if type(key) is int and key < len(self):
-                    return self.sbvalue.GetChildAtIndex(key)
+                if isinstance(key, int):
+                    count = len(self)
+                    if -count <= key < count:
+                        key %= count
+                        return self.sbvalue.GetChildAtIndex(key)
                 return None
 
         def get_child_access_object(self):
index e03b5c6..4488b5d 100644 (file)
@@ -146,7 +146,8 @@ public:
             # Access with "int" to get Nth item in the list
             #------------------------------------------------------------
             if type(key) is int:
-                if key < count:
+                if -count <= key < count:
+                    key %= count
                     return self.GetValueAtIndex(key)
             #------------------------------------------------------------
             # Access with "str" to get values by name
index 014065e..da6c895 100644 (file)
@@ -62,6 +62,9 @@ class BreakpointAPITestCase(TestBase):
         location = breakpoint.GetLocationAtIndex(0)
         self.assertTrue(location.IsValid())
 
+        # Test negative index access.
+        self.assertTrue(breakpoint.location[-1].IsValid())
+
         # Make sure the breakpoint's target is right:
         self.assertEqual(target, breakpoint.GetTarget(), "Breakpoint reports its target correctly")
         
index 7bdcf36..b26d168 100644 (file)
@@ -48,6 +48,11 @@ class ThreadAPITestCase(TestBase):
         self.setTearDownCleanup(dictionary=d)
         self.step_over_3_times(self.exe_name)
 
+    def test_negative_indexing(self):
+        """Test SBThread.frame with negative indexes."""
+        self.build()
+        self.validate_negative_indexing()
+
     def setUp(self):
         # Call super's setUp().
         TestBase.setUp(self)
@@ -269,3 +274,29 @@ class ThreadAPITestCase(TestBase):
         thread.RunToAddress(start_addr)
         self.runCmd("process status")
         #self.runCmd("thread backtrace")
+
+    def validate_negative_indexing(self):
+        exe = self.getBuildArtifact("a.out")
+
+        target = self.dbg.CreateTarget(exe)
+        self.assertTrue(target, VALID_TARGET)
+
+        breakpoint = target.BreakpointCreateByLocation(
+            "main.cpp", self.break_line)
+        self.assertTrue(breakpoint, VALID_BREAKPOINT)
+        self.runCmd("breakpoint list")
+
+        # Launch the process, and do not stop at the entry point.
+        process = target.LaunchSimple(
+            None, None, self.get_process_working_directory())
+
+        thread = get_stopped_thread(process, lldb.eStopReasonBreakpoint)
+        self.assertTrue(
+            thread.IsValid(),
+            "There should be a thread stopped due to breakpoint")
+        self.runCmd("process status")
+
+        pos_range = range(thread.num_frames)
+        neg_range = range(thread.num_frames, 0, -1)
+        for pos, neg in zip(pos_range, neg_range):
+            self.assertEqual(thread.frame[pos].idx, thread.frame[-neg].idx)