Skip to content

Commit 767d667

Browse files
committed
Revert "fixed mypy issues" which commited many other irrelevant files
This reverts commit b4538d4.
1 parent b4538d4 commit 767d667

File tree

5 files changed

+16
-54
lines changed

5 files changed

+16
-54
lines changed

examples/compare_enumeration.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,20 @@
1515
ProgramEnumerator,
1616
auto_type,
1717
)
18-
from synth.syntax.grammars.det_grammar import DetGrammar
1918
from synth.syntax.grammars.enumeration.constant_delay import (
2019
enumerate_prob_grammar as cd,
2120
)
2221
import tqdm
2322
import timeout_decorator
2423

2524
SEARCH_ALGOS = {
26-
# "a_star": as_enumerate_prob_grammar,
27-
# "bee_search": bs_enumerate_prob_grammar,
25+
"a_star": as_enumerate_prob_grammar,
26+
"bee_search": bs_enumerate_prob_grammar,
2827
"beap_search": bps_enumerate_prob_grammar,
29-
# "heap_search": hs_enumerate_prob_grammar,
30-
# "cd4": lambda x: cd(x, k=4),
31-
# "cd16": lambda x: cd(x, k=16),
32-
# "cd64": lambda x: cd(x, k=64),
28+
"heap_search": hs_enumerate_prob_grammar,
29+
"cd4": lambda x: cd(x, k=4),
30+
"cd16": lambda x: cd(x, k=16),
31+
"cd64": lambda x: cd(x, k=64),
3332
}
3433

