Skip to content

Commit e856a65

Browse files
Overrides _post_quantize to reset generate_function graph after quantization (#2436)
* Overrides _post_quantize to reset generate_function graph * added _post_quantize override to image_to_image
1 parent bc823a3 commit e856a65

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,8 @@ def export_to_transformers(self, path):
424424
)
425425

426426
export_to_safetensors(self, path)
427+
428+
def _post_quantize(self, mode, **kwargs):
429+
super()._post_quantize(mode, **kwargs)
430+
# Reset the compiled generate function.
431+
self.generate_function = None

keras_hub/src/models/image_to_image.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,8 @@ def generate(images, x):
415415
# Image-to-image.
416416
outputs = [generate(*x) for x in inputs]
417417
return self._normalize_generate_outputs(outputs, input_is_scalar)
418+
419+
def _post_quantize(self, mode, **kwargs):
420+
super()._post_quantize(mode, **kwargs)
421+
# Reset the compiled generate function.
422+
self.generate_function = None

keras_hub/src/models/inpaint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,3 +518,8 @@ def generate(images, masks, x):
518518
# Inpaint.
519519
outputs = [generate(*x) for x in inputs]
520520
return self._normalize_generate_outputs(outputs, input_is_scalar)
521+
522+
def _post_quantize(self, mode, **kwargs):
523+
super()._post_quantize(mode, **kwargs)
524+
# Reset the compiled generate function.
525+
self.generate_function = None

keras_hub/src/models/text_to_image.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,8 @@ def generate(x):
345345
# Text-to-image.
346346
outputs = [generate(x) for x in inputs]
347347
return self._normalize_generate_outputs(outputs, input_is_scalar)
348+
349+
def _post_quantize(self, mode, **kwargs):
350+
super()._post_quantize(mode, **kwargs)
351+
# Reset the compiled generate function.
352+
self.generate_function = None

0 commit comments

Comments
 (0)