Disable a debug option
[platform/upstream/curl.git] / tests / smbserver.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 #
4 #  Project                     ___| | | |  _ \| |
5 #                             / __| | | | |_) | |
6 #                            | (__| |_| |  _ <| |___
7 #                             \___|\___/|_| \_\_____|
8 #
9 # Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
10 #
11 # This software is licensed as described in the file COPYING, which
12 # you should have received as part of this distribution. The terms
13 # are also available at https://curl.se/docs/copyright.html.
14 #
15 # You may opt to use, copy, modify, merge, publish, distribute and/or sell
16 # copies of the Software, and permit persons to whom the Software is
17 # furnished to do so, under the terms of the COPYING file.
18 #
19 # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20 # KIND, either express or implied.
21 #
22 # SPDX-License-Identifier: curl
23 #
24 """Server for testing SMB"""
25
26 from __future__ import (absolute_import, division, print_function,
27                         unicode_literals)
28
29 import argparse
30 import logging
31 import os
32 import signal
33 import sys
34 import tempfile
35 import threading
36
37 # Import our curl test data helper
38 from util import ClosingFileHandler, TestData
39
40 if sys.version_info.major >= 3:
41     import configparser
42 else:
43     import ConfigParser as configparser
44
45 # impacket needs to be installed in the Python environment
46 try:
47     import impacket
48 except ImportError:
49     sys.stderr.write('Python package impacket needs to be installed!\n')
50     sys.stderr.write('Use pip or your package manager to install it.\n')
51     sys.exit(1)
52 from impacket import smb as imp_smb
53 from impacket import smbserver as imp_smbserver
54 from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
55                                 STATUS_SUCCESS)
56
# Module logger; handlers are attached to the root logger by setup_logging()
log = logging.getLogger(__name__)

# Placeholder share paths. create_and_x() recognises these and answers
# requests on them from temporary files instead of the real filesystem.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"

