Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

发布时间:2024年01月01日

Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

Eclipse Deeplearning4j GitChat课程:Deeplearning4j 快速入门_专栏
Eclipse Deeplearning4j 系列博客:万宫玺的专栏_wangongxi_CSDN博客
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
Eclipse Deeplearning4j 社区:https://community.konduit.ai/

DSSM是微软在2013年提出的,最早用于搜索引擎语义召回的双塔模型。目前在工业界也广泛用于推荐召回、搜索相关性排序、语义召回等环节。DSSM是一个轻量级模型,在线上serving的时候,可以通过对query向量和doc向量计算内积,得到的相似值用来衡量query和doc的相似度,从而进行进一步的排序。下面就分别从DSSM模型结构、基于DL4J的DSSM建模、对开源数据集LCQMC的建模等几个环节来介绍如何使用DSSM模型。当然,由于DSSM模型的论文发表时间较早,发表时给出的模型结构比较简单,在我们具体实现的时候,会做一些调整,具体在介绍模型搭建的部分会提到。

1. DSSM模型简述

论文中,query和doc分别通过各自独立的神经网络映射成一个语义向量。需要注意的是,原论文中doc是一个包含正样本和负样本的集合。正样本取1个,负样本取4个。论文中有提到,正样本是搜素后被点击的样本,负样本则是随机选取的搜索未被点击的样本集合。通过分别计算query的语义向量和正负doc样本的语义向量的余弦相似度,再通过softmax函数得到正负样本的概率分布后,和label计算交叉熵损失。这就是DSSM模型的大致的idea。下面先看下论文中对于DSSM描绘的架构图:
在这里插入图片描述
通过模型架构图可以看到,论文中是使用最简单的MLP对输入进行映射。这里需要提一下word hashing的操作。由于2013年时候word embedding技术还不是较广泛的使用,因此论文中的word hashing是在n-gram语言模型的基础上,通过hash操作将接近50W的词表计算每个词的索引值。这在当时是一种比较高效的做法,目前由于硬件的进步以及embedding技术的进一步成熟,可以直接使用预训练的embedding向量或者做端到端的建模。因此,在第三部分中构建DSSM模型的过程中,我们也是使用的端到端的方案。
在这里插入图片描述

上面这张截图中模型训练的有关描述。就像在本节开始时候提到的,通过softmax计算query和每个doc的余弦相似度的值归一化概率分布。由于softmax函数与cosine相似度的一致性,因此相似度越高的query-doc pair,其softmax值也会越接近于1。在损失函数部分,使用的是经典的log loss。这部分没啥说的。
另外需要说明的是,从ranking loss的角度,论文中的loss应当属于list-wise loss。当然,如果将负样本减少到一个或者doc集合中只有一个正样本或负样本(softmax更改为sigmoid函数),那就退化成pair-wise loss或者point-wise loss。为了方便起见,在第三部分的建模过程中,我们会使用point-wise loss。
对于搜索场景来说,双塔的输入分别是query和doc。对于推荐场景来说,双塔的输入可以是user和item或者item和item,用于U2I的召回或者I2I的召回。

2. LCQMC数据集

LCQMC是哈工大和阿里共同开源的用于QA的数据集,详情可参见论文。下载链接为:地址。压缩包中共有三个文件,三个文件都是以制表符作为分隔符。我们先来看下用于训练的部分数据的截图:
在这里插入图片描述
文件中有三列。最后一列用1或者0来代表 text_a 和 text_b两列文本的是否相关。如果把text_a列文本看作是query,那text_b列可以看作是doc。用于验证的文件中的内容也和训练文件中的数据格式相似,这里就不做另外截图了。

最后提一下,训练样本数量是:238767,验证的样本数量是:12501。

3. 基于DL4J的DSSM模型构建

在第一部分中,我们提到DSSM的论文中双塔内部是使用MLP结构。考虑到MLP结构的单一性,我们使用Embedding+LSTM+MLP的结构作为双塔的内部结构。虽然query和doc对应的塔结构相同,但是不做参数的共享。另外,由于LCQMC数据集中label是1或者0,因此我们将DSSM的输出层改为sigmoid + binary cross entropy loss。具体我们先给出代码片段:

