Here is a possible solution for the specific case of LinearRegression and any other algorithm that supports an objective history (in this case, LinearRegressionTrainingSummary does the job).
Let's first create a minimal, complete, and verifiable example:
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
import org.apache.spark.sql.SparkSession

val spark: SparkSession = SparkSession.builder().getOrCreate()

import org.apache.spark.ml.evaluation.RegressionEvaluator
import spark.implicits._

val data = {
  val tmp = LinearDataGenerator.generateLinearRDD(
    spark.sparkContext,
    nexamples = 10000,
    nfeatures = 4,
    eps = 0.05
  ).toDF
  MLUtils.convertVectorColumnsToML(tmp, "features")
}
As you may have noticed, when you want to generate data for testing purposes with spark-mllib or spark-ml, it's advised to use the built-in data generators.
Now, let's train a linear regressor:
// Create the linear regression model.
val lr = new LinearRegression().setMaxIter(1000)

// The following lines will create two sets of parameters.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.001))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.5))
  .build()

// Create the trainer using a validation split to evaluate which set of parameters performs best.
// I'm using the regular RegressionEvaluator here.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8) // 80% of the data will be used for training and the remaining 20% for validation.

// To retrieve the sub-models, make sure to set collectSubModels to true before fitting.
trainValidationSplit.setCollectSubModels(true)

// Run the train validation split and choose the best set of parameters.
var model = trainValidationSplit.fit(data)
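As a side note, if you only need the loss history of the single best model rather than of every candidate, you can read it straight from bestModel. A minimal sketch (bestLr and bestHistory are just local names I've chosen):

// The best model chosen by the validation split, cast back to its concrete type.
val bestLr = model.bestModel.asInstanceOf[LinearRegressionModel]
// Objective (loss) value recorded at each iteration of the optimizer.
val bestHistory: Array[Double] = bestLr.summary.objectiveHistory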
Now that our model is trained, all we need is to get the objective history of each sub-model.
The following part needs a bit of gymnastics between the model and sub-model object parameters.
In case you have a Pipeline or similar, this code needs to be modified (see the sketch at the end of this answer), so use it carefully. It's just an example:
val objectiveHist = spark.sparkContext.parallelize(
  model.subModels.zip(model.getEstimatorParamMaps).map {
    case (m: LinearRegressionModel, pm: ParamMap) =>
      val history: Array[Double] = m.summary.objectiveHistory
      val idx: Seq[Int] = 1 until history.length
      // regParam, elasticNetParam, fitIntercept
      val parameters = pm.toSeq.map(pair => (pair.param.name, pair.value.toString)) match {
        case Seq(x, y, z) => (x._2, y._2, z._2)
      }
      (parameters._1, parameters._2, parameters._3, idx.zip(history).toMap)
  }).toDF("regParam", "elasticNetParam", "fitIntercept", "objectiveHistory")
We can now examine those metrics:
objectiveHist.show(false)
// +--------+---------------+------------+-------------------------------------------------------------------------------------------------------+
// |regParam|elasticNetParam|fitIntercept|objectiveHistory                                                                                       |
// +--------+---------------+------------+-------------------------------------------------------------------------------------------------------+
// |0.001   |0.5            |true        |[1 -> 0.4999999999999999, 2 -> 0.4038796441909531, 3 -> 0.02659222058006269, 4 -> 0.026592220340980147]|
// |0.001   |0.5            |false       |[1 -> 0.5000637621421942, 2 -> 0.4039303922115196, 3 -> 0.026592220673025396, 4 -> 0.02659222039347222]|
// +--------+---------------+------------+-------------------------------------------------------------------------------------------------------+
You can notice that the training process actually stops after 4 iterations.
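Since the goal is to draw a learning curve, it can be convenient to flatten the objectiveHistory map into one (iteration, loss) row per entry, which is then easy to collect and feed to any plotting library. A minimal sketch against the DataFrame above (learningCurve is just a local name):

import org.apache.spark.sql.functions.explode

// explode turns the Map[Int, Double] column into "key"/"value" columns, one row per entry.
val learningCurve = objectiveHist
  .select($"regParam", $"elasticNetParam", $"fitIntercept", explode($"objectiveHistory"))
  .withColumnRenamed("key", "iteration")
  .withColumnRenamed("value", "loss")
  .orderBy($"fitIntercept", $"iteration")

learningCurve.show()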
If you want just the number of iterations, you can do the following instead:
val objectiveHist2 = spark.sparkContext.parallelize(
  model.subModels.zip(model.getEstimatorParamMaps).map {
    case (m: LinearRegressionModel, pm: ParamMap) =>
      val history: Array[Double] = m.summary.objectiveHistory
      // regParam, elasticNetParam, fitIntercept
      val parameters = pm.toSeq.map(pair => (pair.param.name, pair.value.toString)) match {
        case Seq(x, y, z) => (x._2, y._2, z._2)
      }
      (parameters._1, parameters._2, parameters._3, history.size)
  }).toDF("regParam", "elasticNetParam", "fitIntercept", "iterations")
I've changed the number of features in the generator (nfeatures = 100) for the sake of demonstration:
objectiveHist2.show
// +--------+---------------+------------+----------+
// |regParam|elasticNetParam|fitIntercept|iterations|
// +--------+---------------+------------+----------+
// |   0.001|            0.5|        true|        11|
// |   0.001|            0.5|       false|        11|
// +--------+---------------+------------+----------+
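Finally, regarding the Pipeline caveat above: if your estimator is a Pipeline instead of a bare LinearRegression, the collected sub-models are PipelineModels, so you first have to dig the fitted regression stage out of each one. A hedged sketch, assuming the LinearRegression is the last stage of the pipeline:

import org.apache.spark.ml.PipelineModel

// Assumption: each sub-model is a PipelineModel whose last stage is the fitted LinearRegressionModel.
val pipelineHistories = model.subModels.zip(model.getEstimatorParamMaps).map {
  case (pipelineModel: PipelineModel, pm: ParamMap) =>
    val m = pipelineModel.stages.last.asInstanceOf[LinearRegressionModel]
    (pm.toSeq.map(p => s"${p.param.name}=${p.value}").mkString(", "), m.summary.objectiveHistory)
}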