1 # Copyright 2013 The Chromium Authors. All rights reserved.
2 # Use of this source code is governed by a BSD-style license that can be
3 # found in the LICENSE file.
4 """A wrapper around ssh for common operations on a CrOS-based device"""
13 # TODO(nduca): This whole file is built up around making individual ssh calls
14 # for each operation. It really could get away with a single ssh session built
15 # around pexpect, I suspect, if we wanted it to be faster. But, this was
# Reports whether this process runs on a ChromeOS device. It requires a Linux
# platform and an /etc/lsb-release file whose contents (presumably read into
# `res`) mention CHROMEOS_RELEASE_NAME.
18 def IsRunningOnCrosDevice():
19 """Returns True if we're on a ChromeOS device."""
20 lsb_release = '/etc/lsb-release'
21 if sys.platform.startswith('linux') and os.path.exists(lsb_release):
22 with open(lsb_release, 'r') as f:
# NOTE(review): `res` is assigned on a line not shown here, presumably the text
# read from `f`. Confirm against the full file.
24 if res.count('CHROMEOS_RELEASE_NAME'):
# Runs `args` as a child process without a shell (shell=False) and discards
# its I/O: stdin, stdout and stderr are all bound to os.devnull. The command
# line and `cwd` are written to the debug log.
# NOTE(review): `quiet` is not read on any line shown here; it presumably gates
# the debug log. Confirm against the full file.
28 def RunCmd(args, cwd=None, quiet=False):
29 """Opens a subprocess to execute a program and returns its return value.
32 args: A string or a sequence of program arguments. The program to execute is
33 the string or the first item in the args sequence.
34 cwd: If not None, the subprocess's current directory will be changed to
35 |cwd| before it's executed.
38 Return code from the command execution.
# ' '.join(args) assumes a sequence of strings. A plain string argument would
# be logged with a space between each of its characters.
41 logging.debug(' '.join(args) + ' ' + (cwd or ''))
42 with open(os.devnull, 'w') as devnull:
43 p = subprocess.Popen(args=args, cwd=cwd, stdout=devnull,
44 stderr=devnull, stdin=devnull, shell=False)
# Runs `args` without a shell. stdin comes from os.devnull; stdout and stderr
# are captured through pipes and both are written to the debug log. Every
# caller in this file unpacks the result as (stdout, stderr).
47 def GetAllCmdOutput(args, cwd=None, quiet=False):
48 """Open a subprocess to execute a program and returns its output.
51 args: A string or a sequence of program arguments. The program to execute is
52 the string or the first item in the args sequence.
53 cwd: If not None, the subprocess's current directory will be changed to
54 |cwd| before it's executed.
57 Captures the command's output and returns it as a (stdout, stderr) tuple.
58 Prints the command's stderr to logger (which defaults to stdout).
61 logging.debug(' '.join(args) + ' ' + (cwd or ''))
62 with open(os.devnull, 'w') as devnull:
63 p = subprocess.Popen(args=args, cwd=cwd, stdout=subprocess.PIPE,
64 stderr=subprocess.PIPE, stdin=devnull)
# communicate() drains both pipes together, so the child cannot deadlock by
# filling one pipe while the other is being read.
65 stdout, stderr = p.communicate()
67 logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
# Fragment of HasSSH(); its `def` line and surrounding control flow are not
# shown. It probes for local `ssh` and `scp` executables by running each
# through RunCmd. Presumably a failure to launch either one leads to the
# False branch. Each outcome is written to the debug log.
72 RunCmd(['ssh'], quiet=True)
73 RunCmd(['scp'], quiet=True)
74 logging.debug("HasSSH()->True")
77 logging.debug("HasSSH()->False")
# Raised when logging into the device over ssh fails; see the stderr checks
# in TryLogin() in this file.
80 class LoginException(Exception):
# LoginException subtype raised when ssh reports
# 'Permission denied (publickey,keyboard-interactive)', i.e. key-based ssh
# auth to the device has not been set up yet.
83 class KeylessLoginRequiredException(LoginException):
86 class CrOSInterface(object):
87 # pylint: disable=R0923
# Stores the target hostname and the ssh options shared by all ssh/scp calls.
# A falsy hostname appears to mean "operate on the local machine"; see the
# `return not self._hostname` accessor and the self.local checks in this class.
88 def __init__(self, hostname = None, ssh_identity = None):
89 self._hostname = hostname
90 # List of ports generated from GetRemotePort() that may not be in use yet.
91 self._reserved_ports = []
96 self._ssh_identity = None
# Shared options:
#   - short connect timeout;
#   - no host-key prompts or persistence (known hosts go to /dev/null);
#   - publickey-only, non-interactive authentication.
97 self._ssh_args = ['-o ConnectTimeout=5',
98 '-o StrictHostKeyChecking=no',
99 '-o KbdInteractiveAuthentication=no',
100 '-o PreferredAuthentications=publickey',
101 '-o UserKnownHostsFile=/dev/null']
# Presumably only reached when ssh_identity is given (the guard is not shown).
# Stores an absolute path with '~' expanded.
104 self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))
# Presumably the body of the `local` property; its decorator and `def` lines
# are not shown, but `self.local` is used elsewhere in this class. True when no
# hostname was supplied to the constructor.
108 return not self._hostname
# Presumably the body of a `hostname` accessor (its `def` line is not shown).
# Returns the hostname given to the constructor.
112 return self._hostname
# Builds the argv used to run `args` on the device.
#   - In what appears to be the local-mode branch: `sh -c "<args joined by
#     spaces>"`.
#   - Otherwise: `full_args` (its first lines are not shown), then X11 and
#     `-n` options, the shared ssh options, an optional identity file,
#     optional extra_ssh_args, `root@<hostname>`, and finally `args`.
# In both cases `args` is NOT quoted. The local or remote shell re-parses the
# joined words, and callers rely on that ('$USER', ';', 'VAR=... cmd').
114 def FormSSHCommandLine(self, args, extra_ssh_args=None):
116 # We run the command through the shell locally for consistency with
117 # how commands are run through SSH (crbug.com/239161). This work
118 # around will be unnecessary once we implement a persistent SSH
119 # connection to run remote commands (crbug.com/239607).
120 return ['sh', '-c', " ".join(args)]
124 '-o ForwardX11Trusted=no',
125 '-n'] + self._ssh_args
126 if self._ssh_identity is not None:
127 full_args.extend(['-i', self._ssh_identity])
# Presumably guarded by a check that extra_ssh_args is set (line not shown).
129 full_args.extend(extra_ssh_args)
130 full_args.append('root@%s' % self._hostname)
131 full_args.extend(args)
def _RemoveSSHWarnings(self, toClean):
  """Removes specific ssh warning lines from a string.

  ssh prints "Warning: Permanently added ... to the list of known hosts." on
  stderr when it records a new host key. The shared options send known hosts
  to /dev/null, so the warning can show up on every connection.

  Args:
    toClean: A string that may contain multiple lines.

  Returns:
    A copy of toClean with all such warning lines removed.
  """
  # Raw strings keep the regex escapes out of Python's string-escape handling
  # (a non-raw '\s' is an invalid escape and warns on newer Pythons). The dot
  # is escaped so it matches a literal '.', and '\s' also absorbs the '\r' of
  # a CRLF line ending.
  return re.sub(r'Warning: Permanently added [^\n]* to the list of known '
                r'hosts\.\s\n', '', toClean)
