【ChatGLM3】微调指南

发布时间:2024年01月09日

下载数据集ToolAlpaca

  • 从GitHub下载
cd ChatGLM3/finetune_chatmodel_demo
git clone https://github.com/tangqiaoyu/ToolAlpaca.git
  • 除基础的 torch 依赖外,示例代码运行还需要依赖: pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed
  • 处理数据集格式
./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
  • 处理后的数据: formatted_data/tool_alpaca.jsonl
  • 开始微调: ./scripts/finetune_pt_multiturn.sh
  • 如果出现显存不足的报错提示,需要修改finetune_pt_multiturn.sh
  • 把MAX_SEQ_LEN=2048改成MAX_SEQ_LEN=1024,MAX_SEQ_LEN会影响输入文本的长度限制
  • 使用2张16G的T4显卡,每张显卡都需要加载完整的模型,只是把任务分成2部分
  • 使用MAX_SEQ_LEN=2048需要单张显卡21G以上
  • 参数调整参考
数据量MAX_STEP x BATCHSIZE x gradient_accumulation_steps
100500
10003000
100000100000
  • 训练完成后,checkpoint的路径在: output/tool_alpaca_pt-20240104-184837-128-2e-2

下载数据集AdvertiseGen

  • 从清华大学网站下载
cd ChatGLM3/finetune_chatmodel_demo
curl -O https://cloud.tsinghua.edu.cn/seafhttp/files/93349217-b0ae-4b3e-875e-303fa05d7f08/AdvertiseGen.tar.gz

# 解压下载的文件
tar zxvf AdvertiseGen.tar.gz
  • 处理数据集格式
./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"
  • 开始训练: ./scripts/finetune_pt.sh

加载PT训练的checkpoint

加载pt训练微调后的checkpoint的关键代码

MODEL_PATH = os.environ.get('MODEL_PATH', '/ChatGLM3/THUDM/chatglm3-6b-32k')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
PT_PATH = os.environ.get('PT_PATH', '/ChatGLM3/finetune_chatmodel_demo/output/tool_alpaca_pt-20240105-185333-128-2e-2')
PT_PRE_SEQ_LEN = 128

@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    if PT_PATH is not None and os.path.exists(PT_PATH):
        config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=PT_PRE_SEQ_LEN)
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, config=config, device_map="auto").eval()

        prefix_state_dict = torch.load(os.path.join(PT_PATH, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

    return tokenizer, model
文章来源:https://blog.csdn.net/friendlytkyj/article/details/135477013
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。