3534
parser = argparse.ArgumentParser(
@@ -164,8 +163,6 @@ def enumerative_search(
164163
Tuple[str, int, int, float, int, int, int, int],
165164
List[Tuple[str, int, int, float, int, int, int, int]],
166165
]:
167-
import numpy as np
168-
169166
n = 0
170167
non_terminals = len(pcfg.rules)
171168
derivation_rules = sum(len(pcfg.rules[S]) for S in pcfg.rules)
@@ -174,11 +171,9 @@ def enumerative_search(
174171
pbar = tqdm.tqdm(total=programs, desc=title or name, smoothing=0)
175172
enumerator = custom_enumerate(pcfg)
176173
gen = enumerator.generator()
177-
det_g = pcfg.grammar
178-
assert isinstance(det_g, DetGrammar)
179174
program = 1
180-
datum_each = 10000
181-
target_generation_speed = 10000
175+
datum_each = 100000
176+
target_generation_speed = 1000000
182177
start = 0
183178
detailed = []
184179
try:
@@ -189,15 +184,9 @@ def fun():
189184
get_next = timeout_decorator.timeout(timeout, timeout_exception=StopIteration)(
190185
fun
191186
)
192-
last_multiset = None
193-
max_dist = -1
194187
start = time.perf_counter_ns()
195188
while program is not None:
196189
program = get_next()
197-
new_m = det_g.to_multiset(program)
198-
if last_multiset is not None:
199-
max_dist = max(np.sum(np.abs(new_m - last_multiset)), max_dist)
200-
last_multiset = new_m
201190
n += 1
202191
if n % datum_each == 0 or n >= programs:
203192
used_time = time.perf_counter_ns() - start
@@ -211,7 +200,7 @@ def fun():
211200
(
212201
name,
213202
non_terminals,
214-
max_dist,
203+
derivation_rules,
215204
used_time / 1e9,
216205
n,
217206
enumerator.programs_in_queues(),

examples/plot_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,16 @@ def plot_dist(
246246
methods: Dict[str, Dict[int, List]],
247247
y_data: Tuple[int, str],
248248
x_axis_name: str,
249-
nbins: int = 5,
250249
) -> None:
251250
width = 1.0
252251
data_length = 0
253252
a_index, a_name = y_data
254253
max_a = max(
255-
max(max(y[a_index] for y in x) for x in seed_dico.values())
254+
max(max([y[a_index] for y in x]) for x in seed_dico.values())
256255
for seed_dico in methods.values()
257256
)
258257
bottom = None
258+
nbins = 5
259259
bins = [max_a]
260260
while len(bins) <= nbins:
261261
bins.insert(0, np.sqrt(bins[0] + 1))

synth/pbe/solvers/pbe_solver.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _close_task_solving_(
7272

7373
def solve(
7474
self, task: Task[PBE], enumerator: ProgramEnumerator[None], timeout: float = 60
75-
) -> Generator[Program, None, bool]:
75+
) -> Generator[Program, bool, None]:
7676
"""
7777
Solve the given task by enumerating programs with the given enumerator.
7878
When the timeout is reached, this function returns.
@@ -101,7 +101,6 @@ def solve(
101101
except StopIteration as e:
102102
self._close_task_solving_(task, enumerator, time, False, program)
103103
raise e
104-
return False
105104

106105
def _test_(self, task: Task[PBE], program: Program) -> bool:
107106
"""

synth/pbe/solvers/restart_pbe_solver.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _close_task_solving_(
6161

6262
def solve(
6363
self, task: Task[PBE], enumerator: ProgramEnumerator[None], timeout: float = 60
64-
) -> Generator[Program, None, bool]:
64+
) -> Generator[Program, bool, None]:
6565
with chrono.clock(f"solve.{self.name()}.{self.subsolver.name()}") as c: # type: ignore
6666
self._enumerator = enumerator
6767
self._init_task_solving_(task, self._enumerator, timeout)
@@ -73,15 +73,15 @@ def solve(
7373
self._close_task_solving_(
7474
task, self._enumerator, time, False, program
7575
)
76-
return False
76+
return
7777
self._programs += 1
7878
if self._test_(task, program):
7979
should_stop = yield program
8080
if should_stop:
8181
self._close_task_solving_(
8282
task, self._enumerator, time, True, program
8383
)
84-
return True
84+
return
8585
self._score = self.subsolver._score
8686
# Saves data
8787
if self._score > 0:
@@ -92,7 +92,6 @@ def solve(
9292
self._enumerator = self._restart_(self._enumerator)
9393
gen = self._enumerator.generator()
9494
program = next(gen)
95-
return False
9695

9796
def _should_restart_(self) -> bool:
9897
return self.restart_criterion(self)

synth/syntax/grammars/det_grammar.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Generic,
1212
)
1313
from functools import lru_cache
14-
import numpy as np
14+
import copy
1515

1616
from synth.syntax.grammars.grammar import DerivableProgram, Grammar
1717
from synth.syntax.program import Constant, Function, Primitive, Program, Variable
@@ -23,13 +23,6 @@
2323
T = TypeVar("T")
2424

2525

26-
def __tuplify__(element: Any) -> Any:
27-
if isinstance(element, (List, Tuple)):
28-
return tuple(__tuplify__(x) for x in element)
29-
else:
30-
return element
31-
32-
3326
class DetGrammar(Grammar, ABC, Generic[U, V, W]):
3427
"""
3528
Represents a deterministic grammar.
@@ -61,14 +54,6 @@ def __init__(
6154
self.type_request = self._guess_type_request_()
6255
if clean:
6356
self.clean()
64-
self._derivation2index = {}
65-
self._index2derivation = []
66-
67-
for S in self.rules:
68-
for P, args in self.rules[S].items():
69-
elem = __tuplify__((S, P, args))
70-
self._derivation2index[elem] = len(self._index2derivation)
71-
self._index2derivation.append(elem)
7257

7358
@lru_cache()
7459
def primitives_used(self) -> Set[Primitive]:
@@ -98,16 +83,6 @@ def __str__(self) -> str:
9883
s += " {}\n".format(self.__rule_to_str__(P, out))
9984
return s
10085

101-
def to_multiset(self, program: Program) -> np.ndarray:
102-
out = np.zeros((len(self._derivation2index)))
103-
104-
def reduce(acc, S, P, args):
105-
elem = __tuplify__((S, P, args))
106-
out[self._derivation2index[elem]] += 1
107-
108-
self.reduce_derivations(reduce, None, program, None)
109-
return out
110-
11186
def __repr__(self) -> str:
11287
return self.__str__()
11388

0 commit comments

Comments
 (0)