Tensorflow2.10使用BERT從文本中抽取答案實(shí)現(xiàn)詳解
前言
本文詳細(xì)介紹了用 tensorflow-gpu 2.10 版本實(shí)現(xiàn)一個(gè)簡(jiǎn)單的從文本中抽取答案的過(guò)程。
數(shù)據(jù)準(zhǔn)備
這里主要用于準(zhǔn)備訓(xùn)練和評(píng)估 SQuAD(Standford Question Answering Dataset)數(shù)據(jù)集的 Bert 模型所需的數(shù)據(jù)和工具。
首先,通過(guò)導(dǎo)入相關(guān)庫(kù),包括 os、re、json、string、numpy、tensorflow、tokenizers 和 transformers,為后續(xù)處理數(shù)據(jù)和構(gòu)建模型做好準(zhǔn)備。 然后,設(shè)置了最大長(zhǎng)度為384 ,并創(chuàng)建了一個(gè) BertConfig 對(duì)象。接著從 Hugging Face 模型庫(kù)中下載預(yù)訓(xùn)練模型 bert-base-uncased 模型的 tokenizer ,并將其保存到同一目錄下的名叫 bert_base_uncased 文件夾中。 當(dāng)下載結(jié)束之后,使用 BertWordPieceTokenizer 從已下載的文件夾中夾在 tokenizer 的詞匯表從而創(chuàng)建分詞器 tokenizer 。
剩下的部分就是從指定的 URL 下載訓(xùn)練和驗(yàn)證集,并使用 keras.utils.get_file() 將它們保存到本地,一般存放在 “用戶目錄.keras\datasets”下 ,以便后續(xù)的數(shù)據(jù)預(yù)處理和模型訓(xùn)練。
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
max_len = 384
configuration = BertConfig()
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
save_path = "bert_base_uncased/"
if not os.path.exists(save_path):
os.makedirs(save_path)
slow_tokenizer.save_pretrained(save_path)
tokenizer = BertWordPieceTokenizer("bert_base_uncased/vocab.txt", lowercase=True)
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url)
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url)
打?。?/p>
Downloading data from https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
30288272/30288272 [==============================] - 131s 4us/step
Downloading data from https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
4854279/4854279 [==============================] - 20s 4us/step
模型輸入、輸出處理
這里定義了一個(gè)名為 SquadExample 的類(lèi),用于表示一個(gè) SQuAD 數(shù)據(jù)集中的問(wèn)題和對(duì)應(yīng)的上下文片段、答案位置等信息。
該類(lèi)的構(gòu)造函數(shù) __init__() 接受五個(gè)參數(shù):?jiǎn)栴}(question)、上下文(context)、答案起始字符索引(start_char_idx)、答案文本(answer_text) 和所有答案列表 (all_answers) 。
類(lèi)還包括一個(gè)名為 preprocess() 的方法,用于對(duì)每個(gè) SQuAD 樣本進(jìn)行預(yù)處理,首先對(duì)context 、question 和 answer 進(jìn)行預(yù)處理,并計(jì)算出答案的結(jié)束位置 end_char_idx 。接下來(lái),根據(jù) start_char_idx 和 end_char_idx 在 context 的位置,構(gòu)建了一個(gè)表示 context 中哪些字符屬于 answer 的列表 is_char_in_ans 。然后,使用 tokenizer 對(duì) context 進(jìn)行編碼,得到 tokenized_context。
接著,通過(guò)比較 answer 的字符位置和 context 中每個(gè)標(biāo)記的字符位置,得到了包含答案的標(biāo)記的索引列表 ans_token_idx 。如果 answer 未在 context 中找到,則將 skip 屬性設(shè)置為 True ,并直接返回空結(jié)果。
最后,將 context 和 question 的序列拼接成輸入序列 input_ids ,并根據(jù)兩個(gè)句子的不同生成了同樣長(zhǎng)度的序列 token_type_ids 以及與 input_ids 同樣長(zhǎng)度的 attention_mask 。然后對(duì)這三個(gè)序列進(jìn)行了 padding 操作。
class SquadExample:
def __init__(self, question, context, start_char_idx, answer_text, all_answers):
self.question = question
self.context = context
self.start_char_idx = start_char_idx
self.answer_text = answer_text
self.all_answers = all_answers
self.skip = False
def preprocess(self):
context = self.context
question = self.question
answer_text = self.answer_text
start_char_idx = self.start_char_idx
context = " ".join(str(context).split())
question = " ".join(str(question).split())
answer = " ".join(str(answer_text).split())
end_char_idx = start_char_idx + len(answer)
if end_char_idx >= len(context):
self.skip = True
return
is_char_in_ans = [0] * len(context)
for idx in range(start_char_idx, end_char_idx):
is_char_in_ans[idx] = 1
tokenized_context = tokenizer.encode(context)
ans_token_idx = []
for idx, (start, end) in enumerate(tokenized_context.offsets):
if sum(is_char_in_ans[start:end]) > 0:
ans_token_idx.append(idx)
if len(ans_token_idx) == 0:
self.skip = True
return
start_token_idx = ans_token_idx[0]
end_token_idx = ans_token_idx[-1]
tokenized_question = tokenizer.encode(question)
input_ids = tokenized_context.ids + tokenized_question.ids[1:]
token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(tokenized_question.ids[1:])
attention_mask = [1] * len(input_ids)
padding_length = max_len - len(input_ids)
if padding_length > 0:
input_ids = input_ids + ([0] * padding_length)
attention_mask = attention_mask + ([0] * padding_length)
token_type_ids = token_type_ids + ([0] * padding_length)
elif padding_length < 0:
self.skip = True
return
self.input_ids = input_ids
self.token_type_ids = token_type_ids
self.attention_mask = attention_mask
self.start_token_idx = start_token_idx
self.end_token_idx = end_token_idx
self.context_token_to_char = tokenized_context.offsets
這里的兩個(gè)函數(shù)用于準(zhǔn)備數(shù)據(jù)以訓(xùn)練一個(gè)使用 BERT 結(jié)構(gòu)的問(wèn)答模型。
第一個(gè)函數(shù) create_squad_examples 接受一個(gè) JSON 文件的原始數(shù)據(jù),將里面的每條數(shù)據(jù)都變成 SquadExample 類(lèi)所定義的輸入格式。
第二個(gè)函數(shù) create_inputs_targets 將 SquadExample 對(duì)象列表轉(zhuǎn)換為模型的輸入和目標(biāo)。這個(gè)函數(shù)返回兩個(gè)列表,一個(gè)是模型的輸入,包含了 input_ids 、token_type_ids 、 attention_mask ,另一個(gè)是模型的目標(biāo),包含了 start_token_idx 、end_token_idx。
def create_squad_examples(raw_data):
squad_examples = []
for item in raw_data["data"]:
for para in item["paragraphs"]:
context = para["context"]
for qa in para["qas"]:
question = qa["question"]
answer_text = qa["answers"][0]["text"]
all_answers = [_["text"] for _ in qa["answers"]]
start_char_idx = qa["answers"][0]["answer_start"]
squad_eg = SquadExample(question, context, start_char_idx, answer_text, all_answers)
squad_eg.preprocess()
squad_examples.append(squad_eg)
return squad_examples
def create_inputs_targets(squad_examples):
dataset_dict = {
"input_ids": [],
"token_type_ids": [],
"attention_mask": [],
"start_token_idx": [],
"end_token_idx": [],
}
for item in squad_examples:
if item.skip == False:
for key in dataset_dict:
dataset_dict[key].append(getattr(item, key))
for key in dataset_dict:
dataset_dict[key] = np.array(dataset_dict[key])
x = [ dataset_dict["input_ids"], dataset_dict["token_type_ids"], dataset_dict["attention_mask"], ]
y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
return x, y
這里主要讀取了 SQuAD 訓(xùn)練集和驗(yàn)證集的 JSON 文件,并使用create_squad_examples 函數(shù)將原始數(shù)據(jù)轉(zhuǎn)換為 SquadExample 對(duì)象列表。然后使用 create_inputs_targets 函數(shù)將這些 SquadExample 對(duì)象列表轉(zhuǎn)換為模型輸入和目標(biāo)輸出。最后輸出打印了已創(chuàng)建的訓(xùn)練數(shù)據(jù)樣本數(shù)和評(píng)估數(shù)據(jù)樣本數(shù)。
with open(train_path) as f:
raw_train_data = json.load(f)
with open(eval_path) as f:
raw_eval_data = json.load(f)
train_squad_examplesa = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")
打?。?/p>
87599 training points created.
10570 evaluation points created.
模型搭建
這里定義了一個(gè)基于 BERT 的問(wèn)答模型。在 create_model() 函數(shù)中,首先使用 TFBertModel.from_pretrained() 方法加載預(yù)訓(xùn)練的 BERT 模型。然后創(chuàng)建了三個(gè)輸入層(input_ids、token_type_ids 和 attention_mask),每個(gè)輸入層的形狀都是(max_len,) 。這些輸入層用于接收模型的輸入數(shù)據(jù)。
接下來(lái)使用 encoder() 方法對(duì)輸入進(jìn)行編碼得到 embedding ,然后分別對(duì)這些向量表示進(jìn)行全連接層的操作,得到一個(gè) start_logits 和一個(gè) end_logits 。接著分別對(duì)這兩個(gè)向量進(jìn)行扁平化操作,并將其傳遞到激活函數(shù) softmax 中,得到一個(gè) start_probs 向量和一個(gè) end_probs 向量。
最后,將這三個(gè)輸入層和這兩個(gè)輸出層傳遞給 keras.Model() 函數(shù),構(gòu)建出一個(gè)模型。此模型使用 SparseCategoricalCrossentropy 損失函數(shù)進(jìn)行編譯,并使用 Adam 優(yōu)化器進(jìn)行訓(xùn)練。
def create_model():
encoder = TFBertModel.from_pretrained("bert-base-uncased")
input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
embedding = encoder(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
start_logits = layers.Flatten()(start_logits)
end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
end_logits = layers.Flatten()(end_logits)
start_probs = layers.Activation(keras.activations.softmax)(start_logits)
end_probs = layers.Activation(keras.activations.softmax)(end_logits)
model = keras.Model( inputs=[input_ids, token_type_ids, attention_mask], outputs=[start_probs, end_probs],)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = keras.optimizers.Adam(lr=5e-5)
model.compile(optimizer=optimizer, loss=[loss, loss])
return model
這里主要是展示了一下模型的架構(gòu),可以看到所有的參數(shù)都可以訓(xùn)練,并且主要調(diào)整的部分都幾乎是 bert 中的參數(shù)。
model = create_model() model.summary()
打?。?/p>
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 384)] 0 []
input_6 (InputLayer) [(None, 384)] 0 []
input_5 (InputLayer) [(None, 384)] 0 []
tf_bert_model_1 (TFBertModel) TFBaseModelOutputWi 109482240 ['input_4[0][0]',
thPoolingAndCrossAt 'input_6[0][0]',
tentions(last_hidde 'input_5[0][0]']
n_state=(None, 384,
768),
pooler_output=(Non
e, 768),
past_key_values=No
ne, hidden_states=N
one, attentions=Non
e, cross_attentions
=None)
start_logit (Dense) (None, 384, 1) 768 ['tf_bert_model_1[0][0]']
end_logit (Dense) (None, 384, 1) 768 ['tf_bert_model_1[0][0]']
flatten_2 (Flatten) (None, 384) 0 ['start_logit[0][0]']
flatten_3 (Flatten) (None, 384) 0 ['end_logit[0][0]']
activation_2 (Activation) (None, 384) 0 ['flatten_2[0][0]']
activation_3 (Activation) (None, 384) 0 ['flatten_3[0][0]']
==================================================================================================
Total params: 109,483,776
Trainable params: 109,483,776
Non-trainable params: 0
自定義驗(yàn)證回調(diào)函數(shù)
這里定義了一個(gè)回調(diào)函數(shù) ExactMatch , 有一個(gè)初始化方法 __init__ ,接收驗(yàn)證集的輸入和目標(biāo) x_eval 和 y_eval 。該類(lèi)還實(shí)現(xiàn)了 on_epoch_end 方法,在每個(gè) epoch 結(jié)束時(shí)調(diào)用,計(jì)算模型的預(yù)測(cè)值,并計(jì)算精確匹配分?jǐn)?shù)。
具體地,on_epoch_end 方法首先使用模型對(duì) x_eval 進(jìn)行預(yù)測(cè),得到預(yù)測(cè)的起始位置 pred_start 和結(jié)束位置 pred_end ,并進(jìn)一步找到對(duì)應(yīng)的預(yù)測(cè)答案和正確答案標(biāo)準(zhǔn)化為 normalized_pred_ans 和 normalized_true_ans ,如果前者存在于后者,則說(shuō)明該樣本被正確地回答,最終將精確匹配分?jǐn)?shù)打印出來(lái)。
def normalize_text(text):
text = text.lower()
exclude = set(string.punctuation)
text = "".join(ch for ch in text if ch not in exclude)
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
text = re.sub(regex, " ", text)
text = " ".join(text.split())
return text
class ExactMatch(keras.callbacks.Callback):
def __init__(self, x_eval, y_eval):
self.x_eval = x_eval
self.y_eval = y_eval
def on_epoch_end(self, epoch, logs=None):
pred_start, pred_end = self.model.predict(self.x_eval)
count = 0
eval_examples_no_skip = [_ for _ in eval_squad_examples if _.skip == False]
for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
squad_eg = eval_examples_no_skip[idx]
offsets = squad_eg.context_token_to_char
start = np.argmax(start)
end = np.argmax(end)
if start >= len(offsets):
continue
pred_char_start = offsets[start][0]
if end < len(offsets):
pred_char_end = offsets[end][1]
pred_ans = squad_eg.context[pred_char_start:pred_char_end]
else:
pred_ans = squad_eg.context[pred_char_start:]
normalized_pred_ans = normalize_text(pred_ans)
normalized_true_ans = [normalize_text(_) for _ in squad_eg.all_answers]
if normalized_pred_ans in normalized_true_ans:
count += 1
acc = count / len(self.y_eval[0])
print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")
模型訓(xùn)練和驗(yàn)證
訓(xùn)練模型,并使用驗(yàn)證集對(duì)模型的性能進(jìn)行測(cè)試。這里的 epoch 只設(shè)置了 1 ,如果數(shù)值增大效果會(huì)更好。
exact_match_callback = ExactMatch(x_eval, y_eval) model.fit( x_train, y_train, epochs=1, verbose=2, batch_size=16, callbacks=[exact_match_callback],)
打印:
23/323 [==============================] - 47s 139ms/step
epoch=1, exact match score=0.77
5384/5384 - 1268s - loss: 2.4677 - activation_2_loss: 1.2876 - activation_3_loss: 1.1800 - 1268s/epoch - 236ms/step
以上就是Tensorflow2.10使用BERT從文本中抽取答案實(shí)現(xiàn)詳解的詳細(xì)內(nèi)容,更多關(guān)于Tensorflow BERT文本抽取的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python中max函數(shù)用于二維列表的實(shí)例
下面小編就為大家分享一篇Python中max函數(shù)用于二維列表的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04
django中send_mail功能實(shí)現(xiàn)詳解
這篇文章主要給大家介紹了關(guān)于django中send_mail功能實(shí)現(xiàn)的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧。2018-02-02
Python實(shí)現(xiàn)將不規(guī)范的英文名字首字母大寫(xiě)
這篇文章給大家主要介紹的是利用map()函數(shù),把用戶輸入的不規(guī)范的英文名字,變?yōu)槭鬃帜复髮?xiě),其他小寫(xiě)的規(guī)范名字。文中給出了三種解決方法,大家可以根據(jù)需要選擇使用,感興趣的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧。2016-11-11
Python解決Flutter項(xiàng)目簡(jiǎn)體字問(wèn)題的方法
作為面向大陸外市場(chǎng)的應(yīng)用,我們經(jīng)常編寫(xiě)代碼的時(shí)候往往忘記切換繁體字導(dǎo)致上線后出現(xiàn)簡(jiǎn)體字,因?yàn)檠芯肯聵I(yè)內(nèi)相關(guān)插件,看看怎么好解決這個(gè)問(wèn)題,OpenCC 支持語(yǔ)言比較多,所以基于此嘗試了用 Python 去實(shí)現(xiàn),需要的朋友可以參考下2024-07-07
python讀取nc數(shù)據(jù)并繪圖的方法實(shí)例
最近項(xiàng)目中需要處理和分析NC數(shù)據(jù),所以下面這篇文章主要給大家介紹了關(guān)于python讀取nc數(shù)據(jù)并繪圖的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-05-05
Python實(shí)現(xiàn)的爬取百度文庫(kù)功能示例
這篇文章主要介紹了Python實(shí)現(xiàn)的爬取百度文庫(kù)功能,結(jié)合實(shí)例形式分析了Python針對(duì)百度文庫(kù)的爬取、編碼轉(zhuǎn)換、文件保存等相關(guān)操作技巧,需要的朋友可以參考下2019-02-02
python?PyQt5(自定義)信號(hào)與槽使用及說(shuō)明
這篇文章主要介紹了python?PyQt5(自定義)信號(hào)與槽使用及說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-12-12

