Skip to content

Commit

Permalink
Added recursive copy functionality to parallel client. Added recursiv…
Browse files Browse the repository at this point in the history
…e copy test for parallel client. Fixed docstring indendations
  • Loading branch information
pkittenis committed Dec 9, 2015
1 parent e685a3e commit 6bde7d0
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 16 deletions.
33 changes: 18 additions & 15 deletions pssh/pssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,34 +472,37 @@ def get_stdout(self, greenlet, return_buffers=False):
'stdout' : stdout,
'stderr' : stderr, }}

def copy_file(self, local_file, remote_file):
def copy_file(self, local_file, remote_file, recurse=False):
"""Copy local file to remote file in parallel
:param local_file: Local filepath to copy to remote host
:type local_file: str
:param remote_file: Remote filepath on remote host to copy file to
:type remote_file: str
:param recurse: Whether or not to descend into directories recursively.
:type recurse: bool
:raises: :mod:'ValueError' when a directory is supplied to local_file \
and recurse is not set
.. note ::
Remote directories in `remote_file` that do not exist will be
created as long as permissions allow.
.. note ::
Path separation is handled client side so it is possible to copy
to/from hosts with differing path separators, like from/to Linux
and Windows.
:rtype: List(:mod:`gevent.Greenlet`) of greenlets for remote copy \
commands
"""
return [self.pool.spawn(self._copy_file, host, local_file, remote_file)
return [self.pool.spawn(self._copy_file, host, local_file, remote_file,
{'recurse' : recurse})
for host in self.hosts]

def _copy_file(self, host, local_file, remote_file):
def _copy_file(self, host, local_file, remote_file, recurse=False):
"""Make sftp client, copy file"""
if not self.host_clients[host]:
self.host_clients[host] = SSHClient(host, user=self.user,
password=self.password,
port=self.port, pkey=self.pkey,
forward_ssh_agent=self.forward_ssh_agent)
return self.host_clients[host].copy_file(local_file, remote_file)
self.host_clients[host] = SSHClient(
host, user=self.user, password=self.password,
port=self.port, pkey=self.pkey,
forward_ssh_agent=self.forward_ssh_agent)
return self.host_clients[host].copy_file(local_file, remote_file,
recurse=recurse)
2 changes: 1 addition & 1 deletion pssh/ssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def copy_file(self, local_file, remote_file, recurse=False):
:type remote_file: str
:param recurse: Whether or not to descend into directories recursively.
:type recurse: bool
:raises: :mod:'ValueError' when a directory is supplied to local_file \
and recurse is not set
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/test_pssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import paramiko
import os
import warnings
import shutil

USER_KEY = paramiko.RSAKey.from_private_key_file(
os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']))
Expand Down Expand Up @@ -342,6 +343,36 @@ def test_pssh_copy_file(self):
del client
server.join()

def test_pssh_client_directory(self):
"""Tests copying directories with SSH client. Copy all the files from
local directory to server, then make sure they are all present."""
test_file_data = 'test'
local_test_path = 'directory_test'
remote_test_path = 'directory_test_copied'
for path in [local_test_path, remote_test_path]:
try:
shutil.rmtree(path)
except OSError:
pass
os.mkdir(local_test_path)
remote_file_paths = []
for i in range(0, 10):
local_file_path = os.path.join(local_test_path, 'foo' + str(i))
remote_file_path = os.path.join(remote_test_path, 'foo' + str(i))
remote_file_paths.append(remote_file_path)
test_file = open(local_file_path, 'w')
test_file.write(test_file_data)
test_file.close()
client = ParallelSSHClient([self.host], port=self.listen_port,
pkey=self.user_key)
cmds = client.copy_file(local_test_path, remote_test_path, recurse=True)
for cmd in cmds:
cmd.get()
for path in remote_file_paths:
self.assertTrue(os.path.isfile(path))
shutil.rmtree(local_test_path)
shutil.rmtree(remote_test_path)

def test_pssh_pool_size(self):
"""Test pool size logic"""
hosts = ['host-%01d' % d for d in xrange(5)]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_ssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def test_ssh_client_directory(self):
test_file_data = 'test'
local_test_path = 'directory_test'
remote_test_path = 'directory_test_copied'
for path in [local_test_path, remote_test_path]:
try:
shutil.rmtree(path)
except OSError:
pass
os.mkdir(local_test_path)
remote_file_paths = []
for i in range(0, 10):
Expand All @@ -170,6 +175,11 @@ def test_ssh_client_directory_no_recurse(self):
test_file_data = 'test'
local_test_path = 'directory_test'
remote_test_path = 'directory_test_copied'
for path in [local_test_path, remote_test_path]:
try:
shutil.rmtree(path)
except OSError:
pass
os.mkdir(local_test_path)
remote_file_paths = []
for i in range(0, 10):
Expand Down

0 comments on commit 6bde7d0

Please sign in to comment.