private static ComputationGraph getDSSM(final int QUERY_VOCAB_SIZE, final int DOC_VOCAB_SIZE, 
										final int VECTOR_SIZE) {
	ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
		    .updater(new Adam(5 * 1e-3))
		    .weightInit(WeightInit.XAVIER)
		    .seed(12345L)
		    .graphBuilder()
		    .addInputs("query", "doc")
		    .setInputTypes(InputType.recurrent(QUERY_VOCAB_SIZE), InputType.recurrent(DOC_VOCAB_SIZE))
		    .addLayer("query-embedding", new EmbeddingSequenceLayer.Builder()
		    					.nIn(QUERY_VOCAB_SIZE + 1)
		    					.nOut(VECTOR_SIZE).build(), "query")
		    .addLayer("query-embedding-lstm", new LSTM.Builder()
					.nIn(VECTOR_SIZE)
					.nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "query-embedding")
		    .addLayer("doc-embedding", new EmbeddingSequenceLayer.Builder()
								.nIn(DOC_VOCAB_SIZE + 1)
								.nOut(VECTOR_SIZE).build(), "doc")
		    .addLayer("doc-embedding-lstm", new LSTM.Builder()
								.nIn(VECTOR_SIZE)
								.nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "doc-embedding")
		    .addVertex("query-embedding-lstm-last-output", new LastTimeStepVertex("query"), "query-embedding-lstm")
		    .addVertex("doc-embedding-lstm-last-output", new LastTimeStepVertex("doc"), "doc-embedding-lstm")
		    .addLayer("query-output", new DenseLayer.Builder()
										.nIn(VECTOR_SIZE)
										.nOut(VECTOR_SIZE / 2)
										.activation(Activation.LEAKYRELU)
										.build(), "query-embedding-lstm-last-output")
		    .addLayer("doc-output", new DenseLayer.Builder()
										.nIn(VECTOR_SIZE)
										.nOut(VECTOR_SIZE / 2)
										.activation(Activation.LEAKYRELU)
										.build(), "doc-embedding-lstm-last-output")
		    .addVertex("query-output-l2-norm", new L2NormalizeVertex(), "query-output")
		    .addVertex("doc-output-l2-norm", new L2NormalizeVertex(), "doc-output")
		    .addVertex("cosing-similar", new ElementWiseVertex(ElementWiseVertex.Op.Product), "query-output-l2-norm", "doc-output-l2-norm")
		    .addLayer("out", new OutputLayer.Builder()
                    				.lossFunction(LossFunctions.LossFunction.XENT)	//bce
                    				.nIn(VECTOR_SIZE / 2).nOut(1).activation(Activation.SIGMOID).build(), "cosing-similar")
		    .setOutputs("out")
		    .build();

	ComputationGraph net = new ComputationGraph(conf);
	net.setListeners(new ScoreIterationListener(1));
	net.init();
	return net;
}

由于存在两个输入,因此使用DL4J中的ComputationGraph。这里需要说明的有几点:

  • LastTimeStepVertex的作用:获取LSTM最后一个time step输出的张量
  • L2NormalizeVertex的作用:L2归一化,将query和doc的向量转化为单位向量
  • ElementWiseVertex的作用:通过设置Op为点积,实际为计算query和doc单位向量的内积,因此L2NormalizeVertex + ElementWiseVertex联合起来的作用是计算向量间的余弦相似度值
  • 输出端使用sigmoid + bce 作point-wise的损失函数
    在这里插入图片描述

上面的截图中通过summary接口打印的模型结构和待训练参数。可见待训练参数68W。

另外,对于该静态方法,输入的几个参数QUERY_VOCAB_SIZE,DOC_VOCAB_SIZE,VECTOR_SIZE分别代表LCQMC数据集中text_a的词表大小和text_b的词表大小,以及词向量的大小。

需要指出的是,在第四部分进行建模的操作中,我们使用中文单字作为query和doc的最小粒度特征,而不做分词的处理。

4. DSSM模型训练和评估

首先介绍下数据处理的部分:

  • 读取训练文件,构建中文单字和单字的索引,存储在map结构中。同时记录最长的文本长度,用于后续的padding操作。
  • 再次读取文件,对每条记录构建MultiDataSet对象,并存储在LinkedList对象中。MultiDataSet对象中会存储query和doc作为输入,label作为输出,此外还有query和doc的mask张量,用于统一变长文本的处理。

我们看下具体的实现逻辑:

class DataSetInfo{
	public Map<String,Integer> queryDict = new TreeMap<>();
	public Map<String,Integer> docDict = new TreeMap<>();
	public int queryMaxLen = 0;
	public int docMaxLen = 0;
}

private static DataSetInfo preprocess(String filePath) {
	DataSetInfo info = new DataSetInfo();
	try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){
		String line = null;
		int lineIndex = 0;
		while( (line = br.readLine()) != null ) {
			if( lineIndex == 0 ) {
				lineIndex++;
				continue;
			}
			String[] splits = line.split("\t");
			if( null == splits || splits.length != 3 )continue;
			String query = splits[0];
			String doc = splits[1];
			if( query != null && query.length() > 0 ) {
				info.queryMaxLen = Math.max(query.length(), info.queryMaxLen);
				for( char c : query.toCharArray() ) {
					String charStr = String.valueOf(c);
					if( !info.queryDict.containsKey(charStr) ) {
						int curIndex = info.queryDict.size();
						info.queryDict.put(charStr, curIndex);
					}
				}
			}
			if( doc != null && doc.length() > 0 ) {
				info.docMaxLen = Math.max(doc.length(), info.docMaxLen);
				for( char c : doc.toCharArray() ) {
					String charStr = String.valueOf(c);
					if( !info.docDict.containsKey(charStr) ) {
						int curIndex = info.docDict.size();
						info.docDict.put(charStr, curIndex);
					}
				}
			}
		}
	}catch(Exception ex) {
		ex.printStackTrace();
	}finally {
		int curIndex = info.queryDict.size();
		info.queryDict.put("UNK", curIndex);
		//
		curIndex = info.docDict.size();
		info.docDict.put("UNK", curIndex);
	}
	return info;
}

