1515 ProgramEnumerator ,
1616 auto_type ,
1717)
18- from synth .syntax .grammars .det_grammar import DetGrammar
1918from synth .syntax .grammars .enumeration .constant_delay import (
2019 enumerate_prob_grammar as cd ,
2120)
2221import tqdm
2322import timeout_decorator
2423
2524SEARCH_ALGOS = {
26- # "a_star": as_enumerate_prob_grammar,
27- # "bee_search": bs_enumerate_prob_grammar,
25+ "a_star" : as_enumerate_prob_grammar ,
26+ "bee_search" : bs_enumerate_prob_grammar ,
2827 "beap_search" : bps_enumerate_prob_grammar ,
29- # "heap_search": hs_enumerate_prob_grammar,
30- # "cd4": lambda x: cd(x, k=4),
31- # "cd16": lambda x: cd(x, k=16),
32- # "cd64": lambda x: cd(x, k=64),
28+ "heap_search" : hs_enumerate_prob_grammar ,
29+ "cd4" : lambda x : cd (x , k = 4 ),
30+ "cd16" : lambda x : cd (x , k = 16 ),
31+ "cd64" : lambda x : cd (x , k = 64 ),
3332}
3433
3534parser = argparse .ArgumentParser (
@@ -164,8 +163,6 @@ def enumerative_search(
164163 Tuple [str , int , int , float , int , int , int , int ],
165164 List [Tuple [str , int , int , float , int , int , int , int ]],
166165]:
167- import numpy as np
168-
169166 n = 0
170167 non_terminals = len (pcfg .rules )
171168 derivation_rules = sum (len (pcfg .rules [S ]) for S in pcfg .rules )
@@ -174,11 +171,9 @@ def enumerative_search(
174171 pbar = tqdm .tqdm (total = programs , desc = title or name , smoothing = 0 )
175172 enumerator = custom_enumerate (pcfg )
176173 gen = enumerator .generator ()
177- det_g = pcfg .grammar
178- assert isinstance (det_g , DetGrammar )
179174 program = 1
180- datum_each = 10000
181- target_generation_speed = 10000
175+ datum_each = 100000
176+ target_generation_speed = 1000000
182177 start = 0
183178 detailed = []
184179 try :
@@ -189,15 +184,9 @@ def fun():
189184 get_next = timeout_decorator .timeout (timeout , timeout_exception = StopIteration )(
190185 fun
191186 )
192- last_multiset = None
193- max_dist = - 1
194187 start = time .perf_counter_ns ()
195188 while program is not None :
196189 program = get_next ()
197- new_m = det_g .to_multiset (program )
198- if last_multiset is not None :
199- max_dist = max (np .sum (np .abs (new_m - last_multiset )), max_dist )
200- last_multiset = new_m
201190 n += 1
202191 if n % datum_each == 0 or n >= programs :
203192 used_time = time .perf_counter_ns () - start
@@ -211,7 +200,7 @@ def fun():
211200 (
212201 name ,
213202 non_terminals ,
214- max_dist ,
203+ derivation_rules ,
215204 used_time / 1e9 ,
216205 n ,
217206 enumerator .programs_in_queues (),
0 commit comments