-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodelevaldoc.scala
23 lines (14 loc) · 918 Bytes
/
modelevaldoc.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder,TrainValidationSplit}
// Model-selection walkthrough: tune a linear regression with a single
// train/validation split, then score the held-out test rows.
// NOTE(review): `spark` is not defined in this file; it is assumed to be the
// SparkSession supplied by the surrounding environment (e.g. spark-shell).

// Read the sample dataset using the "libsvm" data source.
val data = {
  val libsvmReader = spark.read.format("libsvm")
  libsvmReader.load("sample_linear_regression_data.txt")
}

// TRAIN TEST SPLIT
// Weights 0.8 / 0.2 with a fixed seed so the split is repeatable.
val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 56)
// data.printSchema()

// Estimator whose hyper-parameters are searched below.
val lr = new LinearRegression()

// Candidate hyper-parameter combinations: regParam x fitIntercept x elasticNetParam.
// The chain is wrapped in parentheses so it stays a single expression when the
// file is fed to a REPL line by line.
val paramGrid = (
  new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.01, 0.1))
    .addGrid(lr.fitIntercept)
    .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
    .build()
)

// Wire estimator, evaluator and grid together; the train ratio is set to 0.8.
val trainValidationSplit = {
  val evaluator = new RegressionEvaluator() // default evaluator settings
  new TrainValidationSplit()
    .setEstimator(lr)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setTrainRatio(0.8)
}

// Fit on the training portion only, then predict on the untouched test portion.
val model = trainValidationSplit.fit(train)
val predictions = model.transform(test)

// Display the inputs next to the model's output.
predictions
  .select("features", "label", "prediction")
  .show()