这部分处理逻辑比较清晰,主要是先定义个DataSetInfo的类,里面包含了单字和单字索引的映射关系,还有最大文本长度。在finally部分,我们使用UNK代表所有未登录词。接着看下MultiDataSet的构造:

private static List<org.nd4j.linalg.dataset.api.MultiDataSet> getMultiDataIter(String filePath, DataSetInfo dataInfo) {
	List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new LinkedList<>();
	try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){
		String line = null;
		int lineIndex = 0;
		while( (line = br.readLine()) != null ) {
			if( lineIndex == 0 ) {
				lineIndex++;
				continue;
			}
			String[] splits = line.split("\t");
			String query = splits[0];
			String doc = splits[1];
			String label = splits[2];
			if( query == null || query.isEmpty() ||
				doc == null || doc.isEmpty() || 
				label == null)continue;
			//
			double[][] queryIndexArray = new double[1][dataInfo.queryMaxLen];
			double[][] docIndexArray = new double[1][dataInfo.docMaxLen];
			double[][] queryIndexMaskArray = new double[1][dataInfo.queryMaxLen];
			double[][] docIndexMaskArray = new double[1][dataInfo.docMaxLen];
			double[][] labelIndexArray = new double[1][1];
			//
			for( int i = 0; i < query.length(); ++i ) {
				queryIndexArray[0][i] = dataInfo.queryDict.getOrDefault(String.valueOf(query.charAt(i)),
													dataInfo.queryDict.get("UNK"));
				queryIndexMaskArray[0][i] = 1.0;
			}
			for( int i = 0; i < doc.length(); ++i ) {
				docIndexArray[0][i] = dataInfo.docDict.getOrDefault(String.valueOf(doc.charAt(i)),
													dataInfo.docDict.get("UNK"));
				docIndexMaskArray[0][i] = 1.0;
			}
			labelIndexArray[0][0] = Double.parseDouble(label);
			//
			org.nd4j.linalg.dataset.api.MultiDataSet mds = 
					new MultiDataSet(new INDArray[] {Nd4j.create(queryIndexArray), Nd4j.create(docIndexArray)},
									new INDArray[] {Nd4j.create(labelIndexArray)},
									new INDArray[] {Nd4j.create(queryIndexMaskArray), Nd4j.create(docIndexMaskArray)},
									null);
			list.add(mds);
		}
	}catch(Exception ex) {
		ex.printStackTrace();
	}
	return list;
}

该部分逻辑主要是通过一个静态方法来读取训练文本中的每一行数据,并且针对text_a和text_b以及label转换成一个MultiDataSet对象,并存储在一个LinkedList对象中。需要注意的是Mask部分的处理。Mask张量中用1.0代表有效,0.0代表无效的部分。下面我们看下训练建模和评估的部分。

final int batchSize = 256;
final int embedding_size = 64;
DataSetInfo dataInfo = preprocess("data/lcqmc/train.tsv");
ComputationGraph dssm = getDSSM(dataInfo.queryDict.size(), dataInfo.docDict.size(), embedding_size);
System.out.println(dssm.summary());
List<org.nd4j.linalg.dataset.api.MultiDataSet> trainDataList = getMultiDataIter("data/lcqmc/train.tsv", dataInfo);
List<org.nd4j.linalg.dataset.api.MultiDataSet> testDataList = getMultiDataIter("data/lcqmc/test.tsv", dataInfo);
System.out.println("Finish Loading Train Data");
for(int epoch = 0; epoch < 5; ++epoch) {
	Collections.shuffle(trainDataList);
	MultiDataSetIterator trainIter = new IteratorMultiDataSetIterator(trainDataList.iterator(), batchSize);
	dssm.fit(trainIter);
	Evaluation eval = dssm.evaluate(new IteratorMultiDataSetIterator(testDataList.iterator(), batchSize));
	System.out.println(eval);
}

通过10个epoch的训练,我们最终在验证集上得到70%左右的准确率, loss值在0.4左右。
在这里插入图片描述

5. 总结

DSSM是一个经典的双塔模型,但其也有明显的缺点,就是两个塔之间是独立的,没有信息的交叉。这种信息的交叉对应推荐场景来说是很重要的。DSSM论文中的结构比较简单,是MLP为主,且输入层使用词袋模型进行处理,这其实忽略的上下文的语义信息,因此我们在实现的时候,使用LSTM模型来捕获序列的完整语义信息。当然,由于时间原因,我们这边并没有做分词处理,相信经过分词处理,在LCQMC数据集上的准确率可以进一步得到提升。另外,双塔的结构可以很灵活,内部可以直接上BERT来做,这里变体就太多,不做过多陈述了。

文章来源:https://blog.csdn.net/wangongxi/article/details/135320951
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。