Skip to content

Commit b3c5620

Browse files
committed
eynollah_ocr: actually replace the model calls
1 parent e510a45 commit b3c5620

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/eynollah/eynollah_ocr.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def run(self, overwrite: bool = False,
192192
indexer_b_s = 0
193193

194194
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
195-
generated_ids_merged = self.model_ocr.generate(
195+
generated_ids_merged = self.model_zoo.get('ocr').generate(
196196
pixel_values_merged.to(self.device))
197197
generated_text_merged = self.model_zoo.get('processor').batch_decode(
198198
generated_ids_merged, skip_special_tokens=True)
@@ -215,7 +215,7 @@ def run(self, overwrite: bool = False,
215215
indexer_b_s = 0
216216

217217
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
218-
generated_ids_merged = self.model_ocr.generate(
218+
generated_ids_merged = self.model_zoo.get('ocr').generate(
219219
pixel_values_merged.to(self.device))
220220
generated_text_merged = self.model_zoo.get('processor').batch_decode(
221221
generated_ids_merged, skip_special_tokens=True)
@@ -235,7 +235,7 @@ def run(self, overwrite: bool = False,
235235
indexer_b_s = 0
236236

237237
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
238-
generated_ids_merged = self.model_ocr.generate(
238+
generated_ids_merged = self.model_zoo.get('ocr').generate(
239239
pixel_values_merged.to(self.device))
240240
generated_text_merged = self.model_zoo.get('processor').batch_decode(
241241
generated_ids_merged, skip_special_tokens=True)
@@ -253,7 +253,7 @@ def run(self, overwrite: bool = False,
253253
indexer_b_s = 0
254254

255255
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
256-
generated_ids_merged = self.model_ocr.generate(
256+
generated_ids_merged = self.model_zoo.get('ocr').generate(
257257
pixel_values_merged.to(self.device))
258258
generated_text_merged = self.model_zoo.get('processor').batch_decode(
259259
generated_ids_merged, skip_special_tokens=True)
@@ -270,7 +270,7 @@ def run(self, overwrite: bool = False,
270270
indexer_b_s = 0
271271

272272
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
273-
generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
273+
generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device))
274274
generated_text_merged = self.model_zoo.get('processor').batch_decode(generated_ids_merged, skip_special_tokens=True)
275275

276276
extracted_texts = extracted_texts + generated_text_merged
@@ -746,10 +746,10 @@ def run(self, overwrite: bool = False,
746746

747747

748748
self.logger.debug("processing next %d lines", len(imgs))
749-
preds = self.prediction_model.predict(imgs, verbose=0)
749+
preds = self.model_zoo.get('ocr').predict(imgs, verbose=0)
750750

751751
if len(indices_ver)>0:
752-
preds_flipped = self.prediction_model.predict(imgs_ver_flipped, verbose=0)
752+
preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0)
753753
preds_max_fliped = np.max(preds_flipped, axis=2 )
754754
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
755755
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
@@ -779,10 +779,10 @@ def run(self, overwrite: bool = False,
779779
preds[indices_to_be_replaced,:,:] = \
780780
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
781781
if dir_in_bin is not None:
782-
preds_bin = self.prediction_model.predict(imgs_bin, verbose=0)
782+
preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0)
783783

784784
if len(indices_ver)>0:
785-
preds_flipped = self.prediction_model.predict(imgs_bin_ver_flipped, verbose=0)
785+
preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0)
786786
preds_max_fliped = np.max(preds_flipped, axis=2 )
787787
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
788788
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
@@ -814,7 +814,7 @@ def run(self, overwrite: bool = False,
814814

815815
preds = (preds + preds_bin) / 2.
816816

817-
pred_texts = decode_batch_predictions(preds, self.num_to_char)
817+
pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char'))
818818

819819
preds_max = np.max(preds, axis=2 )
820820
preds_max_args = np.argmax(preds, axis=2 )

0 commit comments

Comments
 (0)