-
Notifications
You must be signed in to change notification settings - Fork 334
/
Copy path5xa-spark-1hot.txt
44 lines (30 loc) · 1.55 KB
/
5xa-spark-1hot.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
spark-1.6.0-bin-hadoop2.4/bin/spark-shell --packages com.databricks:spark-csv_2.11:1.2.0
// from Joseph Bradley https://gist.github.com/jkbradley/1e3cc0b3116f2f615b3f
// modifications by Xusen Yin https://github.com/szilard/benchm-ml/commit/db65cf000c9b1565b6e93d2d10c92dd646644d85
// TODO (or rather use the Spark 2.0 version): Distance and DepTime numerical, not string (was OK in J.Bradley version)
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.{col, lit}
// Purpose: read the train/test CSVs, encode them together with RFormula so both share one
// fitted encoding, then write the two halves back out as Parquet.
// Paths
val origTrainPath = "train-10m.csv"
val origTestPath = "test.csv"
val newTrainPath = "spark1hot-train-10m.parquet"
val newTestPath = "spark1hot-test-10m.parquet"
// Columns that are quantities rather than categories (the ones named in the TODO near the top
// of this file). The reader below requests no schema, so these arrive as strings, and RFormula
// one-hot encodes string columns; casting them first keeps each one as a single numeric feature.
// NOTE(review): assumes every value in these columns parses as a number (a failed cast yields
// null) -- confirm against the data.
val numericCols = Seq("DepTime", "Distance")
// Read CSV as Spark DataFrames
val loader = sqlContext.read.format("com.databricks.spark.csv").option("header", "true")
// Loads one CSV and converts the numeric columns from string to double.
def loadCsv(path: String) = numericCols.foldLeft(loader.load(path)) { (df, name) =>
  df.withColumn(name, col(name).cast("double"))
}
val trainDF = loadCsv(origTrainPath)
val testDF = loadCsv(origTestPath)
// Combine train, test temporarily; "isTrain" remembers which input each row came from
val fullDF = trainDF.withColumn("isTrain", lit(true)).unionAll(testDF.withColumn("isTrain", lit(false)))
// fullDF.show(10)
// Use RFormula to generate training data.
// "- isTrain" keeps the bookkeeping column out of the feature vector: with a bare "." it would
// be assembled as a feature that is constant within train and within test.
val res = new RFormula().setFormula("dep_delayed_15min ~ . - isTrain").fit(fullDF).transform(fullDF)
// Split back into train, test
val finalTrainDF = res.where(col("isTrain"))
val finalTestDF = res.where(!col("isTrain"))
// Save Spark DataFrames as Parquet
finalTrainDF.write.mode("overwrite").parquet(newTrainPath)
finalTestDF.write.mode("overwrite").parquet(newTestPath)
// finalTrainDF.show(10)
// finalTestDF.show(10)