Adding gst-python package
[platform/upstream/gst-python.git] / testsuite / common.py
1 # -*- Mode: Python; py-indent-offset: 4 -*-
2 # vim: tabstop=4 shiftwidth=4 expandtab
3 #
4 # Copyright (C) 2015 Thibault Saunier <thibault.saunier@collabora.com>
5 #
6 # This program is free software; you can redistribute it and/or
7 # modify it under the terms of the GNU Lesser General Public
8 # License as published by the Free Software Foundation; either
9 # version 2.1 of the License, or (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # Lesser General Public License for more details.
15 #
16 # You should have received a copy of the GNU Lesser General Public
17 # License along with this program; if not, write to the
18 # Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
19 # Boston, MA 02110-1301, USA.
20 # This program is free software; you can redistribute it and/or modify
21 # it under the terms of the GNU General Public License as published by
22 # the Free Software Foundation; either version 3, or (at your option)
23 # any later version.
24 """
25 A collection of objects to use for testing
26
27 Copyied from pitivi
28 """
29
30 import os
31 import gc
32 import unittest
33 import gi.overrides
34 gi.overrides
35
36 from gi.repository import Gst
37
38
39 detect_leaks = os.environ.get("TEST_DETECT_LEAKS", "1") not in ("0", "")
40
41
42 class TestCase(unittest.TestCase):
43     _tracked_types = (Gst.MiniObject, Gst.Element, Gst.Pad, Gst.Caps)
44
45     def gctrack(self):
46         self.gccollect()
47         self._tracked = []
48         for obj in gc.get_objects():
49             if not isinstance(obj, self._tracked_types):
50                 continue
51
52             self._tracked.append(obj)
53
54     def gccollect(self):
55         ret = 0
56         while True:
57             c = gc.collect()
58             ret += c
59             if c == 0:
60                 break
61         return ret
62
63     def gcverify(self):
64         leaked = []
65         for obj in gc.get_objects():
66             if not isinstance(obj, self._tracked_types) or \
67                     obj in self._tracked:
68                 continue
69
70             leaked.append(obj)
71
72         # we collect again here to get rid of temporary objects created in the
73         # above loop
74         self.gccollect()
75
76         for elt in leaked:
77             print(elt)
78             for i in gc.get_referrers(elt):
79                 print("   ", i)
80
81         self.assertFalse(leaked, leaked)
82         del self._tracked
83
84     def setUp(self):
85         self._num_failures = len(getattr(self._result, 'failures', []))
86         self._num_errors = len(getattr(self._result, 'errors', []))
87         if detect_leaks:
88             self.gctrack()
89
90     def tearDown(self):
91         # don't barf gc info all over the console if we have already failed a
92         # test case
93         if (self._num_failures < len(getattr(self._result, 'failures', []))
94            or self._num_errors < len(getattr(self._result, 'failures', []))):
95             return
96         if detect_leaks:
97             self.gccollect()
98             self.gcverify()
99
100     # override run() to save a reference to the test result object
101     def run(self, result=None):
102         if not result:
103             result = self.defaultTestResult()
104         self._result = result
105         unittest.TestCase.run(self, result)
106
107
108 class SignalMonitor(object):
109
110     def __init__(self, obj, *signals):
111         self.signals = signals
112         self.connectToObj(obj)
113
114     def connectToObj(self, obj):
115         self.obj = obj
116         for signal in self.signals:
117             obj.connect(signal, self._signalCb, signal)
118             setattr(self, self._getSignalCounterName(signal), 0)
119             setattr(self, self._getSignalCollectName(signal), [])
120
121     def disconnectFromObj(self, obj):
122         obj.disconnect_by_func(self._signalCb)
123         del self.obj
124
125     def _getSignalCounterName(self, signal):
126         field = '%s_count' % signal.replace('-', '_')
127         return field
128
129     def _getSignalCollectName(self, signal):
130         field = '%s_collect' % signal.replace('-', '_')
131         return field
132
133     def _signalCb(self, obj, *args):
134         name = args[-1]
135         field = self._getSignalCounterName(name)
136         setattr(self, field, getattr(self, field, 0) + 1)
137         field = self._getSignalCollectName(name)
138         setattr(self, field, getattr(self, field, []) + [args[:-1]])