目录
针对bad case的分析表明,错误的主要原因是边界定位不准确:sub、obj抽取得过短。
因此考虑先用jieba分词,再调用GPT-4的API判断扩展后的新span是否符合条件。
import json
from pdb import set_trace as stop
import jieba
import openai
from tqdm import tqdm
# Credentials for the OpenAI client; the requests made in this file ask for GPT-4.
# NOTE(review): "api_key" looks like a placeholder — presumably the real key is
# filled in before running; confirm it is never committed to version control.
openai.api_key = "api_key"
# Route requests through this non-default base URL instead of the library default.
openai.api_base = 'https://api.ngapi.top/v1'
def get_response(prompt, temperature=0.5, max_tokens=2048):
    """Send *prompt* as a single user message and return the raw completion.

    The request always targets ``gpt-4`` with ``temperature=0`` and
    ``top_p=0``. The ``temperature`` and ``max_tokens`` arguments are accepted
    for call-site compatibility but are not forwarded to the API.

    Args:
        prompt: Text placed in the ``content`` of the user message.
        temperature: Not used by the request.
        max_tokens: Not used by the request.

    Returns:
        Whatever ``openai.ChatCompletion.create`` returns for the request.
    """
    # Echo the prompt to stdout so every request is visible in the run log.
    print(prompt)

    user_message = {"role": "user", "content": f"{prompt}"}
    request_kwargs = {
        "model": "gpt-4",
        "temperature": 0,
        "top_p": 0,
        "messages": [user_message],
    }
    return openai.ChatCompletion.create(**request_kwargs)
# Input file: one JSON object per line holding the model's predictions.
llm_generated_path = (
    "/public/home/hongy/qtxu/Qwen-main/results/Ele_lora/"
    "pred_20240101_instruction_0104.jsonl"
)
# Output file: sentences followed by their post-processed tuples, as plain text.
change_path = (
    "/public/home/hongy/qtxu/Qwen-main/results/Ele_lora/"
    "pred_20240101_instruction_0104_post.txt"
)

# Maps the Chinese polarity labels found in the predictions to English labels.
po_dict = {
    "相等": "equal",
    "更好": "better",
    "更差": "worse",
    "不同": "different",
}

# Placeholder written for an empty tuple element.
pad_word = "无"

# Punctuation marks; not referenced by the code visible in this file.
chinese_punctuation = [
    ',', '。', '?', '!', ':', ';',
    '‘', '’', '“', '”',
    '(', ')', '【', '】', '{', '}', '《', '》',
    '、', '——', '-', '……', '~', '·',
]
def get_previsous_word(cur_span, cur_sent):
    """Ask the model which word comes right before *cur_span* in *cur_sent*.

    When the model's reply is exactly '的', that reply is prepended to the
    span and the question is asked a second time for the widened span.

    Args:
        cur_span: The span whose preceding word is wanted.
        cur_sent: The sentence quoted in the prompt.

    Returns:
        The reply text of the last query that was made.
    """
    def _ask(question):
        # Pull the reply text out of the first choice of the completion.
        return get_response(question)['choices'][0]['message']['content']

    answer = _ask(f"在输入语句({cur_sent})中,({cur_span})的前一个单词是什么?。直接给出答案即可。")
    if answer != '的':
        return answer

    # The reply was '的': widen the span with it and ask once more.
    widened_span = answer + cur_span
    return _ask(f"在输入语句({cur_sent})中,({widened_span})的前一个单词是什么?直接给出答案即可。")
def identify_nonu_phrase(front_result, cur_span, cur_sent):
    """Ask the model whether ``front_result + cur_span`` names an item or brand.

    The prompt quotes *cur_sent* and asks for a bare 'yes' or 'no'.

    Args:
        front_result: Text placed directly in front of the span in the prompt.
        cur_span: The span being extended.
        cur_sent: The sentence quoted in the prompt.

    Returns:
        The model's reply text, unmodified.
    """
    candidate = f"{front_result}{cur_span}"
    question = f"在输入语句({cur_sent})中,({candidate})是一个可以表示物品名称、物品品牌的名词或名词短语吗?直接回答'yes'或'no'"
    reply = get_response(question)
    return reply['choices'][0]['message']['content']
def get_chinese_index(cur_span, cur_sent):
    """Return the offset of the first occurrence of *cur_span* in *cur_sent*.

    Returns:
        The character offset, or -1 when the span does not occur.
    """
    return cur_sent.find(cur_span)
