机器学习---决策树和随机森林代码

发布时间:2023年12月20日

1、决策树代码

1.object ClassificationDecisionTree {
2.
3.  def main(args: Array[String]): Unit = {
4.    val conf = new SparkConf()
5.    conf.setAppName("analysItem")
6.    conf.setMaster("local[3]")
7.    val sc = new SparkContext(conf)
8.    val data = MLUtils.loadLibSVMFile(sc, "汽车数据样本.txt")
9.    // Split the data into training and test sets (30% held out for testing)
10.    val splits = data.randomSplit(Array(0.7, 0.3))
11.    val (trainingData, testData) = (splits(0), splits(1))
12.    //指明类别
13.    val numClasses=2
14.    //指定离散变量,未指明的都当作连续变量处理
15.    //1,2,3,4维度进来就变成了0,1,2,3
16.    //这里天气维度有3类,但是要指明4,这里是个坑,后面以此类推
17.    val categoricalFeaturesInfo=Map[Int,Int](0->4,1->4,2->3,3->3)
18.    //设定评判标准 "gini"/"entropy"
19.    val impurity="entropy"
20.    //树的最大深度,太深运算量大也没有必要 剪枝 防止模型的过拟合!!!
21.    val maxDepth=3
22.    //设置离散化程度,连续数据需要离散化,分成32个区间,默认其实就是32,分割的区间保证数量差不多 这个参数也可以进行剪枝
23.     val maxBins=32
24.    //生成模型
25.    val model =DecisionTree.trainClassifier(trainingData,numClasses,categoricalFeaturesInfo,impurity,maxDepth,maxBins)
26.    //测试
27.    val labelAndPreds = testData.map { point =>
28.       val prediction = model.predict(point.features)
29.       (point.label, prediction)
30.    }
31.    val testErr = labelAndPreds.filter(r => r._1 !=    r._2).count().toDouble / testData.count()
32.    println("Test Error = " + testErr)
33.    println("Learned classification tree model:\n" +      model.toDebugString)
34.
35.  }
36.}

2、随机森林代码

1.object ClassificationRandomForest {
2.  def main(args: Array[String]): Unit = {
3.    val conf = new SparkConf()
4.    conf.setAppName("analysItem")
5.    conf.setMaster("local[3]")
6.    val sc = new SparkContext(conf)
7.    //读取数据
8.    val data = MLUtils.loadLibSVMFile(sc,"汽车数据样本.txt")
9.    //将样本按7:3的比例分成
10.    val splits = data.randomSplit(Array(0.7, 0.3))
11.    val (trainingData, testData) = (splits(0), splits(1))
12.    //分类数
13.    val numClasses = 2
14.    // categoricalFeaturesInfo 为空,意味着所有的特征为连续型变量
15.    val categoricalFeaturesInfo =Map[Int, Int](0->4,1->4,2->3,3->3)
16.    //树的个数
17.    val numTrees = 3 
18.    //特征子集采样策略,auto 表示算法自主选取
19.    //"auto"根据特征数量在4个中进行选择
20.    // 1,all 全部特征 2,sqrt 把特征数量开根号后随机选择的 3,log2 取对数个 4,onethird 三分之一
21.    val featureSubsetStrategy = "auto"
22.   //纯度计算 "gini"/"entropy"
23.   val impurity = "entropy"
24.    //树的最大层次
25.    val maxDepth = 3
26.    //特征最大装箱数,即连续数据离散化的区间
27.    val maxBins = 32
28.    //训练随机森林分类器,trainClassifier 返回的是 RandomForestModel 对象
29.    val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
30.       numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
31.    //打印模型
32.    println(model.toDebugString)
33.    //保存模型
34.    //model.save(sc,"汽车保险")
35.    //在测试集上进行测试
36.    val count = testData.map { point =>
37.      val prediction = model.predict(point.features)
38.      // Math.abs(prediction-point.label)
39.      (prediction,point.label)
40.    }.filter(r => r._1 != r._2).count()
41.    println("Test Error = " + count.toDouble/testData.count().toDouble)
42.    println()
43.  }
44.}

文章来源:https://blog.csdn.net/yaya_jn/article/details/135099044
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。