Skip to content

Commit

Permalink
Cache the AmazonS3 instances instead of just the Credentials and adde…
Browse files Browse the repository at this point in the history
…d some additional logging
  • Loading branch information
tpunder committed Apr 7, 2017
1 parent b920738 commit f629294
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions src/main/scala/fm/sbt/S3URLHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import scala.util.matching.Regex

object S3URLHandler {
private val DOT_SBT_DIR: File = new File(System.getProperty("user.home"), ".sbt")

// This is for matching region names in URLs or host names
private val RegionMatcher: Regex = Regions.values().map{ _.getName }.sortBy{ -1 * _.length }.mkString("|").r

Expand Down Expand Up @@ -81,7 +81,7 @@ object S3URLHandler {

def getCredentials(): AWSCredentials = {
val roleArn: String = getRoleArn(RoleArnKeyNames: _*)

if (roleArn == null || roleArn == "") return null

val securityTokenService: AWSSecurityTokenService = AWSSecurityTokenServiceClient.builder().withCredentials(providerChain).build()
Expand Down Expand Up @@ -168,12 +168,12 @@ final class S3URLHandler extends URLHandler {
def getURLInfo(url: URL): URLInfo = getURLInfo(url, 0)

private def debug(msg: String): Unit = Message.debug("S3URLHandler."+msg)

private def makePropertiesFileCredentialsProvider(fileName: String): PropertiesFileCredentialsProvider = {
val file: File = new File(DOT_SBT_DIR, fileName)
new PropertiesFileCredentialsProvider(file.toString)
}

private def makeCredentialsProviderChain(bucket: String): AWSCredentialsProviderChain = {
val basicProviders: Vector[AWSCredentialsProvider] = Vector(
new BucketSpecificEnvironmentVariableCredentialsProvider(bucket),
Expand All @@ -196,30 +196,24 @@ final class S3URLHandler extends URLHandler {
new RoleBasedSystemPropertiesCredentialsProvider(basicProviderChain),
new RoleBasedPropertiesFileCredentialsProvider(basicProviderChain, s".s3credentials")
)
new AWSCredentialsProviderChain((roleBasedProviders union basicProviders): _*)

new AWSCredentialsProviderChain((roleBasedProviders ++ basicProviders): _*)
}

private val credentialsCache: ConcurrentHashMap[String,AWSCredentials] = new ConcurrentHashMap()

def getCredentials(bucket: String): AWSCredentials = {
var credentials: AWSCredentials = credentialsCache.get(bucket)

if (null == credentials) {
credentials = try {
makeCredentialsProviderChain(bucket).getCredentials()
} catch {
case ex: com.amazonaws.AmazonClientException =>
Message.error("Unable to find AWS Credentials.")
throw ex
}

Message.info("S3URLHandler - Using AWS Access Key Id: "+credentials.getAWSAccessKeyId+" for bucket: "+bucket)

credentialsCache.put(bucket, credentials)

def getCredentialsProvider(bucket: String): AWSCredentialsProvider = {
Message.info("S3URLHandler - Looking up AWS Credentials for bucket: "+bucket+" ...")

val credentialsProvider: AWSCredentialsProvider = try {
makeCredentialsProviderChain(bucket)
} catch {
case ex: com.amazonaws.AmazonClientException =>
Message.error("Unable to find AWS Credentials.")
throw ex
}

credentials

Message.info("S3URLHandler - Using AWS Access Key Id: "+credentialsProvider.getCredentials().getAWSAccessKeyId+" for bucket: "+bucket)

credentialsProvider
}

def getProxyConfiguration: ClientConfiguration = {
Expand All @@ -234,20 +228,29 @@ final class S3URLHandler extends URLHandler {
configuration
}

// Bucket Name => AmazonS3
private val amazonS3ClientCache: ConcurrentHashMap[String,AmazonS3] = new ConcurrentHashMap()

def getClientBucketAndKey(url: URL): (AmazonS3, String, String) = {
val (bucket, key) = getBucketAndKey(url)

val client: AmazonS3 = AmazonS3Client.builder()
.withCredentials(new AWSStaticCredentialsProvider(getCredentials(bucket)))
.withClientConfiguration(getProxyConfiguration)
.build()
var client: AmazonS3 = amazonS3ClientCache.get(bucket)

if (null == client) {
client = AmazonS3Client.builder()
.withCredentials(getCredentialsProvider(bucket))
.withClientConfiguration(getProxyConfiguration)
.withRegion(getRegion(url, bucket))
.build()

amazonS3ClientCache.put(bucket, client)

Message.info("S3URLHandler - Created S3 Client for bucket: "+bucket+" and region: "+client.getRegionName)
}

val region: Option[Region] = getRegion(url, bucket, client)
region.foreach{ client.setRegion }

(client, bucket, key)
}

def getURLInfo(url: URL, timeout: Int): URLInfo = try {
debug(s"getURLInfo($url, $timeout)")

Expand Down Expand Up @@ -330,10 +333,10 @@ final class S3URLHandler extends URLHandler {
def setRequestMethod(requestMethod: Int): Unit = debug(s"setRequestMethod($requestMethod)")

// Try to get the region of the S3 URL so we can set it on the S3Client
def getRegion(url: URL, bucket: String, client: AmazonS3): Option[Region] = {
val region: Option[String] = getRegionNameFromURL(url) orElse getRegionNameFromDNS(bucket) orElse getRegionNameFromService(bucket, client)
def getRegion(url: URL, bucket: String/*, client: AmazonS3*/): Regions = {
val region: Option[String] = getRegionNameFromURL(url) orElse getRegionNameFromDNS(bucket) orElse Option(Regions.getCurrentRegion()).map{ _.getName }

region.map{ RegionUtils.getRegion }.flatMap{ Option(_) }
region.map{ Regions.fromName }.flatMap{ Option(_) } getOrElse Regions.DEFAULT_REGION
}

def getRegionNameFromURL(url: URL): Option[String] = {
Expand All @@ -349,12 +352,12 @@ final class S3URLHandler extends URLHandler {
// So we use our regex based RegionMatcher to try and extract the region since AmazonS3URI doesn't work
RegionMatcher.findFirstIn(canonicalHostName)
}
// TODO: cache the result of this so we aren't always making the call
def getRegionNameFromService(bucket: String, client: AmazonS3): Option[String] = {
// This might fail if the current credentials don't have access to the getBucketLocation call
Try { client.getBucketLocation(bucket) }.toOption
}

// Not used anymore since the AmazonS3ClientBuilder requires the region during construction
// def getRegionNameFromService(bucket: String, client: AmazonS3): Option[String] = {
// // This might fail if the current credentials don't have access to the getBucketLocation call
// Try { client.getBucketLocation(bucket) }.toOption
// }

def getBucketAndKey(url: URL): (String, String) = {
// The AmazonS3URI constructor should work for standard S3 urls. But if a custom domain is being used
Expand Down

0 comments on commit f629294

Please sign in to comment.