def get_front_end_word(text, span):
    """Return the words jieba places directly before and after *span* in *text*.

    Both strings are segmented with ``jieba.cut(..., cut_all=False)``, the
    tokens are joined with single spaces, and the joined span is searched for
    in the joined sentence (first match wins).

    Args:
        text: The full sentence.
        span: The span to locate inside the sentence.

    Returns:
        A ``(front_word, end_word)`` tuple of strings. ``front_word`` is the
        token preceding the span; when that token is '的' and another token
        precedes it, the two are concatenated (previous token + '的').
        ``end_word`` is the token following the span. An element is ``""``
        when there is no neighbour on that side, and both are ``""`` when the
        segmented span cannot be found in the segmented sentence.
    """
    text_result = " ".join(jieba.cut(text, cut_all=False))
    span_result = " ".join(jieba.cut(span, cut_all=False))

    # find() returns -1 when the span is segmented differently in isolation
    # than inside the sentence. Slicing with -1 would silently pick unrelated
    # words, so report "no neighbours" instead.
    index = text_result.find(span_result)
    if index == -1:
        return "", ""

    words_before = text_result[:index].split()
    words_after = text_result[index + len(span_result):].split()

    # An empty side means the span is the first / last token of the sentence.
    front_word = words_before[-1] if words_before else ""
    if front_word == '的' and len(words_before) > 1:
        # A bare '的' is not informative on its own: attach the word before it.
        front_word = words_before[-2] + front_word

    end_word = words_after[0] if words_after else ""
    return front_word, end_word
def post_processing(cur_span, cur_sent, pad_word):
    """Optionally extend *cur_span* to the left by one word of *cur_sent*.

    The word in front of the span is looked up with ``get_front_end_word`` and
    the model is asked, via ``identify_nonu_phrase``, whether that word plus
    the span forms an item/brand noun phrase. The extension is kept only when
    the model answers yes.

    Args:
        cur_span: The extracted span, or *pad_word* for an empty element.
        cur_sent: The sentence the span was extracted from.
        pad_word: Placeholder marking an empty element.

    Returns:
        *pad_word* when the span is the placeholder; the unchanged span when
        it starts the sentence, does not occur in the sentence, has no word in
        front of it, or the model does not answer yes; otherwise the front
        word concatenated with the span.
    """
    if cur_span == pad_word:
        return pad_word

    span_start = cur_sent.find(cur_span)
    if span_start <= 0:
        # 0: the span already starts the sentence, nothing can be prepended.
        # -1: the span is not in the sentence, so there is no neighbour to
        # look up and no point in spending an LLM call.
        return cur_span

    front_result, _ = get_front_end_word(cur_sent, cur_span)
    if not front_result:
        # Nothing to prepend; the outcome would be the span either way.
        return cur_span

    identify_result = identify_nonu_phrase(front_result, cur_span, cur_sent)
    print("identify_result结果是:", identify_result)

    # Model replies are not byte-stable ("Yes", "yes.", "'yes'"), so compare
    # a normalised form rather than the raw text.
    verdict = identify_result.strip().strip("'\"`.。!!").lower()
    if verdict == 'yes':
        return front_result + cur_span
    return cur_span
# Driver: for every comparative sentence in the prediction file, write the
# sentence followed by its predicted tuples with post-processed subject/object.
# The files hold Chinese text, so the encoding is pinned instead of relying on
# the platform default.
with open(llm_generated_path, 'r', encoding='utf-8') as fr, \
        open(change_path, 'w', encoding='utf-8') as fw:
    for line in fr:
        cur_line = json.loads(line)
        # Take the second blank-line-separated section of the query and drop 7
        # leading and 52 trailing characters ("instruction2" layout).
        # NOTE(review): presumably these offsets strip the prompt template
        # around the sentence; the author's other template used
        # split('\n\n')[-1][7:-57] — confirm against the instruction in use.
        cur_sent = cur_line['query'].split('\n\n')[1][7:-52].strip()

        compar = cur_line['type']  # 1 marks a comparative sentence
        if compar != 1:
            continue

        fw.write(cur_sent + "\n")
        result = cur_line['output'].strip().split('\n')

        # Only every second output line is parsed as a tuple.
        # NOTE(review): per the original comment this step of 2 applies when
        # the output also carries position lines — confirm for this data.
        for i in range(0, len(result), 2):
            # Drop the 7-character prefix of the line, then the surrounding
            # parentheses. A few tuples cannot be split cleanly on commas, so
            # the last field is always read from the end.
            cur_quintuple = result[i][7:].strip()
            fields = [part.strip() for part in cur_quintuple[1:-1].split(',')]
            if len(fields) < 4:
                # An empty or truncated line would raise IndexError below and
                # abort the whole run; skip it and keep going.
                print("skip malformed tuple line:", result[i])
                continue

            sub, obj, asp, op = fields[:4]
            polarity = fields[-1]

            # Empty elements are written as the placeholder.
            sub = sub or pad_word
            obj = obj or pad_word
            asp = asp or pad_word
            op = op or pad_word
            # Unknown labels are passed through unchanged rather than raising
            # KeyError in the middle of a long run.
            polarity = po_dict.get(polarity, polarity) if polarity else pad_word

            # Only subject and object are post-processed; aspect and opinion
            # are written as predicted.
            post_sub = post_processing(sub, cur_sent, pad_word)
            post_obj = post_processing(obj, cur_sent, pad_word)

            post_final_quintuple = f"({post_sub},{post_obj},{asp},{op},{polarity})"
            fw.write(post_final_quintuple + "\n")