TensorFlow2.4完成Word2vec詞嵌入訓(xùn)練方法詳解
前言
本文使用 cpu 版本的 tensorflow 2.4 ,在 shakespeare 數(shù)據(jù)的基礎(chǔ)上使用 Skip-Gram 算法訓(xùn)練詞嵌入。
相關(guān)概念
Word2Vec 不是一個(gè)單純的算法,而是對最初的神經(jīng)概率語言模型 NNLM 做出的改進(jìn)和優(yōu)化,可用于從大型數(shù)據(jù)集中學(xué)習(xí)單詞的詞嵌入。通過word2vec 學(xué)習(xí)到的單詞詞嵌入已經(jīng)被成功地運(yùn)用于下游的各種 NLP 任務(wù)中。
Word2Vec 是輕量級的神經(jīng)網(wǎng)絡(luò),其模型僅僅包括輸入層、隱藏層和輸出層,模型框架根據(jù)輸入輸出的不同,主要包括 CBOW 和 Skip-Gram 模型:
- CBOW :根據(jù)周圍的上下文詞預(yù)測中間的目標(biāo)詞。周圍的上下文詞由當(dāng)中間的目標(biāo)詞的前面和后面的若干個(gè)單詞組成,這種體系結(jié)構(gòu)中單詞的順序并不重要。
- Skip-Gram :在同一個(gè)句子中預(yù)測當(dāng)前單詞前后一定范圍內(nèi)的若干個(gè)目標(biāo)單詞。
實(shí)現(xiàn)過程
1. 使用例子介紹 Skip-Gram 操作
(1)使用例子說明負(fù)采樣過程 我們首先用一個(gè)句子 "我是中國人"來說明相關(guān)的操作流程。
(2)先要對句子中的 token 進(jìn)行拆分,保存每個(gè)字到整數(shù)的映射關(guān)系,以及每個(gè)整數(shù)到字的映射關(guān)系。
(3)然后用整數(shù)對整個(gè)句子進(jìn)行向量化,也就是用整數(shù)表示對應(yīng)的字,從而形成一個(gè)包含了整數(shù)的的向量,需要注意的是這里要特意保留 0 作為填充占位符。
(4)sequence 模塊提供可以簡化 word2vec 數(shù)據(jù)準(zhǔn)備的功能,我們使用 skipgram 函數(shù),在 example_sequence 中以每個(gè)單詞為中心,與前后窗口大小為 window_size 范圍內(nèi)的詞生成 Skip-Gram 整數(shù)對集合。具體結(jié)果可以結(jié)合例子的輸出理解。
import io
import re
import string
import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
SEED = 2
AUTOTUNE = tf.data.AUTOTUNE
window_size = 2
sentence = "我是一個(gè)偉大的中國人"
tokens = list(sentence)
vocab, index = {}, 1
vocab['<pad>'] = 0
for token in tokens:
if token not in vocab:
vocab[token] = index
index += 1
vocab_size = len(vocab)
inverse_vocab = {index: token for token, index in vocab.items()}
example_sequence = [vocab[word] for word in tokens]
positive_skip, _ = tf.keras.preprocessing.sequence.skipgrams( example_sequence, vocabulary_size = vocab_size, window_size = window_size, negative_samples = 0)
positive_skip.sort()
for t, c in positive_skip:
print(f"({t}, {c}): ({inverse_vocab[t]}, {inverse_vocab[c]})")
所有正樣本輸出:
(1, 2): (我, 是)
(1, 3): (我, 一)
(2, 1): (是, 我)
(2, 3): (是, 一)
(2, 4): (是, 個(gè))
(3, 1): (一, 我)
(3, 2): (一, 是)
(3, 4): (一, 個(gè))
(3, 5): (一, 偉)
(4, 2): (個(gè), 是)
(4, 3): (個(gè), 一)
(4, 5): (個(gè), 偉)
(4, 6): (個(gè), 大)
(5, 3): (偉, 一)
(5, 4): (偉, 個(gè))
(5, 6): (偉, 大)
(5, 7): (偉, 的)
(6, 4): (大, 個(gè))
(6, 5): (大, 偉)
(6, 7): (大, 的)
(6, 8): (大, 中)
(7, 5): (的, 偉)
(7, 6): (的, 大)
(7, 8): (的, 中)
(7, 9): (的, 國)
(8, 6): (中, 大)
(8, 7): (中, 的)
(8, 9): (中, 國)
(8, 10): (中, 人)
(9, 7): (國, 的)
(9, 8): (國, 中)
(9, 10): (國, 人)
(10, 8): (人, 中)
(10, 9): (人, 國)
(5)skipgram 函數(shù)通過在給定的 window_size 上窗口上進(jìn)行滑動(dòng)來返回所有正樣本對,但是我們在進(jìn)行模型訓(xùn)練的時(shí)候還需要負(fù)樣本,要生成 skip-gram 負(fù)樣本,就需要從詞匯表中對單詞進(jìn)行隨機(jī)采樣。使用 log_uniform_candidate_sampler 函數(shù)對窗口中給定 target 進(jìn)行 num_ns 個(gè)負(fù)采樣。我們可以在一個(gè)目標(biāo)詞 target 上調(diào)用負(fù)采樣函數(shù),并將上下文 context 出現(xiàn)的詞作為 true_classes ,以避免在負(fù)采樣時(shí)被采樣到。但是這里需要注意的是,雖然理論上負(fù)采樣中 true_classes 是不被采樣的,但是由于 log_uniform_candidate_sampler 中實(shí)現(xiàn)的負(fù)采樣算法不同,所以還是可能會(huì)被采樣到,想要了解具體的情況,我們可以查看 github.com/tensorflow/… 。
(6)在較小的數(shù)據(jù)集中一般將 num_ns 設(shè)置為 [5, 20] 范圍內(nèi)的整數(shù),而在較大的數(shù)據(jù)集一般設(shè)置為 [2, 5] 范圍內(nèi)整數(shù)。
target_word, context_word = positive_skip[0] num_ns = 3 context_class = tf.expand_dims( tf.constant([context_word], dtype="int64"), 1) negative_sampling, _, _ = tf.random.log_uniform_candidate_sampler( true_classes=context_class, num_true=1, num_sampled=num_ns, unique=True, range_max=vocab_size, seed=SEED, name="negative_sampling" )
(7)我們選用了一個(gè)正樣本 (我, 是) 來為其生成對應(yīng)的負(fù)采樣樣本,目標(biāo)詞為“我”,該樣本的標(biāo)簽類別定義為“是”,使用函數(shù) log_uniform_candidate_sampler 就會(huì)以“我”為目標(biāo),再去在詞典中隨機(jī)采樣一個(gè)不存在于 true_classes 的字作為負(fù)采樣的標(biāo)簽類別, 如下我們生成了三個(gè)樣本類別,可以分別組成(我,一)、(我,個(gè))、(我,我)三個(gè)負(fù)樣本。
(8)對于一個(gè)給定的正樣本 skip-gram 對,每個(gè)樣本對都是 (target_word, context_word) 的形式,我們現(xiàn)在又生成了 3 個(gè)負(fù)采樣,將 1 個(gè)正樣本 和 3 負(fù)樣本組合到一個(gè)張量中。對于正樣本標(biāo)簽標(biāo)記為 1 和負(fù)樣本標(biāo)簽標(biāo)記為 0 。
squeezed_context_class = tf.squeeze(context_class, 1)
context = tf.concat([squeezed_context_class, negative_sampling], 0)
label = tf.constant([1] + [0]*num_ns, dtype="int64")
target = target_word
print(f"target_index : {target}")
print(f"target_word : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label : {label}")
結(jié)果為:
target_index : 1
target_word : 我
context_indices : [2 3 4 1]
context_words : ['是', '一', '個(gè)', '我']
label : [1 0 0 0]
2. 獲取、處理數(shù)據(jù)
(1)這里我們使用 tensorflow 的內(nèi)置函數(shù)從網(wǎng)絡(luò)上下載 shakespeare 文本數(shù)據(jù),里面保存的都是莎士比亞的作品。
(2)我們使用內(nèi)置的 TextVectorization 函數(shù)對數(shù)據(jù)進(jìn)行預(yù)處理,并且將出現(xiàn)的所有的詞都映射層對應(yīng)的整數(shù),并且保證每個(gè)樣本的長度不超過 10
(3)將所有的數(shù)據(jù)都轉(zhuǎn)化成對應(yīng)的整數(shù)表示,并且設(shè)置每個(gè) batcc_size 為 1024 。
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))
def custom_standardization(input_data):
lowercase = tf.strings.lower(input_data)
return tf.strings.regex_replace(lowercase, '[%s]' % re.escape(string.punctuation), '')
vocab_size = 4096
sequence_length = 10
vectorize_layer = layers.TextVectorization(
standardize=custom_standardization,
max_tokens=vocab_size,
output_mode='int',
output_sequence_length=sequence_length)
vectorize_layer.adapt(text_ds.batch(1024))
inverse_vocab = vectorize_layer.get_vocabulary()
text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_layer).unbatch()
sequences = list(text_vector_ds.as_numpy_iterator())
截取部分進(jìn)行打印:
for seq in sequences[:5]:
print(f"{seq} => {[inverse_vocab[i] for i in seq]}")
結(jié)果為:
[ 89 270 0 0 0 0 0 0 0 0] => ['first', 'citizen', '', '', '', '', '', '', '', '']
[138 36 982 144 673 125 16 106 0 0] => ['before', 'we', 'proceed', 'any', 'further', 'hear', 'me', 'speak', '', '']
[34 0 0 0 0 0 0 0 0 0] => ['all', '', '', '', '', '', '', '', '', '']
[106 106 0 0 0 0 0 0 0 0] => ['speak', 'speak', '', '', '', '', '', '', '', '']
[ 89 270 0 0 0 0 0 0 0 0] => ['first', 'citizen', '', '', '', '', '', '', '', '']
(4)我們將上面所使用到的步驟都串聯(lián)起來,可以組織形成生成訓(xùn)練數(shù)據(jù)的函數(shù),里面包括了正采樣和負(fù)采樣操作。另外可以使用 make_sampling_table 函數(shù)生成基于詞頻的采樣概率表,對應(yīng)到詞典中第 i 個(gè)最常見的詞的概率,為平衡期起見,對于越經(jīng)常出現(xiàn)的詞,采樣到的概率越低。
(5)這里調(diào)用 generate_training_data 函數(shù)可以生成訓(xùn)練數(shù)據(jù),target 的維度為 (64901,) ,contexts 和 labels 的維度為 (64901, 5) 。
(6)要對大量的訓(xùn)練樣本執(zhí)行高效的批處理,可以使用 Dataset 相關(guān)的 API ,使用 shuffle 可以從緩存的 BUFFER_SIZE 大小的樣本集中隨機(jī)選擇一個(gè),使用 batch 表示我們每個(gè) batch 的大小設(shè)置為 BATCH_SIZE ,使用 cache 為了保證在加載數(shù)據(jù)的時(shí)候不會(huì)出現(xiàn) I/O 不會(huì)阻塞,我們在從磁盤加載完數(shù)據(jù)之后,使用 cache 會(huì)將數(shù)據(jù)保存在內(nèi)存中,確保在訓(xùn)練模型過程中數(shù)據(jù)的獲取不會(huì)成為訓(xùn)練速度的瓶頸。如果說要保存的數(shù)據(jù)量太大,可以使用 cache 創(chuàng)建磁盤緩存提高數(shù)據(jù)的讀取效率。我們使用 prefetch 在訓(xùn)練過程中可以并行執(zhí)行數(shù)據(jù)的預(yù)獲取。
(7)最終每個(gè)樣本最終的形態(tài)為 ((target_word, context_word),label) 。
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
targets, contexts, labels = [], [], []
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)
for sequence in tqdm.tqdm(sequences):
positive_skip, _ = tf.keras.preprocessing.sequence.skipgrams(
sequence,
vocabulary_size=vocab_size,
sampling_table=sampling_table,
window_size=window_size,
negative_samples=0)
for target_word, context_word in positive_skip:
context_class = tf.expand_dims(
tf.constant([context_word], dtype="int64"), 1)
negative_sampling, _, _ = tf.random.log_uniform_candidate_sampler(
true_classes=context_class,
num_true=1,
num_sampled=num_ns,
unique=True,
range_max=vocab_size,
seed=seed,
name="negative_sampling")
context = tf.concat([tf.squeeze(context_class,1), negative_sampling], 0)
label = tf.constant([1] + [0]*num_ns, dtype="int64")
targets.append(target_word)
contexts.append(context)
labels.append(label)
return targets, contexts, labels
targets, contexts, labels = generate_training_data( sequences=sequences, window_size=2, num_ns=4, vocab_size=vocab_size, seed=SEED)
targets = np.array(targets)
contexts = np.array(contexts)
labels = np.array(labels)
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
3. 搭建、訓(xùn)練模型
(1)第一層是 target Embedding 層,我們?yōu)槟繕?biāo)單詞初始化詞嵌入,設(shè)置詞嵌入向量的維度為 128 ,也就是說這一層的參數(shù)總共有 (vocab_size * embedding_dim) 個(gè),輸入長度為 1 。
(2)第二層是 context Embedding 層,我們?yōu)樯舷挛膯卧~初始化詞嵌入,我們?nèi)匀辉O(shè)置詞嵌入向量的維度為 128 ,這一層的參數(shù)也有 (vocab_size * embedding_dim) 個(gè),輸入長度為 num_ns+1 。
(3)第三層是點(diǎn)積計(jì)算層,用于計(jì)算訓(xùn)練對中 target 和 context 嵌入的點(diǎn)積。
(4)我們選擇 Adam 優(yōu)化器來進(jìn)行優(yōu)化,選用 CategoricalCrossentropy 作為損失函數(shù),選用 accuracy 作為評估指標(biāo),使用訓(xùn)練數(shù)據(jù)來完成 20 個(gè) eopch 的訓(xùn)練。
class Word2Vec(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim):
super(Word2Vec, self).__init__()
self.target_embedding = layers.Embedding(vocab_size,
embedding_dim,
input_length=1,
name="w2v_embedding")
self.context_embedding = layers.Embedding(vocab_size,
embedding_dim,
input_length=num_ns+1)
def call(self, pair):
target, context = pair
if len(target.shape) == 2:
target = tf.squeeze(target, axis=1)
word_emb = self.target_embedding(target)
context_emb = self.context_embedding(context)
dots = tf.einsum('be,bce->bc', word_emb, context_emb)
return dots
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
word2vec.fit(dataset, epochs=20)
過程如下:
Epoch 1/20 63/63 [==============================] - 1s 14ms/step - loss: 1.6082 - accuracy: 0.2321 Epoch 2/20 63/63 [==============================] - 1s 14ms/step - loss: 1.5888 - accuracy: 0.5527 ... Epoch 19/20 63/63 [==============================] - 1s 13ms/step - loss: 0.5041 - accuracy: 0.8852 Epoch 20/20 63/63 [==============================] - 1s 13ms/step - loss: 0.4737 - accuracy: 0.8945
4. 查看 Word2Vec 向量
我們已經(jīng)訓(xùn)練好所有的詞向量,可以查看前三個(gè)單詞對應(yīng)的詞嵌入,不過因?yàn)榈谝粋€(gè)是一個(gè)填充字符,我們直接跳過了,所以只顯示了兩個(gè)單詞的結(jié)果。
weights = word2vec.get_layer('w2v_embedding').get_weights()[0]
vocab = vectorize_layer.get_vocabulary()
for index, word in enumerate(vocab[:3]):
if index == 0:
continue
vec = weights[index]
print(word, "||",' '.join([str(x) for x in vec]) + "")
輸出:
[UNK] || -0.033048704 -0.13244359 0.011660721 0.04466736 0.016815167 -0.0021747486 -0.22271504 -0.19703679 -0.23452276 0.11212586 -0.016061027 0.17981936 0.07774545 0.024562761 -0.17993309 -0.18202212 -0.13211365 -0.0836222 0.14589612 0.10907205 0.14628777 -0.10057361 -0.20254703 -0.012516517 -0.026788604 0.10540704 0.10908849 0.2110478 0.09297589 -0.20392798 0.3033481 -0.06899316 -0.11218286 0.08671802 -0.032792106 0.015512758 -0.11241121 0.03193802 -0.07420188 0.058226038 0.09341678 0.0020246594 0.11772731 0.22016191 -0.019723132 -0.124759704 0.15371098 -0.032143503 -0.16924457 0.07010268 -0.27322608 -0.04762394 0.1720905 -0.27821517 -0.021202642 0.022981782 0.017429957 -0.018919267 0.0821674 0.14892177 0.032966584 0.016503694 -0.024588188 -0.15450846 0.25163063 -0.09960359 -0.08205034 -0.059559997 -0.2328465 -0.017229442 -0.11387295 0.027335169 -0.21991524 -0.25220546 -0.057238836 0.062819794 -0.07596143 0.1036019 -0.11330178 0.041029476 -0.0036062107 -0.09850497 0.026396573 0.040283844 0.09707356 -0.108100675 0.14983237 0.094585866 -0.11460251 0.159306 -0.18871744 -0.0021350821 0.21181738 -0.11000824 0.026631303 0.0043079373 -0.10093511 -0.057986196 -0.13534115 -0.05459506 0.067853846 -0.09538108 -0.1882101 0.15350497 -0.1521072 -0.01917603 -0.2464314 0.07098584 -0.085702434 -0.083315894 0.01850418 -0.019426668 0.215964 -0.04208141 0.18032664 -0.067350626 0.29129744 0.07231988 0.2200896 0.04984232 -0.2129336 -0.005486685 0.0047443025 -0.06323578 0.10223014 -0.14854044 -0.09165846 0.14745502
the || 0.012899147 -0.11042492 -0.2028403 0.20705906 -0.14402795 -0.012134922 -0.008227022 -0.19896115 -0.18482314 -0.31812677 -0.050654292 0.063769065 0.013379926 -0.04029531 -0.19954327 0.020137483 -0.035291195 -0.03429038 0.07547649 0.04313068 -0.05675204 0.34193155 -0.13978302 0.033053987 -0.0038114514 8.5749794e-05 0.15582523 0.11737131 0.1599838 -0.14866571 -0.19313708 -0.0936122 0.12842192 -0.037015382 -0.05241146 -0.00085232017 -0.04838702 -0.17497984 0.13466156 0.17985387 0.032516308 0.028982501 -0.08578549 0.014078035 0.11176433 -0.08876962 -0.12424359 -0.00049041177 -0.07127252 0.13457641 -0.17463619 0.038635027 -0.23191011 -0.13592774 -0.01954393 -0.28888118 0.0130044455 0.10935221 -0.10274326 0.16326426 0.24069212 -0.068884164 -0.042140033 -0.08411051 0.14803806 -0.08204498 0.13407354 -0.08042538 0.032217037 -0.2666482 -0.17485079 0.37256253 -0.02551431 -0.25904474 -0.002844959 0.1672513 0.035283662 -0.11897226 0.14446032 0.08866355 -0.024791516 -0.22040974 0.0137709975 -0.16484109 0.18097405 0.07075867 0.13830985 0.025787655 0.017255543 -0.0387513 0.07857641 0.20455246 -0.02442122 -0.18393797 -0.0361829 -0.12946953 -0.15860991 -0.10650375 -0.251683 -0.1709236 0.12092594 0.20731401 0.035180748 -0.09422942 0.1373039 0.121121824 -0.09530268 -0.15685256 -0.14398256 -0.068801016 0.0666081 0.13958378 0.0868633 -0.036316663 0.10832365 -0.21385072 0.15025891 0.2161903 0.2097545 -0.0487211 -0.18837014 -0.16750671 0.032201447 0.03347862 0.09050423 -0.20007794 0.11616628 0.005944925
以上就是TensorFlow2.4完成Word2vec詞嵌入訓(xùn)練方法詳解的詳細(xì)內(nèi)容,更多關(guān)于TensorFlow Word2vec詞嵌入訓(xùn)練的資料請關(guān)注腳本之家其它相關(guān)文章!
- python深度學(xué)習(xí)tensorflow訓(xùn)練好的模型進(jìn)行圖像分類
- python神經(jīng)網(wǎng)絡(luò)tensorflow利用訓(xùn)練好的模型進(jìn)行預(yù)測
- 詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式
- Tensorflow2.1 MNIST圖像分類實(shí)現(xiàn)思路分析
- Tensorflow2.1實(shí)現(xiàn)文本中情感分類實(shí)現(xiàn)解析
- Tensorflow?2.1完成對MPG回歸預(yù)測詳解
- Tensorflow?2.4加載處理圖片的三種方式詳解
- Tensorflow2.1實(shí)現(xiàn)Fashion圖像分類示例詳解
相關(guān)文章
python網(wǎng)絡(luò)爬蟲精解之正則表達(dá)式的使用說明
正則表達(dá)式是對字符串操作的一種邏輯公式,就是用事先定義好的一些特定字符、及這些特定字符的組合,組成一個(gè)“規(guī)則字符串”,這個(gè)“規(guī)則字符串”用來表達(dá)對字符串的一種過濾邏輯2021-09-09
PyTorch之torch.randn()如何創(chuàng)建正態(tài)分布隨機(jī)數(shù)
這篇文章主要介紹了PyTorch之torch.randn()如何創(chuàng)建正態(tài)分布隨機(jī)數(shù)問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-02-02
深入解析python項(xiàng)目引用運(yùn)行路徑
這篇文章主要介紹了python項(xiàng)目引用運(yùn)行路徑的問題,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2023-05-05
python3+mysql查詢數(shù)據(jù)并通過郵件群發(fā)excel附件
這篇文章主要為大家詳細(xì)介紹了python3+mysql查詢數(shù)據(jù),并通過郵件群發(fā)excel附件,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-02-02

