forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
145 lines (119 loc) · 5.76 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..._utils import pad_vocab_size
from ...functional import PositionEmbeddingType, Tensor
from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
Embedding, LayerNorm)
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
PretrainedConfig)
class GPTNeoXDecoderLayer(Module):
    """A single GPT-NeoX transformer block.

    Attention and the MLP both consume the *same* block input, each through
    its own LayerNorm, and their outputs are summed with that input::

        out = attention(input_layernorm(x)) + mlp(post_attention_layernorm(x)) + x
    """

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        """Create the block's two LayerNorms, attention and MLP.

        Args:
            config: model configuration. Reads ``hidden_size``, ``dtype``,
                ``mapping``, ``num_hidden_layers``, ``num_attention_heads``,
                ``rotary_pct``, ``rotary_emb_base``,
                ``max_position_embeddings`` and ``hidden_act``.
            layer_idx: index of this layer within the full model.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        width = config.hidden_size
        precision = config.dtype
        mapping = config.mapping
        group, world = mapping.tp_group, mapping.tp_size

        self.input_layernorm = LayerNorm(normalized_shape=width,
                                         dtype=precision)
        self.post_attention_layernorm = LayerNorm(normalized_shape=width,
                                                  dtype=precision)

        # Attention receives an index relative to the first layer of the
        # range that mapping.pp_layers(...) assigns to this rank.
        first_owned_layer = mapping.pp_layers(config.num_hidden_layers)[0]
        self.attention = Attention(
            local_layer_idx=layer_idx - first_owned_layer,
            hidden_size=width,
            num_attention_heads=config.num_attention_heads,
            rotary_embedding_percentage=config.rotary_pct,
            rotary_embedding_base=config.rotary_emb_base,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            max_position_embeddings=config.max_position_embeddings,
            dtype=precision,
            attention_mask_type=AttentionMaskType.causal,
            bias=True,
            tp_group=group,
            tp_size=world)

        # Feed-forward width is fixed at four times the hidden size.
        self.mlp = MLP(hidden_size=width,
                       ffn_hidden_size=width * 4,
                       hidden_act=config.hidden_act,
                       dtype=precision,
                       tp_group=group,
                       tp_size=world)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None):
        """Apply the block to ``hidden_states``.

        Returns:
            The block output. When ``use_cache`` is true, returns
            ``(output, presents)``, where ``presents`` is the second value
            the attention layer returned alongside its output.
        """
        shortcut = hidden_states
        attn_in = self.input_layernorm(hidden_states)
        mlp_in = self.post_attention_layernorm(hidden_states)

        # NOTE(review): norm_before_bmm1 presumably applies the softmax
        # scaling to Q before the QK^T matmul -- confirm against Attention.
        attn_result = self.attention(attn_in,
                                     attention_mask=attention_mask,
                                     use_cache=use_cache,
                                     kv_cache_params=kv_cache_params,
                                     attention_params=attention_params,
                                     norm_before_bmm1=True)
        if use_cache:
            attn_out, presents = attn_result
        else:
            attn_out = attn_result

        mlp_out = self.mlp(mlp_in)
        # Parallel residual: both branch outputs plus the untouched input.
        block_out = attn_out + mlp_out + shortcut

        if not use_cache:
            return block_out
        return (block_out, presents)
class GPTNeoXModel(Module):
    """GPT-NeoX backbone: token embedding -> decoder stack -> final LayerNorm."""

    def __init__(self, config: PretrainedConfig):
        """Create the embedding, the decoder layers and the final norm.

        Args:
            config: model configuration. Reads ``vocab_size``,
                ``hidden_size`` and ``dtype``, and is forwarded to
                ``DecoderLayerList`` to build the decoder layers.
        """
        super().__init__()
        width = config.hidden_size
        precision = config.dtype

        self.vocab_embedding = Embedding(num_embeddings=config.vocab_size,
                                         embedding_dim=width,
                                         dtype=precision)
        self.layers = DecoderLayerList(GPTNeoXDecoderLayer, config)
        self.ln_f = LayerNorm(normalized_shape=width, dtype=precision)

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None):
        """Embed ``input_ids``, run the decoder stack and apply ``ln_f``.

        ``position_ids`` is accepted but not used by this method.

        Returns:
            The normalized hidden states. When ``use_cache`` is true,
            returns ``(hidden_states, presents)``, with ``presents``
            converted to a tuple.
        """
        embedded = self.vocab_embedding(input_ids)
        stack_out = self.layers(embedded,
                                use_cache=use_cache,
                                attention_mask=attention_mask,
                                kv_cache_params=kv_cache_params,
                                attention_params=attention_params)

        if not use_cache:
            return self.ln_f(stack_out)

        # With caching enabled the stack yields (hidden_states, presents).
        last_hidden, presents = stack_out
        normalized = self.ln_f(last_hidden)
        return (normalized, tuple(presents))
class GPTNeoXForCausalLM(DecoderModelForCausalLM):
    """GPT-NeoX backbone paired with a bias-free language-model head."""

    def __init__(self, config: PretrainedConfig):
        """Build the backbone and the LM head, then hand both to the base class.

        Args:
            config: model configuration. Reads ``hidden_size``,
                ``vocab_size``, ``dtype`` and ``mapping``.
        """
        mapping = config.mapping
        backbone = GPTNeoXModel(config)

        # The vocab dimension is padded with pad_vocab_size relative to
        # tp_size (presumably up to a multiple of it) so the head can be
        # split column-wise across tensor-parallel ranks; gather_output=True
        # asks ColumnLinear to gather the per-rank outputs.
        padded_vocab = pad_vocab_size(config.vocab_size, mapping.tp_size)
        head = ColumnLinear(config.hidden_size,
                            padded_vocab,
                            bias=False,
                            dtype=config.dtype,
                            tp_group=mapping.tp_group,
                            tp_size=mapping.tp_size,
                            gather_output=True)

        super().__init__(config, backbone, head)