spark|SCALA下的GBDT与LR融合实现
admin
2023-08-16 15:45:36
0

我们直接使用 Spark MLlib(org.apache.spark.mllib)的包对 GBDT 与 LR 进行融合。
首先我们需要导入的包如下所示:

import org.apache.spark.sql. Row import scala.collection.mutable import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, FeatureType, Strategy} import org.apache.spark.mllib.tree.model.Node import com.suning.aps.util.handle_data.deleteHDFS import com.suning.aps.utils.StringUtil.getSign import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.sql.types._

其次,由于 MLlib 的训练数据需要 LabeledPoint 格式,而我们从 Hive 中提取的数据是 DataFrame 格式,所以需要先做数据转换:
// ---------------------------------------------------------------------------
// Data preparation: Hive query -> hashed numeric Rows -> DataFrame -> RDD[LabeledPoint]
// ---------------------------------------------------------------------------

// Hive-enabled session shared by every step below.
val spark = SparkSession.builder()
  .appName("wap_cpc_searchApp_v")
  .enableHiveSupport()
  .getOrCreate()
import spark.implicits._

val instance = "yuntu/cpc_ctr/hotmai_02&20"
val statis_day = "20191207"
val statis_day_before = "20191205"
val parallels = 50

// NOTE(review): presumably clears the previous output directory of this run —
// confirm against com.suning.aps.util.handle_data.deleteHDFS.
deleteHDFS(instance + "/wap_cpc_searchApp_v/" + statis_day, spark.sparkContext)

// NOTE(review): the SQL text below is truncated in this listing (it ends after the
// `terminal` column and has no FROM clause). The complete query must also produce
// every column named in the `.select(...)` that follows.
val pv_click_searchApp1 = spark.sql(
  s"""select
     |case when length(ta.query)=0 then '-' else ta.query end as query,
     |case when length(ta.ideaid)=0 then '00000000000000000-'
     |when ta.ideaid is NULL then '00000000000000000-'
     |else lpad(ta.ideaid,18,'0') end as ideaid,
     |case when length(ta.terminal)=0 then '-' else ta.terminal end as terminal,""".stripMargin)
  .select("is_click", "query", "terminal", "query_brand_name", "query_third_categ", "userid",
    "ideaid", "idea_first_categ", "idea_second_categ", "idea_third_categ", "idea_brand_name")
  .rdd.map { l =>
    // Encode one discrete column as a numeric value through getSign.
    def signOf(i: Int): Double = getSign(l(i).toString).toString.toDouble

    // Column 0 (is_click) is the label; columns 1 and 3..10 become features x1..x9.
    // NOTE(review): column 2 ("terminal") is selected but never used — confirm intended.
    Row(
      l(0).toString.toDouble,
      signOf(1), signOf(3), signOf(4), signOf(5), signOf(6),
      signOf(7), signOf(8), signOf(9), signOf(10))
  }

// The map above yields an RDD[Row] (many of the original columns are discrete and are
// hash-encoded), so it is first turned back into a DataFrame with an explicit schema:
// one label column "y" followed by nine feature columns "x1".."x9", all non-null doubles.
val ScoreSchema = StructType(
  ("y" +: (1 to 9).map(i => s"x$i")).map(name => StructField(name, DoubleType, nullable = false)))

val df = spark.createDataFrame(pv_click_searchApp1, ScoreSchema)

val ignored = List("y")
// Get index of every feature column (all columns except the ignored ones)
val featInd = df.columns.diff(ignored).map(df.columns.indexOf(_))
// Get index of target
val targetInd = df.columns.indexOf("y")

val ds = df.rdd.map(r => LabeledPoint(
  r.getDouble(targetInd), // Get target value
  // Map feature indices to values
  Vectors.dense(featInd.map(r.getDouble(_)))
))

上文的 ds 即为转换后的数据,下面介绍 GBDT 的训练:
// Train the GBDT.

// Number of boosting iterations; the later steps also use it as the number of trees
// to walk in `model.trees`.
val numTrees = 2

// Train a GradientBoostedTrees model.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
// The iteration count is set exactly once, from numTrees (the earlier hard-coded
// `numIterations = 10` was immediately overwritten and has been dropped).
boostingStrategy.setNumIterations(numTrees)
boostingStrategy.learningRate = 0.3
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 3
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(ds, boostingStrategy)

训练完GBDT后,我们需要从GBDT中探查树模型的结构,如下
// treeLeafArray(i) holds the ids of the leaf nodes of tree i, in the left-to-right
// order produced by getLeafNodes. The position of a leaf id inside this array is later
// used as that leaf's slot in the one-hot feature vector.
val treeLeafArray: Array[Array[Int]] =
  Array.tabulate(numTrees)(i => getLeafNodes(model.trees(i).topNode))

// Debug output: list the leaf ids of every tree.
for (i <- 0.until(numTrees)) {
  // format(...) is required here: println(template, i) would print a tuple instead of
  // substituting %d.
  println("正在打印第%d 棵树的 topnode 叶子节点".format(i))
  // Print the leaf ids themselves rather than their positions.
  treeLeafArray(i).foreach(leafId => println(leafId))
}

其中,我们定义了两个函数(getLeafNodes 与 predictModify),如下所示:
/**
 * Collects the ids of all leaf nodes below (and including) `node`.
 *
 * The left subtree is visited before the right one, so the ids come back in
 * left-to-right order.
 *
 * @param node root of the (sub)tree to inspect
 * @return ids of every leaf reachable from `node`
 * @throws NoSuchElementException if a non-leaf node has no left or right child
 */
