diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RateLimiter.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RateLimiter.scala
new file mode 100644
index 0000000..c0e9131
--- /dev/null
+++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RateLimiter.scala
@@ -0,0 +1,29 @@
+package dev.caraml.spark.stores.redis
+
+import dev.caraml.spark.RedisWriteProperties
+import io.github.bucket4j.{Bandwidth, Bucket}
+
+import java.time.Duration.ofSeconds
+import java.util.concurrent.ConcurrentHashMap
+
+object RateLimiter {
+
+  private lazy val buckets: ConcurrentHashMap[RedisWriteProperties, Bucket] = new ConcurrentHashMap
+
+  def get(properties: RedisWriteProperties): Bucket = {
+    buckets.computeIfAbsent(properties, create)
+  }
+
+  def create(properties: RedisWriteProperties): Bucket = {
+    Bucket
+      .builder()
+      .addLimit(
+        Bandwidth
+          .builder()
+          .capacity(properties.ratePerSecondLimit)
+          .refillIntervally(properties.ratePerSecondLimit, ofSeconds(1))
+          .build()
+      )
+      .build()
+  }
+}
diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala
index a28c69e..302db14 100644
--- a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala
+++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala
@@ -48,17 +48,6 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
     ratePerSecondLimit = sparkConf.get("spark.redis.properties.ratePerSecondLimit").toInt
   )
 
-  lazy private val rateLimitBucket: Bucket = Bucket
-    .builder()
-    .addLimit(
-      Bandwidth
-        .builder()
-        .capacity(properties.ratePerSecondLimit)
-        .refillIntervally(properties.ratePerSecondLimit, ofSeconds(1))
-        .build()
-    )
-    .build()
-
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
     data.foreachPartition { partition: Iterator[Row] =>
       java.security.Security.setProperty("networkaddress.cache.ttl", "3");
@@ -69,6 +58,7 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
       // grouped iterator to only allocate memory for a portion of rows
      partition.grouped(properties.pipelineSize).foreach { batch =>
        if (properties.enableRateLimit) {
+          val rateLimitBucket = RateLimiter.get(properties)
          rateLimitBucket.asBlocking().consume(batch.length)
        }
        val rowsWithKey: Seq[(String, Row)] = batch.map(row => dataKeyId(row) -> row)
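
Reviewer note (not part of the patch): the change replaces the per-relation
lazy val rateLimitBucket with a process-wide registry keyed by
RedisWriteProperties, so every task in the same executor JVM that writes with
an identical configuration now shares a single Bucket, and the configured rate
limit is enforced once per executor rather than once per RedisSinkRelation
instance. Below is a minimal self-contained sketch of that behaviour; Props is
a hypothetical stand-in for RedisWriteProperties, whose full field list is not
shown in this diff.

    import io.github.bucket4j.{Bandwidth, Bucket}
    import java.time.Duration.ofSeconds
    import java.util.concurrent.ConcurrentHashMap

    object RateLimiterSketch {
      // Stand-in for dev.caraml.spark.RedisWriteProperties; case-class value
      // equality is what makes equal configs resolve to the same map key.
      final case class Props(ratePerSecondLimit: Int)

      // computeIfAbsent creates at most one bucket per distinct config, even
      // when several Spark tasks race on the first lookup.
      private val buckets = new ConcurrentHashMap[Props, Bucket]

      def get(p: Props): Bucket = buckets.computeIfAbsent(p, create)

      // Mirrors the patch: a bucket refilled with ratePerSecondLimit tokens
      // every second, with capacity equal to one second's budget.
      private def create(p: Props): Bucket =
        Bucket
          .builder()
          .addLimit(
            Bandwidth
              .builder()
              .capacity(p.ratePerSecondLimit)
              .refillIntervally(p.ratePerSecondLimit, ofSeconds(1))
              .build()
          )
          .build()

      def main(args: Array[String]): Unit = {
        // Two lookups with equal configs return the *same* instance, so all
        // callers draw from one shared token budget.
        val a = get(Props(50000))
        val b = get(Props(50000))
        assert(a eq b)
      }
    }

One caveat worth raising in review: because the map key is the whole
properties value, two writes that differ in any field (for example
pipelineSize) but share the same ratePerSecondLimit would get separate
buckets, each with the full per-second quota.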