This article is based on hands-on exploration of the following open-source projects:
- Awesome-Text2SQL: GitHub - eosphoros-ai/Awesome-Text2SQL: Curated tutorials and resources for Large Language Models, Text2SQL, Text2DSL, Text2API, Text2Vis and more.
- DB-GPT-Hub: GitHub - eosphoros-ai/DB-GPT-Hub: A repository that contains models, datasets, and fine-tuning techniques for DB-GPT, with the purpose of enhancing model performance in Text-to-SQL
- DB-GPT: GitHub - eosphoros-ai/DB-GPT: Revolutionizing Database Interactions with Private LLM Technology
- DeepSpeedExamples: GitHub - microsoft/DeepSpeedExamples: Example models using DeepSpeed
Open source is not easy; please give these projects a star to show your support. Thanks!
This chapter introduces the basic definition of Text2SQL, the open-source datasets in common use, and the evaluation metrics, along with some hands-on projects for reference.
Text-to-SQL (abbreviated Text2SQL) is, as the name suggests, the task of converting text into SQL. A more academic definition: converting a natural language (NL) question in the database domain into a structured query language (SQL) statement that can be executed against a relational database. For this reason, Text2SQL is also written as NL2SQL.
An example makes this intuitive:
Query all columns of table t_user, sort the results by id in descending order, and keep only the first 10 rows
SELECT * FROM t_user ORDER BY id DESC LIMIT 10
Figure 1: Native chat in the DB-GPT project
There are many public Text2SQL datasets; here we only introduce the most widely used ones:
If you want to learn about more datasets and Text2SQL fundamentals, see my earlier Zhihu survey article: Text-to-SQL小白入门(一)综述文章学习
Taking the Spider dataset as an example, there are two main metrics: execution accuracy (Execution Accuracy, EX) and logical-form accuracy (Exact Match, EM).
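To make the two metrics concrete, here is a minimal sketch of how EX and EM can be checked for a single prediction (my own illustration, assuming a SQLite database; the official Spider evaluation additionally parses queries into clauses and normalizes values):

import sqlite3

def execution_match(db_path: str, pred_sql: str, gold_sql: str) -> bool:
    """EX: correct if the predicted query returns the same rows as the gold query."""
    conn = sqlite3.connect(db_path)
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
        gold_rows = conn.execute(gold_sql).fetchall()
    except sqlite3.Error:
        return False  # a query that fails to execute counts as wrong
    finally:
        conn.close()
    return sorted(pred_rows, key=repr) == sorted(gold_rows, key=repr)

def exact_match(pred_sql: str, gold_sql: str) -> bool:
    """EM (simplified): compare normalized SQL text. The official Spider EM
    parses both queries into components and ignores literal values, so it is
    stricter about structure and looser about values than this string check."""
    normalize = lambda s: " ".join(s.strip().rstrip(";").lower().split())
    return normalize(pred_sql) == normalize(gold_sql)

Note that the two metrics can disagree: semantically equivalent but differently written SQL passes EX while failing EM, which is why leaderboards report both.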
The Awesome-Text2SQL project maintains leaderboards of the common datasets and their metrics, as shown in Figure 2. On Spider, for example, the current top EX score is 91.2 and the top EM score is 81.5, both submitted by the MiniSeek organization, which achieved them by using GPT-4 together with some additional tricks.
Figure 2: Dataset leaderboards in the Awesome-Text2SQL project
Text2SQL research has mainly gone through template- and matching-based methods, Seq2Seq-based methods, and pretrained-model-based methods. With the rise of LLMs, fine-tuning an LLM for Text2SQL has become increasingly common; DB-GPT-Hub, for example, implements LoRA and QLoRA fine-tuning of various open-source models on the Spider dataset, and it works well in my own testing (see the repository for details).
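For a rough picture of what LoRA fine-tuning looks like in code, here is a sketch using the Hugging Face peft library (the rank/alpha/target values mirror the DB-GPT-Hub training script shown later; this is an illustration, not the project's exact code):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Only the small low-rank matrices injected into q_proj/v_proj are trained;
# the base model weights stay frozen, which keeps memory requirements modest.
base_model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-13b-Instruct-hf")
lora_config = LoraConfig(
    r=64,                                  # rank of the low-rank update
    lora_alpha=32,                         # scaling factor for the update
    target_modules=["q_proj", "v_proj"],   # attention projections to adapt
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()         # typically well under 1% of all parameters

QLoRA follows the same recipe but loads the base model in 4-bit quantized form first.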
This chapter covers the basic definition of RLHF, along with the fundamentals of reinforcement learning and the RLHF framework.
RLHF (Reinforcement Learning from Human Feedback): optimizing a language model via reinforcement learning based on human feedback, so that a model trained on a general text corpus can be aligned with complex human values.
RL stands for Reinforcement Learning.
To better understand reinforcement learning, it helps to first look at the more familiar supervised learning (Supervised Learning, SL). For supervised learning, the full training pipeline typically looks like Figure 3:
Figure 3: Supervised learning
For reinforcement learning, the training pipeline is similar, as shown in Figure 4.
Figure 4: Supervised learning vs. reinforcement learning
From the discussion above, the basic components of reinforcement learning are:
- Agent: the learner and decision maker
- Environment: everything the agent interacts with
- State: the agent's observation of the environment at a given step
- Action: the choice the agent makes in a state
- Reward: the scalar feedback signal the agent tries to maximize
- Policy: the agent's strategy for mapping states to actions
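To make these components concrete, here is a minimal sketch of the agent-environment interaction loop (illustrative stubs only, not taken from any of the projects above):

import random

def environment_step(state, action):
    """Return (next_state, reward) after taking `action` in `state`."""
    next_state = state + 1
    reward = 1.0 if action == state % 2 else 0.0  # toy reward rule
    return next_state, reward

def policy(state):
    """The agent's strategy; random here, but this is what RL optimizes."""
    return random.choice([0, 1])

state, episode_return = 0, 0.0
for _ in range(100):                        # one episode of 100 steps
    action = policy(state)                  # agent acts based on the state
    state, reward = environment_step(state, action)
    episode_return += reward                # accumulate the feedback signal
print(f"episode return: {episode_return}")

In RLHF, the language model plays the agent, generating a response is the action, and the reward comes from a model trained on human preferences rather than from a hand-written rule.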
The RLHF approach was first proposed in the 2017 paper Deep Reinforcement Learning from Human Preferences.
For a detailed walkthrough of the InstructGPT paper, see my earlier Zhihu article: Text-to-SQL小白入门(九)InstructGPT论文:教你如何训练ChatGPT
Figure 5: The RLHF paradigm in the InstructGPT paper
The RLHF pipeline has three main stages:
Stage 1: SFT (supervised fine-tuning)
I once heard an interesting point in a university professor's talk: many of the foundational innovations, including the Transformer, came out of Google research, so why is OpenAI ahead of Google in large models? The answer: because OpenAI does an excellent job of data cleaning and data quality control. Data is that important!
Stage 2: RM (reward model training; its core objective is sketched below)
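The core of the RM stage is a pairwise ranking loss: the reward model should score the human-preferred ("chosen") answer above the dispreferred ("rejected") one. A minimal sketch of that objective (my own illustration; DeepSpeed-Chat's actual implementation also handles padding and per-token scores):

import torch
import torch.nn.functional as F

def reward_model_loss(chosen_scores: torch.Tensor,
                      rejected_scores: torch.Tensor) -> torch.Tensor:
    """Pairwise loss: -log(sigmoid(r_chosen - r_rejected)). Minimizing it
    pushes the score of the chosen answer above that of the rejected one."""
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

# Toy scalar scores for a batch of three (prompt, chosen, rejected) triples.
chosen = torch.tensor([1.2, 0.3, 2.0])
rejected = torch.tensor([0.8, 0.9, -0.5])
print(reward_model_loss(chosen, rejected))  # decreases as chosen scores pull ahead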
Stage 3: RL (reinforcement learning, typically with PPO)
This chapter explores Text2SQL in practice, based on the DB-GPT-Hub project code together with some RLHF code.
The SFT module follows DB-GPT-Hub, here using the Spider dataset. Data preprocessing is run with:
sh dbgpt_hub/scripts/gen_train_eval_data.sh
After preprocessing, you get example_text2sql_train.json and example_text2sql_dev.json.
The data format is as follows:
{
"db_id": "department_management",
"instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n\n",
"input": "###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:",
"output": "SELECT count(*) FROM head WHERE age > 56",
"history": []
}
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}
sh dbgpt_hub/scripts/train_sft.sh
The base model for fine-tuning is CodeLlama-13b-instruct; for background on this open-source model, see the paper walkthrough: Text-to-SQL小白入门(五)开源代码大模型Code Llama
The training parameters are as follows:
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
--model_name_or_path /home/model/CodeLlama-13B-Instruct \
--do_train \
--dataset example_text2sql_train \
--max_source_length 2048 \
--max_target_length 512 \
--template llama2 \
--finetuning_type lora \
--lora_rank 64 \
--lora_alpha 32 \
--lora_target q_proj,v_proj \
--output_dir dbgpt_hub/output/adapter/CodeLlama-13B-Instruct-lora \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--lr_scheduler_type cosine_with_restarts \
--logging_steps 500 \
--save_steps 2000 \
--learning_rate 2e-4 \
--num_train_epochs 8 \
--plot_loss \
--bf16
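After LoRA training, the adapter weights are typically merged back into the base model before inference. A sketch of that step with peft (paths taken from the script above; DB-GPT-Hub's own scripts may handle this differently):

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("/home/model/CodeLlama-13B-Instruct")
model = PeftModel.from_pretrained(
    base, "dbgpt_hub/output/adapter/CodeLlama-13B-Instruct-lora")
merged = model.merge_and_unload()   # fold the LoRA updates into the base weights
merged.save_pretrained("dbgpt_hub/output/merged_model")  # hypothetical output path

Prediction on the dev set is then run with: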
sh dbgpt_hub/scripts/predict_sft.sh
After prediction completes, a predict.sql file is generated, containing the 1034 predicted SQL statements for the dev set.
Evaluation uses the ts (test suite) databases and is run with:
python dbgpt_hub/eval/evaluation.py --plug_value --input Your_model_pred_file
During evaluation, each predicted SQL statement is compared with the gold SQL, and incorrect predictions are printed for inspection.
After all 1034 SQL statements have been checked, you obtain the EX and EM accuracy scores.
Baseline scores for a number of other models are also available from DB-GPT-Hub.
RM training starts from the SFT-stage model and follows Microsoft's DeepSpeedExamples code (DB-GPT-Hub will also add RLHF support soon, stay tuned). A small Text2SQL RM training dataset was constructed by hand for test training.
The data format is as follows:
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","chosen": "SELECT count(*) FROM head WHERE age > 56","rejected":"SELECT COUNT(head_name) FROM head WHERE age > 56;"}
deepspeed --num_gpus=$n_gpu \
main.py \
--data_path $data_path \
--data_split 2,4,4 \
--model_name_or_path $model_name_or_path \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--max_seq_len 1024 \
--learning_rate 9.65e-6 \
--weight_decay 0.1 \
--num_padding_at_beginning 0 \
--num_train_epochs 10 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--num_warmup_steps 0 \
--seed 1234 \
--gradient_checkpointing \
--zero_stage $ZERO_STAGE \
--deepspeed \
--offload \
--lora_dim 128 \
--lora_module_name "layers." \
--output_dir $OUTPUT \
2>&1 | tee $OUTPUT/log.txt
After training completes, the trained model files are written to the specified output directory.
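Once trained, the RM maps a (prompt, response) pair to a scalar score. A scoring sketch for intuition (hypothetical: it treats the RM as a one-label sequence classifier, whereas DeepSpeed-Chat wraps the model in its own RewardModel class with a different interface):

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("path/to/trained_rm")   # hypothetical path
model = AutoModelForSequenceClassification.from_pretrained(
    "path/to/trained_rm", num_labels=1)

def score(prompt: str, response: str) -> float:
    """Higher scores should mean the response better matches human preference."""
    inputs = tokenizer(prompt + response, return_tensors="pt", truncation=True)
    with torch.no_grad():
        return model(**inputs).logits.item()

good = score("...How many heads of the departments are older than 56 ?...",
             "SELECT count(*) FROM head WHERE age > 56")
bad = score("...How many heads of the departments are older than 56 ?...",
            "SELECT COUNT(head_name) FROM head WHERE age > 56;")
print(good > bad)   # a well-trained RM should prefer the correct SQL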
The RL stage keeps the same data format as the SFT stage. For the Text2SQL task, the RL data can be constructed as (prompt, output) pairs, as shown below:
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}
deepspeed --master_port 12346 main.py \
--data_path $data_path \
--data_split 2,4,4 \
--actor_model_name_or_path $ACTOR_MODEL_PATH \
--critic_model_name_or_path $CRITIC_MODEL_PATH \
--num_padding_at_beginning 1 \
--per_device_generation_batch_size 8 \
--per_device_training_batch_size 8 \
--generation_batches 1 \
--ppo_epochs 1 \
--max_answer_seq_len 256 \
--max_prompt_seq_len 1024 \
--actor_learning_rate ${Actor_Lr} \
--critic_learning_rate ${Critic_Lr} \
--actor_weight_decay 0.1 \
--critic_weight_decay 0.1 \
--num_train_epochs 10 \
--lr_scheduler_type cosine \
--gradient_accumulation_steps 1 \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
--offload_reference_model \
--disable_actor_dropout \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
--actor_zero_stage $ACTOR_ZERO_STAGE \
--critic_zero_stage $CRITIC_ZERO_STAGE \
--enable_hybrid_engine \
--actor_lora_dim 64 \
--critic_lora_dim 64 \
--critic_lora_module_name "layers." \
--actor_lora_module_name "layers." \
--output_dir $OUTPUT \
2>&1 | tee $OUTPUT/log.txt
Training produces two models; the actor model is the final model to be evaluated.
Compared with SFT alone, RLHF yields a slight accuracy improvement; the bottleneck is mainly data quality, which leaves room for further exploration.
Other articles in this series:
- Text-to-SQL小白入门(二)Transformer学习
- Text-to-SQL小白入门(三)IRNet:引入中间表示SemQL
- Text-to-SQL小白入门(四)指令进化大模型WizardLM
- Text-to-SQL小白入门(五)开源代码大模型Code Llama
- Text-to-SQL小白入门(六)Awesome-Text2SQL项目介绍
- Text-to-SQL小白入门(七)PanGu-Coder2论文——RRTF