Imported Upstream version 7.59.0
[platform/upstream/curl.git] / tests / smbserver.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 #  Project                     ___| | | |  _ \| |
5 #                             / __| | | | |_) | |
6 #                            | (__| |_| |  _ <| |___
7 #                             \___|\___/|_| \_\_____|
8 #
9 # Copyright (C) 2017, Daniel Stenberg, <daniel@haxx.se>, et al.
10 #
11 # This software is licensed as described in the file COPYING, which
12 # you should have received as part of this distribution. The terms
13 # are also available at https://curl.haxx.se/docs/copyright.html.
14 #
15 # You may opt to use, copy, modify, merge, publish, distribute and/or sell
16 # copies of the Software, and permit persons to whom the Software is
17 # furnished to do so, under the terms of the COPYING file.
18 #
19 # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20 # KIND, either express or implied.
21 #
22 """Server for testing SMB"""
23
24 from __future__ import (absolute_import, division, print_function)
25 # unicode_literals)
26 import argparse
27 import ConfigParser
28 import os
29 import sys
30 import logging
31 import tempfile
32
33 # Import our curl test data helper
34 import curl_test_data
35
36 # This saves us having to set up the PYTHONPATH explicitly
37 deps_dir = os.path.join(os.path.dirname(__file__), "python_dependencies")
38 sys.path.append(deps_dir)
39 from impacket import smbserver as imp_smbserver
40 from impacket import smb as imp_smb
41 from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_SUCCESS,
42                                 STATUS_NO_SUCH_FILE)
43
# Module logger; setup_logging() attaches handlers to the root logger, which
# records from this logger propagate to.
log = logging.getLogger(__name__)

# Placeholder share "paths".  Opens on shares configured with these values are
# intercepted by TestSmbServer and answered with generated temporary files.
SERVER_MAGIC, TESTS_MAGIC = "SERVER_MAGIC", "TESTS_MAGIC"

# File name requested on the SERVER share to check the server is alive, and
# the reply template that is formatted with this process's PID.
VERIFIED_REQ, VERIFIED_RSP = "verifiedserver", b"WE ROOLZ: {pid}\n"
49
50
def smbserver(options):
    """Start up a TCP SMB server that serves forever.

    Validates ``options.srcdir``, writes this process's PID to
    ``options.pidfile`` (if given), then serves the SERVER and TESTS shares on
    127.0.0.1:``options.port`` via TestSmbServer.

    :param options: parsed command-line options (see get_options()).
    :returns: 0 if serve_forever() ever returns.
    :raises ScriptException: if ``--srcdir`` is missing or not a directory.
    """
    # Validate before touching the pidfile so a failed start does not leave a
    # pidfile naming a process that is about to exit.
    if not options.srcdir:
        raise ScriptException("--srcdir is mandatory")
    if not os.path.isdir(options.srcdir):
        raise ScriptException(
            "--srcdir {0!r} is not a directory".format(options.srcdir))

    if options.pidfile:
        pid = os.getpid()
        with open(options.pidfile, "w") as f:
            f.write("{0}".format(pid))

    test_data_dir = os.path.join(options.srcdir, "data")

    smb_server = TestSmbServer(("127.0.0.1", options.port),
                               config_parser=_make_smb_config(),
                               test_data_directory=test_data_dir)
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()
    smb_server.serve_forever()
    return 0


def _make_smb_config():
    """Build the mini in-memory server configuration used by smbserver().

    Besides the [global] section it defines two read-only shares whose
    "path" values are the SERVER_MAGIC / TESTS_MAGIC markers that
    TestSmbServer.create_and_x intercepts.
    """
    smb_config = ConfigParser.ConfigParser()
    smb_config.add_section("global")
    smb_config.set("global", "server_name", "SERVICE")
    smb_config.set("global", "server_os", "UNIX")
    smb_config.set("global", "server_domain", "WORKGROUP")
    smb_config.set("global", "log_file", "")
    smb_config.set("global", "credentials_file", "")

    # We need a share which allows us to test that the server is running
    smb_config.add_section("SERVER")
    smb_config.set("SERVER", "comment", "server function")
    smb_config.set("SERVER", "read only", "yes")
    smb_config.set("SERVER", "share type", "0")
    smb_config.set("SERVER", "path", SERVER_MAGIC)

    # Have a share for tests.  These files will be autogenerated from the
    # test input.
    smb_config.add_section("TESTS")
    smb_config.set("TESTS", "comment", "tests")
    smb_config.set("TESTS", "read only", "yes")
    smb_config.set("TESTS", "share type", "0")
    smb_config.set("TESTS", "path", TESTS_MAGIC)

    return smb_config
