@@ -192,7 +192,7 @@ def run(self, overwrite: bool = False,
192192 indexer_b_s = 0
193193
194194 pixel_values_merged = self .model_zoo .get ('processor' )(imgs , return_tensors = "pt" ).pixel_values
195- generated_ids_merged = self .model_ocr .generate (
195+ generated_ids_merged = self .model_zoo . get ( 'ocr' ) .generate (
196196 pixel_values_merged .to (self .device ))
197197 generated_text_merged = self .model_zoo .get ('processor' ).batch_decode (
198198 generated_ids_merged , skip_special_tokens = True )
@@ -215,7 +215,7 @@ def run(self, overwrite: bool = False,
215215 indexer_b_s = 0
216216
217217 pixel_values_merged = self .model_zoo .get ('processor' )(imgs , return_tensors = "pt" ).pixel_values
218- generated_ids_merged = self .model_ocr .generate (
218+ generated_ids_merged = self .model_zoo . get ( 'ocr' ) .generate (
219219 pixel_values_merged .to (self .device ))
220220 generated_text_merged = self .model_zoo .get ('processor' ).batch_decode (
221221 generated_ids_merged , skip_special_tokens = True )
@@ -235,7 +235,7 @@ def run(self, overwrite: bool = False,
235235 indexer_b_s = 0
236236
237237 pixel_values_merged = self .model_zoo .get ('processor' )(imgs , return_tensors = "pt" ).pixel_values
238- generated_ids_merged = self .model_ocr .generate (
238+ generated_ids_merged = self .model_zoo . get ( 'ocr' ) .generate (
239239 pixel_values_merged .to (self .device ))
240240 generated_text_merged = self .model_zoo .get ('processor' ).batch_decode (
241241 generated_ids_merged , skip_special_tokens = True )
@@ -253,7 +253,7 @@ def run(self, overwrite: bool = False,
253253 indexer_b_s = 0
254254
255255 pixel_values_merged = self .model_zoo .get ('processor' )(imgs , return_tensors = "pt" ).pixel_values
256- generated_ids_merged = self .model_ocr .generate (
256+ generated_ids_merged = self .model_zoo . get ( 'ocr' ) .generate (
257257 pixel_values_merged .to (self .device ))
258258 generated_text_merged = self .model_zoo .get ('processor' ).batch_decode (
259259 generated_ids_merged , skip_special_tokens = True )
@@ -270,7 +270,7 @@ def run(self, overwrite: bool = False,
270270 indexer_b_s = 0
271271
272272 pixel_values_merged = self .model_zoo .get ('processor' )(imgs , return_tensors = "pt" ).pixel_values
273- generated_ids_merged = self .model_ocr .generate (pixel_values_merged .to (self .device ))
273+ generated_ids_merged = self .model_zoo . get ( 'ocr' ) .generate (pixel_values_merged .to (self .device ))
274274 generated_text_merged = self .model_zoo .get ('processor' ).batch_decode (generated_ids_merged , skip_special_tokens = True )
275275
276276 extracted_texts = extracted_texts + generated_text_merged
@@ -746,10 +746,10 @@ def run(self, overwrite: bool = False,
746746
747747
748748 self .logger .debug ("processing next %d lines" , len (imgs ))
749- preds = self .prediction_model .predict (imgs , verbose = 0 )
749+ preds = self .model_zoo . get ( 'ocr' ) .predict (imgs , verbose = 0 )
750750
751751 if len (indices_ver )> 0 :
752- preds_flipped = self .prediction_model .predict (imgs_ver_flipped , verbose = 0 )
752+ preds_flipped = self .model_zoo . get ( 'ocr' ) .predict (imgs_ver_flipped , verbose = 0 )
753753 preds_max_fliped = np .max (preds_flipped , axis = 2 )
754754 preds_max_args_flipped = np .argmax (preds_flipped , axis = 2 )
755755 pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped [:,:]!= self .end_character
@@ -779,10 +779,10 @@ def run(self, overwrite: bool = False,
779779 preds [indices_to_be_replaced ,:,:] = \
780780 preds_flipped [indices_where_flipped_conf_value_is_higher , :, :]
781781 if dir_in_bin is not None :
782- preds_bin = self .prediction_model .predict (imgs_bin , verbose = 0 )
782+ preds_bin = self .model_zoo . get ( 'ocr' ) .predict (imgs_bin , verbose = 0 )
783783
784784 if len (indices_ver )> 0 :
785- preds_flipped = self .prediction_model .predict (imgs_bin_ver_flipped , verbose = 0 )
785+ preds_flipped = self .model_zoo . get ( 'ocr' ) .predict (imgs_bin_ver_flipped , verbose = 0 )
786786 preds_max_fliped = np .max (preds_flipped , axis = 2 )
787787 preds_max_args_flipped = np .argmax (preds_flipped , axis = 2 )
788788 pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped [:,:]!= self .end_character
@@ -814,7 +814,7 @@ def run(self, overwrite: bool = False,
814814
815815 preds = (preds + preds_bin ) / 2.
816816
817- pred_texts = decode_batch_predictions (preds , self .num_to_char )
817+ pred_texts = decode_batch_predictions (preds , self .model_zoo . get ( ' num_to_char' ) )
818818
819819 preds_max = np .max (preds , axis = 2 )
820820 preds_max_args = np .argmax (preds , axis = 2 )
0 commit comments