diff --git a/src/main/scala/io/archivesunleashed/spark/rdd/RecordRDD.scala b/src/main/scala/io/archivesunleashed/spark/rdd/RecordRDD.scala index c179d7d6..20e27857 100644 --- a/src/main/scala/io/archivesunleashed/spark/rdd/RecordRDD.scala +++ b/src/main/scala/io/archivesunleashed/spark/rdd/RecordRDD.scala @@ -73,8 +73,8 @@ object RecordRDD extends java.io.Serializable { rdd.filter(r => mimeTypes.contains(r.getMimeType)) } - def keepDate(date: String, component: DateComponent = DateComponent.YYYYMMDD) = { - rdd.filter(r => ExtractDate(r.getCrawlDate, component) == date) + def keepDate(dates: List[String], component: DateComponent = DateComponent.YYYYMMDD) = { + rdd.filter(r => dates.contains(ExtractDate(r.getCrawlDate, component))) } def keepUrls(urls: Set[String]) = { diff --git a/src/test/scala/io/archivesunleashed/spark/ArcTest.scala b/src/test/scala/io/archivesunleashed/spark/ArcTest.scala index 2ff87922..1046d32c 100644 --- a/src/test/scala/io/archivesunleashed/spark/ArcTest.scala +++ b/src/test/scala/io/archivesunleashed/spark/ArcTest.scala @@ -45,12 +45,12 @@ class ArcTest extends FunSuite with BeforeAndAfter { test("filter date") { val four = RecordLoader.loadArchives(arcPath, sc, keepValidPages = false) - .keepDate("200804", DateComponent.YYYYMM) + .keepDate(List("200804","200805"), DateComponent.YYYYMM) .map(r => r.getCrawlDate) .collect() val five = RecordLoader.loadArchives(arcPath, sc, keepValidPages = false) - .keepDate("200805", DateComponent.YYYYMM) + .keepDate(List("200805","200807"), DateComponent.YYYYMM) .map(r => r.getCrawlDate) .collect() diff --git a/src/test/scala/io/archivesunleashed/spark/rdd/RecordRDDTest.scala b/src/test/scala/io/archivesunleashed/spark/rdd/RecordRDDTest.scala index f74f8b58..462e5bcd 100644 --- a/src/test/scala/io/archivesunleashed/spark/rdd/RecordRDDTest.scala +++ b/src/test/scala/io/archivesunleashed/spark/rdd/RecordRDDTest.scala @@ -62,7 +62,7 @@ class RecordRDDTest extends FunSuite with BeforeAndAfter { val r = base .filter (x => ExtractDate(x.getCrawlDate, component) == "2008") .map ( mp => mp.getUrl).take(3) - val r2 = base.keepDate("2008", component) + val r2 = base.keepDate(List("2008"), component) .map ( mp => mp.getUrl).take(3) assert (r2.sameElements(r)) }