【Spark-ML源码解析】Word2Vec

发布时间:2023年12月20日

前言

在阅读源码之前,需要了解Spark机器学习Pipline的概念。

相关阅读:SparkMLlib之Pipeline介绍及其应用

这里比较核心的两个概念是:Transformer和Estimator
Transformer包括特征转换和学习后的模型两种情况,用来将一个DataFrame转换成另一个DataFrame;
Estimator接收一个DataFrame并输出一个模型(Transformer)。

Word2Vec类是一个Estimator,实现了fit方法,返回模型Word2VecModel,即一个Transformer,该类实现了transform方法。

代码逐行注释见文章:链接
哈夫曼编码复习:哈夫曼编码
阅读中有任何疑问,请参考PDF《word2vec 中的数学原理详解》

Word2VecBase——Word2Vec和Word2VecModel的参数类

参数名介绍默认大小
vectorSize单词的Embedding维度100
windowSize窗口大小,上下文单词在[-window,window]5
numPartitions训练数据的分区数1
minCount单词出现的最低次数5
maxSentenceLength单个序列的最大长度,所有序列合并后根据该值截断1000
stepSize优化器的步长0.025
maxIter最大迭代次数1
inputCol只支持string类型的数据1
outputCol训练数据的分区数1
seed随机种子this.getClass.getName.hashCode.toLong

Word2Vec

Word2Vec训练一个模型Map(String,Vector)

ml包——流程构建

ml包中的Word2Vec仅仅是“格式化”成了新版Pipline流程,具体的模型训练代码还是调用的mllib包的Word2Vec。

基础功能:

  • 各种set方法。设置Word2VecBase里的参数,如setInputCol、setOutputCol等
  • 重写PipelineStage类的transformSchema方法。用来check输入列的类型以及生成输出列
  • 重写Estimator的copy方法。复制一个Word2Vec对象,UID相同,Embedding矩阵和参数不同

核心功能:

  • 重写Estimator的fit方法,调用mllib包的Word2Vec

其余功能:

  • 实现DefaultParamsWritable,具有模型保存的功能。

mllib包——模型训练逻辑

从注释来看,该实现完全重写的C版本的。实现了skip-gram模型,并使用分层softmax方法来训练模型。

词库实体类

首先映入眼帘的是一个VocabWord的case class,包含属性如下:

  • word——词
  • cn——单词出现次数
  • point——存的是从根节点到这个词对应的叶子节点的路径经过的节点,最大40
  • code——Huffman编码,最大40
  • codeLen——路径长度

fit方法

整体流程:

  1. 调用learnVocab方法。初始化词表vocab(VocabWord数组)和vocabHash(词->索引映射Map,根据词频排序),输出一些统计信息。

这块会拉到driver节点计算,因此driver内存设置和词表大小相关。

  1. 调用createBinaryTree方法。

    这一步按照词频构建出一个Huffman树。这里为每个词保存point和code属性,point是每个单词的路径经过的非叶子节点,每个节点为一个二分类器,对应一个参数向量θ;code则为值为{0,1}的Huffman编码,也即为路径上每个二分类器的真实Label,是logloss计算公式中的一个参数。

skip-gram下,通过中心词预测上下文词时,预测概率就是上下文词对应叶子节点的路径上二分类器预测的概率乘积;损失函数就是这条路径上一系列二分类的logloss之和。也就是说层次哈夫曼树方法将Softmax计算转成了logN复杂度的二分类计算。

  1. 广播exp计算表、词对象表和词索引表

    exp计算表是为了加速计算。x大于6和小于-6的时候,simoid函数值都无限接近1和0。将范围限制在[-6,6]之间,将该区间划分为1000份,近似计算好sigmoid(x)对应的取值,能够显著减少计算量。

  2. 执行doFit方法(分布式梯度上升法参数训练)

    这里首先需要明确是梯度上升法学习的参数有哪些?词向量矩阵V和二分类器参数θ。

    doFit流程:

    1. 构造sentences

    2. 梯度上升法训练
      梯度下降法步骤
      注:每计算一个context word就更新一次v(·),源码中neu1e为e,syn1为θ,syn0为v。

Word2VecModel

这块的代码没啥可看的。在doFit训练时通过wordIndex和wordVector两个参数创建。重写的transform函数,主要功能是将每个单词转换成一个向量,即查map。

文章来源:https://blog.csdn.net/qq_30057549/article/details/134852042
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。