def RunCmdOnDevice(self, args, cwd=None, quiet=False):
  """Runs |args| on the device and returns an (stdout, stderr) pair.

  The command line is produced by FormSSHCommandLine. The first connection to
  a host makes ssh print a "Permanently added ... to the list of known hosts"
  warning on stderr, so that warning is stripped from the returned stderr.
  """
  command_line = self.FormSSHCommandLine(args)
  output, errors = GetAllCmdOutput(command_line, cwd, quiet=quiet)
  cleaned_errors = self._RemoveSSHWarnings(errors)
  return output, cleaned_errors
# Body of TryLogin(); its `def` line is not shown. It checks that root can log
# into the device non-interactively by running `echo $USER` remotely. Known
# ssh failure messages on stderr become LoginException or
# KeylessLoginRequiredException, and any user other than root is rejected.
156 logging.debug('TryLogin()')
157 assert not self.local
# '$USER' is expanded by the remote shell, so stdout should read 'root\n'.
158 stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
160 if 'Host key verification failed' in stderr:
161 raise LoginException(('%s host key verification failed. ' +
162 'SSH to it manually to fix connectivity.') %
# NOTE(review): FileExistsOnDevice checks 'Connection timed out' instead. The
# wording of ssh's timeout error likely differs between platforms, so
# consider matching both.
164 if 'Operation timed out' in stderr:
165 raise LoginException('Timed out while logging into %s' % self._hostname)
166 if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
167 raise LoginException('Permissions for %s are too open. To fix this,\n'
168 'chmod 600 %s' % (self._ssh_identity,
170 if 'Permission denied (publickey,keyboard-interactive)' in stderr:
171 raise KeylessLoginRequiredException(
172 'Need to set up ssh auth for %s' % self._hostname)
# Any other stderr text is reported verbatim. This is presumably reached only
# when stderr is non-empty; the guard is not shown.
173 raise LoginException('While logging into %s, got %s' % (
174 self._hostname, stderr))
175 if stdout != 'root\n':
176 raise LoginException(
177 'Logged into %s, expected $USER=root, but got %s.' % (
178 self._hostname, stdout))
# Returns whether `file_name` exists.
#   - In what appears to be the local-mode branch: os.path.exists.
#   - Otherwise: runs `if test -e ...; then echo 1; ...` in the device's shell
#     and treats stdout == '1\n' as "exists".
# An ssh timeout, and presumably any other non-empty stderr, raises OSError.
# NOTE(review): file_name is spliced unquoted into the remote shell command,
# so paths with spaces or shell metacharacters will misbehave.
180 def FileExistsOnDevice(self, file_name):
182 return os.path.exists(file_name)
184 stdout, stderr = self.RunCmdOnDevice([
185 'if', 'test', '-e', file_name, ';',
186 'then', 'echo', '1', ';',
190 if "Connection timed out" in stderr:
191 raise OSError('Machine wasn\'t responding to ssh: %s' %
# NOTE(review): 'Unepected' in this runtime message is a typo for 'Unexpected'.
193 raise OSError('Unepected error: %s' % stderr)
194 exists = stdout == '1\n'
195 logging.debug("FileExistsOnDevice(<text>, %s)->%s" % (file_name, exists))
# Copies local `filename` (recursively) to `remote_filename`.
#   - In what appears to be the local-mode branch: `cp -r`.
#   - Otherwise: `scp -r` to root@<hostname>, using the shared ssh options and
#     the optional identity file.
# Failures surface as OSError carrying the copy command's stderr.
# NOTE(review): the 'No such file or directory' prefix is used whatever the
# real failure was. The conditions guarding the raises are not shown.
198 def PushFile(self, filename, remote_filename):
200 args = ['cp', '-r', filename, remote_filename]
201 stdout, stderr = GetAllCmdOutput(args, quiet=True)
203 raise OSError('No such file or directory %s' % stderr)
206 args = ['scp', '-r' ] + self._ssh_args
207 if self._ssh_identity:
208 args.extend(['-i', self._ssh_identity])
210 args.extend([os.path.abspath(filename),
211 'root@%s:%s' % (self._hostname, remote_filename)])
213 stdout, stderr = GetAllCmdOutput(args, quiet=True)
214 stderr = self._RemoveSSHWarnings(stderr)
216 raise OSError('No such file or directory %s' % stderr)
# Writes `text` to `remote_filename` on the device. The text is staged in a
# local temporary file (the write/flush lines are not shown) and sent with
# PushFile. The temp file is deleted when the `with` block exits.
# NOTE(review): NamedTemporaryFile() defaults to binary mode ('w+b'). If
# `text` is a str on Python 3, the write must encode it; confirm.
218 def PushContents(self, text, remote_filename):
219 logging.debug("PushContents(<text>, %s)" % remote_filename)
220 with tempfile.NamedTemporaryFile() as f:
223 self.PushFile(f.name, remote_filename)
# Copies `filename` from the device to a local path.
#   - Remote case: `scp` from root@<hostname> to the absolute path of
#     `destfile`. `destfile` defaults to the basename of `filename`.
#   - In what appears to be the local-mode branch: shutil.copyfile, skipped
#     when no destfile is given or it equals `filename`.
# scp failures surface as OSError carrying scp's stderr.
225 def GetFile(self, filename, destfile=None):
226 """Copies the file |filename| on the device to the local file |destfile|.
229 filename: The name of the source file on the device.
230 destfile: The name of the local file to copy to, and if it is not specified
231 then it is the basename of the source file.
234 logging.debug("GetFile(%s, %s)" % (filename, destfile))
236 if destfile is not None and destfile != filename:
237 shutil.copyfile(filename, destfile)
241 destfile = os.path.basename(filename)
242 args = ['scp'] + self._ssh_args
243 if self._ssh_identity:
244 args.extend(['-i', self._ssh_identity])
246 args.extend(['root@%s:%s' % (self._hostname, filename),
247 os.path.abspath(destfile)])
248 stdout, stderr = GetAllCmdOutput(args, quiet=True)
249 stderr = self._RemoveSSHWarnings(stderr)
251 raise OSError('No such file or directory %s' % stderr)
# Reads a device file by copying it with GetFile into a local temporary file
# and reading that copy back. Local mode is not supported yet (see the TODO
# and the assert).
# NOTE(review): on the lines shown, `t` is never closed, so deleting the temp
# file is left to garbage collection; a `with` block would make cleanup
# deterministic. The `assert` is also stripped under `python -O`.
253 def GetFileContents(self, filename):
254 """Get the contents of a file on the device.
257 filename: The name of the file on the device.
260 A string containing the contents of the file.
262 # TODO: handle the self.local case
263 assert not self.local
264 t = tempfile.NamedTemporaryFile()
265 self.GetFile(filename, t.name)
266 with open(t.name, 'r') as f2:
268 logging.debug("GetFileContents(%s)->%s" % (filename, res))
# Lists device processes with `ps` (no header line). The columns are pid,
# ppid, args (widened to 4096 chars) and state. Each row is reordered into a
# (pid, cmd, ppid, state) tuple, with pid and ppid as ints and trailing
# whitespace stripped from cmd.
272 def ListProcesses(self):
273 """Returns (pid, cmd, ppid, state) of all processes on the device."""
274 stdout, stderr = self.RunCmdOnDevice([
275 '/bin/ps', '--no-headers',
277 '-o', 'pid,ppid,args:4096,state'], quiet=True)
278 assert stderr == '', stderr
280 for l in stdout.split('\n'): # pylint: disable=E1103
# The greedy `(.+)\s+(.+)` leaves the last whitespace-separated field as the
# state and everything before it as the command line.
# NOTE(review): this pattern is not a raw string, so '\s' and '\d' are invalid
# str escapes (DeprecationWarning; SyntaxWarning on Python 3.12+). Prefer r'...'.
283 m = re.match('^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', l, re.DOTALL)
285 procs.append((int(m.group(1)), m.group(3).rstrip(),
286 int(m.group(2)), m.group(4)))
287 logging.debug("ListProcesses(<predicate>)->[%i processes]" % len(procs))
def RmRF(self, filename):
  """Runs `rm -rf |filename|` on the device; the command's output is ignored."""
  rm_cmd = ['rm', '-rf', filename]
  logging.debug("rm -rf %s" % filename)
  self.RunCmdOnDevice(rm_cmd, quiet=True)
def Chown(self, filename):
  """Recursively hands ownership of |filename| on the device to chronos."""
  owner = 'chronos:chronos'
  self.RunCmdOnDevice(['chown', '-R', owner, filename])
# Sends SIGKILL, through one `kill -KILL pid...` invocation, to processes
# from ListProcesses whose command line presumably satisfies `predicate`
# (the filtering lines are not shown). `kills` starts with the two-word
# command, so len(kills) - 2 is the number of processes targeted; that count
# is returned.
297 def KillAllMatching(self, predicate):
298 kills = ['kill', '-KILL']
299 for pid, cmd, _, _ in self.ListProcesses():
# NOTE(review): BUG. `%` binds only to `cmd`, so the two-placeholder format
# string gets one argument. It raises TypeError ("not enough arguments for
# format string") whenever this line runs. It should be
# logging.info('Killing %s, pid %d', cmd, pid).
301 logging.info('Killing %s, pid %d' % cmd, pid)
303 logging.debug("KillAllMatching(<predicate>)->%i" % (len(kills) - 2))
305 self.RunCmdOnDevice(kills, quiet=True)
306 return len(kills) - 2
# Reports whether `service_name` is running. It invokes `status <name>` on the
# device and looks for 'running, process' in stdout; the result is presumably
# returned on a line not shown. Any stderr output trips the assert.
308 def IsServiceRunning(self, service_name):
309 stdout, stderr = self.RunCmdOnDevice([
310 'status', service_name], quiet=True)
311 assert stderr == '', stderr
312 running = 'running, process' in stdout
313 logging.debug("IsServiceRunning(%s)->%s" % (service_name, running))
# Picks a port one above the highest port that is either listed by
# `netstat -ant` on the device or was handed out earlier.
#   - Device ports come from the local-address column, after skipping the
#     two header lines.
#   - The chosen port is remembered in self._reserved_ports, so later calls
#     never return it again.
# NOTE(review): sorted(...)[-1] is just max(...). It raises IndexError if
# `ports_in_use` is empty (its initialisation is not shown), and nothing caps
# the result at 65535.
316 def GetRemotePort(self):
317 netstat = self.RunCmdOnDevice(['netstat', '-ant'])
318 netstat = netstat[0].split('\n')
321 for line in netstat[2:]:
324 address_in_use = line.split()[3]
325 port_in_use = address_in_use.split(':')[-1]
326 ports_in_use.append(int(port_in_use))
328 ports_in_use.extend(self._reserved_ports)
330 new_port = sorted(ports_in_use)[-1] + 1
331 self._reserved_ports.append(new_port)
# Probes localhost:<port> on the device with a single wget attempt (-t1) and
# a 1-second timeout (-T1). 'Connection refused' in wget's stderr
# (wget_output[1]) means nothing is listening. The return statements are not
# shown.
335 def IsHTTPServerRunningOnPort(self, port):
336 wget_output = self.RunCmdOnDevice(
337 ['wget', 'localhost:%i' % (port), '-T1', '-t1'])
339 if 'Connection refused' in wget_output[1]:
# Looks up the filesystem that `df` reports for `path` on the device. It uses
# the whitespace-split fields of the first line after df's header; which field
# is returned is on lines not shown.
# NOTE(review): without -P, df may wrap a long filesystem name onto its own
# line, shifting the fields this parsing expects. Consider `df -P`.
344 def FilesystemMountedAt(self, path):
345 """Returns the filesystem mounted at |path|"""
346 df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
347 df_ary = df_out.split('\n')
348 # 3 lines for title, mount info, and empty line.
350 line_ary = df_ary[1].split()
def CryptohomePath(self, user):
  """Asks the device's `cryptohome-path` tool where |user|'s home is mounted.

  The user name is wrapped in single quotes so the shell that runs the
  command passes it through literally (e.g. '$guest' is not expanded).

  Returns:
    The tool's stdout with surrounding whitespace stripped.
  """
  quoted_user = "'%s'" % user
  output = self.RunCmdOnDevice(['cryptohome-path', 'user', quoted_user])
  return output[0].strip()
def IsCryptohomeMounted(self, username):
  """Returns True iff |username|'s cryptohome is mounted.

  The guest session ('$guest') must be backed by a filesystem whose name
  starts with 'guestfs'; every other user by one under '/home/.shadow/'.
  """
  profile_path = self.CryptohomePath(username)
  mount = self.FilesystemMountedAt(profile_path)
  mount_prefix = 'guestfs' if username == '$guest' else '/home/.shadow/'
  # Coerce a falsy `mount` (e.g. '' or None) to False, so callers always get
  # a real bool as documented instead of the raw lookup result.
  return bool(mount) and mount.startswith(mount_prefix)
# Saves a screenshot of X display :0.0 on the device to
# /var/log/screenshots/<screenshot_prefix>-<i>.png. It uses the first index
# `i` whose file does not exist yet (the loop over `i` is not shown) and
# presumably logs the warning below when no free name is left.
367 def TakeScreenShot(self, screenshot_prefix):
368 """Takes a screenshot, useful for debugging failures."""
369 # TODO(achuith): Find a better location for screenshots. Cros autotests
370 # upload everything in /var/log so use /var/log/screenshots for now.
371 SCREENSHOT_DIR = '/var/log/screenshots/'
372 SCREENSHOT_EXT = '.png'
374 self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
376 screenshot_file = ('%s%s-%d%s' %
377 (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
378 if not self.FileExistsOnDevice(screenshot_file):
# The two adjacent literals form a single word that the device's shell
# re-parses, so the DISPLAY/XAUTHORITY assignments apply to `import`.
379 self.RunCmdOnDevice([
380 'DISPLAY=:0.0 XAUTHORITY=/home/chronos/.Xauthority '
381 '/usr/local/bin/import',
386 logging.warning('screenshot directory full.')