# Name of the file a client reads from the SERVER share to check that the
# server is alive, and the template of the reply it gets (filled with a pid).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
62
63
class ShutdownHandler(threading.Thread):
    """Cleanly shut down the SMB server

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.

    On exit the SIGINT/SIGTERM handlers that were active on entry are
    restored and every path in server.tmpfiles is deleted (best effort).
    """

    def __init__(self, server):
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()
        # Signal handlers replaced in __enter__, restored in __exit__
        self._old_handlers = {}

    def __enter__(self):
        # Install the handlers before starting the thread: signal.signal()
        # can raise (e.g. when not called from the main thread), and a
        # started, non-daemon thread would then wait forever on the event
        # and keep the interpreter from exiting.
        for signum in (signal.SIGINT, signal.SIGTERM):
            self._old_handlers[signum] = signal.signal(signum,
                                                       self._sighandler)
        self.start()
        return self

    def __exit__(self, *_):
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Put back the handlers that were installed before __enter__.
        # signal.signal() reports None for a handler that was not set from
        # Python; fall back to the default action in that case.
        for signum, handler in self._old_handlers.items():
            signal.signal(signum,
                          signal.SIG_DFL if handler is None else handler)
        self._old_handlers.clear()
        # Delete any temporary files created by the server during its run
        log.info("Deleting %d temporary files", len(self.server.tmpfiles))
        for f in self.server.tmpfiles:
            try:
                os.unlink(f)
            except OSError:
                # Keep going so one missing file doesn't leave the rest
                log.warning("Failed to delete temporary file %s", f)

    def _sighandler(self, _signum, _frame):
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()
105
106
def smbserver(options):
    """Run an SMB server over TCP until it is asked to shut down.

    Writes the pidfile (when --pidfile is given), builds an in-memory
    impacket configuration with a "SERVER" share (used to check that the
    server is running) and a "TESTS" share (files generated from the test
    data in <srcdir>/data), then serves until ShutdownHandler stops the
    server on SIGINT/SIGTERM.

    Returns 0 once the server has stopped. Raises ScriptException when
    --srcdir is missing or is not a directory.
    """
    if options.pidfile:
        # see tests/server/util.c function write_pidfile
        pid_offset = 65536 if os.name == "nt" else 0
        with open(options.pidfile, "w") as pid_out:
            pid_out.write(str(os.getpid() + pid_offset))

    # (section, ((key, value), ...)) entries, added in this exact order
    config_layout = (
        ("global", (
            ("server_name", "SERVICE"),
            ("server_os", "UNIX"),
            ("server_domain", "WORKGROUP"),
            ("log_file", ""),
            ("credentials_file", ""),
        )),
        # Share which lets a client check that the server is running
        ("SERVER", (
            ("comment", "server function"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", SERVER_MAGIC),
        )),
        # Share for tests; its files are generated from the test input
        ("TESTS", (
            ("comment", "tests"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", TESTS_MAGIC),
        )),
    )
    smb_config = configparser.ConfigParser()
    for section, entries in config_layout:
        smb_config.add_section(section)
        for key, value in entries:
            smb_config.set(section, key, value)

    srcdir = options.srcdir
    if not (srcdir and os.path.isdir(srcdir)):
        raise ScriptException("--srcdir is mandatory")

    smb_server = TestSmbServer((options.host, options.port),
                               config_parser=smb_config,
                               test_data_directory=os.path.join(srcdir,
                                                                "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # serve_forever() blocks until ShutdownHandler's thread calls
    # smb_server.shutdown()
    with ShutdownHandler(smb_server):
        smb_server.serve_forever()

    return 0
160
161
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    Create requests on the SERVER_MAGIC and TESTS_MAGIC shares are answered
    from temporary files generated on the fly; create requests on any other
    share are handed to impacket's SMBCommands.smbComNtCreateAndX.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Paths of temporary files handed out to clients; ShutdownHandler
        # deletes them when the server stops.
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Only FILE_OPEN dispositions are accepted. Returns
        ([response command], None, status code); on an SmbException the
        response carries empty parameters/data and the exception's status.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
                                                                    smb_server,
                                                                    smb_command,
                                                                    recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command[
                                                         "Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert (path == TESTS_MAGIC)
                fid, full_path = self.get_test_path(requested_file)

            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Simple way to generate a fid: the lowest positive number not
            # already open on this connection. (dict views cannot be
            # indexed on Python 3, and "highest + 1" could overflow the
            # 16-bit fid field.)
            opened_files = conn_data["OpenedFiles"]
            fakefid = 1
            while fakefid in opened_files:
                fakefid += 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # Describe the file that was actually opened, not the share name
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The descriptor is never registered in OpenedFiles, so
                # close it here rather than leak it.
                os.close(fid)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection
            opened_files[fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """Return the path a create request on tree ``tid`` refers to.

        With a non-zero ``root_fid`` this is the FileName stored for that
        open fid; otherwise it is the share's configured "path".

        Raises SmbException with STATUS_SMB_BAD_TID for an unknown tid and
        with STATUS_ACCESS_DENIED when the share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")
        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """Create a temp file holding the reply to a SERVER share request.

        Only VERIFIED_REQ exists; its contents are VERIFIED_RSP filled with
        this process's pid (offset like write_pidfile in
        tests/server/util.c). Returns (fd, path) with the fd rewound.
        Raises SmbException(STATUS_NO_SUCH_FILE) for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename != VERIFIED_REQ:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        log.debug("[SMB] Verifying server is alive")
        pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            pid += 65536
        contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        try:
            self.write_to_fid(fid, contents)
        except OSError:
            # The caller never sees this file, so remove it here
            self._discard_temp_file(fid, filename)
            raise
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write all of the bytes ``contents`` to descriptor ``fid``, sync
        it, and rewind so a following read returns the contents."""
        # os.write() may write fewer bytes than requested; loop until done.
        remaining = memoryview(contents)
        while remaining:
            written = os.write(fid, remaining)
            remaining = remaining[written:]
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """Create a temp file holding the test data for ``requested_filename``.

        Returns (fd, path) with the fd rewound. If the data cannot be
        produced the temp file is removed and
        SmbException(STATUS_NO_SUCH_FILE) is raised.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
        except Exception:
            log.exception("Failed to make test file")
            # The caller never sees this file, so remove it here
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename

    @staticmethod
    def _discard_temp_file(fid, filename):
        """Best-effort close and delete of a temp file not handed out."""
        try:
            os.close(fid)
        except OSError:
            log.warning("[SMB] Could not close descriptor %d", fid)
        try:
            os.unlink(filename)
        except OSError:
            log.warning("[SMB] Could not delete temporary file %s", filename)
363
364
class SmbException(Exception):
    """Error raised inside the create handler to abort a request.

    ``error_code`` is the status returned to the client (for example one of
    the impacket.nt_errors constants); the message is kept for logging.
    """

    def __init__(self, error_code, error_message):
        self.error_code = error_code
        Exception.__init__(self, error_message)
369
370
class ScriptRC(object):
    """Exit codes used by this script (a plain enum-like namespace)."""

    # Normal completion
    SUCCESS = 0
    # Generic failure (not referenced elsewhere in this file)
    FAILURE = 1
    # An exception escaped smbserver() in the __main__ block
    EXCEPTION = 2
376
377
class ScriptException(Exception):
    """Fatal script-level error, e.g. a missing or invalid --srcdir."""
380
381
def get_options(argv=None):
    """Parse the command line options for the SMB test server.

    argv: argument list to parse instead of sys.argv[1:]; the default of
        None keeps the original behavior of reading the real command line.
    Returns the parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # store_true flags are booleans; use False (not 0) as the unset value
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
401
402
def setup_logging(options):
    """
    Configure the root logger from the parsed command line options.

    With --logfile a DEBUG-level file handler is attached. A stdout handler
    is attached when no logfile was given, and also in --verbose mode. The
    root level is DEBUG when verbose, INFO otherwise.
    """
    root_logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")

    def attach(handler):
        # Every handler shares the formatter and passes all records through
        handler.setFormatter(formatter)
        handler.setLevel(logging.DEBUG)
        root_logger.addHandler(handler)

    if options.logfile:
        attach(ClosingFileHandler(options.logfile))

    root_logger.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # stdout is the fallback without a logfile, and an extra sink when verbose
    if options.verbose or not options.logfile:
        attach(logging.StreamHandler(sys.stdout))
434
435
if __name__ == '__main__':
    # Parse the command line first so logging can honour --logfile/--verbose
    cli_options = get_options()
    setup_logging(cli_options)

    # Any exception escaping the server is logged and becomes an exit code
    try:
        exit_code = smbserver(cli_options)
    except Exception as e:
        log.exception(e)
        exit_code = ScriptRC.EXCEPTION

    # Remove the pidfile smbserver() may have written
    pidfile = cli_options.pidfile
    if pidfile and os.path.isfile(pidfile):
        os.unlink(pidfile)

    log.info("[SMB] Returning %d", exit_code)
    sys.exit(exit_code)