使用 EmbeddingBag 和 Embedding 完成词嵌入

发布时间:2024年01月12日

🍨 本文为[🔗365天深度学习训练营学习记录博客\n🍦 参考文章:365天深度学习训练营\n🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

使用 EmbeddingBagEmbedding 完成词嵌入,首先需要处理文档中的文本,将其转换为适合进行词嵌入的格式,涉及到以下步骤:

  1. 文本清洗:移除文档中的特殊字符和标点符号,将文本统一为小写(如果适用)。
  2. 分词:将文本分割成单词或标记(tokens)。
  3. 建立词汇表:从分词后的文本中创建一个词汇表,每个唯一的单词对应一个索引。
  4. 文本向量化:将文本转换为数字形式,以便进行嵌入处理。

第二步,使用 EmbeddingBagEmbedding 层进行词嵌入。EmbeddingBag 层适用于处理变长的文本,它会计算所有嵌入向量的平均值或和。而 Embedding 层适用于单个单词或固定长度的序列。

目标文件:

实现代码:?

from collections import Counter
import torch
import torch.nn as nn
import re

# 清洗文本并进行分词
def tokenize(text):
    # 移除特殊字符和标点,并转换为小写
    text = re.sub(r'[^\w\s]', '', text).lower()
    # 分词
    return text.split()

# 创建词汇表
def create_vocab(text_tokens):
    vocab = Counter(text_tokens)
    vocab = sorted(vocab, key=vocab.get, reverse=True)
    vocab_to_int = {word: ii for ii, word in enumerate(vocab, 1)} # 索引从1开始
    return vocab_to_int

# 将文本转换为数字形式
def text_to_int(tokens, vocab_to_int):
    return [vocab_to_int[word] for word in tokens if word in vocab_to_int]

# 定义Embedding和EmbeddingBag层
def define_embedding_layers(vocab_size, embedding_dim=100):
    embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
    embedding_bag = nn.EmbeddingBag(num_embeddings=vocab_size, embedding_dim=embedding_dim, mode='mean')
    return embedding, embedding_bag

# 读取文件内容
file_path = 'D:/任务文件 (1).txt'
with open(file_path, 'r', encoding='utf-8') as file:
    file_content = file.read()

# 文本清洗和分词
tokens = tokenize(file_content)

# 创建词汇表
vocab_to_int = create_vocab(tokens)

# 将文本转换为数字形式
int_text = text_to_int(tokens, vocab_to_int)

# 定义嵌入层参数
embedding_dim = 100
vocab_size = len(vocab_to_int) + 1

# 定义Embedding和EmbeddingBag层
embedding, embedding_bag = define_embedding_layers(vocab_size, embedding_dim)

# 转换为tensor以供嵌入层使用
input_tensor = torch.tensor([int_text], dtype=torch.long)

# 使用Embedding和EmbeddingBag进行词嵌入
embedded = embedding(input_tensor)
embedded_bag = embedding_bag(input_tensor)

# 打印结果
print("Embedding shape:", embedded.shape)
print("EmbeddingBag shape:", embedded_bag.shape)

文章来源:https://blog.csdn.net/qq_60245590/article/details/135559809
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。