# Copyright (c) 2016 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for SSH connection management."""

import StringIO
from time import time, sleep

import socket
import paramiko
from paramiko import RSAKey
from paramiko.ssh_exception import SSHException
from scp import SCPClient
from robot.api import logger
from robot.utils.asserts import assert_equal

__all__ = ["exec_cmd", "exec_cmd_no_error"]

# TODO: load priv key


class SSHTimeout(Exception):
    """This exception is raised when a timeout occurs."""


class SSH(object):
    """Contains methods for managing and using SSH connections.

    Connections are cached at class level (keyed by the node's host/port
    hash), so every SSH instance talking to the same node shares a single
    paramiko client.
    """

    # Upper bound on the number of bytes requested by one recv() call.
    __MAX_RECV_BUF = 10 * 1024 * 1024
    # Cache of open paramiko clients: {node hash: paramiko.SSHClient}.
    __existing_connections = {}

    def __init__(self):
        self._ssh = None
        self._node = None

    @staticmethod
    def _node_hash(node):
        """Get IP address and port hash from node dictionary.

        :param node: Node in topology.
        :type node: dict
        :return: IP address and port for the specified node.
        :rtype: int
        """
        return hash(frozenset([node['host'], node['port']]))

    def connect(self, node, attempts=5):
        """Connect to node prior to running exec_command or scp.

        If there already is a connection to the node, this method reuses it.
        On failure the connection is dropped and re-established until the
        attempts are used up; the last error is then re-raised.

        :param node: Node in topology. Keys read here: 'host', 'port',
            'username' and, optionally, 'password' and 'priv_key'.
        :param attempts: Number of reconnect attempts made after a failure.
        :type node: dict
        :type attempts: int
        """
        try:
            self._node = node
            node_hash = self._node_hash(node)
            if node_hash in SSH.__existing_connections:
                self._ssh = SSH.__existing_connections[node_hash]
                logger.debug('reusing ssh: {0}'.format(self._ssh))
            else:
                start = time()
                pkey = None
                if 'priv_key' in node:
                    pkey = RSAKey.from_private_key(
                        StringIO.StringIO(node['priv_key']))

                self._ssh = paramiko.SSHClient()
                self._ssh.set_missing_host_key_policy(
                    paramiko.AutoAddPolicy())

                self._ssh.connect(node['host'], username=node['username'],
                                  password=node.get('password'), pkey=pkey,
                                  port=node['port'])

                self._ssh.get_transport().set_keepalive(10)

                SSH.__existing_connections[node_hash] = self._ssh

                logger.trace('connect took {} seconds'.format(time() - start))
                logger.debug('new ssh: {0}'.format(self._ssh))

            # These calls also touch the transport of a reused connection;
            # if that raises, the handler below reconnects.
            logger.debug('Connect peer: {0}'.
                         format(self._ssh.get_transport().getpeername()))
            logger.debug('Connections: {0}'.
                         format(str(SSH.__existing_connections)))
        except Exception:
            # Not a bare "except:" - KeyboardInterrupt and SystemExit must
            # propagate immediately instead of triggering reconnects.
            if attempts > 0:
                self._reconnect(attempts - 1)
            else:
                raise

    def disconnect(self, node):
        """Close SSH connection to the node.

        Does nothing when there is no cached connection for the node.

        :param node: The node to disconnect from.
        :type node: dict
        """
        node_hash = self._node_hash(node)
        if node_hash in SSH.__existing_connections:
            logger.debug('Disconnecting peer: {}, {}'.
                         format(node['host'], node['port']))
            ssh = SSH.__existing_connections.pop(node_hash)
            ssh.close()

    def _reconnect(self, attempts=0):
        """Close the SSH connection and open it again.

        :param attempts: Number of further reconnect attempts that connect()
            may make if it fails again.
        :type attempts: int
        """
        node = self._node
        self.disconnect(node)
        self.connect(node, attempts)
        logger.debug('Reconnecting peer done: {}'.
                     format(self._ssh.get_transport().getpeername()))

    def exec_command(self, cmd, timeout=10):
        """Execute SSH command on a new channel on the connected Node.

        :param cmd: Command to run on the Node.
        :param timeout: Maximal time in seconds to wait until the command is
            done. If set to None then wait forever.
        :type cmd: str
        :type timeout: int
        :return: return_code, stdout, stderr
        :rtype: tuple(int, str, str)
        :raise SSHTimeout: If command is not finished in timeout time.
        """
        start = time()
        stdout = StringIO.StringIO()
        stderr = StringIO.StringIO()
        try:
            chan = self._ssh.get_transport().open_session(timeout=5)
        except (AttributeError, SSHException):
            # AttributeError covers a missing client/transport (e.g.
            # self._ssh is still None); either way retry once on a fresh
            # connection.
            self._reconnect()
            chan = self._ssh.get_transport().open_session(timeout=5)

        try:
            chan.settimeout(timeout)
            logger.trace('exec_command on {0}: {1}'
                         .format(self._ssh.get_transport().getpeername(),
                                 cmd))

            chan.exec_command(cmd)
            # Poll instead of blocking so that output is drained while the
            # command runs and the timeout can be enforced. With
            # timeout=None the loop is skipped and recv_exit_status() below
            # does the waiting.
            while not chan.exit_status_ready() and timeout is not None:
                if chan.recv_ready():
                    stdout.write(chan.recv(self.__MAX_RECV_BUF))

                if chan.recv_stderr_ready():
                    stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))

                if time() - start > timeout:
                    raise SSHTimeout(
                        'Timeout exception.\n'
                        'Current contents of stdout buffer: {0}\n'
                        'Current contents of stderr buffer: {1}\n'
                        .format(stdout.getvalue(), stderr.getvalue())
                    )

                sleep(0.1)
            return_code = chan.recv_exit_status()

            # Collect whatever is still buffered after the command exited.
            while chan.recv_ready():
                stdout.write(chan.recv(self.__MAX_RECV_BUF))

            while chan.recv_stderr_ready():
                stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
        finally:
            # The channel is single-use; release it explicitly, also when
            # the command timed out or reading failed.
            chan.close()

        end = time()
        logger.trace('exec_command on {0} took {1} seconds'.
                     format(self._ssh.get_transport().getpeername(),
                            end - start))

        logger.trace('chan_recv/_stderr took {} seconds'.
                     format(time() - end))

        logger.trace('return RC {}'.format(return_code))
        logger.trace('return STDOUT {}'.format(stdout.getvalue()))
        logger.trace('return STDERR {}'.format(stderr.getvalue()))
        return return_code, stdout.getvalue(), stderr.getvalue()

    def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
        """Execute SSH command with sudo on a new channel on the connected Node.

        :param cmd: Command to be executed.
        :param cmd_input: Input redirected to the command.
        :param timeout: Timeout.
        :return: return_code, stdout, stderr

        :Example:

        >>> from ssh import SSH
        >>> ssh = SSH()
        >>> ssh.connect(node)
        >>> # Execute command without input (sudo -S cmd)
        >>> ssh.exec_command_sudo("ifconfig eth0 down")
        >>> # Execute command with input (sudo -S cmd <<< "input")
        >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
        """
        if cmd_input is None:
            command = 'sudo -S {c}'.format(c=cmd)
        else:
            command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
        return self.exec_command(command, timeout)

    def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
                         timeout=30):
        """Execute command in LXC on a new SSH channel on the connected Node.

        :param lxc_cmd: Command to be executed.
        :param lxc_name: LXC name.
        :param lxc_params: Additional parameters for LXC attach.
        :param sudo: Run in privileged LXC mode. Default: privileged
        :param timeout: Timeout.
        :type lxc_cmd: str
        :type lxc_name: str
        :type lxc_params: str
        :type sudo: bool
        :type timeout: int
        :return: return_code, stdout, stderr
        """
        command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
            .format(p=lxc_params, n=lxc_name, c=lxc_cmd)

        if sudo:
            command = 'sudo -S {c}'.format(c=command)
        return self.exec_command(command, timeout)

    def _recv_until_prompt(self, chan, prompts):
        """Read from the channel until the received text ends with a prompt.

        Reading also stops when the channel returns no data or reports that
        an exit status is ready.

        :param chan: SSH channel with opened terminal.
        :param prompts: Prompt, or tuple of prompts, to wait for.
        :type prompts: str or tuple
        :return: Everything received from the channel.
        :rtype: str
        :raise Exception: If the channel's socket timeout expires; the
            message carries the text received so far.
        """
        buf = ''
        while not buf.endswith(prompts):
            try:
                chunk = chan.recv(self.__MAX_RECV_BUF)
                if not chunk:
                    break
                buf += chunk
                if chan.exit_status_ready():
                    logger.error('Channel exit status ready')
                    break
            except socket.timeout:
                raise Exception('Socket timeout: {0}'.format(buf))
        return buf

    def interactive_terminal_open(self, time_out=30):
        """Open interactive terminal on a new channel on the connected Node.

        Waits until the shell prints a prompt ending with ":~$ " or "~]$ ".

        :param time_out: Timeout in seconds, applied to every read on the
            channel.
        :return: SSH channel with opened terminal.
        :raise Exception: If no prompt is received within the timeout.
        """
        chan = self._ssh.get_transport().open_session()
        chan.get_pty()
        chan.invoke_shell()
        chan.settimeout(int(time_out))
        chan.set_combine_stderr(True)

        try:
            self._recv_until_prompt(chan, (":~$ ", "~]$ "))
        except Exception:
            # The caller never gets the channel in this case, so nobody
            # else could close it.
            chan.close()
            raise
        return chan

    def interactive_terminal_exec_command(self, chan, cmd, prompt):
        """Execute command on interactive terminal.

        interactive_terminal_open() method has to be called first!

        :param chan: SSH channel with opened terminal.
        :param cmd: Command to be executed.
        :param prompt: Command prompt, sequence of characters used to
            indicate readiness to accept commands. Either a single string or
            a tuple of strings.
        :return: Command output with the echoed command and the prompt(s)
            removed.
        :raise Exception: If the prompt is not received within the channel's
            timeout.
        """
        chan.sendall('{c}\n'.format(c=cmd))
        buf = self._recv_until_prompt(chan, prompt)

        tmp = buf.replace(cmd.replace('\n', ''), '')
        # A bare string must be treated as ONE prompt; iterating over it
        # would strip its individual characters from the output.
        prompts = prompt if isinstance(prompt, tuple) else (prompt,)
        for item in prompts:
            # str.replace() returns a new string - keep the result.
            tmp = tmp.replace(item, '')
        return tmp

    @staticmethod
    def interactive_terminal_close(chan):
        """Close interactive terminal SSH channel.

        :param chan: SSH channel to be closed.
        """
        chan.close()

    def scp(self, local_path, remote_path, get=False):
        """Copy files from local_path to remote_path or vice versa.

        connect() method has to be called first!

        :param local_path: Path to local file that should be uploaded; or
            path where to save remote file.
        :param remote_path: Remote path where to place uploaded file; or
            path to remote file which should be downloaded.
        :param get: scp operation to perform. Default is put.
        :type local_path: str
        :type remote_path: str
        :type get: bool
        """
        if not get:
            logger.trace('SCP {0} to {1}:{2}'.format(
                local_path, self._ssh.get_transport().getpeername(),
                remote_path))
        else:
            logger.trace('SCP {0}:{1} to {2}'.format(
                self._ssh.get_transport().getpeername(), remote_path,
                local_path))
        # SCPClient works on top of the already connected paramiko
        # transport.
        client = SCPClient(self._ssh.get_transport(), socket_timeout=10)
        start = time()
        try:
            if not get:
                client.put(local_path, remote_path)
            else:
                client.get(remote_path, local_path)
        finally:
            # Release the SCP channel even when the transfer fails.
            client.close()
        end = time()
        logger.trace('SCP took {0} seconds'.format(end - start))


