@@ -17,10 +17,12 @@ def run_classification(
 ):
     from util.active_transfer_learning import ATLClassifier

+    print("progress: 0.05", flush=True)
     classifier = ATLClassifier()
     prediction_probabilities = classifier.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    print("progress: 0.8", flush=True)
     if os.path.exists("/inference"):
         pickle_path = os.path.join(
             "/inference", f"active-learner-{information_source_id}.pkl"
@@ -36,6 +38,7 @@ def run_classification(
         prediction = classifier.model.classes_[probas.argmax()]
         predictions_with_probabilities.append([proba, prediction])

+    print("progress: 0.9", flush=True)
     ml_results_by_record_id = {}
     for record_id, (probability, prediction) in zip(
         corpus_ids, predictions_with_probabilities
@@ -48,8 +51,12 @@ def run_classification(
             probability,
             prediction,
         )
+    print("progress: 0.95", flush=True)
     if len(ml_results_by_record_id) == 0:
-        print("No records were predicted. Try lowering the confidence threshold.")
+        print(
+            "No records were predicted. Try lowering the confidence threshold.",
+            flush=True,
+        )
     return ml_results_by_record_id


@@ -62,10 +69,12 @@ def run_extraction(
 ):
     from util.active_transfer_learning import ATLExtractor

+    print("progress: 0.05", flush=True)
     extractor = ATLExtractor()
     predictions, probabilities = extractor.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    print("progress: 0.5", flush=True)
     if os.path.exists("/inference"):
         pickle_path = os.path.join(
             "/inference", f"active-learner-{information_source_id}.pkl"
@@ -75,8 +84,9 @@ def run_extraction(
         print("Saved model to disk", flush=True)

     ml_results_by_record_id = {}
-    for record_id, prediction, probability in zip(
-        corpus_ids, predictions, probabilities
+    amount = len(corpus_ids)
+    for idx, (record_id, prediction, probability) in enumerate(
+        zip(corpus_ids, predictions, probabilities)
     ):
         df = pd.DataFrame(
             list(zip(prediction, probability)),
@@ -101,14 +111,22 @@ def run_extraction(
101
111
)
102
112
new_start_idx = True
103
113
ml_results_by_record_id [record_id ] = predictions_with_probabilities
114
+ if idx % 100 == 0 :
115
+ progress = round ((idx + 1 ) / amount , 4 ) * 0.5 + 0.5
116
+ print ("progress: " , progress , flush = True )
117
+
118
+ print ("progress: 0.9" , flush = True )
104
119
if len (ml_results_by_record_id ) == 0 :
105
- print ("No records were predicted. Try lowering the confidence threshold." )
120
+ print (
121
+ "No records were predicted. Try lowering the confidence threshold." ,
122
+ flush = True ,
123
+ )
106
124
return ml_results_by_record_id
107
125
108
126
109
127
if __name__ == "__main__" :
110
128
_ , payload_url = sys .argv
111
- print ("Preparing data for machine learning." )
129
+ print ("Preparing data for machine learning." , flush = True )
112
130
113
131
(
114
132
information_source_id ,
@@ -120,7 +138,7 @@ def run_extraction(
     is_extractor = any([isinstance(val, list) for val in corpus_labels["manual"]])

     if is_extractor:
-        print("Running extractor.")
+        print("Running extractor.", flush=True)
         ml_results_by_record_id = run_extraction(
             information_source_id,
             corpus_embeddings,
@@ -129,7 +147,7 @@ def run_extraction(
             training_ids,
         )
     else:
-        print("Running classifier.")
+        print("Running classifier.", flush=True)
         ml_results_by_record_id = run_classification(
             information_source_id,
             corpus_embeddings,
@@ -138,5 +156,6 @@ def run_extraction(
             training_ids,
         )

-    print("Finished execution.")
+    print("progress: 1", flush=True)
+    print("Finished execution.", flush=True)
     requests.put(payload_url, json=ml_results_by_record_id)