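A data mining task breaks down into six main steps: data preprocessing, feature transformation, feature selection, model training, model prediction, and evaluation of the predictions.
Here is a 20-row dataset of hobbies for people of different regions, sexes, ages, heights, and weights (saved as hobby.csv):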
id,hobby,sex,address,age,height,weight
1,football,male,dalian,12,168,55
2,pingpang,female,yangzhou,21,163,60
3,football,male,dalian,,172,70
4,football,female,,13,167,58
5,pingpang,female,shanghai,63,170,64
6,football,male,dalian,30,177,76
7,basketball,male,shanghai,25,181,90
8,football,male,dalian,15,172,71
9,basketball,male,shanghai,25,179,80
10,pingpang,male,shanghai,55,175,72
11,football,male,dalian,13,169,55
12,pingpang,female,yangzhou,22,164,61
13,football,male,dalian,23,170,71
14,football,female,,12,164,55
15,pingpang,female,shanghai,64,169,63
16,football,male,dalian,30,177,76
17,basketball,male,shanghai,22,180,80
18,football,male,dalian,16,173,72
19,basketball,male,shanghai,23,176,73
20,pingpang,male,shanghai,56,171,71
To load the data, first create a SparkSession object.
Build it with SparkSession.builder(), set appName and master, and finish with getOrCreate():
// Create the SparkSession
val spark = SparkSession.builder().appName("兴趣预测").master("local[*]").getOrCreate()
Load the data with spark.read: set the format to CSV, treat the first row as the header, and give the file path:
val df=spark.read.format("CSV").option("header",true).load("C:/Users/35369/Desktop/hobby.csv")
Inspect the data with df.show() and df.printSchema():
df.show()
df.printSchema()
spark.stop() // stop Spark
Output:
+---+----------+------+--------+----+------+------+
| id| hobby| sex| address| age|height|weight|
+---+----------+------+--------+----+------+------+
| 1| football| male| dalian| 12| 168| 55|
| 2| pingpang|female|yangzhou| 21| 163| 60|
| 3| football| male| dalian|null| 172| 70|
| 4| football|female| null| 13| 167| 58|
| 5| pingpang|female|shanghai| 63| 170| 64|
| 6| football| male| dalian| 30| 177| 76|
| 7|basketball| male|shanghai| 25| 181| 90|
| 8| football| male| dalian| 15| 172| 71|
| 9|basketball| male|shanghai| 25| 179| 80|
| 10| pingpang| male|shanghai| 55| 175| 72|
| 11| football| male| dalian| 13| 169| 55|
| 12| pingpang|female|yangzhou| 22| 164| 61|
| 13| football| male| dalian| 23| 170| 71|
| 14| football|female| null| 12| 164| 55|
| 15| pingpang|female|shanghai| 64| 169| 63|
| 16| football| male| dalian| 30| 177| 76|
| 17|basketball| male|shanghai| 22| 180| 80|
| 18| football| male| dalian| 16| 173| 72|
| 19|basketball| male|shanghai| 23| 176| 73|
| 20| pingpang| male|shanghai| 56| 171| 71|
+---+----------+------+--------+----+------+------+
root
|-- id: string (nullable = true)
|-- hobby: string (nullable = true)
|-- sex: string (nullable = true)
|-- address: string (nullable = true)
|-- age: string (nullable = true)
|-- height: string (nullable = true)
|-- weight: string (nullable = true)
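Filling in missing numeric values takes three steps:
(1) Select the column with its null rows dropped
(2) Compute the mean of the column from (1)
(3) Fill that mean into the original table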
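The three steps above can be sketched without Spark, on a plain Scala collection. This is only an illustration of the idea: the object name MeanImputation and the function fillWithMean are made up for this example and are not part of the Spark API.

```scala
object MeanImputation {
  // Replace each missing value (None) with the mean of the values that are present.
  def fillWithMean(values: Seq[Option[Double]]): Seq[Double] = {
    val present = values.flatten                  // step 1: drop the missing entries
    require(present.nonEmpty, "need at least one non-missing value")
    val mean = present.sum / present.size         // step 2: mean of what is left
    values.map(_.getOrElse(mean))                 // step 3: fill the gaps with the mean
  }
}
```

For example, MeanImputation.fillWithMean(Seq(Some(12.0), None, Some(30.0))) gives Seq(12.0, 21.0, 30.0). The Spark code that follows does the same thing on the age column.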
val ageNaDF = df.select("age").na.drop()
ageNaDF.show()
+---+
|age|
+---+
| 12|
| 21|
| 13|
| 63|
| 30|
| 25|
| 15|
| 25|
| 55|
| 13|
| 22|
| 23|
| 12|
| 64|
| 30|
| 22|
| 16|
| 23|
| 56|
+---+
Look at the summary statistics of ageNaDF:
ageNaDF.describe("age").show()
Output:
+-------+-----------------+
|summary| age|
+-------+-----------------+
| count| 19|
| mean|28.42105263157895|
| stddev|17.48432882286206|
| min| 12|
| max| 64|
+-------+-----------------+
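The mean is 28.42105263157895, and we need to extract it. describe() returns count, mean, stddev, min, and max in that order, so the mean sits in the second row (index 1):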
val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
print(mean) //28.42105263157895
val ageFilledDF = df.na.fill(mean,List("age"))
ageFilledDF.show()
Output:
+---+----------+------+--------+-----------------+------+------+
| id| hobby| sex| address| age|height|weight|
+---+----------+------+--------+-----------------+------+------+
| 1| football| male| dalian| 12| 168| 55|
| 2| pingpang|female|yangzhou| 21| 163| 60|
| 3| football| male| dalian|28.42105263157895| 172| 70|
| 4| football|female| null| 13| 167| 58|
| 5| pingpang|female|shanghai| 63| 170| 64|
| 6| football| male| dalian| 30| 177| 76|
| 7|basketball| male|shanghai| 25| 181| 90|
| 8| football| male| dalian| 15| 172| 71|
| 9|basketball| male|shanghai| 25| 179| 80|
| 10| pingpang| male|shanghai| 55| 175| 72|
| 11| football| male| dalian| 13| 169| 55|
| 12| pingpang|female|yangzhou| 22| 164| 61|
| 13| football| male| dalian| 23| 170| 71|
| 14| football|female| null| 12| 164| 55|
| 15| pingpang|female|shanghai| 64| 169| 63|
| 16| football| male| dalian| 30| 177| 76|
| 17|basketball| male|shanghai| 22| 180| 80|
| 18| football| male| dalian| 16| 173| 72|
| 19|basketball| male|shanghai| 23| 176| 73|
| 20| pingpang| male|shanghai| 56| 171| 71|
+---+----------+------+--------+-----------------+------+------+
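The null in the age column has been filled with the mean.
The address column has no sensible value to fill in, so any row with a missing address is dropped instead,
using the .na.drop() method: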
val addressDf = ageFilledDF.na.drop()
addressDf.show()
Output:
+---+----------+------+--------+-----------------+------+------+
| id| hobby| sex| address| age|height|weight|
+---+----------+------+--------+-----------------+------+------+
| 1| football| male| dalian| 12| 168| 55|
| 2| pingpang|female|yangzhou| 21| 163| 60|
| 3| football| male| dalian|28.42105263157895| 172| 70|
| 5| pingpang|female|shanghai| 63| 170| 64|
| 6| football| male| dalian| 30| 177| 76|
| 7|basketball| male|shanghai| 25| 181| 90|
| 8| football| male| dalian| 15| 172| 71|
| 9|basketball| male|shanghai| 25| 179| 80|
| 10| pingpang| male|shanghai| 55| 175| 72|
| 11| football| male| dalian| 13| 169| 55|
| 12| pingpang|female|yangzhou| 22| 164| 61|
| 13| football| male| dalian| 23| 170| 71|
| 15| pingpang|female|shanghai| 64| 169| 63|
| 16| football| male| dalian| 30| 177| 76|
| 17|basketball| male|shanghai| 22| 180| 80|
| 18| football| male| dalian| 16| 173| 72|
| 19|basketball| male|shanghai| 23| 176| 73|
| 20| pingpang| male|shanghai| 56| 171| 71|
+---+----------+------+--------+-----------------+------+------+
Rows 4 and 14 have been removed.
// Adjust the DataFrame's schema: every column was read as a string
val formatDF = addressDf.select(
  col("id").cast("int"),
  col("hobby").cast("String"),
  col("sex").cast("String"),
  col("address").cast("String"),
  col("age").cast("Double"),
  col("height").cast("Double"),
  col("weight").cast("Double")
)
formatDF.printSchema()
Output:
root
|-- id: integer (nullable = true)
|-- hobby: string (nullable = true)
|-- sex: string (nullable = true)
|-- address: string (nullable = true)
|-- age: double (nullable = true)
|-- height: double (nullable = true)
|-- weight: double (nullable = true)
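This completes the data preprocessing.
To make model training easier, the feature transformation step discretizes or encodes the age, weight, height, address, and sex features.
Age is bucketed with the Bucketizer class: set the input and output column names, pass the array of split points that defines the buckets, and apply it to the DataFrame to be bucketed.
// 2.1 Bucket the age column
// Array of split points that defines the buckets
val ageSplits = Array(Double.NegativeInfinity,18,35,60,Double.PositiveInfinity)
val bucketizerDF = new Bucketizer()
  .setInputCol("age")
  .setOutputCol("ageFeature")
  .setSplits(ageSplits)
  .transform(formatDF)
bucketizerDF.show()
Bucketing result: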
+---+----------+------+--------+-----------------+------+------+----------+
| id| hobby| sex| address| age|height|weight|ageFeature|
+---+----------+------+--------+-----------------+------+------+----------+
| 1| football| male| dalian| 12.0| 168.0| 55.0| 0.0|
| 2| pingpang|female|yangzhou| 21.0| 163.0| 60.0| 1.0|
| 3| football| male| dalian|28.42105263157895| 172.0| 70.0| 1.0|
| 5| pingpang|female|shanghai| 63.0| 170.0| 64.0| 3.0|
| 6| football| male| dalian| 30.0| 177.0| 76.0| 1.0|
| 7|basketball| male|shanghai| 25.0| 181.0| 90.0| 1.0|
| 8| football| male| dalian| 15.0| 172.0| 71.0| 0.0|
| 9|basketball| male|shanghai| 25.0| 179.0| 80.0| 1.0|
| 10| pingpang| male|shanghai| 55.0| 175.0| 72.0| 2.0|
| 11| football| male| dalian| 13.0| 169.0| 55.0| 0.0|
| 12| pingpang|female|yangzhou| 22.0| 164.0| 61.0| 1.0|
| 13| football| male| dalian| 23.0| 170.0| 71.0| 1.0|
| 15| pingpang|female|shanghai| 64.0| 169.0| 63.0| 3.0|
| 16| football| male| dalian| 30.0| 177.0| 76.0| 1.0|
| 17|basketball| male|shanghai| 22.0| 180.0| 80.0| 1.0|
| 18| football| male| dalian| 16.0| 173.0| 72.0| 0.0|
| 19|basketball| male|shanghai| 23.0| 176.0| 73.0| 1.0|
| 20| pingpang| male|shanghai| 56.0| 171.0| 71.0| 2.0|
+---+----------+------+--------+-----------------+------+------+----------+
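How a split array turns into bucket indices can be sketched in plain Scala. As far as I understand Bucketizer, each bucket is the half-open interval [lo, hi), except the last one, which also includes its upper bound. The names BucketSketch and bucketIndex below are hypothetical and only mirror that rule; this is not Spark's implementation.

```scala
object BucketSketch {
  // Index of the bucket that `value` falls into, given strictly increasing split points.
  // Buckets are half-open [lo, hi), except the last one, which also includes its upper bound.
  def bucketIndex(splits: Array[Double], value: Double): Int = {
    require(splits.length >= 3, "need at least three split points")
    require(value >= splits.head && value <= splits.last, "value outside the splits")
    if (value == splits.last) splits.length - 2   // upper bound belongs to the last bucket
    else splits.lastIndexWhere(_ <= value)        // last split point not above the value
  }
}
```

With the splits used above, an age of 12 maps to bucket 0, 28.42 to bucket 1, 55 to bucket 2, and 63 to bucket 3, which agrees with the ageFeature column in the table. An age of exactly 18 would land in bucket 1, not bucket 0.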
Height is binarized against a threshold of 170 using the Binarizer class:
// 2.2 Binarize the height column
val heightDF = new Binarizer()
  .setInputCol("height")
  .setOutputCol("heightFeature")
  .setThreshold(170) // threshold
  .transform(bucketizerDF)
heightDF.show()
Result:
+---+----------+------+--------+-----------------+------+------+----------+-------------+
| id| hobby| sex| address| age|height|weight|ageFeature|heightFeature|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
| 1| football| male| dalian| 12.0| 168.0| 55.0| 0.0| 0.0|
| 2| pingpang|female|yangzhou| 21.0| 163.0| 60.0| 1.0| 0.0|
| 3| football| male| dalian|28.42105263157895| 172.0| 70.0| 1.0| 1.0|
| 5| pingpang|female|shanghai| 63.0| 170.0| 64.0| 3.0| 0.0|
| 6| football| male| dalian| 30.0| 177.0| 76.0| 1.0| 1.0|
| 7|basketball| male|shanghai| 25.0| 181.0| 90.0| 1.0| 1.0|
| 8| football| male| dalian| 15.0| 172.0| 71.0| 0.0| 1.0|
| 9|basketball| male|shanghai| 25.0| 179.0| 80.0| 1.0| 1.0|
| 10| pingpang| male|shanghai| 55.0| 175.0| 72.0| 2.0| 1.0|
| 11| football| male| dalian| 13.0| 169.0| 55.0| 0.0| 0.0|
| 12| pingpang|female|yangzhou| 22.0| 164.0| 61.0| 1.0| 0.0|
| 13| football| male| dalian| 23.0| 170.0| 71.0| 1.0| 0.0|
| 15| pingpang|female|shanghai| 64.0| 169.0| 63.0| 3.0| 0.0|
| 16| football| male| dalian| 30.0| 177.0| 76.0| 1.0| 1.0|
| 17|basketball| male|shanghai| 22.0| 180.0| 80.0| 1.0| 1.0|
| 18| football| male| dalian| 16.0| 173.0| 72.0| 0.0| 1.0|
| 19|basketball| male|shanghai| 23.0| 176.0| 73.0| 1.0| 1.0|
| 20| pingpang| male|shanghai| 56.0| 171.0| 71.0| 2.0| 1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
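Note that rows with a height of exactly 170.0 get 0.0 in the table above: Binarizer outputs 1.0 only for values strictly greater than the threshold. A one-line sketch of that rule (BinarizeSketch and binarize are illustrative names, not Spark API):

```scala
object BinarizeSketch {
  // 1.0 only when the value is strictly above the threshold, otherwise 0.0.
  def binarize(threshold: Double, value: Double): Double =
    if (value > threshold) 1.0 else 0.0
}
```

So BinarizeSketch.binarize(170, 170.0) is 0.0 while BinarizeSketch.binarize(170, 171.0) is 1.0, matching rows 5 and 20.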
Weight is binarized the same way, with the threshold set to 65:
// 2.3 Binarize the weight column
val weightDF = new Binarizer()
  .setInputCol("weight")
  .setOutputCol("weightFeature")
  .setThreshold(65)
  .transform(heightDF)
weightDF.show()
The sex, address, and hobby fields are strings, a form that machine learning algorithms cannot work with directly, so they also need a feature transformation (encoding).
// 2.4 Label-encode sex
val sexIndex = new StringIndexer()
  .setInputCol("sex")
  .setOutputCol("sexIndex")
  .fit(weightDF)
  .transform(weightDF)
// 2.5 Label-encode the home address
val addIndex = new StringIndexer()
  .setInputCol("address")
  .setOutputCol("addIndex")
  .fit(sexIndex)
  .transform(sexIndex)
// 2.6 One-hot encode the address index
val addOneHot = new OneHotEncoder()
  .setInputCol("addIndex")
  .setOutputCol("addOneHot")
  .fit(addIndex)
  .transform(addIndex)
// 2.7 Label-encode the hobby field
val hobbyIndexDF = new StringIndexer()
  .setInputCol("hobby")
  .setOutputCol("hobbyIndex")
  .fit(addOneHot)
  .transform(addOneHot)
hobbyIndexDF.show()
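Here the address additionally gets one-hot encoding on top of the label encoding.
Rename the hobbyIndex column to label, because hobby serves as the label during model training.
// 2.8 Rename the column
val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex","label")
resultDF.show()
Final result after feature transformation: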
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
| id| hobby| sex| address| age|height|weight|ageFeature|heightFeature|weightFeature|sexIndex|addIndex| addOneHot|label|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
| 1| football| male| dalian| 12.0| 168.0| 55.0| 0.0| 0.0| 0.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 2| pingpang|female|yangzhou| 21.0| 163.0| 60.0| 1.0| 0.0| 0.0| 1.0| 2.0| (2,[],[])| 1.0|
| 3| football| male| dalian|28.42105263157895| 172.0| 70.0| 1.0| 1.0| 1.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 5| pingpang|female|shanghai| 63.0| 170.0| 64.0| 3.0| 0.0| 0.0| 1.0| 1.0|(2,[1],[1.0])| 1.0|
| 6| football| male| dalian| 30.0| 177.0| 76.0| 1.0| 1.0| 1.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 7|basketball| male|shanghai| 25.0| 181.0| 90.0| 1.0| 1.0| 1.0| 0.0| 1.0|(2,[1],[1.0])| 2.0|
| 8| football| male| dalian| 15.0| 172.0| 71.0| 0.0| 1.0| 1.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 9|basketball| male|shanghai| 25.0| 179.0| 80.0| 1.0| 1.0| 1.0| 0.0| 1.0|(2,[1],[1.0])| 2.0|
| 10| pingpang| male|shanghai| 55.0| 175.0| 72.0| 2.0| 1.0| 1.0| 0.0| 1.0|(2,[1],[1.0])| 1.0|
| 11| football| male| dalian| 13.0| 169.0| 55.0| 0.0| 0.0| 0.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 12| pingpang|female|yangzhou| 22.0| 164.0| 61.0| 1.0| 0.0| 0.0| 1.0| 2.0| (2,[],[])| 1.0|
| 13| football| male| dalian| 23.0| 170.0| 71.0| 1.0| 0.0| 1.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 15| pingpang|female|shanghai| 64.0| 169.0| 63.0| 3.0| 0.0| 0.0| 1.0| 1.0|(2,[1],[1.0])| 1.0|
| 16| football| male| dalian| 30.0| 177.0| 76.0| 1.0| 1.0| 1.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 17|basketball| male|shanghai| 22.0| 180.0| 80.0| 1.0| 1.0| 1.0| 0.0| 1.0|(2,[1],[1.0])| 2.0|
| 18| football| male| dalian| 16.0| 173.0| 72.0| 0.0| 1.0| 1.0| 0.0| 0.0|(2,[0],[1.0])| 0.0|
| 19|basketball| male|shanghai| 23.0| 176.0| 73.0| 1.0| 1.0| 1.0| 0.0| 1.0|(2,[1],[1.0])| 2.0|
| 20| pingpang| male|shanghai| 56.0| 171.0| 71.0| 2.0| 1.0| 1.0| 0.0| 1.0|(2,[1],[1.0])| 1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
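Two details of the table above are easy to miss: dalian and shanghai both occur 8 times yet get different indices, and yangzhou is encoded as the empty sparse vector (2,[],[]). The plain-Scala sketch below reproduces both. It assumes what I understand to be the defaults, namely that StringIndexer orders labels by descending frequency with ties broken alphabetically, and that OneHotEncoder drops the last category. EncodingSketch and its functions are hypothetical names, not Spark API.

```scala
object EncodingSketch {
  // Label index per distinct string: most frequent first, ties broken alphabetically.
  def indexByFrequency(values: Seq[String]): Map[String, Int] =
    values.groupBy(identity).toSeq
      .map { case (label, occurrences) => (label, occurrences.size) }
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  // Dense one-hot vector that drops the last category: it has numCategories - 1 slots,
  // and the last category is encoded as all zeros.
  def oneHotDropLast(index: Int, numCategories: Int): Vector[Double] =
    Vector.tabulate(numCategories - 1)(i => if (i == index) 1.0 else 0.0)
}
```

With 8 dalian, 8 shanghai, and 2 yangzhou rows, indexByFrequency assigns 0, 1, and 2 respectively, and oneHotDropLast(2, 3) is Vector(0.0, 0.0), the dense form of (2,[],[]).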
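Feature transformation produces many columns, but not all of them are usable for training a machine learning model; feature selection picks the columns that are.
VectorAssembler collects the chosen columns into a single feature vector column: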
// 3.1 Assemble the selected features (the label itself must not be one of them)
val vectorAssembler = new VectorAssembler()
  .setInputCols(Array("ageFeature","heightFeature","weightFeature","sexIndex","addOneHot"))
  .setOutputCol("features")
// 3.2 Standardize the features
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("featureScaler")
  .setWithStd(true) // scale each feature to unit standard deviation
  .setWithMean(false) // do not center the features on their mean
// Feature selection with the chi-squared test
val selector = new ChiSqSelector()
  .setFeaturesCol("featureScaler")
  .setLabelCol("label")
  .setOutputCol("featuresSelector")
// Logistic regression model
val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("featuresSelector")
// Build the pipeline
val pipeline = new Pipeline().setStages(Array(vectorAssembler,scaler,selector,lr))
// Grid search for the best parameters
val params = new ParamGridBuilder()
  .addGrid(lr.regParam,Array(0.1,0.01)) // regularization parameter
  .addGrid(selector.numTopFeatures,Array(5,10)) // number of features kept by the chi-squared selector
  .build()
// Cross-validation; hobby has three classes, so a multiclass evaluator is used
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new MulticlassClassificationEvaluator()) // from org.apache.spark.ml.evaluation
  .setEstimatorParamMaps(params)
  .setNumFolds(5)
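To get a feeling for how much work the grid search plus cross-validation does, the sketch below enumerates a parameter grid as a Cartesian product and counts the pipeline fits. It is plain Scala with made-up names (GridSketch, grid, fitsDuringSearch), and the candidate values passed in the usage line are only illustrative.

```scala
object GridSketch {
  // Every (regParam, numTopFeatures) pair: the Cartesian product of the two candidate lists.
  def grid(regParams: Seq[Double], topFeatures: Seq[Int]): Seq[(Double, Int)] =
    for (r <- regParams; k <- topFeatures) yield (r, k)

  // Number of pipeline fits performed while searching: one per combination per fold.
  def fitsDuringSearch(combinations: Int, folds: Int): Int = combinations * folds
}
```

GridSketch.grid(Seq(0.1, 0.01), Seq(5, 10)) has 4 combinations, so 5-fold cross-validation fits the pipeline 20 times before refitting the best combination on the whole training set. On a dataset of fewer than 20 rows each fold is tiny, which is worth keeping in mind when reading the resulting metrics.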
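Before training, split the data into a training set and a test set; the randomSplit method does this:
val Array(trainDF,testDF) = resultDF.randomSplit(Array(0.8,0.2))
// Train the model with cross-validation
val model = cv.fit(trainDF)
// Model prediction
val prediction = model.bestModel.transform(testDF)
prediction.show()
The call to cv.fit(trainDF) fails with the error below, and I could not find this message anywhere online: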
Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/spark/sql/catalyst/trees/BinaryLike
at java.lang.ClassLoader.defineClass1(Native Method)
at java.lang.ClassLoader.defineClass(ClassLoader.java:756)
at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
at java.net.URLClassLoader.defineClass(URLClassLoader.java:473)
at java.net.URLClassLoader.access$100(URLClassLoader.java:74)
at java.net.URLClassLoader$1.run(URLClassLoader.java:369)
at java.net.URLClassLoader$1.run(URLClassLoader.java:363)
at java.security.AccessController.doPrivileged(Native Method)
at java.net.URLClassLoader.findClass(URLClassLoader.java:362)
at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
at org.apache.spark.ml.stat.SummaryBuilderImpl.summary(Summarizer.scala:251)
at org.apache.spark.ml.stat.SummaryBuilder.summary(Summarizer.scala:54)
at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:112)
at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:84)
at org.apache.spark.ml.Pipeline.$anonfun$fit$5(Pipeline.scala:151)
at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
at org.apache.spark.ml.Pipeline.$anonfun$fit$4(Pipeline.scala:151)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
at org.apache.spark.ml.Pipeline.$anonfun$fit$2(Pipeline.scala:147)
at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
at org.apache.spark.ml.Pipeline.$anonfun$fit$1(Pipeline.scala:133)
at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
at scala.util.Try$.apply(Try.scala:213)
at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:133)
at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:93)
at org.apache.spark.ml.Estimator.fit(Estimator.scala:59)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$7(CrossValidator.scala:174)
at scala.runtime.java8.JFunction0$mcD$sp.apply(JFunction0$mcD$sp.java:23)
at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
at scala.util.Success.$anonfun$map$1(Try.scala:255)
at scala.util.Success.map(Try.scala:213)
at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
at org.sparkproject.guava.util.concurrent.MoreExecutors$SameThreadExecutorService.execute(MoreExecutors.java:293)
at scala.concurrent.impl.ExecutionContextImpl$$anon$4.execute(ExecutionContextImpl.scala:138)
at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:72)
at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete(Promise.scala:372)
at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete$(Promise.scala:371)
at scala.concurrent.impl.Promise$KeptPromise$Successful.onComplete(Promise.scala:379)
at scala.concurrent.impl.Promise.transform(Promise.scala:33)
at scala.concurrent.impl.Promise.transform$(Promise.scala:31)
at scala.concurrent.impl.Promise$KeptPromise$Successful.transform(Promise.scala:379)
at scala.concurrent.Future.map(Future.scala:292)
at scala.concurrent.Future.map$(Future.scala:292)
at scala.concurrent.impl.Promise$KeptPromise$Successful.map(Promise.scala:379)
at scala.concurrent.Future$.apply(Future.scala:659)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$6(CrossValidator.scala:182)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$4(CrossValidator.scala:172)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$1(CrossValidator.scala:166)
at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
at scala.util.Try$.apply(Try.scala:213)
at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
at org.apache.spark.ml.tuning.CrossValidator.fit(CrossValidator.scala:137)
at org.example.SparkML.SparkMl01$.main(SparkMl01.scala:147)
at org.example.SparkML.SparkMl01.main(SparkMl01.scala)
Caused by: java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.trees.BinaryLike
at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
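A NoClassDefFoundError of this kind usually points to jars on the classpath that were built against different versions of each other. The trace shows spark-mllib's Summarizer triggering the load, while BinaryLike lives in the org.apache.spark.sql.catalyst package that arrives with spark-sql. The pom.xml at the end of this post declares spark-mllib 3.5.0 next to spark-core and spark-sql 3.0.0-preview2 and spark-hive 3.1.2. As far as I know BinaryLike only exists from Spark 3.2 onward, so putting all four Spark artifacts on the same version is the likely fix, although that is an assumption rather than something verified on this project. A small helper (ClasspathCheck and classAvailable are hypothetical names) can confirm whether the class is visible at runtime:

```scala
object ClasspathCheck {
  // True when the named class can be found on the current classpath.
  // The class is looked up without being initialized.
  def classAvailable(name: String): Boolean =
    try {
      Class.forName(name, false, getClass.getClassLoader)
      true
    } catch {
      case _: ClassNotFoundException | _: LinkageError => false
    }
}
```

Calling ClasspathCheck.classAvailable("org.apache.spark.sql.catalyst.trees.BinaryLike") at the top of main should return false with the current dependencies and true once the Spark versions are aligned, if the diagnosis above is right.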
package org.example.SparkML
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{Binarizer, Bucketizer, ChiSqSelector, OneHotEncoder, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
/**
 * The data mining process
 * 1. Data preprocessing
 * 2. Feature transformation (encoding, ...)
 * 3. Feature selection
 * 4. Model training
 * 5. Model prediction
 * 6. Evaluating the predictions
 */
object SparkMl01 {
  def main(args: Array[String]): Unit = {
    // Create the SparkSession
    val spark = SparkSession.builder().appName("兴趣预测").master("local").getOrCreate()
    import spark.implicits._
    val df=spark.read.format("CSV").option("header",true).load("C:/Users/35369/Desktop/hobby.csv")
    // 1. Data preprocessing: fill in the missing ages with the mean
    val ageNaDF = df.select("age").na.drop()
    val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
    val ageFilledDF = df.na.fill(mean,List("age"))
    // Drop the rows whose address is null
    val addressDf = ageFilledDF.na.drop()
    // Adjust the DataFrame's schema
    val formatDF = addressDf.select(
      col("id").cast("int"),
      col("hobby").cast("String"),
      col("sex").cast("String"),
      col("address").cast("String"),
      col("age").cast("Double"),
      col("height").cast("Double"),
      col("weight").cast("Double")
    )
    // 2. Feature transformation
    // 2.1 Bucket the age column
    // Array of split points that defines the buckets
    val ageSplits = Array(Double.NegativeInfinity,18,35,60,Double.PositiveInfinity)
    val bucketizerDF = new Bucketizer()
      .setInputCol("age")
      .setOutputCol("ageFeature")
      .setSplits(ageSplits)
      .transform(formatDF)
    // 2.2 Binarize the height column
    val heightDF = new Binarizer()
      .setInputCol("height")
      .setOutputCol("heightFeature")
      .setThreshold(170) // threshold
      .transform(bucketizerDF)
    // 2.3 Binarize the weight column
    val weightDF = new Binarizer()
      .setInputCol("weight")
      .setOutputCol("weightFeature")
      .setThreshold(65)
      .transform(heightDF)
    // 2.4 Label-encode sex
    val sexIndex = new StringIndexer()
      .setInputCol("sex")
      .setOutputCol("sexIndex")
      .fit(weightDF)
      .transform(weightDF)
    // 2.5 Label-encode the home address
    val addIndex = new StringIndexer()
      .setInputCol("address")
      .setOutputCol("addIndex")
      .fit(sexIndex)
      .transform(sexIndex)
    // 2.6 One-hot encode the address index
    val addOneHot = new OneHotEncoder()
      .setInputCol("addIndex")
      .setOutputCol("addOneHot")
      .fit(addIndex)
      .transform(addIndex)
    // 2.7 Label-encode the hobby field
    val hobbyIndexDF = new StringIndexer()
      .setInputCol("hobby")
      .setOutputCol("hobbyIndex")
      .fit(addOneHot)
      .transform(addOneHot)
    // 2.8 Rename the column
    val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex","label")
    // 3. Feature selection
    // 3.1 Assemble the selected features
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("ageFeature","heightFeature","weightFeature","sexIndex","addOneHot"))
      .setOutputCol("features")
    // 3.2 Standardize the features
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("featureScaler")
      .setWithStd(true) // scale each feature to unit standard deviation
      .setWithMean(false) // do not center the features on their mean
    // Feature selection with the chi-squared test
    val selector = new ChiSqSelector()
      .setFeaturesCol("featureScaler")
      .setLabelCol("label")
      .setOutputCol("featuresSelector")
    // Logistic regression model
    val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("featuresSelector")
    // Build the pipeline
    val pipeline = new Pipeline()
      .setStages(Array(vectorAssembler,scaler,selector,lr))
    // Grid search for the best parameters
    val params = new ParamGridBuilder()
      .addGrid(lr.regParam,Array(0.1,0.01)) // regularization parameter
      .addGrid(selector.numTopFeatures,Array(5,10)) // number of features kept by the chi-squared selector
      .build()
    // Cross-validation; hobby has three classes, so a multiclass evaluator is used
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator())
      .setEstimatorParamMaps(params)
      .setNumFolds(5)
    // Model training
    val Array(trainDF,testDF) = resultDF.randomSplit(Array(0.8,0.2))
    trainDF.show()
    testDF.show()
    val model = cv.fit(trainDF)
    // Fit the pipeline directly, without cross-validation
    // val model = pipeline.fit(trainDF)
    // val prediction = model.transform(testDF)
    // prediction.show()
    // Model prediction
    // val prediction = model.bestModel.transform(testDF)
    // prediction.show()
    spark.stop()
  }
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>untitled</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>2.12.18</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>3.0.0-preview2</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.12</artifactId>
<version>3.1.2</version>
<!-- <scope>provided</scope>-->
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>3.0.0-preview2</version>
<!-- <scope>compile</scope>-->
</dependency>
<!-- <dependency>-->
<!-- <groupId>mysql</groupId>-->
<!-- <artifactId>mysql-connector-java</artifactId>-->
<!-- <version>8.0.16</version>-->
<!-- </dependency>-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>3.5.0</version>
<!-- <scope>compile</scope>-->
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>2.4.1</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>com.xxg.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>