def getLeafNodes(node: Node): Array[Int] =
  if (node.isLeaf) Array(node.id)
  else getLeafNodes(node.leftNode.get) ++ getLeafNodes(node.rightNode.get)

/**
 * Predict decision tree leaf's node value: walks the tree from `node` following the
 * split conditions and returns the id of the leaf that `features` falls into.
 *
 * For a continuous split the left child is taken when the feature value is
 * `<=` the split threshold; for a categorical split the left child is taken when the
 * feature value is contained in the split's category list. Otherwise the right child
 * is taken.
 *
 * @param node     root of the (sub)tree to evaluate
 * @param features feature vector of one sample, indexed by the split's feature index
 * @return id of the leaf node reached
 * @throws NoSuchElementException if a non-leaf node has no split or is missing a child
 */
def predictModify(node: Node, features: DenseVector): Int =
  if (node.isLeaf) {
    node.id
  } else {
    val split = node.split.get
    val value = features(split.feature)
    val goLeft =
      if (split.featureType == FeatureType.Continuous) value <= split.threshold
      else split.categories.contains(value)
    predictModify(if (goLeft) node.leftNode.get else node.rightNode.get, features)
  }

下文将利用上文训练的GBDT构造新特征并用LR训练
// ---------------------------------------------------------------------------
// Build the new features from the GBDT, then train and evaluate LR on them.
// ---------------------------------------------------------------------------

// For every sample: (label, concatenation over all trees of a one-hot vector marking
// the leaf that the sample falls into).
val newFeatureDataSet = df.rdd.map { row =>
  // Column 0 is the label; every remaining column is a feature.
  val features = Array.tabulate(row.length - 1)(i => row(i + 1).toString.toDouble)
  (row(0).toString.toDouble, new DenseVector(features))
}.map { case (label, features) =>
  val newFeature = (0 until numTrees).toArray.flatMap { i =>
    val leafIds = treeLeafArray(i)
    // One slot per leaf of tree i. The trees are binary, so this equals
    // (numNodes + 1) / 2.
    val oneHot = new Array[Double](leafIds.length)
    oneHot(leafIds.indexOf(predictModify(model.trees(i).topNode, features))) = 1
    oneHot
  }
  (label, newFeature)
}
newFeatureDataSet.take(2).foreach(println)

val newData = newFeatureDataSet.map(x => LabeledPoint(x._1, new DenseVector(x._2)))
newData.take(2).foreach(println)

// 80% / 20% train / test split of the leaf-encoded data.
val splits2 = newData.randomSplit(Array(0.8, 0.2))
val train2 = splits2(0)
val test2 = splits2(1)

// Accuracy of the GBDT alone, measured on the data it was trained on (ds).
val predictions = ds.map(lp => model.predict(lp.features))
predictions.take(10).foreach(println)
val predictionAndLabel = predictions.zip(ds.map(_.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count / ds.count
println("GBTR accuracy " + accuracy)

// Logistic regression on the GBDT leaf features.
// NOTE(review): the decision threshold is set to 0.01 — confirm this low value is
// intended for this data.
val model1 = new LogisticRegressionWithLBFGS().setNumClasses(2).run(train2).setThreshold(0.01)
// Evaluated only for inspection (e.g. in a REPL); the result is not used below.
model1.weights

val predictionAndLabels = test2.map { case LabeledPoint(label, features) =>
  val prediction = model1.predict(features)
  (prediction, label)
}
val metrics = new MulticlassMetrics(predictionAndLabels)
// The value reported as "Precision" is MulticlassMetrics.accuracy.
val precision = metrics.accuracy
println("Precision = " + precision)

相关内容

热门资讯

超... 本文目录导航: 超级云计算是什么 怎么做难看的PPT 1、...
谢... 本文目录导航: 请问云主机是什么 云主机有什么好处 具体的教程,谢谢! 云...
w... 本文目录导航: wps是什么意思 ppt的新配置designer和morp...
大... 本文目录导航: 大专学什么专业务工率高? 未来十年务工率最高的几大专业都是...
软... 本文目录导航: 软件技术专升本可以报什么专业 云计算专升本可以报医学吗 ...
云... 本文目录导航: 云计算务工前景 云计算务工方向及前景怎样样 ...
学... 本文目录导航: 学云计算进去无能嘛 云计算技术与运行是干什么的 ...
中... 本文目录导航: 如何了解云计算,中国的云计算产业开展现状如何 云计算未来几...
云... 本文目录导航: 云计算1+x证书含金量 云计算须要考什么证书 ...
云... 本文目录导航: 云计算股票龙头股票有哪些? 普通云计算概念龙头股有哪些?...
大... 本文目录导航: 大专云计算技术运行务工方向 大专毕业证上是物联网,实践学习...
大... 本文目录导航: 大数据云计算有必要升本吗 内蒙古大专云计算技术与运行专业升...
9... 本文目录导航: 99%学霸假期逆袭必看网站 99%学霸假期逆袭必看网站 ...
云... 本文目录导航: 云计算属于哪个专业 云计算属于什么专业 计...
计... 本文目录导航: 计算机二级MSOffice上机操作题及答案 想做一篇关于解...
A... 本文目录导航: AI能否会彻底扭转上流职业市场,如律师、会计师和医师? A...
人... 本文目录导航: 人工智能芯片产业链有哪些? 更多本行业钻研剖析详见前瞻产业...
人... 本文目录导航: 人工智能会带来哪些风险? 或许有一天,人工智能机器人将取代...
a... 本文目录导航: ai智能写作软件哪个好 ai智能写作软件有哪些?ai智能对...
自... 本文目录导航: 自考本科计算机专业难吗 自考计算机专业须要考哪些科目 ...