diff --git a/src/main/scala/fm/sbt/S3URLHandler.scala b/src/main/scala/fm/sbt/S3URLHandler.scala index 29b69aa..ec3562d 100644 --- a/src/main/scala/fm/sbt/S3URLHandler.scala +++ b/src/main/scala/fm/sbt/S3URLHandler.scala @@ -35,7 +35,7 @@ import scala.util.matching.Regex object S3URLHandler { private val DOT_SBT_DIR: File = new File(System.getProperty("user.home"), ".sbt") - + // This is for matching region names in URLs or host names private val RegionMatcher: Regex = Regions.values().map{ _.getName }.sortBy{ -1 * _.length }.mkString("|").r @@ -81,7 +81,7 @@ object S3URLHandler { def getCredentials(): AWSCredentials = { val roleArn: String = getRoleArn(RoleArnKeyNames: _*) - + if (roleArn == null || roleArn == "") return null val securityTokenService: AWSSecurityTokenService = AWSSecurityTokenServiceClient.builder().withCredentials(providerChain).build() @@ -168,12 +168,12 @@ final class S3URLHandler extends URLHandler { def getURLInfo(url: URL): URLInfo = getURLInfo(url, 0) private def debug(msg: String): Unit = Message.debug("S3URLHandler."+msg) - + private def makePropertiesFileCredentialsProvider(fileName: String): PropertiesFileCredentialsProvider = { val file: File = new File(DOT_SBT_DIR, fileName) new PropertiesFileCredentialsProvider(file.toString) } - + private def makeCredentialsProviderChain(bucket: String): AWSCredentialsProviderChain = { val basicProviders: Vector[AWSCredentialsProvider] = Vector( new BucketSpecificEnvironmentVariableCredentialsProvider(bucket), @@ -196,30 +196,24 @@ final class S3URLHandler extends URLHandler { new RoleBasedSystemPropertiesCredentialsProvider(basicProviderChain), new RoleBasedPropertiesFileCredentialsProvider(basicProviderChain, s".s3credentials") ) - - new AWSCredentialsProviderChain((roleBasedProviders union basicProviders): _*) + + new AWSCredentialsProviderChain((roleBasedProviders ++ basicProviders): _*) } - - private val credentialsCache: ConcurrentHashMap[String,AWSCredentials] = new ConcurrentHashMap() - - def getCredentials(bucket: String): AWSCredentials = { - var credentials: AWSCredentials = credentialsCache.get(bucket) - - if (null == credentials) { - credentials = try { - makeCredentialsProviderChain(bucket).getCredentials() - } catch { - case ex: com.amazonaws.AmazonClientException => - Message.error("Unable to find AWS Credentials.") - throw ex - } - - Message.info("S3URLHandler - Using AWS Access Key Id: "+credentials.getAWSAccessKeyId+" for bucket: "+bucket) - - credentialsCache.put(bucket, credentials) + + def getCredentialsProvider(bucket: String): AWSCredentialsProvider = { + Message.info("S3URLHandler - Looking up AWS Credentials for bucket: "+bucket+" ...") + + val credentialsProvider: AWSCredentialsProvider = try { + makeCredentialsProviderChain(bucket) + } catch { + case ex: com.amazonaws.AmazonClientException => + Message.error("Unable to find AWS Credentials.") + throw ex } - - credentials + + Message.info("S3URLHandler - Using AWS Access Key Id: "+credentialsProvider.getCredentials().getAWSAccessKeyId+" for bucket: "+bucket) + + credentialsProvider } def getProxyConfiguration: ClientConfiguration = { @@ -234,20 +228,29 @@ final class S3URLHandler extends URLHandler { configuration } + // Bucket Name => AmazonS3 + private val amazonS3ClientCache: ConcurrentHashMap[String,AmazonS3] = new ConcurrentHashMap() + def getClientBucketAndKey(url: URL): (AmazonS3, String, String) = { val (bucket, key) = getBucketAndKey(url) - val client: AmazonS3 = AmazonS3Client.builder() - .withCredentials(new AWSStaticCredentialsProvider(getCredentials(bucket))) - .withClientConfiguration(getProxyConfiguration) - .build() + var client: AmazonS3 = amazonS3ClientCache.get(bucket) + + if (null == client) { + client = AmazonS3Client.builder() + .withCredentials(getCredentialsProvider(bucket)) + .withClientConfiguration(getProxyConfiguration) + .withRegion(getRegion(url, bucket)) + .build() + + amazonS3ClientCache.put(bucket, client) + + Message.info("S3URLHandler - Created S3 Client for bucket: "+bucket+" and region: "+client.getRegionName) + } - val region: Option[Region] = getRegion(url, bucket, client) - region.foreach{ client.setRegion } - (client, bucket, key) } - + def getURLInfo(url: URL, timeout: Int): URLInfo = try { debug(s"getURLInfo($url, $timeout)") @@ -330,10 +333,10 @@ final class S3URLHandler extends URLHandler { def setRequestMethod(requestMethod: Int): Unit = debug(s"setRequestMethod($requestMethod)") // Try to get the region of the S3 URL so we can set it on the S3Client - def getRegion(url: URL, bucket: String, client: AmazonS3): Option[Region] = { - val region: Option[String] = getRegionNameFromURL(url) orElse getRegionNameFromDNS(bucket) orElse getRegionNameFromService(bucket, client) + def getRegion(url: URL, bucket: String/*, client: AmazonS3*/): Regions = { + val region: Option[String] = getRegionNameFromURL(url) orElse getRegionNameFromDNS(bucket) orElse Option(Regions.getCurrentRegion()).map{ _.getName } - region.map{ RegionUtils.getRegion }.flatMap{ Option(_) } + region.map{ Regions.fromName }.flatMap{ Option(_) } getOrElse Regions.DEFAULT_REGION } def getRegionNameFromURL(url: URL): Option[String] = { @@ -349,12 +352,12 @@ final class S3URLHandler extends URLHandler { // So we use our regex based RegionMatcher to try and extract the region since AmazonS3URI doesn't work RegionMatcher.findFirstIn(canonicalHostName) } - - // TODO: cache the result of this so we aren't always making the call - def getRegionNameFromService(bucket: String, client: AmazonS3): Option[String] = { - // This might fail if the current credentials don't have access to the getBucketLocation call - Try { client.getBucketLocation(bucket) }.toOption - } + + // Not used anymore since the AmazonS3ClientBuilder requires the region during construction +// def getRegionNameFromService(bucket: String, client: AmazonS3): Option[String] = { +// // This might fail if the current credentials don't have access to the getBucketLocation call +// Try { client.getBucketLocation(bucket) }.toOption +// } def getBucketAndKey(url: URL): (String, String) = { // The AmazonS3URI constructor should work for standard S3 urls. But if a custom domain is being used