[HuggingFace Transformers Library Study Notes] Basic Components: Model

Published: January 10, 2024

Basic Component: Model


1. Loading and Saving Models

from transformers import AutoConfig, AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("hfl/rbt3", force_download=False)    # set force_download=True to force a fresh download
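
The section title also covers saving; save_pretrained is the counterpart call. A minimal sketch (the local directory name here is hypothetical):

model.save_pretrained("./rbt3_local")                 # writes config.json plus the model weights
model = AutoModel.from_pretrained("./rbt3_local")     # later, reload the model from disk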

Other loading approaches
Downloading the model

# Download every file in the repository
!git clone "https://huggingface.co/hfl/rbt3"

# Download selected files only, e.g. the .bin weight files
!git lfs clone "https://huggingface.co/hfl/rbt3" --include="*.bin"
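
A git-free alternative is the huggingface_hub client; a sketch assuming a recent huggingface_hub version, where allow_patterns filters which files are fetched:

from huggingface_hub import snapshot_download

# download only the PyTorch weights plus the config/vocab files
snapshot_download(repo_id="hfl/rbt3", allow_patterns=["*.bin", "*.json", "*.txt"])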

Offline loading

model = AutoModel.from_pretrained("rbt3")    # pass the local directory instead of the Hub id

Inspecting the configuration

model.config

BertConfig {
  "_name_or_path": "rbt3",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 3,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.35.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

Loading the config with AutoConfig

config = AutoConfig.from_pretrained("rbt3")
config

(prints the same BertConfig shown above)

# check whether the model is configured to output attention weights
config.output_attentions

# BertConfig can be imported directly if you want to modify the parameters in model.config
from transformers import BertConfig
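
A minimal sketch of that workflow, reusing the local rbt3 directory from above: edit a config field, then pass the config back into from_pretrained.

config = AutoConfig.from_pretrained("rbt3")
config.output_attentions = True                           # ask the model to return attention weights
model = AutoModel.from_pretrained("rbt3", config=config)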

2. Calling the Model

sen = "弱小的我也有大梦想!"
tokenizer = AutoTokenizer.from_pretrained("rbt3")
inputs = tokenizer(sen, return_tensors="pt")    # return PyTorch tensors
inputs

{'input_ids': tensor([[ 101, 2483, 2207, 4638, 2769,  738, 3300, 1920, 3457, 2682, 8013,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
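
To sanity-check the encoding, the ids can be mapped back to tokens with the tokenizer's convert_ids_to_tokens; for this model the Chinese text should split roughly one token per character, plus the special tokens:

tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# should print 12 tokens, e.g. ['[CLS]', '弱', '小', ..., '[SEP]'], matching the 12 input_ids above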

3. Calling the Model without a Model Head

model = AutoModel.from_pretrained("rbt3", output_attentions=True)    # also return attention weights
output = model(**inputs)
output

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.6804,  0.6664,  0.7170,  ..., -0.4102,  0.7839, -0.0262],
         [-0.7378, -0.2748,  0.5034,  ..., -0.1359, -0.4331, -0.5874],
         [-0.0212,  0.5642,  0.1032,  ..., -0.3617,  0.4646, -0.4747],
         ...,
         [ 0.0853,  0.6679, -0.1757,  ..., -0.0942,  0.4664,  0.2925],
         [ 0.3336,  0.3224, -0.3355,  ..., -0.3262,  0.2532, -0.2507],
         [ 0.6761,  0.6688,  0.7154,  ..., -0.4083,  0.7824, -0.0224]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-1.2646e-01, -9.8619e-01, -1.0000e+00, -9.8325e-01,  8.0238e-01,
         ... (pooler_output and attentions truncated for readability) ...
          [1.7047e-01, 3.6989e-02, 2.3646e-02,  ..., 4.6833e-02,
           2.5233e-01, 1.6721e-01]]]], grad_fn=<SoftmaxBackward0>)), cross_attentions=None)

output.last_hidden_state.size()

torch.Size([1, 12, 768])

len(inputs["input_ids"][0])

12
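
The shape (1, 12, 768) reads as (batch_size, sequence_length, hidden_size): 12 matches the number of input ids above, and 768 matches hidden_size in the config. A minimal sanity check, reusing the objects above:

assert output.last_hidden_state.shape[1] == inputs["input_ids"].shape[1]   # sequence length: 12
assert output.last_hidden_state.shape[2] == model.config.hidden_size       # hidden size: 768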

4. Calling a Model with a Model Head

from transformers import AutoModelForSequenceClassification, BertForSequenceClassification

# num_labels sets how many classification labels the head outputs
clz_model = AutoModelForSequenceClassification.from_pretrained("rbt3", num_labels=10)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /u01/zhanggaoke/project/transformers-code-master/model/rbt3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

clz_model(**inputs)     # feed the tokenized inputs to the model

# note the ten classification logits in the output tensor
SequenceClassifierOutput(loss=None, logits=tensor([[-0.0586,  0.7688, -0.0553,  0.2559, -0.2247,  0.4708, -0.6269, -0.3731,
          0.0453, -0.1891]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)          

clz_model.config.num_labels     # this config field controls how many label classes are output

10
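
To turn the logits into a predicted label, a minimal sketch (the classifier head above is freshly initialized, so the prediction is meaningless until the model is fine-tuned):

import torch

logits = clz_model(**inputs).logits
pred = torch.argmax(logits, dim=-1)    # index of the highest-scoring class
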
Source: https://blog.csdn.net/qq_41094332/article/details/134740025