validate: Use a single TCPServer for subprocess communication
authorEdward Hervey <edward@centricular.com>
Sun, 3 Dec 2017 09:42:49 +0000 (10:42 +0100)
committerEdward Hervey <bilboed@bilboed.com>
Sun, 3 Dec 2017 11:23:29 +0000 (12:23 +0100)
Instead of creating a separate TCPServer for each test, just create
one which handles all connections in a threaded fashion.

Shaves off ~500ms per test

https://bugzilla.gnome.org/show_bug.cgi?id=791159

validate/gst/validate/gst-validate-report.c
validate/launcher/baseclasses.py

index 3d17b69..e773f75 100644 (file)
@@ -452,7 +452,7 @@ done:
 void
 gst_validate_report_init (void)
 {
-  const gchar *var, *file_env, *server_env;
+  const gchar *var, *file_env, *server_env, *uuid;
   const GDebugKey keys[] = {
     {"fatal_criticals", GST_VALIDATE_FATAL_CRITICALS},
     {"fatal_warnings", GST_VALIDATE_FATAL_WARNINGS},
@@ -481,7 +481,11 @@ gst_validate_report_init (void)
   }
 
   server_env = g_getenv ("GST_VALIDATE_SERVER");
-  if (server_env) {
+  uuid = g_getenv ("GST_VALIDATE_UUID");
+
+  if (server_env && !uuid) {
+    GST_ERROR ("No GST_VALIDATE_UUID specified !");
+  } else if (server_env) {
     GstUri *server_uri = gst_uri_from_string (server_env);
 
     if (server_uri && !g_strcmp0 (gst_uri_get_scheme (server_uri), "tcp")) {
@@ -502,6 +506,8 @@ gst_validate_report_init (void)
             g_io_stream_get_output_stream (G_IO_STREAM (server_connection));
         jbuilder = json_builder_new ();
         json_builder_begin_object (jbuilder);
+        json_builder_set_member_name (jbuilder, "uuid");
+        json_builder_add_string_value (jbuilder, uuid);
         json_builder_set_member_name (jbuilder, "started");
         json_builder_add_boolean_value (jbuilder, TRUE);
         json_builder_end_object (jbuilder);
index a41c343..1fe340d 100644 (file)
@@ -36,6 +36,7 @@ import queue
 import configparser
 import xml
 import random
+import uuid
 
 from . import reporters
 from . import loggable
@@ -95,6 +96,7 @@ class Test(Loggable):
         self.queue = None
         self.duration = duration
         self.stack_trace = None
+        self._uuid = None
         if expected_failures is None:
             self.expected_failures = []
         elif not isinstance(expected_failures, list):
@@ -208,6 +210,11 @@ class Test(Loggable):
     def get_name(self):
         return self.classname.split('.')[-1]
 
+    def get_uuid(self):
+        if self._uuid is None:
+            self._uuid = self.classname + str(uuid.uuid4())
+        return self._uuid
+
     def add_arguments(self, *args):
         self.command += args
 
@@ -514,10 +521,13 @@ class Test(Loggable):
 
         return self.result
 
+class GstValidateTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+    pass
 
 class GstValidateListener(socketserver.BaseRequestHandler):
     def handle(self):
         """Implements BaseRequestHandler handle method"""
+        test = None
         while True:
             raw_len = self.request.recv(4)
             if raw_len == b'':
@@ -528,7 +538,20 @@ class GstValidateListener(socketserver.BaseRequestHandler):
                 return
 
             obj = json.loads(msg)
-            test = getattr(self.server, "test")
+
+            if test is None:
+                # First message must contain the uuid
+                uuid = obj.get("uuid", None)
+                if uuid is None:
+                    return
+                # Find test from launcher
+                for t in self.server.launcher.tests:
+                    if uuid == t.get_uuid():
+                        test = t
+                        break
+                if test is None:
+                    self.server.launcher.error("Could not find test for UUID %s" % uuid)
+                    return
 
             obj_type = obj.get("type", '')
             if obj_type == 'position':
@@ -617,16 +640,8 @@ class GstValidateTest(Test):
         else:
             self.scenario = scenario
 
-    def stop_server(self):
-        if self.server:
-            self.server.shutdown()
-            self.server_thread.join()
-            self.server.server_close()
-            self.server = None
-
     def kill_subprocess(self):
         Test.kill_subprocess(self)
-        self.stop_server()
 
     def add_report(self, report):
         self.reports.append(report)
@@ -642,31 +657,6 @@ class GstValidateTest(Test):
             self._sent_eos_time = time.time()
         self.actions_infos.append(action_infos)
 
-    def server_wrapper(self, ready):
-        self.server = socketserver.TCPServer(('localhost', 0), GstValidateListener)
-        self.server.socket.settimeout(None)
-        self.server.test = self
-        self.serverport = self.server.socket.getsockname()[1]
-        self.info("%s server port: %s" % (self, self.serverport))
-        ready.set()
-
-        self.server.serve_forever()
-
-    def test_start(self, queue):
-        ready = threading.Event()
-        self.server_thread = threading.Thread(target=self.server_wrapper,
-                                              kwargs={'ready': ready})
-        self.server_thread.start()
-        ready.wait()
-
-        Test.test_start(self, queue)
-
-    def test_end(self):
-        res = Test.test_end(self)
-        self.stop_server()
-
-        return res
-
     def get_override_file(self, media_descriptor):
         if media_descriptor:
             if media_descriptor.get_path():
@@ -701,7 +691,7 @@ class GstValidateTest(Test):
     def get_subproc_env(self):
         subproc_env = os.environ.copy()
 
-        subproc_env["GST_VALIDATE_SERVER"] = "tcp://localhost:%s" % self.serverport
+        subproc_env["GST_VALIDATE_UUID"] = self.get_uuid()
 
         if 'GST_DEBUG' in os.environ and not self.options.redirect_logs:
             gstlogsfile = self.logfile + '.gstdebug'
@@ -1294,6 +1284,7 @@ class _TestsLauncher(Loggable):
         self.queue = queue.Queue()
         self.jobs = []
         self.total_num_tests = 0
+        self.server = None
 
     def _list_app_dirs(self):
         app_dirs = []
@@ -1551,6 +1542,32 @@ class _TestsLauncher(Loggable):
         cur_test_num = self.tests.index(test) + 1
         sys.stdout.write("[%d / %d] " % (cur_test_num, self.total_num_tests))
 
+    def server_wrapper(self, ready):
+        self.server = GstValidateTCPServer(('localhost', 0), GstValidateListener)
+        self.server.socket.settimeout(None)
+        self.server.launcher = self
+        self.serverport = self.server.socket.getsockname()[1]
+        self.info("%s server port: %s" % (self, self.serverport))
+        ready.set()
+
+        self.server.serve_forever(poll_interval=0.05)
+
+    def _start_server(self):
+        self.info("Starting TCP Server")
+        ready = threading.Event()
+        self.server_thread = threading.Thread(target=self.server_wrapper,
+                                              kwargs={'ready': ready})
+        self.server_thread.start()
+        ready.wait()
+        os.environ["GST_VALIDATE_SERVER"] = "tcp://localhost:%s" % self.serverport
+
+    def _stop_server(self):
+        if self.server:
+            self.server.shutdown()
+            self.server_thread.join()
+            self.server.server_close()
+            self.server = None
+
     def test_wait(self):
         while True:
             # Check process every second for timeout
@@ -1640,8 +1657,10 @@ class _TestsLauncher(Loggable):
     def clean_tests(self):
         for test in self.tests:
             test.clean()
+        self._stop_server()
 
     def run_tests(self):
+        self._start_server()
         if self.options.forever:
             r = 1
             while True: