Merge pull request #56 from mmmwhy/wip-fy
feat(bert): add tokenizer part
mmmwhy authored Jan 25, 2022
2 parents 6a446fb + 054df14 commit 97b624a
Showing 25 changed files with 894 additions and 41 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1 +1,3 @@
.idea
.idea

__pycache__/
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,11 @@
## 0.0.4 (2022-01-25)


### Features

* **bert:** add tokenizer part ([3683fa9](https://github.com/mmmwhy/pure_attention/commit/3683fa937a4355d616893cc6f99e2a7b69b2a2af))
* **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde))
* **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927))



2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,2 @@
include *.txt
include *.md
17 changes: 11 additions & 6 deletions README.md
@@ -17,16 +17,21 @@ Many methods and techniques in CV and NLP also influence each other, for example large-scale pre-training
2. Online inference deployment with ONNX in a Java environment; I chose ONNX because at my company I use TensorFlow for inference, and I don't want this code to be identical to the company's code.

# todo
- Goal of phase 1: implement the typical NLP and CV tasks and evaluate the downstream performance.
- [ ] PyTorch implementation of Transformer;
- [ ] Implementations of multiple BERT variants;
- [ ] Implement NLP downstream tasks (sequence labeling, classification) and evaluate them on public datasets, mainly to show that the implemented BERT performs as expected;
Phase 1: implement the typical NLP and CV tasks and evaluate the downstream performance.
- [x] Implement the encoder part of the Transformer in PyTorch and build BERT on top of it.
> The design follows [transformers](https://github.com/huggingface/transformers) but keeps only the code relevant to the encoder, which greatly reduces the amount of code.
The results stay consistent with the original huggingface encoder; see [backbone_bert](pure_attention/backbone_bert/README.md) for usage and the consistency check.
- [ ] Implement the decoder part of the Transformer in PyTorch and a seq2seq task.
> todo
- [ ] Implement NLP downstream tasks (sequence labeling, classification) and evaluate them on public datasets, mainly to show that the implemented backbone performs as expected;
> todo
- [ ] Implement ViT and verify on downstream tasks that the implementation performs as expected;
> todo
- Goal of phase 2: add the remaining common NLP and CV tasks to broaden the project's scope.
Phase 2: add the remaining common NLP and CV tasks to broaden the project's scope.
- [ ] UNILM;
- [ ] MAE;
- [ ] GPT series;
- [ ] seq2seq, with a translation task;
- [ ] Implement ONNX export of the models;
- [ ] Implement ONNX inference in Java;
- [ ] Implement ONNX inference in Java;
Binary file added images/bert-base-chinese.png
Binary file added images/chinese-roberta-wwm-ext-large.png
Binary file added images/chinese-roberta-wwm-ext.png
12 changes: 12 additions & 0 deletions package.json
@@ -0,0 +1,12 @@
{
"name": "pure_attention",
"version": "0.0.5",
"description": "Generate a changelog from git metadata",
"repository": {
"type": "git",
"url": "https://github.com/mmmwhy/pure_attention"
},
"scripts": {
"changelog": "conventional-changelog -p angular -i CHANGELOG.md -s -r 0"
}
}
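With this configuration, running `npm run changelog` (assuming the conventional-changelog CLI is installed) regenerates CHANGELOG.md from the git history using the Angular commit convention.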
Binary file added pure_attention/.DS_Store
Binary file not shown.
7 changes: 7 additions & 0 deletions pure_attention/__init__.py
@@ -0,0 +1,7 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/25
#
""""""
71 changes: 71 additions & 0 deletions pure_attention/backbone_bert/README.md
@@ -0,0 +1,71 @@
# Introduction
[transformers](https://github.com/huggingface/transformers) supports a very large number of model architectures, which has made its structure quite complex.

Building on the code of
[transformers](https://github.com/huggingface/transformers),
[bert4pytorch](https://github.com/MuQiuJun-AI/bert4pytorch), and
[Read_Bert_Code](https://github.com/DA-southampton/Read_Bert_Code),
I adjusted the structure to improve readability, while keeping the results exactly consistent with [transformers](https://github.com/huggingface/transformers).

# Usage

``` python
from pure_attention.common.nlp.tokenization import Tokenizer
from pure_attention.backbone_bert.bert_model import BertModel

bert_model_path = "/data/pretrain_modal/bert-base-chinese"
test_query = "结果一致性验证"

tokenizer = Tokenizer(bert_model_path + "/vocab.txt")
bert = BertModel(bert_model_path)
tokens_ids, segments_ids = tokenizer.encode(test_query, max_len=64)

bert_pooler_output = bert(tokens_ids, token_type_ids=segments_ids).pooler_output

```


# Result consistency
Tested on the three common Chinese BERT checkpoints below; the results are exactly consistent with transformers.
- [bert-base-chinese](https://huggingface.co/bert-base-chinese)

![](../../images/bert-base-chinese.png)


- [chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext)

![](../../images/chinese-roberta-wwm-ext.png)


- [chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)

![](../../images/chinese-roberta-wwm-ext-large.png)


``` python
import torch
from transformers import BertModel
from transformers import BertTokenizer

from pure_attention.common.nlp.tokenization import Tokenizer as LocalTokenizer
from pure_attention.backbone_bert.bert_model import BertModel as OurBertModel

bert_model_path = "/data/pretrain_modal/chinese-roberta-wwm-ext-large"
test_query = "结果一致性验证"

text_tokenizer = BertTokenizer.from_pretrained(bert_model_path, do_lower_case=True)
bert_model = BertModel.from_pretrained(bert_model_path)

tensor_caption = text_tokenizer.encode(test_query, return_tensors="pt", padding='max_length', truncation=True,
max_length=64)

origin_bert_pooler_output = bert_model(tensor_caption).pooler_output

tokenizer = LocalTokenizer(bert_model_path + "/vocab.txt")
bert = OurBertModel(bert_model_path)
tokens_ids, segments_ids = tokenizer.encode(test_query, max_len=64)

our_bert_pooler_output = bert(tokens_ids, token_type_ids=segments_ids).pooler_output

print("check result:", torch.cosine_similarity(origin_bert_pooler_output, our_bert_pooler_output))
```
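Since the two implementations are expected to be fully consistent, the cosine similarity printed between the two pooler outputs should be numerically very close to 1.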
File renamed without changes.
@@ -10,8 +10,8 @@
import torch
from torch import nn

from common.activate import activations
from common.layers import LayerNorm as BertLayerNorm
from pure_attention.common.activate import activations
from pure_attention.common.layers import LayerNorm as BertLayerNorm


class BertEmbeddings(nn.Module):
@@ -6,29 +6,11 @@
#
""""""
import os
import json
from pure_attention.backbone_bert.package import BertConfig, BertOutput
import torch
from torch import nn
from common.bert.bert_layer import BertLayer
from common.bert.bert_layer import BertEmbeddings
from common.layers import LayerNorm as BertLayerNorm


class BertConfig:
def __init__(self, vocab_size_or_config_json_file):
"""
Customized config; lookups for keys that are not set are handled in __getattr__
"""
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
print(f"{key}:{value}({type(value)})")

def __getattr__(self, key):
if key in self.__dict__:
return self.__dict__[key]
return None
from pure_attention.backbone_bert.bert_layer import BertLayer, BertEmbeddings
from pure_attention.common.layers import LayerNorm as BertLayerNorm


class BertEncoder(nn.Module):
@@ -75,8 +57,7 @@ def __init__(self, config):
self.activation = nn.Tanh()

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
# only take the embedding of the first token (the [CLS] position) and apply the dense transformation
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
@@ -95,6 +76,7 @@ def __init__(self, config_path):

self.init_weights()
self.from_pretrained(os.path.join(os.path.join(config_path, "pytorch_model.bin")))
self.eval()

def init_weights(self):
self.apply(self._init_weights)
@@ -156,6 +138,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.

extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
@@ -190,13 +173,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)

# add hidden_states and attentions if they are here
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
outputs = BertOutput(last_hidden_state=sequence_output, pooler_output=pooled_output,
attentions=encoder_outputs[1:])

# sequence_output, pooled_output, (hidden_states), (attentions)
return outputs


if __name__ == "__main__":
bert = BertModel("/data1/lifeiyang/data/pretrain_modal/bert-base-chinese")
print("pass")
37 changes: 37 additions & 0 deletions pure_attention/backbone_bert/package.py
@@ -0,0 +1,37 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/25
#
""""""
import json
import torch
from typing import Optional, Tuple


class BertConfig:
def __init__(self, vocab_size_or_config_json_file):
"""
Customized config; lookups for keys that are not set are handled in __getattr__
"""
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value

def __getattr__(self, key):
if key in self.__dict__:
return self.__dict__[key]
return None


class BertOutput:
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
attentions: Optional[Tuple[torch.FloatTensor]] = None

def __init__(self, last_hidden_state, pooler_output, attentions):
self.last_hidden_state = last_hidden_state
self.pooler_output = pooler_output
self.attentions = attentions
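A minimal usage sketch for these two helpers (the config path reuses the one from the README example and is only illustrative; `BertConfig` exposes every key of the JSON file as an attribute and returns `None` for anything missing):

``` python
import torch
from pure_attention.backbone_bert.package import BertConfig, BertOutput

# illustrative path; any BERT-style config.json works
config = BertConfig("/data/pretrain_modal/bert-base-chinese/config.json")

print(config.hidden_size)      # value read from config.json (standard BERT configs include it)
print(config.not_a_real_key)   # None, thanks to __getattr__

# BertOutput is a plain container for the model's forward outputs
output = BertOutput(last_hidden_state=torch.zeros(1, 64, 768),
                    pooler_output=torch.zeros(1, 768),
                    attentions=None)
print(output.pooler_output.shape)  # torch.Size([1, 768])
```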
File renamed without changes.
File renamed without changes.
File renamed without changes.
113 changes: 113 additions & 0 deletions pure_attention/common/logger.py
@@ -0,0 +1,113 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/24
#
""""""

import logging
import os
import sys

_log_format = "%(asctime)s |- %(levelname).1s %(name)s - %(message)s"


def _get_log_level():
""" get log level from env variable 'LOG_LEVEL'
Returns:
str|int: e.g. "INFO", 20, "DEBUG", 10, "ERROR", 40.
"""
level = os.environ.get("LOG_LEVEL", "INFO")
try:
level = int(level)
except ValueError:
assert isinstance(level, str)
level = level.upper()
return level


# Attention: do not change this
# unless you have tested it online and verified that DEBUG logs still show up.
# The stream handler is added to avoid the default stream handler's NOTSET level.
logging.basicConfig(
format=_log_format,
level=_get_log_level(),
stream=sys.stdout)


class LogLevel(object):
CRITICAL = 50
ERROR = 40
WARNING = 30
INFO = 20
DEBUG = 10
DETAIL = 5
NOTSET = 0


def init_logger(logger_name=None,
log_file=os.environ.get("LOG_FILE", ""),
log_format=_log_format,
level=_get_log_level()):
""" init logger
Args:
logger_name(str): optional, default: None.
log_file(str): optional, default: "".
output log messages to file if specified, by default is set by env
`LOG_FILE`.
log_format(str): optional, default:
"%(asctime)s |- %(levelname).1s %(name)s - %(message)s"
level(int|logging.Level): set log level, by default it is set by env
`LOG_LEVEL`, `INFO` level is used if not set.
:: level
- CRITICAL 50
- ERROR 40
- WARNING 30
- INFO 20
- DEBUG 10
- DETAIL 5
- NOTSET 0
Returns:
logging.Logger: a logger instance
"""
logger = logging.getLogger(logger_name)
logger.setLevel(level)

if log_file:
handler = logging.FileHandler(log_file)
if log_format:
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)

logger.addHandler(handler)

return logger


def _test():
logger = init_logger("test_logger", "test_file.log",
level=_get_log_level())
logger.info("level: {}".format(os.environ.get("LOG_LEVEL", "INFO")))
import sys
logger.info(sys.modules[__name__])
logger.info(logging.getLoggerClass())
logger.debug("test DEBUG 10")
logger.info("test INFO 20")
logger.warning("test WARNING 30")
logger.error("test ERROR 40")
logger.critical("test CRITICAL 50")

if logger.isEnabledFor(logging.DEBUG):
logger.warning("debug enabled!")
if logger.isEnabledFor(LogLevel.DEBUG):
logger.info("detail enabled")


if __name__ == "__main__":
_test()
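A minimal usage sketch of this logger (the logger name and log file path are illustrative):

``` python
from pure_attention.common.logger import init_logger, LogLevel

# omit log_file to log to stdout only; "demo.log" is just an example
logger = init_logger("bert_demo", log_file="demo.log", level=LogLevel.DEBUG)
logger.debug("tokenizer loaded")
logger.info("encoding finished")
```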
7 changes: 7 additions & 0 deletions pure_attention/common/nlp/__init__.py
@@ -0,0 +1,7 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/24
#
""""""
