-
Notifications
You must be signed in to change notification settings - Fork 0
/
Build_Hash.py
32 lines (25 loc) · 1.19 KB
/
Build_Hash.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.sql.types import *
from pyspark.sql.functions import col
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline, PipelineModel
import sys
def main(spark, sc):
    """Fit string indexers for user_id/track_id on the MSD training ratings
    and persist the fitted PipelineModel to HDFS.

    Args:
        spark: active SparkSession.
        sc: SparkContext backing the session (used only to silence logging).
    """
    sc.setLogLevel("OFF")
    spark.conf.set("spark.blacklist.enabled", "False")
    # Disable broadcast joins entirely; -1 turns off the size threshold.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
    schemaRatings = spark.read.parquet('hdfs:/user/bm106/pub/MSD/cf_train_new.parquet').select('user_id','track_id','count')
    schemaRatings = schemaRatings.repartition(1000)
    # Index every column except the raw play count, in deterministic column
    # order. Stages are left UNFITTED here: calling .fit() per indexer and
    # then pipeline.fit() again (as the original did) scans the data twice
    # per column — pipeline.fit() below performs the single fitting pass.
    index_cols = [c for c in schemaRatings.columns if c != 'count']
    indexers = [StringIndexer(inputCol=c, outputCol=c + "_index").setHandleInvalid("skip")
                for c in index_cols]
    pipeline = Pipeline(stages=indexers)
    model = pipeline.fit(schemaRatings)
    path = 'hdfs:/user/zm2114/hash'
    # overwrite() so reruns replace the previously saved model.
    model.write().overwrite().save(path)
if __name__ == "__main__":
    # Script entry point: build the session, grab its context, and run.
    session = SparkSession.builder.appName('Build_Hash').getOrCreate()
    context = session.sparkContext
    main(session, context)