diff --git a/.gitignore b/.gitignore index 9e53722e0..be75df848 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,6 @@ node_modules/ # Ignore mock database **/*.sqlite + +# Ignore virtual envs +*.venv diff --git a/covalent/_file_transfer/file_transfer.py b/covalent/_file_transfer/file_transfer.py index ee982a92f..96555c3b1 100644 --- a/covalent/_file_transfer/file_transfer.py +++ b/covalent/_file_transfer/file_transfer.py @@ -18,10 +18,37 @@ from .enums import FileTransferStrategyTypes, FtCallDepReturnValue, Order from .file import File +from .strategies.gcloud_strategy import GCloud from .strategies.http_strategy import HTTP +from .strategies.s3_strategy import S3 from .strategies.shutil_strategy import Shutil from .strategies.transfer_strategy_base import FileTransferStrategy +# TODO: make this pluggable similar to executor plugins +_strategy_type_map = { + FileTransferStrategyTypes.Shutil: Shutil, + FileTransferStrategyTypes.S3: S3, + FileTransferStrategyTypes.HTTP: HTTP, + FileTransferStrategyTypes.GCloud: GCloud, +} + + +def _guess_transfer_strategy(from_file: File, to_file: File) -> FileTransferStrategy: + # Handle the following cases automatically + # Local-Remote (except HTTP destination) + # Remote-local + # Local-local + + if ( + from_file.mapped_strategy_type == FileTransferStrategyTypes.Shutil + and to_file.mapped_strategy_type != FileTransferStrategyTypes.HTTP + ): + return _strategy_type_map[to_file.mapped_strategy_type] + elif to_file.mapped_strategy_type == FileTransferStrategyTypes.Shutil: + return _strategy_type_map[from_file.mapped_strategy_type] + else: + raise AttributeError("FileTransfer requires a file transfer strategy to be specified") + class FileTransfer: """ @@ -58,15 +85,8 @@ def __init__( # assign explicit strategy or default to strategy based on from_file & to_file schemes if strategy: self.strategy = strategy - elif ( - from_file.mapped_strategy_type == FileTransferStrategyTypes.Shutil - and to_file.mapped_strategy_type == FileTransferStrategyTypes.Shutil - ): - self.strategy = Shutil() - elif from_file.mapped_strategy_type == FileTransferStrategyTypes.HTTP: - self.strategy = HTTP() else: - raise AttributeError("FileTransfer requires a file transfer strategy to be specified") + self.strategy = _guess_transfer_strategy(from_file, to_file)() self.to_file = to_file self.from_file = from_file diff --git a/tests/covalent_tests/file_transfer/file_transfer_test.py b/tests/covalent_tests/file_transfer/file_transfer_test.py index 746e490c9..afd5ff5f5 100644 --- a/tests/covalent_tests/file_transfer/file_transfer_test.py +++ b/tests/covalent_tests/file_transfer/file_transfer_test.py @@ -26,6 +26,8 @@ TransferToRemote, ) from covalent._file_transfer.strategies.rsync_strategy import Rsync +from covalent._file_transfer.strategies.s3_strategy import S3 +from covalent._file_transfer.strategies.shutil_strategy import Shutil class TestFileTransfer: @@ -109,3 +111,18 @@ def test_transfer_to_remote(self): with pytest.raises(ValueError): result = TransferToRemote("file:///home/one", "file:///home/one/", strategy=strategy) + + def test_auto_transfer_strategy(self): + from_file = File("s3://bucket/object.pkl") + to_file = File("file:///tmp/object.pkl") + ft = FileTransfer(from_file, to_file) + assert type(ft.strategy) is S3 + + ft = FileTransfer(to_file, from_file) + assert type(ft.strategy) is S3 + + ft = FileTransfer(to_file, to_file) + assert type(ft.strategy) is Shutil + + with pytest.raises(AttributeError): + _ = FileTransfer(from_file, from_file)