96
97
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    NT_CREATE_ANDX requests on the SERVER_MAGIC and TESTS_MAGIC shares are
    answered with temporary files generated on the fly; requests on any other
    share are handed to impacket's stock smbComNtCreateAndX handler.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)

        # Set up a test data object so we can get test data later.
        self.ctd = curl_test_data.TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        :returns: ``([response_command], None, error_code)``; requests on
                  non-magic shares return whatever impacket's own
                  smbComNtCreateAndX returns.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
                                                                    smb_server,
                                                                    smb_command,
                                                                    recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert (path == TESTS_MAGIC)
                fid, full_path = self.get_test_path(requested_file)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Allocate a fid one larger than any currently open.  Dict key
            # order is arbitrary, so keys()[-1] is not necessarily the largest
            # key and could collide with (and overwrite) an open file.
            opened_files = conn_data["OpenedFiles"]
            fakefid = max(opened_files) + 1 if opened_files else 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # `path` is only the magic share marker; the generated file that
            # is actually being opened lives at full_path.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                "", full_path, level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # Nothing will ever close this descriptor; release it now.
                self._discard_temp_file(fid, full_path)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            # Overrides the FileAttributes chosen above.
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection.
            # NOTE(review): FileName holds the magic share marker rather than
            # full_path and DeleteOnClose is False, so nothing in this file
            # ever removes the generated temp file.
            opened_files[fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path a request refers to.

        With a non-zero *root_fid* this is the stored FileName of that open
        file; otherwise it is the "path" of the share connected as *tid*.

        :raises SmbException: if *tid* is not a connected share, *root_fid*
                              is not an open file, or the share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid.
            opened_files = conn_data["OpenedFiles"]
            if root_fid not in opened_files:
                # Previously a bare KeyError escaped the handler here.
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "RootFid was invalid")
            path = opened_files[root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        share = conn_shares[tid]
        if "path" not in share:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")
        return share["path"]

    def get_server_path(self, requested_filename):
        """
        Create a temp file answering a request on the SERVER share.

        :returns: ``(fd, filename)`` of the temp file, rewound to offset 0.
        :raises SmbException: STATUS_NO_SUCH_FILE for unknown file names.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        contents = ""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            contents = VERIFIED_RSP.format(pid=os.getpid())

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write *contents* to descriptor *fid*, sync it, rewind to offset 0."""
        # os.write may write fewer bytes than requested; loop until all done.
        remaining = contents
        while remaining:
            written = os.write(fid, remaining)
            remaining = remaining[written:]
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create a temp file holding the test data for *requested_filename*.

        :returns: ``(fd, filename)`` of the temp file, rewound to offset 0.
        :raises SmbException: STATUS_NO_SUCH_FILE if the data can't be made.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename)
            self.write_to_fid(fid, contents)
        except Exception:
            log.exception("Failed to make test file")
            # Don't leak the descriptor and temp file of a failed lookup.
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename

    @staticmethod
    def _discard_temp_file(fid, filename):
        """Best-effort close of *fid* and removal of *filename*."""
        try:
            os.close(fid)
        except OSError as err:
            log.debug("[SMB] Failed to close %d: %s", fid, err)
        try:
            os.remove(filename)
        except OSError as err:
            log.debug("[SMB] Failed to remove %s: %s", filename, err)
291
292
class SmbException(Exception):
    """Abort SMB request handling with a specific status code.

    ``error_code`` holds the status to report; the message is the ordinary
    exception text.
    """

    def __init__(self, error_code, error_message):
        Exception.__init__(self, error_message)
        self.error_code = error_code
297
298
class ScriptRC(object):
    """Process exit codes used by this script (success, failure, exception)."""
    SUCCESS, FAILURE, EXCEPTION = 0, 1, 2
304
305
class ScriptException(Exception):
    """Fatal script-level error, such as a bad command-line option."""
308
309
def get_options(argv=None):
    """Parse the command line options.

    :param argv: list of arguments to parse; ``None`` (the default) parses
                 ``sys.argv[1:]``, exactly as before.
    :returns: an ``argparse.Namespace`` with ``port``, ``verbose``,
              ``pidfile``, ``logfile``, ``srcdir``, ``id`` and ``ipv4``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # store_true flags are booleans; the previous default=0 was merely falsy
    # and gave the attribute a different type depending on the command line.
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
327
328
def setup_logging(options):
    """
    Configure the root logger from the parsed command-line options.

    A file handler (truncating ``options.logfile``) is attached when a
    logfile is given.  A stdout handler is attached when no logfile was
    given or when ``options.verbose`` is truthy.  The root level is DEBUG in
    verbose mode and INFO otherwise; each handler itself passes DEBUG and up.
    """
    root = logging.getLogger()
    fmt = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")

    def _attach(handler):
        # Every handler shares the formatter and lets everything through;
        # filtering is done by the root logger's level.
        handler.setFormatter(fmt)
        handler.setLevel(logging.DEBUG)
        root.addHandler(handler)

    if options.logfile:
        _attach(logging.FileHandler(options.logfile, mode="w"))

    root.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # Without a logfile, stdout is the only sink; verbose mode echoes to
    # stdout as well.
    if options.verbose or not options.logfile:
        _attach(logging.StreamHandler(sys.stdout))
360
361
if __name__ == '__main__':
    # Parse the command line and configure logging first, so that any
    # failure while starting the server is recorded.
    cli_options = get_options()
    setup_logging(cli_options)

    try:
        rc = smbserver(cli_options)
    except Exception as exc:
        # Record the traceback and report the failure through the exit code.
        log.exception(exc)
        rc = ScriptRC.EXCEPTION

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)