Skip to content

Commit 7186173

Browse files
JWittmeyerJWittmeyer
andauthored
Adds example implementation for progress (#22)
* Adds example implementation for progress * Flush prints * Change percentage jump --------- Co-authored-by: JWittmeyer <[email protected]>
1 parent 5eaf369 commit 7186173

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

run_ml.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ def run_classification(
1717
):
1818
from util.active_transfer_learning import ATLClassifier
1919

20+
print("progress: 0.05", flush=True)
2021
classifier = ATLClassifier()
2122
prediction_probabilities = classifier.fit_predict(
2223
corpus_embeddings, corpus_labels, corpus_ids, training_ids
2324
)
25+
print("progress: 0.8", flush=True)
2426
if os.path.exists("/inference"):
2527
pickle_path = os.path.join(
2628
"/inference", f"active-learner-{information_source_id}.pkl"
@@ -36,6 +38,7 @@ def run_classification(
3638
prediction = classifier.model.classes_[probas.argmax()]
3739
predictions_with_probabilities.append([proba, prediction])
3840

41+
print("progress: 0.9", flush=True)
3942
ml_results_by_record_id = {}
4043
for record_id, (probability, prediction) in zip(
4144
corpus_ids, predictions_with_probabilities
@@ -48,8 +51,12 @@ def run_classification(
4851
probability,
4952
prediction,
5053
)
54+
print("progress: 0.95", flush=True)
5155
if len(ml_results_by_record_id) == 0:
52-
print("No records were predicted. Try lowering the confidence threshold.")
56+
print(
57+
"No records were predicted. Try lowering the confidence threshold.",
58+
flush=True,
59+
)
5360
return ml_results_by_record_id
5461

5562

@@ -62,10 +69,12 @@ def run_extraction(
6269
):
6370
from util.active_transfer_learning import ATLExtractor
6471

72+
print("progress: 0.05", flush=True)
6573
extractor = ATLExtractor()
6674
predictions, probabilities = extractor.fit_predict(
6775
corpus_embeddings, corpus_labels, corpus_ids, training_ids
6876
)
77+
print("progress: 0.5", flush=True)
6978
if os.path.exists("/inference"):
7079
pickle_path = os.path.join(
7180
"/inference", f"active-learner-{information_source_id}.pkl"
@@ -75,8 +84,9 @@ def run_extraction(
7584
print("Saved model to disk", flush=True)
7685

7786
ml_results_by_record_id = {}
78-
for record_id, prediction, probability in zip(
79-
corpus_ids, predictions, probabilities
87+
amount = len(corpus_ids)
88+
for idx, (record_id, prediction, probability) in enumerate(
89+
zip(corpus_ids, predictions, probabilities)
8090
):
8191
df = pd.DataFrame(
8292
list(zip(prediction, probability)),
@@ -101,14 +111,22 @@ def run_extraction(
101111
)
102112
new_start_idx = True
103113
ml_results_by_record_id[record_id] = predictions_with_probabilities
114+
if idx % 100 == 0:
115+
progress = round((idx + 1) / amount, 4) * 0.5 + 0.5
116+
print("progress: ", progress, flush=True)
117+
118+
print("progress: 0.9", flush=True)
104119
if len(ml_results_by_record_id) == 0:
105-
print("No records were predicted. Try lowering the confidence threshold.")
120+
print(
121+
"No records were predicted. Try lowering the confidence threshold.",
122+
flush=True,
123+
)
106124
return ml_results_by_record_id
107125

108126

109127
if __name__ == "__main__":
110128
_, payload_url = sys.argv
111-
print("Preparing data for machine learning.")
129+
print("Preparing data for machine learning.", flush=True)
112130

113131
(
114132
information_source_id,
@@ -120,7 +138,7 @@ def run_extraction(
120138
is_extractor = any([isinstance(val, list) for val in corpus_labels["manual"]])
121139

122140
if is_extractor:
123-
print("Running extractor.")
141+
print("Running extractor.", flush=True)
124142
ml_results_by_record_id = run_extraction(
125143
information_source_id,
126144
corpus_embeddings,
@@ -129,7 +147,7 @@ def run_extraction(
129147
training_ids,
130148
)
131149
else:
132-
print("Running classifier.")
150+
print("Running classifier.", flush=True)
133151
ml_results_by_record_id = run_classification(
134152
information_source_id,
135153
corpus_embeddings,
@@ -138,5 +156,6 @@ def run_extraction(
138156
training_ids,
139157
)
140158

141-
print("Finished execution.")
159+
print("progress: 1", flush=True)
160+
print("Finished execution.", flush=True)
142161
requests.put(payload_url, json=ml_results_by_record_id)

0 commit comments

Comments
 (0)