diff --git a/cdap-runtime-ext-dataproc/src/main/java/io/cdap/cdap/runtime/spi/runtimejob/DataprocRuntimeJobManager.java b/cdap-runtime-ext-dataproc/src/main/java/io/cdap/cdap/runtime/spi/runtimejob/DataprocRuntimeJobManager.java
index 0d252dc0462e..94cdf06a5717 100644
--- a/cdap-runtime-ext-dataproc/src/main/java/io/cdap/cdap/runtime/spi/runtimejob/DataprocRuntimeJobManager.java
+++ b/cdap-runtime-ext-dataproc/src/main/java/io/cdap/cdap/runtime/spi/runtimejob/DataprocRuntimeJobManager.java
@@ -162,7 +162,8 @@ public DataprocRuntimeJobManager(DataprocClusterInfo clusterInfo,
   /**
    * Returns a {@link Storage} object for interacting with GCS.
    */
-  private Storage getStorageClient() {
+  @VisibleForTesting
+  public Storage getStorageClient() {
     Storage client = storageClient;
     if (client != null) {
       return client;
@@ -573,7 +574,8 @@ private LocalFile uploadCacheableFile(String bucket, String targetFilePath,
   /**
    * Uploads files to gcs.
    */
-  private LocalFile uploadFile(String bucket, String targetFilePath,
+  @VisibleForTesting
+  public LocalFile uploadFile(String bucket, String targetFilePath,
       LocalFile localFile, boolean cacheable)
       throws IOException, StorageException {
     BlobId blobId = BlobId.of(bucket, targetFilePath);
@@ -612,9 +614,9 @@ private LocalFile uploadFile(String bucket, String targetFilePath,
         // https://cloud.google.com/storage/docs/request-preconditions#special-case
         // Overwrite the file
         Blob existingBlob = storage.get(blobId);
-        BlobInfo newBlobInfo = existingBlob.toBuilder().setContentType(contentType).build();
-        uploadToGcsUtil(localFile, storage, targetFilePath, newBlobInfo,
-            Storage.BlobWriteOption.generationNotMatch());
+        BlobInfo newBlobInfo =
+            BlobInfo.newBuilder(existingBlob.getBlobId()).setContentType(contentType).build();
+        uploadToGcsUtil(localFile, storage, targetFilePath, newBlobInfo);
       } else {
         LOG.debug("Skip uploading file {} to gs://{}/{} because it exists.",
             localFile.getURI(), bucket, targetFilePath);
@@ -637,7 +639,8 @@ private long getCustomTime() {
   /**
    * Uploads the file to GCS Bucket.
    */
-  private void uploadToGcsUtil(LocalFile localFile, Storage storage, String targetFilePath,
+  @VisibleForTesting
+  public void uploadToGcsUtil(LocalFile localFile, Storage storage, String targetFilePath,
       BlobInfo blobInfo, Storage.BlobWriteOption... blobWriteOptions)
       throws IOException, StorageException {
     long start = System.nanoTime();
diff --git a/cdap-runtime-ext-dataproc/src/test/java/io/cdap/cdap/runtime/spi/provisioner/dataproc/DataprocRuntimeJobManagerTest.java b/cdap-runtime-ext-dataproc/src/test/java/io/cdap/cdap/runtime/spi/provisioner/dataproc/DataprocRuntimeJobManagerTest.java
index d11a2897e0d9..47bda8cc1020 100644
--- a/cdap-runtime-ext-dataproc/src/test/java/io/cdap/cdap/runtime/spi/provisioner/dataproc/DataprocRuntimeJobManagerTest.java
+++ b/cdap-runtime-ext-dataproc/src/test/java/io/cdap/cdap/runtime/spi/provisioner/dataproc/DataprocRuntimeJobManagerTest.java
@@ -16,12 +16,21 @@
 
 package io.cdap.cdap.runtime.spi.provisioner.dataproc;
 
+import com.google.auth.oauth2.GoogleCredentials;
+import com.google.cloud.storage.Blob;
+import com.google.cloud.storage.BlobId;
+import com.google.cloud.storage.BlobInfo;
+import com.google.cloud.storage.Bucket;
+import com.google.cloud.storage.Storage;
+import com.google.cloud.storage.StorageException;
 import com.google.common.collect.ImmutableMap;
 import io.cdap.cdap.runtime.spi.ProgramRunInfo;
 import io.cdap.cdap.runtime.spi.SparkCompat;
+import io.cdap.cdap.runtime.spi.runtimejob.DataprocClusterInfo;
 import io.cdap.cdap.runtime.spi.runtimejob.DataprocRuntimeJobManager;
 import io.cdap.cdap.runtime.spi.runtimejob.LaunchMode;
 import io.cdap.cdap.runtime.spi.runtimejob.RuntimeJobInfo;
+import java.net.HttpURLConnection;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
@@ -32,6 +41,8 @@
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import org.mockito.Matchers;
+import org.mockito.Mockito;
 
 /** Tests for DataprocRuntimeJobManager. */
 public class DataprocRuntimeJobManagerTest {
@@ -163,4 +174,50 @@ public void getPropertiesTest() {
     Assert.assertEquals(
         runInfo.getRun(), properties.get(DataprocRuntimeJobManager.CDAP_RUNTIME_RUNID));
   }
+
+  @Test
+  public void uploadFileTest() throws Exception {
+    final String bucketName = "bucket";
+    GoogleCredentials credentials = Mockito.mock(GoogleCredentials.class);
+    Mockito.doReturn(true).when(credentials).createScopedRequired();
+    DataprocRuntimeJobManager dataprocRuntimeJobManager = new DataprocRuntimeJobManager(
+        new DataprocClusterInfo(new MockProvisionerContext(), "test-cluster", credentials,
+            null, "test-project", "test-region", bucketName, Collections.emptyMap()),
+        Collections.emptyMap(), null);
+
+    DataprocRuntimeJobManager mockedDataprocRuntimeJobManager =
+        Mockito.spy(dataprocRuntimeJobManager);
+
+    Storage storage = Mockito.mock(Storage.class);
+    Mockito.doReturn(storage).when(mockedDataprocRuntimeJobManager).getStorageClient();
+
+    Bucket bucket = Mockito.mock(Bucket.class);
+    Mockito.doReturn(bucket).when(storage).get(Matchers.eq(bucketName));
+    Mockito.doReturn("regional").when(bucket).getLocationType();
+    Mockito.doReturn("test-region").when(bucket).getLocation();
+
+    String targetFilePath = "cdap-job/target";
+    BlobId blobId = BlobId.of(bucketName, targetFilePath);
+    BlobId newBlobId = BlobId.of(bucketName, targetFilePath, 1L);
+    Blob blob = Mockito.mock(Blob.class);
+    Mockito.doReturn(blob).when(storage).get(blobId);
+    Mockito.doReturn(newBlobId).when(blob).getBlobId();
+
+    BlobInfo expectedBlobInfo =
+        BlobInfo.newBuilder(blobId).setContentType("application/octet-stream").build();
+    Mockito.doThrow(
+        new StorageException(HttpURLConnection.HTTP_PRECON_FAILED, "blob already exists"))
+        .when(mockedDataprocRuntimeJobManager).uploadToGcsUtil(Mockito.any(),
+            Mockito.any(), Mockito.any(), Matchers.eq(expectedBlobInfo),
+            Matchers.eq(Storage.BlobWriteOption.doesNotExist()));
+
+    expectedBlobInfo =
+        BlobInfo.newBuilder(newBlobId).setContentType("application/octet-stream").build();
+    Mockito.doNothing().when(mockedDataprocRuntimeJobManager).uploadToGcsUtil(Mockito.any(),
+        Mockito.any(), Mockito.any(), Matchers.eq(expectedBlobInfo));
+
+    // call the method
+    LocalFile localFile = Mockito.mock(LocalFile.class);
+    mockedDataprocRuntimeJobManager.uploadFile(bucketName, targetFilePath, localFile, false);
+  }
 }