@@ -1608,15 +1608,51 @@ def _phi3_self_attn_sdpa_forward(
     return attn_output, None, past_key_value


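+# Runtime selection between the short- and long-context rotary factors.
+# torch.jit.script keeps this comparison as a graph-level branch, so the choice
+# is not frozen to the example sequence length seen during tracing/export.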
+@torch.jit.script
+def select_ext_factor(seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor):
+    if seq_len > max_pos_embeddings:
+        return long_factor
+    return short_factor
+
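+# Replacement forward for the model's rotary embedding module: picks the short or
+# long inverse frequencies for the current sequence length and applies the
+# LongRoPE scaling factor to the returned cos/sin tables.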
+def long_rope(self, x, position_ids, seq_len=None):
+    seq_len = torch.max(position_ids) + 1
+    original_max_position_embeddings = (
+        self.original_max_position_embeddings
+        if hasattr(self, "original_max_position_embeddings")
+        else self.config.original_max_position_embeddings
+    )
+    max_position_embeddings = (
+        self.max_position_embeddings if hasattr(self, "max_position_embeddings") else self.config.max_position_embeddings
+    )
+    inv_freq = select_ext_factor(
+        seq_len,
+        torch.tensor(original_max_position_embeddings),
+        self.inv_freq,
+        self.long_inv_freq,
+    )
+
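+    # Reshape so the batched matmul of inv_freq against position ids yields
+    # a rotary angle for every (position, frequency) pair.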
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+    emb = torch.cat((freqs, freqs), dim=-1)
+
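+    # LongRoPE attention scaling: when the configured context exceeds the original
+    # training context, cos/sin are scaled by sqrt(1 + log(scale) / log(original_max_position_embeddings)).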
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        scaling_factor = 1.0
+    else:
+        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+    cos = emb.cos() * scaling_factor
+    sin = emb.sin() * scaling_factor
+    return cos, sin
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()

         # currently, long RoPE can not be traced for long context support, disable it to avoid potential accuracy issues
-        if self._model.config.max_position_embeddings != getattr(
-            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
-        ):
-            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

         if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._orig_forward = self._model.model.forward
@@ -1643,6 +1679,17 @@ def __enter__(self):
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
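+        # For "longrope" rotary embeddings, precompute the long-context inverse
+        # frequencies (rope_init_fn with a seq_len past the original limit) and
+        # patch forward with long_rope so the exported graph keeps both factor
+        # sets; otherwise keep clamping the context to original_max_position_embeddings.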
+        if hasattr(self._model.model, "rotary_emb") and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope":
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
+                self._model.config, torch.device("cpu"), seq_len=self._model.config.original_max_position_embeddings + 1
+            )
+            self._model.model.rotary_emb.long_inv_freq = long_inv_freq
+            self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
+            self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
+        elif self._model.config.max_position_embeddings != getattr(
+            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
+        ):
+            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
+

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -1653,6 +1700,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
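+        # restore the rotary embedding forward patched in __enter__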
+        if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
+            self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward


 # Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756