Spark: Read Data from MySQL, Run a Machine Learning Prediction, and Write the Results Back to MySQL

Published: January 17, 2024


import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{SaveMode, SparkSession}

object Spqrk_lr_predict_mysql_3 {
  def main(args: Array[String]): Unit = {
    // Set up the Spark environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    // Load the source table from MySQL
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:mysql://hadoop102:3306/localstreamdata?characterEncoding=utf8&useSSL=false")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "000000")
      .option("dbtable", "normal_data")
      .load()
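    // Note: com.mysql.jdbc.Driver is the MySQL Connector/J 5.x driver class; with Connector/J 8.x the class is com.mysql.cj.jdbc.Driver.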
    // Build the features and label, then split into training and test sets
    val convertedDF = df.select("stream_id","stream_money", "stream_consume_type", "stream_time_date", "stream_seconds", "stream_is_new", "stream_is_normal")
    // Use StringIndexer to convert the string column stream_id into a numeric index
    val stringIndexer = new StringIndexer().setInputCol("stream_id").setOutputCol("indexed_stream_id").fit(convertedDF)
    val convertedDF1 = stringIndexer.transform(convertedDF.as("df1"))
    val selectedDF = convertedDF1.withColumn("stream_is_normal", col("stream_is_normal").cast("integer"))
    // VectorAssembler merges the feature columns into a single "features" vector column, which is the input the ML algorithm expects.
    val assembler = new VectorAssembler()
      .setInputCols(Array("indexed_stream_id","stream_money", "stream_consume_type", "stream_time_date", "stream_seconds", "stream_is_new"))
      .setOutputCol("features")
    val featureDF = assembler.transform(selectedDF) // assembler.transform() returns a new DataFrame that keeps the existing columns and appends the new "features" column
    val labelIndexer = new StringIndexer().setInputCol("stream_is_normal").setOutputCol("label")
    val featureLabelDF = labelIndexer.fit(featureDF).transform(featureDF)
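    // Note: StringIndexer orders values by frequency by default, so the most frequent stream_is_normal value becomes label 0.0.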
    // Split into training and test sets (with weights 0.1/0.9, only about 10% of the rows go to training)
    val Array(trainData, testData) = featureLabelDF.randomSplit(Array(0.1, 0.9), seed = 1234)

    // Create the logistic regression classifier
    val lr = new LogisticRegression().setFeaturesCol("features").setLabelCol("label")
    // Build the hyperparameter grid; only a few example values are listed here
    val paramGrid = new ParamGridBuilder().addGrid(lr.maxIter, Array(10, 20, 30)).addGrid(lr.regParam, Array(0.01, 0.1, 1.0)).build()
    // Build the pipeline (a single logistic regression stage)
    val pipeline = new Pipeline().setStages(Array(lr))
    // Grid search with a train-validation split, evaluated by area under the ROC curve
    val evaluator = new BinaryClassificationEvaluator().setLabelCol("label").setRawPredictionCol("rawPrediction").setMetricName("areaUnderROC")
    val trainValidationSplit = new TrainValidationSplit().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setTrainRatio(0.8)

    // Fit the train-validation split (runs the grid search on trainData)
    val model = trainValidationSplit.fit(trainData)
    // Evaluate on the test set
    val result = model.transform(testData)
    val areaUnderROC = evaluator.evaluate(result)
    println(s"Area Under ROC: $areaUnderROC")
    // Extract the best model found by the grid search
    val bestModel = model.bestModel.asInstanceOf[PipelineModel]
    val lrModel = bestModel.stages(0).asInstanceOf[LogisticRegressionModel]
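    // Optional sketch (not in the original post): print the hyperparameters chosen by the grid
    // search and the fitted coefficients, to verify which grid point won.
    println(s"Best regParam: ${lrModel.getRegParam}, best maxIter: ${lrModel.getMaxIter}")
    println(s"Coefficients: ${lrModel.coefficients}, intercept: ${lrModel.intercept}")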
    result.show()
    result.printSchema()
    val rowCount = result.count()
    println(s"Result has $rowCount rows.")
    df.show()
    df.printSchema()
    val dfCount = df.count()
    println(s"df has $dfCount rows.")

    // Select the stream_id and prediction columns from result into a new DataFrame
    val predictionDF = result.select("stream_id", "prediction")
    // Left-join df with predictionDF on the stream_id column
    val mergedDF = df.join(predictionDF, Seq("stream_id"), "left")
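    // Caveat: if stream_id is not unique in normal_data, this left join can multiply rows, and ids that did not fall into testData get a null prediction.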
    // Drop the prediction and location columns from mergedDF before writing it back
    val updatedDFWithoutPrediction = mergedDF.drop("prediction","stream_consume_location", "stream_sign_location")
    updatedDFWithoutPrediction.show()
    val mergedDFCount = updatedDFWithoutPrediction.count()
    println(s"mergedDFCount has $mergedDFCount rows.")




    updatedDFWithoutPrediction.write
      .format("jdbc")
      .option("url", "jdbc:mysql://hadoop102:3306/localstreamdata?characterEncoding=utf8&useSSL=false")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "000000")
      .option("dbtable", "normal_data2")
      .mode(SaveMode.Overwrite)  // Save mode: Overwrite replaces the target table if it already exists
      .save()

  }


}
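The listing above extracts bestModel but never writes it to disk. The sketch below shows one way the fitted PipelineModel could be persisted and reloaded for later scoring; it assumes it is placed inside main after bestModel is obtained, and the output path "lr_best_model" is a placeholder, not something from the original post.

    // Persist the best pipeline model (the path is a placeholder assumption)
    bestModel.write.overwrite().save("lr_best_model")
    // Reload it later and score a DataFrame that already has the assembled "features" column
    val reloadedModel = PipelineModel.load("lr_best_model")
    val rescored = reloadedModel.transform(testData)
    rescored.select("stream_id", "prediction").show(5)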


Source: https://blog.csdn.net/qq_45972323/article/details/135636710