def exec_cmd(node, cmd, timeout=600, sudo=False):
    """Convenience function to ssh/exec/return rc, out & err.

    :param node: Node in topology to run the command on.
    :param cmd: Command to be executed.
    :param timeout: Maximal time in seconds to wait for the command.
    :param sudo: Run the command through sudo.
    :type node: dict
    :type cmd: str
    :type timeout: int
    :type sudo: bool
    :return: (rc, stdout, stderr); (None, None, None) if connecting or
        executing raised SSHException.
    :rtype: tuple
    :raise TypeError: If node or cmd is None.
    :raise ValueError: If cmd is empty.
    """
    if node is None:
        raise TypeError('Node parameter is None')
    if cmd is None:
        raise TypeError('Command parameter is None')
    if len(cmd) == 0:
        raise ValueError('Empty command parameter')

    ssh = SSH()
    try:
        ssh.connect(node)
    except SSHException as err:
        logger.error('Failed to connect to node: {0}'.format(err))
        return None, None, None

    try:
        if not sudo:
            (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
        else:
            (ret_code, stdout, stderr) = ssh.exec_command_sudo(
                cmd, timeout=timeout)
    except SSHException as err:
        logger.error(err)
        return None, None, None

    return ret_code, stdout, stderr


def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
    """Convenience function to ssh/exec/return out & err.

    Verifies that return code is zero.

    :param node: Node in topology to run the command on.
    :param cmd: Command to be executed.
    :param timeout: Maximal time in seconds to wait for the command.
    :param sudo: Run the command through sudo.
    :return: (stdout, stderr).
    :rtype: tuple
    """
    (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout,
                                          sudo=sudo)
    assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
                 format(cmd, stderr))
    return stdout, stderr