diff --git a/.gitignore b/.gitignore index 723ef36..786a438 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -.idea \ No newline at end of file +.idea + +__pycache__/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..5245235 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,11 @@ +## 0.0.4 (2022-01-25) + + +### Features + +* **bert:** add tokenizer part ([3683fa9](https://github.com/mmmwhy/pure_attention/commit/3683fa937a4355d616893cc6f99e2a7b69b2a2af)) +* **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde)) +* **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927)) + + + diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..4e4aaa4 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include *.txt +include *.md \ No newline at end of file diff --git a/README.md b/README.md index 43c02ba..5fd1f32 100644 --- a/README.md +++ b/README.md @@ -17,16 +17,21 @@ cv 和 nlp 中的很多方法和技巧也在相互影响,比如大规模的预 2、java 环境下使用 onnx 的在线推理部署,使用 onnx 的原因是我在公司用的是 TensorFlow 做推理,我不想和公司的代码一致。 # todo -- 第一阶段的目标:实现 NLP 和 CV 的典型任务,并评估下游效果。 -- [ ] Transformer 的 pytorch 实现; -- [ ] 多版本的 Bert 的实现; -- [ ] NLP 下游任务 序列标注、分类 的实现,并在公开数据集上进行评估。这里主要是想证明实现的 Bert 效果是符合预期的; +第一阶段:实现 NLP 和 CV 的典型任务,并评估下游效果。 +- [x] Pytorch 实现 Transformer 的 encode 阶段,并实现 bert。 + > 参考 [transformers](https://github.com/huggingface/transformers) 的设计,但只保留与关键 encode 相关的代码,简化代码量。 + 并保持与原始 huggingface encode 的结果一致, 使用方法和一致性校验可以参考 [backbone_bert](pure_attention/backbone_bert/README.md) +- [ ] Pytorch 实现 Transformer 的 decode 阶段,并实现 seq2seq 任务。 + > todo +- [ ] NLP 下游任务 序列标注、分类 的实现,并在公开数据集上进行评估,这里主要是想证明实现的 backbone 效果是符合预期的; + > todo - [ ] 实现 Vit,并在下游任务上验证实现 Vit 的效果是否符合预期; + > todo -- 第二阶段的目标:增加 NLP 和 CV 的其余常见任务,扩增项目的能力范围。 + 第二阶段:增加 NLP 和 CV 的其余常见任务,扩增项目的能力范围。 - [ ] UNILM; - [ ] MAE; - [ ] GPT系列; - [ ] seq2seq,搞一个翻译任务; - [ ] 实现模型的 onnx export; -- [ ] 实现 java 下的 onnx 推理过程; \ No newline at end of file +- [ ] 实现 java 下的 onnx 推理过程; diff --git a/images/bert-base-chinese.png b/images/bert-base-chinese.png new file mode 100644 index 0000000..923b262 Binary files /dev/null and b/images/bert-base-chinese.png differ diff --git a/images/chinese-roberta-wwm-ext-large.png b/images/chinese-roberta-wwm-ext-large.png new file mode 100644 index 0000000..bd4b06e Binary files /dev/null and b/images/chinese-roberta-wwm-ext-large.png differ diff --git a/images/chinese-roberta-wwm-ext.png b/images/chinese-roberta-wwm-ext.png new file mode 100644 index 0000000..68b73e1 Binary files /dev/null and b/images/chinese-roberta-wwm-ext.png differ diff --git a/package.json b/package.json new file mode 100644 index 0000000..1b9eac9 --- /dev/null +++ b/package.json @@ -0,0 +1,12 @@ +{ + "name": "pure_attention", + "version": "0.0.5", + "description": "Generate a changelog from git metadata", + "repository": { + "type": "git", + "url": "https://github.com/mmmwhy/pure_attention" + }, + "scripts": { + "changelog": "conventional-changelog -p angular -i CHANGELOG.md -s -r 0" + } +} diff --git a/pure_attention/.DS_Store b/pure_attention/.DS_Store new file mode 100644 index 0000000..461e1f5 Binary files /dev/null and b/pure_attention/.DS_Store differ diff --git a/pure_attention/__init__.py b/pure_attention/__init__.py new file mode 100644 index 0000000..db832f3 --- /dev/null +++ b/pure_attention/__init__.py @@ -0,0 +1,7 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# 
@author: fly.sun +# @date: 2022/01/25 +# +"""""" diff --git a/pure_attention/backbone_bert/README.md b/pure_attention/backbone_bert/README.md new file mode 100644 index 0000000..6106c08 --- /dev/null +++ b/pure_attention/backbone_bert/README.md @@ -0,0 +1,71 @@ +# 介绍 +[transformers](https://github.com/huggingface/transformers) 为了适应非常多种模型结构,结构变得非常复杂。 + +我在参考 +[transformers](https://github.com/huggingface/transformers) 、 +[bert4pytorch](https://github.com/MuQiuJun-AI/bert4pytorch) 、 +[Read_Bert_Code](https://github.com/DA-southampton/Read_Bert_Code) +的代码基础上,对结构进行了一些调整,提高了代码的易读性,并和 [transformers](https://github.com/huggingface/transformers) 的结果完全一致。 + +# 使用 + +``` python +from pure_attention.common.nlp.tokenization import Tokenizer +from pure_attention.backbone_bert.bert_model import BertModel + +bert_model_path = "/data/pretrain_modal/bert-base-chinese" +test_query = "结果一致性验证" + +tokenizer = Tokenizer(bert_model_path + "/vocab.txt") +bert = BertModel(bert_model_path) +tokens_ids, segments_ids = tokenizer.encode(test_query, max_len=64) + +bert_pooler_output = bert(tokens_ids, token_type_ids=segments_ids).pooler_output + +``` + + +# 结果一致性 +分别在下边三个常用中文 bert 上进行测试,结果与 transformers 完全一致。 +- [bert-base-chinese](https://huggingface.co/bert-base-chinese) + + ![](../../images/bert-base-chinese.png) + + +- [chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext) + + ![](../../images/chinese-roberta-wwm-ext.png) + + +- [chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large) + + ![](../../images/chinese-roberta-wwm-ext-large.png) + + +``` python +import torch +from transformers import BertModel +from transformers import BertTokenizer + +from pure_attention.common.nlp.tokenization import Tokenizer as LocalTokenizer +from pure_attention.backbone_bert.bert_model import BertModel as OurBertModel + +bert_model_path = "/data/pretrain_modal/chinese-roberta-wwm-ext-large" +test_query = "结果一致性验证" + +text_tokenizer = BertTokenizer.from_pretrained(bert_model_path, do_lower_case=True) +bert_model = BertModel.from_pretrained(bert_model_path) + +tensor_caption = text_tokenizer.encode(test_query, return_tensors="pt", padding='max_length', truncation=True, + max_length=64) + +origin_bert_pooler_output = bert_model(tensor_caption).pooler_output + +tokenizer = LocalTokenizer(bert_model_path + "/vocab.txt") +bert = OurBertModel(bert_model_path) +tokens_ids, segments_ids = tokenizer.encode(test_query, max_len=64) + +our_bert_pooler_output = bert(tokens_ids, token_type_ids=segments_ids).pooler_output + +print("check result:", torch.cosine_similarity(origin_bert_pooler_output, our_bert_pooler_output)) +``` \ No newline at end of file diff --git a/common/__init__.py b/pure_attention/backbone_bert/__init__.py similarity index 100% rename from common/__init__.py rename to pure_attention/backbone_bert/__init__.py diff --git a/common/bert/bert_layer.py b/pure_attention/backbone_bert/bert_layer.py similarity index 99% rename from common/bert/bert_layer.py rename to pure_attention/backbone_bert/bert_layer.py index dc64961..cd254a2 100644 --- a/common/bert/bert_layer.py +++ b/pure_attention/backbone_bert/bert_layer.py @@ -10,8 +10,8 @@ import torch from torch import nn -from common.activate import activations -from common.layers import LayerNorm as BertLayerNorm +from pure_attention.common.activate import activations +from pure_attention.common.layers import LayerNorm as BertLayerNorm class BertEmbeddings(nn.Module): diff --git a/common/bert/bert_model.py 
b/pure_attention/backbone_bert/bert_model.py similarity index 86% rename from common/bert/bert_model.py rename to pure_attention/backbone_bert/bert_model.py index 4e4de81..b201227 100644 --- a/common/bert/bert_model.py +++ b/pure_attention/backbone_bert/bert_model.py @@ -6,29 +6,11 @@ # """""" import os -import json +from pure_attention.backbone_bert.package import BertConfig, BertOutput import torch from torch import nn -from common.bert.bert_layer import BertLayer -from common.bert.bert_layer import BertEmbeddings -from common.layers import LayerNorm as BertLayerNorm - - -class BertConfig: - def __init__(self, vocab_size_or_config_json_file): - """ - 定制化的 config,__getattr__ 处进行判断 - """ - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - print(f"{key}:{value}({type(value)})") - - def __getattr__(self, key): - if key in self.__dict__: - return self.__dict__[key] - return None +from pure_attention.backbone_bert.bert_layer import BertLayer, BertEmbeddings +from pure_attention.common.layers import LayerNorm as BertLayerNorm class BertEncoder(nn.Module): @@ -75,8 +57,7 @@ def __init__(self, config): self.activation = nn.Tanh() def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. + # 只取出第一个 token 也就是 cls 位置上的 embedding 进行 dense 变形 first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) @@ -95,6 +76,7 @@ def __init__(self, config_path): self.init_weights() self.from_pretrained(os.path.join(os.path.join(config_path, "pytorch_model.bin"))) + self.eval() def init_weights(self): self.apply(self._init_weights) @@ -156,6 +138,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
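# shape walk-through (annotation, not part of the commit): attention_mask starts as [batch_size, seq_len];
# unsqueeze(1) -> [batch_size, 1, seq_len], unsqueeze(2) -> [batch_size, 1, 1, seq_len],
# which then broadcasts against the [batch_size, num_heads, seq_len, seq_len] attention scores.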
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for @@ -190,13 +173,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) - # add hidden_states and attentions if they are here - outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] + outputs = BertOutput(last_hidden_state=sequence_output, pooler_output=pooled_output, + attentions=encoder_outputs[1:]) - # sequence_output, pooled_output, (hidden_states), (attentions) return outputs - - -if __name__ == "__main__": - bert = BertModel("/data1/lifeiyang/data/pretrain_modal/bert-base-chinese") - print("pass") diff --git a/pure_attention/backbone_bert/package.py b/pure_attention/backbone_bert/package.py new file mode 100644 index 0000000..1985fcd --- /dev/null +++ b/pure_attention/backbone_bert/package.py @@ -0,0 +1,37 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# @author: fly.sun +# @date: 2022/01/25 +# +"""""" +import json +import torch +from typing import Optional, Tuple + + +class BertConfig: + def __init__(self, vocab_size_or_config_json_file): + """ + 定制化的 config,__getattr__ 处进行判断 + """ + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + + def __getattr__(self, key): + if key in self.__dict__: + return self.__dict__[key] + return None + + +class BertOutput: + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + def __init__(self, last_hidden_state, pooler_output, attentions): + self.last_hidden_state = last_hidden_state + self.pooler_output = pooler_output + self.attentions = attentions diff --git a/common/bert/__init__.py b/pure_attention/common/__init__.py similarity index 100% rename from common/bert/__init__.py rename to pure_attention/common/__init__.py diff --git a/common/activate.py b/pure_attention/common/activate.py similarity index 100% rename from common/activate.py rename to pure_attention/common/activate.py diff --git a/common/layers.py b/pure_attention/common/layers.py similarity index 100% rename from common/layers.py rename to pure_attention/common/layers.py diff --git a/pure_attention/common/logger.py b/pure_attention/common/logger.py new file mode 100644 index 0000000..4d39d6b --- /dev/null +++ b/pure_attention/common/logger.py @@ -0,0 +1,113 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# @author: fly.sun +# @date: 2022/01/24 +# +"""""" + +import logging +import os +import sys + +_log_format = "%(asctime)s |- %(levelname).1s %(name)s - %(message)s" + + +def _get_log_level(): + """ get log level from env variable 'LOG_LEVEL' + + Returns: + str|int: e.g. "INFO", 20, "DEBUG", 10, "ERROR", 40. 
+ + """ + level = os.environ.get("LOG_LEVEL", "INFO") + try: + level = int(level) + except ValueError: + assert isinstance(level, str) + level = level.upper() + return level + + +# Attention do not change +# unless you test it online and there is no DEBUG log +# add stream handler to avoid default stream heandler NOSET level +logging.basicConfig( + format=_log_format, + level=_get_log_level(), + stream=sys.stdout) + + +class LogLevel(object): + CRITICAL = 50 + ERROR = 40 + WARNING = 30 + INFO = 20 + DEBUG = 10 + DETAIL = 5 + NOTSET = 0 + + +def init_logger(logger_name=None, + log_file=os.environ.get("LOG_FILE", ""), + log_format=_log_format, + level=_get_log_level()): + """ init logger + + Args: + logger_name(str): optional, default: None. + log_file(str): optional, default: "". + output log messages to file if specified, by default is set by env + `LOG_FILE`. + log_format(str): optional, default: + "%(asctime)s |- %(levelname).1s %(name)s - %(message)s" + level(int|logging.Level): set log level, by default it is set by env + `LOG_LEVEL`, `INFO` level is used if not set. + :: level + - CRITICAL 50 + - ERROR 40 + - WARNING 30 + - INFO 20 + - DEBUG 10 + - DETAIL 5 + - NOTSET 0 + + Returns: + logging.Logger: a logger instance + + """ + logger = logging.getLogger(logger_name) + logger.setLevel(level) + + if log_file: + handler = logging.FileHandler(log_file) + if log_format: + formatter = logging.Formatter(log_format) + handler.setFormatter(formatter) + + logger.addHandler(handler) + + return logger + + +def _test(): + logger = init_logger("test_logger", "test_file.log", + level=_get_log_level()) + logger.info("level: {}".format(os.environ.get("LOG_LEVEL", "INFO"))) + import sys + logger.info(sys.modules[__name__]) + logger.info(logging.getLoggerClass()) + logger.debug("test DEBUG 10") + logger.info("test INFO 20") + logger.warning("test WARNING 30") + logger.error("test ERROR 40") + logger.critical("test CRITICAL 50") + + if logger.isEnabledFor(logging.DEBUG): + logger.warning("debug enabled!") + if logger.isEnabledFor(LogLevel.DEBUG): + logger.info("detail enabled") + + +if __name__ == "__main__": + _test() diff --git a/pure_attention/common/nlp/__init__.py b/pure_attention/common/nlp/__init__.py new file mode 100644 index 0000000..31c0fbc --- /dev/null +++ b/pure_attention/common/nlp/__init__.py @@ -0,0 +1,7 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# @author: fly.sun +# @date: 2022/01/24 +# +"""""" diff --git a/pure_attention/common/nlp/tokenization.py b/pure_attention/common/nlp/tokenization.py new file mode 100644 index 0000000..45ef72f --- /dev/null +++ b/pure_attention/common/nlp/tokenization.py @@ -0,0 +1,418 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# @author: fly.sun +# @date: 2022/01/24 +# +"""""" +import collections +import os +import unicodedata +from io import open +import torch +import numpy as np +from pure_attention.common.logger import init_logger + +logger = init_logger(__name__) + + +def truncate_sequences(maxlen, indices, *sequences): + """截断总长度至不超过 maxlen + """ + sequences = [s for s in sequences if s] + if not isinstance(indices, (list, tuple)): + indices = [indices] * len(sequences) + + while True: + lengths = [len(s) for s in sequences] + if sum(lengths) > maxlen: + # 从较长的一侧进行 pop + i = np.argmax(lengths) + sequences[i].pop(indices[i]) + else: + return sequences + + +def load_vocab(vocab_file): + """ + 加载词典文件到 dict + """ + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = 
reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """ + 去除文本中的空白符 + """ + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class Tokenizer(object): + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]"): + """ + + 参数: + vocab_file: + 词典文件 + do_lower_case: + 是否转换成小写 + do_basic_tokenize: + 分词前,是否进行基础的分词 + unk_token: + 未知词标记 + sep_token: + 句子切分标记,当只有一句话作为输入时,此标记知识作为结束符;当有两句话作为输入时,此标记作为分隔符、最后一句话的结束符 + pad_token: + padding 填充标记 + cls_token: + 分类标记,位于整个序列的第一个 + mask_token: + mask标记 + + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'.".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=(unk_token, sep_token, pad_token, cls_token, mask_token)) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.unk_token = unk_token + self.sep_token = sep_token + self.pad_token = pad_token + self.cls_token = cls_token + self.mask_token = mask_token + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + if self.cls_token is not None: + split_tokens.insert(0, self.cls_token) + if self.sep_token is not None: + split_tokens.append(self.sep_token) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """ + tokens 转为 vocab 中的 id + """ + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + return ids + + def convert_ids_to_tokens(self, ids): + """ + ids 转为词表中的 token + """ + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def encode( + self, + first_text, + second_text=None, + is_padding=True, + max_len=512, + truncate_from='right' + ): + """ + 输出文本对应 token id 和 segment id + """ + if isinstance(first_text, str): + first_tokens = self.tokenize(first_text) + else: + first_tokens = first_text + + if second_text is None: + second_tokens = None + elif isinstance(second_text, str): + second_tokens = self.tokenize(second_text) + else: + second_tokens = second_text + + if max_len is not None: + if truncate_from == 'right': + index = -2 + elif truncate_from == 'left': + index = 1 + else: + index = truncate_from + if second_text is not None: + max_len += 1 + truncate_sequences(max_len, index, first_tokens, second_tokens) + + # token_ids 等价于 input_ids,segment_ids 等价于 token_type_ids + first_token_ids = self.convert_tokens_to_ids(first_tokens) + first_segment_ids = [0] * len(first_token_ids) + + if second_text is not None: + second_tokens = second_tokens[1:] + second_token_ids = self.convert_tokens_to_ids(second_tokens) + second_segment_ids = [1] * len(second_token_ids) + first_token_ids.extend(second_token_ids) + first_segment_ids.extend(second_segment_ids) + + # 做一个 padding 操作 + if is_padding: + while len(first_token_ids) < max_len: + first_token_ids.append(self.vocab[self.pad_token]) + 
first_segment_ids.append(self.vocab[self.pad_token]) + + if max_len and len(first_token_ids) > max_len: + first_token_ids = first_token_ids[:max_len] + first_segment_ids = first_segment_ids[:max_len] + + first_token_ids = torch.tensor([first_token_ids]) + first_segment_ids = torch.tensor([first_segment_ids]) + + return first_token_ids, first_segment_ids + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """ + 文本切分成 token,这个操作可能会导致序列标注类任务位置无法对齐,ner 类任务中需注意这一点。 + + """ + text = self._clean_text(text) + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
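# e.g. ord("中") == 0x4E2D lies in the 0x4E00-0x9FFF range below, so it is treated as a CJK character,
# while ord("A") == 0x41 falls outside every listed range and is not.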
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
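# e.g. ord("$") == 36 and ord("^") == 94 fall inside the ASCII ranges below, so they are treated as
# punctuation even though their Unicode categories (Sc, Sk) do not start with "P".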
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/script/build.sh b/script/build.sh new file mode 100755 index 0000000..8aa0b9f --- /dev/null +++ b/script/build.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# +# @author: fy.li +# @date: 2021/10/12 +# + +SELF_DIR=$(cd "$(dirname "$0")" || exit 1; pwd) +source "$SELF_DIR/common.sh" + +# 使用 pep8 为校验标准 +function format() { + if ! which "autopep8" ; then + cecho r "autopep8 not install, run:pip install autopep8==1.5.7" + exit 1 + fi + + # 只对最近一次提交 commit 内的文件进行 autopep8 + git log --name-only -1|grep .py|while read py_file; + do + if [ -f "$py_file" ] ; then + autopep8 --in-place --recursive --max-line-length=120 "$py_file" + fi + done + + cecho g "auto format done" +} + +# 更新版本号 +function update_version() { + python -m script.fetch_newest_version +} + + +# 生成 changelog +function changelog() { + if ! which "npm" ; then + cecho r "see: {} for detail" + exit 1 + fi + npm run changelog +} + +# 完成发包 +function release_package_to_pypi() { + python setup.py sdist + + twine upload dist/* + + python setup.py sdist upload -r pypi + rm -rf dist pure_attention.egg-info +} + +function usage() { + cat << EOF + Usage: + $0 sub_command + sub_command: + - format: format code (running by autopep8) + - release: release code to pypi + - all: run all +EOF + exit 1 +} + + +function args() { + if [[ $# -lt 1 ]]; then + usage + fi + + case $1 in + format|f) + cecho y ">>>>>>> formatting ..." + changelog + format + update_version + ;; + release|r|all) + cecho y ">>>>>>> release code to pypi ..." + changelog + format + update_version + release_package_to_pypi + ;; + *) + ;; + esac + +} + +args "$@" \ No newline at end of file diff --git a/script/common.sh b/script/common.sh index d36d8b3..b7f4e37 100755 --- a/script/common.sh +++ b/script/common.sh @@ -29,4 +29,3 @@ function cecho { echo -e "${text}" } -autopep8 --in-place --recursive --max-line-length=120 . 
\ No newline at end of file diff --git a/script/fetch_newest_version.py b/script/fetch_newest_version.py new file mode 100644 index 0000000..0b78e5a --- /dev/null +++ b/script/fetch_newest_version.py @@ -0,0 +1,71 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# @author: fly.sun +# @date: 2022/01/25 +# +""" +得到最近的版本号,用于自动更新版本。 +""" + +import json +import os +import urllib.request + +module_root_directory = os.path.join(os.path.abspath(__file__).split("pure_attention")[0], "pure_attention") + + +# 读取文件内容 +def read_file(filename): + with open(os.path.join(module_root_directory, filename), encoding='utf-8') as f: + long_description = f.read() + return long_description + + +# 从 package.json 读取版本号 +def fetch_package_version(): + version = json.loads(read_file("package.json"))["version"] + return version + + +# 解析远程版本号 +def fetch_remote_versions() -> str: + url = "https://pypi.org/pypi/pure_attention/json" + data = json.load(urllib.request.urlopen(url)) + versions = data["info"]["version"] + return versions + + +# 替换 package.json 内的版本号 +def alter(old_str, new_str): + file_data = "" + with open(os.path.join(module_root_directory, "package.json"), encoding='utf-8') as f: + for line in f: + if old_str in line: + line = line.replace(old_str, new_str) + file_data += line + + with open(os.path.join(module_root_directory, "package.json"), "w", encoding="utf-8") as f: + f.write(file_data) + + +# 组织在一起 +def update_version(): + local_version_part = fetch_package_version().split(".") + remote_version_part = fetch_remote_versions().split(".") + + # 如果本地的版本号大于远程的,则直接上 + for i, j in zip(local_version_part, remote_version_part): + if i > j: + return ".".join(local_version_part) + + # 说明本地版本号要么比远程的旧,要么没有一样 + remote_version_part[-1] = str(int(remote_version_part[-1]) + 1) + + # 进行替换 + alter(".".join(local_version_part), ".".join(remote_version_part)) + return ".".join(remote_version_part) + + +if __name__ == "__main__": + print(update_version()) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..438d0d3 --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# +# @author: fly.sun +# @date: 2022/01/22 +# + +from setuptools import setup, find_packages +from script.fetch_newest_version import read_file, fetch_package_version + + +# 获取依赖 +def read_requirements(filename): + return [line.strip() for line in read_file(filename).splitlines() + if not line.startswith('#')] + + +setup( + name='pure_attention', + version=fetch_package_version(), + description='use pure attention implement cv/nlp backbone', + long_description=read_file('README.md'), + long_description_content_type="text/markdown", + license='Apache License 2.0', + url="https://github.com/mmmwhy/pure_attention", + author='mmmwhy', + author_email="mmmwhy@mail.ustc.edu.cn", + install_requires=read_requirements("requirements.txt"), + packages=find_packages(exclude=['tests']) +)
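The version-bump helper added in `script/fetch_newest_version.py` above compares the local `package.json` version against the latest release on PyPI part by part and, when the local version is not ahead, increments the remote patch number. Below is a minimal, self-contained sketch of that comparison policy; the function and variable names are illustrative only and do not appear in the diff.

```python
# Illustrative sketch of the comparison performed by update_version() in
# script/fetch_newest_version.py; pick_next_version and its arguments are
# hypothetical names, not part of the commit.
def pick_next_version(local_version: str, remote_version: str) -> str:
    local_parts = local_version.split(".")
    remote_parts = remote_version.split(".")

    # If any local part compares greater (string comparison, as in the diff),
    # the local version is kept as-is.
    for local_part, remote_part in zip(local_parts, remote_parts):
        if local_part > remote_part:
            return ".".join(local_parts)

    # Otherwise the remote patch number is incremented and becomes the new version.
    remote_parts[-1] = str(int(remote_parts[-1]) + 1)
    return ".".join(remote_parts)


print(pick_next_version("0.0.5", "0.0.4"))  # -> "0.0.5" (local is ahead, kept)
print(pick_next_version("0.0.4", "0.0.4"))  # -> "0.0.5" (remote patch bumped)
```

As in the diff, the parts are compared as strings, so for instance "9" sorts above "10" lexicographically; with single-digit version components this has no effect.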