forked from johnsmith0031/alpaca_lora_4bit
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathamp_wrapper.py
26 lines (19 loc) · 917 Bytes
/
amp_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
class AMPWrapper:
    """Wrap a model's ``forward``/``generate`` calls in torch automatic mixed precision.

    Monkey-patches the model in place: the original callables are stashed as
    ``model.non_autocast_forward`` / ``model.non_autocast_generate`` and the
    public attributes are rebound to wrappers that enter a
    ``torch.amp.autocast`` context built from ``options`` on every call.
    """

    def __init__(self, model, options=None):
        """
        Args:
            model: any object exposing ``forward`` and/or ``generate``
                attributes (e.g. a ``torch.nn.Module`` or HF model).
            options: keyword arguments forwarded to ``torch.amp.autocast``.
                Defaults to ``{'enabled': True, 'device_type': 'cuda'}``.
        """
        self.model = model
        # Build the default inside __init__ rather than as a mutable
        # default argument, so instances never share one dict.
        if options is None:
            options = {'enabled': True, 'device_type': 'cuda'}
        self.options = options

    def autocast_forward(self, *args, **kwargs):
        """Invoke the model's original forward inside an autocast context."""
        with torch.amp.autocast(**self.options):
            return self.model.non_autocast_forward(*args, **kwargs)

    def autocast_generate(self, *args, **kwargs):
        """Invoke the model's original generate inside an autocast context."""
        with torch.amp.autocast(**self.options):
            return self.model.non_autocast_generate(*args, **kwargs)

    def apply_forward(self):
        """Patch ``model.forward`` with the autocast wrapper.

        Idempotent: a second call is a no-op. Without the guard, re-applying
        would point ``non_autocast_forward`` at the wrapper itself, making
        every subsequent ``model.forward`` call recurse infinitely.
        """
        if getattr(self.model, 'non_autocast_forward', None) is None:
            self.model.non_autocast_forward = self.model.forward
            self.model.forward = self.autocast_forward

    def apply_generate(self):
        """Patch ``model.generate`` with the autocast wrapper (idempotent,
        same re-application guard as :meth:`apply_forward`)."""
        if getattr(self.model, 'non_autocast_generate', None) is None:
            self.model.non_autocast_generate = self.model.generate
            self.model.generate = self.autocast_generate