From f9cb13a3e811eb8c44ba8ff1373d688311426927 Mon Sep 17 00:00:00 2001
From: mmmwhy
Date: Wed, 19 Jan 2022 20:52:11 +0800
Subject: [PATCH] feat(nlp): init basic bert code

---
 README.md                 |   4 +-
 common/activate.py        |  10 +-
 common/bert/bert_layer.py | 289 ++++++++++++++++++++++++++++++++++++++
 common/bert/bert_model.py | 202 ++++++++++++++++++++++++++
 common/bert/layers.py     |  24 ----
 common/layers.py          |  35 +++++
 script/common.sh          |   4 +-
 7 files changed, 540 insertions(+), 28 deletions(-)
 create mode 100644 common/bert/bert_layer.py
 create mode 100644 common/bert/bert_model.py
 delete mode 100644 common/bert/layers.py
 create mode 100644 common/layers.py
 mode change 100644 => 100755 script/common.sh

diff --git a/README.md b/README.md
index 6d65acc..43c02ba 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,9 @@ Many methods and tricks in CV and NLP influence each other, e.g. large-scale pre
 
 # Goals
 Provide a complete set of basic algorithm services
-1. Python training tasks, covering both NLP and CV.
+
+1. Python training tasks, covering both NLP and CV.
+
 2. Online inference deployment in a Java environment via ONNX. I use ONNX because inference at my company runs on TensorFlow, and I would rather not have this code mirror the company stack.
 
 # todo

diff --git a/common/activate.py b/common/activate.py
index 3f8abe3..f140933 100644
--- a/common/activate.py
+++ b/common/activate.py
@@ -5,9 +5,12 @@
 # @date: 2022/01/19
 #
 """"""
-import torch
 import math
+import torch
+import torch.nn.functional as F
+
+
 def gelu(x):
     """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -16,13 +19,16 @@ def gelu(x):
     """
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
+
 def gelu_new(x):
     """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
         Also see https://arxiv.org/abs/1606.08415
     """
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
+
 def swish(x):
     return x * torch.sigmoid(x)
 
-activations = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
+
+activations = {"gelu": gelu, "relu": F.relu, "swish": swish, "gelu_new": gelu_new}
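As a quick sanity check of the registry above, the erf-based gelu and the tanh approximation in gelu_new should agree to within about 1e-3. A minimal sketch, assuming the repo root is on PYTHONPATH; the tensor values and tolerance are illustrative:

    import torch
    from common.activate import activations

    x = torch.randn(4, 8)
    # erf-based gelu vs. the tanh approximation used by gelu_new
    print(torch.allclose(activations["gelu"](x), activations["gelu_new"](x), atol=1e-3))  # True
    # swish is simply x * sigmoid(x)
    print(torch.allclose(activations["swish"](x), x * torch.sigmoid(x)))  # True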
diff --git a/common/bert/bert_layer.py b/common/bert/bert_layer.py
new file mode 100644
index 0000000..dc64961
--- /dev/null
+++ b/common/bert/bert_layer.py
@@ -0,0 +1,289 @@
+# !/usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# @author: fly.sun
+# @date: 2022/01/22
+#
+""""""
+import math
+
+import torch
+from torch import nn
+
+from common.activate import activations
+from common.layers import LayerNorm as BertLayerNorm
+
+
+class BertEmbeddings(nn.Module):
+    def __init__(self, config):
+        """
+        Implementation of the "input embedding" block.
+        vocab_size: vocabulary size;
+        hidden_size: hidden dimension of the internal networks;
+        type_vocab_size: usually 2 (only 0 and 1), telling the model whether a token belongs to the first or the second sentence;
+            (a) For sequence pairs:
+
+            ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
+
+            ``token_type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1``
+
+            (b) For single sequences:
+
+            ``tokens:         [CLS] the dog is hairy . [SEP]``
+
+            ``token_type_ids: 0     0   0   0  0     0 0``
+        max_position_embeddings: maximum sequence length; position_embeddings stores one embedding per position and
+            could also be generated with alternating sin/cos. Learning it directly works about as well as sin/cos,
+            but writing it this way is more compatible when loading pretrained checkpoints.
+        hidden_dropout_prob: dropout probability
+        layer_norm_eps: eps in the LayerNorm denominator
+        """
+        super(BertEmbeddings, self).__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids=None, position_ids=None):
+        seq_length = input_ids.size(1)
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        words_embeddings = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        # note: the three embeddings are added element-wise
+        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
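The embedding block above only touches a handful of config fields, so it can be smoke-tested on its own. A minimal sketch, where the SimpleNamespace stands in for the real BertConfig and every value is illustrative:

    from types import SimpleNamespace

    import torch
    from common.bert.bert_layer import BertEmbeddings

    toy_config = SimpleNamespace(vocab_size=100, hidden_size=32, type_vocab_size=2,
                                 max_position_embeddings=64, hidden_dropout_prob=0.1,
                                 layer_norm_eps=1e-12)
    emb = BertEmbeddings(toy_config)
    input_ids = torch.randint(0, 100, (2, 10))  # [batch_size, seq_len]
    print(emb(input_ids).shape)                 # torch.Size([2, 10, 32])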
+
+
+class MultiHeadAttentionLayer(nn.Module):
+    def __init__(self, config):
+        """
+        Implementation of "Multi-Head Attention", the core of the attention code.
+        hidden_size: hidden dimension
+        num_attention_heads: number of attention heads
+        attention_probs_dropout_prob: dropout probability on the attention probabilities
+        attention_scale: scale the query/key dot product so that the softmax stays numerically stable
+        return_attention_scores: whether to return the attention matrix
+        """
+        super(MultiHeadAttentionLayer, self).__init__()
+
+        assert config.hidden_size % config.num_attention_heads == 0, \
+            "hidden_size must be a multiple of num_attention_heads, otherwise the per-head embeddings cannot be built"
+
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.return_attention_scores = config.return_attention_scores
+
+        self.query = nn.Linear(config.hidden_size, config.hidden_size)
+        self.key = nn.Linear(config.hidden_size, config.hidden_size)
+        self.value = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        """
+        The name of this function is a bit confusing.
+
+        Take the query of a standard bert-base as an example: the input x has shape [batch_size, query_len, hidden_size],
+        with hidden_size 768, num_attention_heads 12 and attention_head_size 768 / 12 = 64.
+
+        new_x_shape = [batch_size, query_len] + [12, 64], i.e. [batch_size, query_len, num_attention_heads, attention_head_size]
+
+        In other words, this function splits each token vector into 12 chunks, giving every attention head its own 64-d slice.
+        """
+
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(self, query, key, value, attention_mask=None, head_mask=None):
+        """
+        query shape: [batch_size, query_len, hidden_size]
+        key shape: [batch_size, key_len, hidden_size]
+        value shape: [batch_size, value_len, hidden_size]
+        In the usual case query_len, key_len and value_len are all equal.
+        """
+
+        mixed_query_layer = self.query(query)
+        mixed_key_layer = self.key(key)
+        mixed_value_layer = self.value(value)
+        """
+        mixed_query_layer shape: [batch_size, query_len, hidden_size]
+        mixed_key_layer shape: [batch_size, key_len, hidden_size]
+        mixed_value_layer shape: [batch_size, value_len, hidden_size]
+        """
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+        """
+        query_layer shape: [batch_size, num_attention_heads, query_len, attention_head_size]
+        key_layer shape: [batch_size, num_attention_heads, key_len, attention_head_size]
+        value_layer shape: [batch_size, num_attention_heads, value_len, attention_head_size]
+        """
+
+        # swap the last two dimensions of k, then take the dot product of q and k to get the attention scores
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        # attention_scores shape: [batch_size, num_attention_heads, query_len, key_len]
+
+        """
+        Apply the attention scale: dividing by math.sqrt(self.attention_head_size) keeps the softmax from becoming
+        extremely peaked and attending to only a couple of keys. Compare softmax([1, 2]) with
+        softmax([1 * np.sqrt(768), 2 * np.sqrt(768)]) to get a feel for the difference.
+        """
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # attention_mask holds large negative values (effectively -inf), so masked positions get weight 0 after softmax
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        # softmax over the scores gives, for each query, a distribution over the values
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        """
+        Note that this dropout is somewhat unusual: it masks out the score of an entire value, but that is indeed
+        what the original paper does. It also enables an interesting pretraining trick: contrastive learning with two
+        weight-shared BERTs (e.g. a MoCo-style setup) works precisely because the dropout masks differ between passes.
+        """
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        # some BERT variants use a head_mask; this version does not implement it yet. todo @sun
+
+        """
+        As a reminder:
+        value_layer shape: [batch_size, num_attention_heads, value_len, attention_head_size]
+        attention_probs shape: [batch_size, num_attention_heads, query_len, key_len]
+
+        value_len == key_len
+        """
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        # context_layer shape: [batch_size, num_attention_heads, query_len, attention_head_size]
+
+        # after transpose/permute the tensor is no longer stored contiguously in memory, while view requires
+        # contiguous storage, so contiguous() is called to get a contiguous copy before view;
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        # context_layer shape: [batch_size, query_len, num_attention_heads, attention_head_size]
+
+        # note that the last two dimensions are merged back together here with a view
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # optionally return the attention scores; note these are the raw attention_scores, neither normalized nor dropped out
+        # the first slot is the output embedding, the second is attention_scores; later code branches on this
+        outputs = (context_layer, attention_scores) if self.return_attention_scores else (context_layer,)
+        return outputs
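The scaling comment in the forward pass above is easy to verify numerically; a minimal sketch, with purely illustrative numbers:

    import math

    import torch

    scores = torch.tensor([1.0, 2.0])
    print(torch.softmax(scores, dim=-1))                   # tensor([0.2689, 0.7311])
    print(torch.softmax(scores * math.sqrt(768), dim=-1))  # roughly tensor([0., 1.]), one key takes all the weight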
+
+
+class BertAddNorm(nn.Module):
+    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob, layer_norm_eps):
+        """
+        Implementation of the "Add & Norm" block, which is reused several times.
+        The original BertSelfOutput and BertOutput are merged into this single module.
+
+        This Add & Norm block does three things:
+        1. After Multi-Head attention, the per-head results are simply concatenated (the view that merges the last two
+           dimensions is effectively a concat). Using that raw concat directly would be odd, so a fully connected
+           layer is needed to blend the separate head outputs back together;
+        2. After the Feed Forward step, where the dimension has been expanded to intermediate_size, BertAddNorm also
+           projects it back down from intermediate_size to hidden_size;
+        3. The actual Add & Norm, i.e. the LayerNorm(hidden_states + input_tensor) line.
+        """
+        super(BertAddNorm, self).__init__()
+        self.dense = nn.Linear(intermediate_size, hidden_size)
+        self.LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.dropout = nn.Dropout(hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        # the residual connection, which matters a lot
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertIntermediate(nn.Module):
+    """
+    Partial implementation of the "Position-wise Feed-Forward Networks":
+    FFN(x) = max(0, xW1 + b1)W2 + b2
+
+    In the original "Attention is all you need", hidden_size is 512 and intermediate_size is 2048, i.e. the dimension
+    is expanded. It resembles two kernel-size-1 convolutions in CNNs: blow the dimension up, then squeeze it back.
+
+    Note that the code here only covers the activate(x W1 + b1) part; the outer fc lives in BertAddNorm.
+    """
+
+    def __init__(self, hidden_size, intermediate_size, hidden_act):
+        super(BertIntermediate, self).__init__()
+        self.dense = nn.Linear(hidden_size, intermediate_size)
+        self.intermediate_act_fn = activations[hidden_act]
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config):
+        """
+        Implementation of "Multi-Head Attention plus Add & Norm".
+        hidden_size: hidden dimension
+        num_attention_heads: number of attention heads
+        attention_probs_dropout_prob: dropout probability on the attention probabilities
+        attention_scale: scale the query/key dot product so that the softmax stays numerically stable
+        return_attention_scores: whether to return the attention matrix
+        hidden_dropout_prob: hidden-layer dropout probability
+        layer_norm_eps: eps inside the LayerNorm
+        """
+        super(BertAttention, self).__init__()
+        self.self = MultiHeadAttentionLayer(config)
+        # this is the lower-left Add & Norm in the architecture diagram
+        self.output = BertAddNorm(config.hidden_size, config.hidden_size,
+                                  config.hidden_dropout_prob, config.layer_norm_eps)
+        self.pruned_heads = set()
+
+    def forward(self, input_tensor, attention_mask=None, head_mask=None):
+        self_outputs = self.self(input_tensor, input_tensor, input_tensor, attention_mask, head_mask)
+        attention_output = self.output(self_outputs[0], input_tensor)
+        outputs = (attention_output,) + self_outputs[1:]
+        return outputs
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config):
+        """
+        A complete single BERT layer.
+        The parameters are deliberately unpacked from config here so that each one can be commented on.
+        """
+        super(BertLayer, self).__init__()
+        self.attention = BertAttention(config)
+
+        self.intermediate = BertIntermediate(config.hidden_size, config.intermediate_size, config.hidden_act)
+        self.output = BertAddNorm(config.intermediate_size, config.hidden_size,
+                                  config.hidden_dropout_prob, config.layer_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        attention_output = attention_outputs[0]
+
+        # this is the upper-left Add & Norm, which together with BertIntermediate forms the complete FFN
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+
+        # attention_outputs[0] is the embedding, [1] is the attention scores
+        outputs = (layer_output,) + attention_outputs[1:]
+        return outputs
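Putting the pieces of bert_layer.py together, one BertLayer can be driven end to end with a toy config. A minimal sketch; the SimpleNamespace and every field value are illustrative stand-ins for the real BertConfig:

    from types import SimpleNamespace

    import torch
    from common.bert.bert_layer import BertLayer

    toy_config = SimpleNamespace(hidden_size=32, num_attention_heads=4, intermediate_size=64,
                                 hidden_act="gelu", hidden_dropout_prob=0.1,
                                 attention_probs_dropout_prob=0.1, layer_norm_eps=1e-12,
                                 return_attention_scores=False)
    layer = BertLayer(toy_config)
    hidden_states = torch.randn(2, 10, 32)                 # [batch_size, seq_len, hidden_size]
    attention_mask = torch.zeros(2, 1, 1, 10)              # additive mask: 0 keeps a position, -10000 masks it
    print(layer(hidden_states, attention_mask)[0].shape)   # torch.Size([2, 10, 32])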
diff --git a/common/bert/bert_model.py b/common/bert/bert_model.py
new file mode 100644
index 0000000..4e4de81
--- /dev/null
+++ b/common/bert/bert_model.py
@@ -0,0 +1,202 @@
+# !/usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# @author: fly.sun
+# @date: 2022/01/22
+#
+""""""
+import os
+import json
+import torch
+from torch import nn
+from common.bert.bert_layer import BertLayer
+from common.bert.bert_layer import BertEmbeddings
+from common.layers import LayerNorm as BertLayerNorm
+
+
+class BertConfig:
+    def __init__(self, vocab_size_or_config_json_file):
+        """
+        Customized config; missing keys are handled in __getattr__.
+        """
+        with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+            json_config = json.loads(reader.read())
+        for key, value in json_config.items():
+            self.__dict__[key] = value
+            print(f"{key}:{value}({type(value)})")
+
+    def __getattr__(self, key):
+        # only reached when normal attribute lookup fails; keys missing from the json config resolve to None
+        if key in self.__dict__:
+            return self.__dict__[key]
+        return None
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super(BertEncoder, self).__init__()
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
+
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+        all_hidden_states = ()
+        all_attentions = ()
+        for i, layer_module in enumerate(self.layer):
+            if self.output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+            # [0] is the embedding, [1] is the attention score
+            hidden_states = layer_outputs[0]
+
+            if self.output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        # Add last layer
+        if self.output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        outputs = (hidden_states,)
+        if self.output_hidden_states:
+            # expose the intermediate layers; some work finds the intermediate embeddings valuable as well
+            outputs = outputs + (all_hidden_states,)
+        if self.output_attentions:
+            outputs = outputs + (all_attentions,)
+
+        # last-layer hidden state, (all hidden states), (all attentions)
+        return outputs
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super(BertPooler, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertModel(nn.Module):
+    def __init__(self, config_path):
+        super(BertModel, self).__init__()
+
+        self.config = BertConfig(os.path.join(config_path, "config.json"))
+
+        self.embeddings = BertEmbeddings(self.config)
+        self.encoder = BertEncoder(self.config)
+        self.pooler = BertPooler(self.config)
+
+        self.init_weights()
+        self.from_pretrained(os.path.join(config_path, "pytorch_model.bin"))
+
+    def init_weights(self):
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        """ Initialize the weights """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, BertLayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+    def from_pretrained(self, pretrained_model_path):
+        if not os.path.exists(pretrained_model_path):
+            print(f"missing pretrained_model_path: {pretrained_model_path}")
+            return
+
+        state_dict = torch.load(pretrained_model_path, map_location='cpu')
+
+        # parameter names may differ between checkpoints, so rename them here
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = key
+            if 'gamma' in key:
+                new_key = new_key.replace('gamma', 'weight')
+            if 'beta' in key:
+                new_key = new_key.replace('beta', 'bias')
+            if 'bert.' in key:
+                new_key = new_key.replace('bert.', '')
+
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+
+        for old_key, new_key in zip(old_keys, new_keys):
+
+            if new_key in self.state_dict().keys():
+                state_dict[new_key] = state_dict.pop(old_key)
+            else:
+                # drop tensors that only exist in the pretrained checkpoint, otherwise the strict load_state_dict fails
+                state_dict.pop(old_key)
+
+        # make sure everything matches exactly
+        self.load_state_dict(state_dict, strict=True)
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        encoder_outputs = self.encoder(embedding_output,
+                                       extended_attention_mask,
+                                       head_mask=head_mask)
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output)
+
+        # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
+
+        # sequence_output, pooled_output, (hidden_states), (attentions)
+        return outputs
+
+
+if __name__ == "__main__":
+    bert = BertModel("/data1/lifeiyang/data/pretrain_modal/bert-base-chinese")
+    print("pass")
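The (1.0 - mask) * -10000.0 trick in the forward pass above is worth seeing in isolation; a minimal sketch with an illustrative padding pattern:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])     # 1 = real token, 0 = padding
    extended = attention_mask[:, None, None, :].float()  # [batch_size, 1, 1, seq_len]
    extended = (1.0 - extended) * -10000.0
    print(extended)  # tensor([[[[-0., -0., -0., -10000., -10000.]]]])
    # added to the raw attention scores, the padded positions end up with ~0 probability after softmax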
diff --git a/common/bert/layers.py b/common/bert/layers.py
deleted file mode 100644
index cb1869e..0000000
--- a/common/bert/layers.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# !/usr/bin/python
-# -*- coding: utf-8 -*-
-#
-# @author: fly.sun
-# @date: 2022/01/19
-#
-""""""
-import torch
-from torch import nn
-
-class LayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-12):
-        """ layernorm layer; PyTorch's built-in layernorm could also be used.
-            see https://iii.run/archives/fae41911210f.html for a reference implementation
-        """
-        super(LayerNorm, self).__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size))
-        self.eps = eps
-
-    def forward(self, x):
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        return self.weight * (x - mean) / torch.sqrt(std + self.eps) + self.bias

diff --git a/common/layers.py b/common/layers.py
new file mode 100644
index 0000000..4f3418b
--- /dev/null
+++ b/common/layers.py
@@ -0,0 +1,35 @@
+# !/usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# @author: fly.sun
+# @date: 2022/01/19
+#
+"""
+A hand-rolled layerNorm. Note that layerNorm normalizes each individual sample rather than each batch; the two look
+similar but operate over different dimensions.
+In NLP tasks layerNorm is used more often, because:
+
+1. Text is variable-length. With max_length 512, most samples may only contain a few dozen tokens, and normalizing
+   those tokens together with a large amount of padding is not reasonable.
+2. The mean and variance in batchNorm are learned during training and then reused at inference time, e.g. the usual
+   transforms.Normalize in CV. With layerNorm there is no need to precompute the mean and variance: they are computed
+   separately for each input sentence, which makes more sense for variable-length text.
+3. Implementing layerNorm ourselves also leaves room for small optimizations later.
+4. Implementation follows https://iii.run/archives/fae41911210f.html
+
+"""
+
+import torch
+import torch.nn as nn
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-12):
+        super(LayerNorm, self).__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.bias = nn.Parameter(torch.zeros(hidden_size))
+        self.eps = eps
+
+    def forward(self, x):
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        return self.weight * x + self.bias

diff --git a/script/common.sh b/script/common.sh
old mode 100644
new mode 100755
index ae6190a..d36d8b3
--- a/script/common.sh
+++ b/script/common.sh
@@ -27,4 +27,6 @@ function cecho {
   [[ -z "${text}" ]] && local text="${color}$2${code}0m"
 
   echo -e "${text}"
-}
\ No newline at end of file
+}
+
+autopep8 --in-place --recursive --max-line-length=120 .
\ No newline at end of file
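The new common/layers.py normalizes by the variance (the deleted common/bert/layers.py put the standard deviation under the square root), so a freshly initialized instance should match PyTorch's built-in LayerNorm. A minimal sketch, with illustrative shapes:

    import torch
    from common.layers import LayerNorm

    x = torch.randn(2, 5, 16)
    ours = LayerNorm(16, eps=1e-12)
    ref = torch.nn.LayerNorm(16, eps=1e-12)
    # both start from weight=1 and bias=0, so the outputs should agree
    print(torch.allclose(ours(x), ref(x), atol=1e-5))  # True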