Merge pull request #56 from mmmwhy/wip-fy
feat(bert): add tokenizer part
Showing 25 changed files with 894 additions and 41 deletions.
@@ -1 +1,3 @@
.idea

__pycache__/
@@ -0,0 +1,11 @@
## 0.0.4 (2022-01-25)


### Features

* **bert:** add tokenizer part ([3683fa9](https://github.com/mmmwhy/pure_attention/commit/3683fa937a4355d616893cc6f99e2a7b69b2a2af))
* **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde))
* **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927))
@@ -0,0 +1,2 @@
include *.txt
include *.md
3 files could not be displayed.
@@ -0,0 +1,12 @@
{
  "name": "pure_attention",
  "version": "0.0.5",
  "description": "Generate a changelog from git metadata",
  "repository": {
    "type": "git",
    "url": "https://github.com/mmmwhy/pure_attention"
  },
  "scripts": {
    "changelog": "conventional-changelog -p angular -i CHANGELOG.md -s -r 0"
  }
}
Binary file not shown.
@@ -0,0 +1,7 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/25
#
""""""
@@ -0,0 +1,71 @@
# Introduction
[transformers](https://github.com/huggingface/transformers) has to accommodate a very large variety of model architectures, so its structure has become quite complex.

Building on the code of
[transformers](https://github.com/huggingface/transformers),
[bert4pytorch](https://github.com/MuQiuJun-AI/bert4pytorch), and
[Read_Bert_Code](https://github.com/DA-southampton/Read_Bert_Code),
I made some structural adjustments to improve readability, while keeping the results exactly consistent with [transformers](https://github.com/huggingface/transformers).

# Usage

``` python
from pure_attention.common.nlp.tokenization import Tokenizer
from pure_attention.backbone_bert.bert_model import BertModel

bert_model_path = "/data/pretrain_modal/bert-base-chinese"
test_query = "结果一致性验证"

tokenizer = Tokenizer(bert_model_path + "/vocab.txt")
bert = BertModel(bert_model_path)
tokens_ids, segments_ids = tokenizer.encode(test_query, max_len=64)

bert_pooler_output = bert(tokens_ids, token_type_ids=segments_ids).pooler_output
```
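
The object returned by `bert(...)` above is the `BertOutput` container added later in this commit, so besides `pooler_output` it also carries `last_hidden_state` and `attentions`. A minimal follow-on sketch, reusing the objects created above:

``` python
# Follow-on sketch: reuses `bert`, `tokens_ids` and `segments_ids` from the snippet above.
outputs = bert(tokens_ids, token_type_ids=segments_ids)

print(outputs.pooler_output.shape)      # pooled sentence-level representation
print(outputs.last_hidden_state.shape)  # per-token hidden states
```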

# Result consistency
Verified on each of the three commonly used Chinese BERT checkpoints below; the results are exactly the same as those of transformers.
- [bert-base-chinese](https://huggingface.co/bert-base-chinese)

  

- [chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext)

  

- [chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)

  

``` python
import torch
from transformers import BertModel
from transformers import BertTokenizer

from pure_attention.common.nlp.tokenization import Tokenizer as LocalTokenizer
from pure_attention.backbone_bert.bert_model import BertModel as OurBertModel

bert_model_path = "/data/pretrain_modal/chinese-roberta-wwm-ext-large"
test_query = "结果一致性验证"

text_tokenizer = BertTokenizer.from_pretrained(bert_model_path, do_lower_case=True)
bert_model = BertModel.from_pretrained(bert_model_path)

tensor_caption = text_tokenizer.encode(test_query, return_tensors="pt", padding='max_length', truncation=True,
                                       max_length=64)

origin_bert_pooler_output = bert_model(tensor_caption).pooler_output

tokenizer = LocalTokenizer(bert_model_path + "/vocab.txt")
bert = OurBertModel(bert_model_path)
tokens_ids, segments_ids = tokenizer.encode(test_query, max_len=64)

our_bert_pooler_output = bert(tokens_ids, token_type_ids=segments_ids).pooler_output

print("check result:", torch.cosine_similarity(origin_bert_pooler_output, our_bert_pooler_output))
```
File renamed without changes.
@@ -0,0 +1,37 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/25
#
""""""
import json
import torch
from typing import Optional, Tuple


class BertConfig:
    def __init__(self, vocab_size_or_config_json_file):
        """
        Customized config: lookups for missing attributes are handled in __getattr__.
        """
        with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
            json_config = json.loads(reader.read())
        for key, value in json_config.items():
            self.__dict__[key] = value

    def __getattr__(self, key):
        if key in self.__dict__:
            return self.__dict__[key]
        return None


class BertOutput:
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    def __init__(self, last_hidden_state, pooler_output, attentions):
        self.last_hidden_state = last_hidden_state
        self.pooler_output = pooler_output
        self.attentions = attentions
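
A minimal usage sketch for the two classes above; the config path and the `hidden_size` field are hypothetical examples that follow the standard BERT `config.json` layout:

``` python
# Sketch only: the path and the accessed fields are illustrative assumptions.
config = BertConfig("/data/pretrain_modal/bert-base-chinese/config.json")

print(config.hidden_size)  # any key present in config.json is copied onto the instance
print(config.missing_key)  # unknown keys fall through __getattr__ and return None

# BertOutput simply bundles the tensors produced by a forward pass:
# BertOutput(last_hidden_state=..., pooler_output=..., attentions=None)
```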
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,113 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/24
#
""""""

import logging
import os
import sys

_log_format = "%(asctime)s |- %(levelname).1s %(name)s - %(message)s"


def _get_log_level():
    """ get log level from env variable 'LOG_LEVEL'
    Returns:
        str|int: e.g. "INFO", 20, "DEBUG", 10, "ERROR", 40.
    """
    level = os.environ.get("LOG_LEVEL", "INFO")
    try:
        level = int(level)
    except ValueError:
        assert isinstance(level, str)
        level = level.upper()
    return level


# Attention: do not change this
# unless you have tested it online and confirmed there is no DEBUG log.
# Add a stream handler to avoid the default stream handler's NOTSET level.
logging.basicConfig(
    format=_log_format,
    level=_get_log_level(),
    stream=sys.stdout)


class LogLevel(object):
    CRITICAL = 50
    ERROR = 40
    WARNING = 30
    INFO = 20
    DEBUG = 10
    DETAIL = 5
    NOTSET = 0


def init_logger(logger_name=None,
                log_file=os.environ.get("LOG_FILE", ""),
                log_format=_log_format,
                level=_get_log_level()):
    """ init logger
    Args:
        logger_name(str): optional, default: None.
        log_file(str): optional, default: "".
            output log messages to file if specified, by default is set by env
            `LOG_FILE`.
        log_format(str): optional, default:
            "%(asctime)s |- %(levelname).1s %(name)s - %(message)s"
        level(int|logging.Level): set log level, by default it is set by env
            `LOG_LEVEL`, `INFO` level is used if not set.
            :: level
                - CRITICAL 50
                - ERROR 40
                - WARNING 30
                - INFO 20
                - DEBUG 10
                - DETAIL 5
                - NOTSET 0
    Returns:
        logging.Logger: a logger instance
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)

    if log_file:
        handler = logging.FileHandler(log_file)
        if log_format:
            formatter = logging.Formatter(log_format)
            handler.setFormatter(formatter)

        logger.addHandler(handler)

    return logger


def _test():
    logger = init_logger("test_logger", "test_file.log",
                         level=_get_log_level())
    logger.info("level: {}".format(os.environ.get("LOG_LEVEL", "INFO")))
    import sys
    logger.info(sys.modules[__name__])
    logger.info(logging.getLoggerClass())
    logger.debug("test DEBUG 10")
    logger.info("test INFO 20")
    logger.warning("test WARNING 30")
    logger.error("test ERROR 40")
    logger.critical("test CRITICAL 50")

    if logger.isEnabledFor(logging.DEBUG):
        logger.warning("debug enabled!")
    if logger.isEnabledFor(LogLevel.DEBUG):
        logger.info("detail enabled")


if __name__ == "__main__":
    _test()
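
A minimal usage sketch for this logger module; the import path is an assumption based on the package layout in this commit, and the level can also be driven by the `LOG_LEVEL` environment variable read above:

``` python
# Sketch only: the module path `pure_attention.common.logger` is assumed, not confirmed by this diff.
import os

os.environ["LOG_LEVEL"] = "DEBUG"  # must be set before import, since basicConfig runs at import time

from pure_attention.common.logger import init_logger

logger = init_logger(__name__, log_file="bert_run.log")
logger.info("loading vocab and weights")
logger.debug("tokenized query into 64 ids")
```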
@@ -0,0 +1,7 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: fly.sun <[email protected]>
# @date: 2022/01/24
#
""""""