Skip to content

Commit

Permalink
Implement live results updates
Browse files — browse the repository at this point in the history
  • Loading branch information
WeetHet committed Jan 17, 2025
1 parent 27c3195 commit 2c9c252
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 115 deletions.
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
];
shellHook = ''
poetry sync
export PATH=$(pwd)/.venv/bin:$PATH
source .venv/bin/activate
'';
};

Expand Down
275 changes: 161 additions & 114 deletions verified_cogen/several_modes/several_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import logging
import multiprocessing as mp
import pathlib
from typing import Dict, Optional
from dataclasses import dataclass
from multiprocessing.managers import DictProxy, SyncManager
from threading import Lock
from typing import Optional

from verified_cogen.llm.llm import LLM
from verified_cogen.main import construct_rewriter, make_runner_cls
from verified_cogen.runners import Runner, RunnerConfig
from verified_cogen.runners.languages import register_basic_languages
from verified_cogen.runners.rewriters import Rewriter
from verified_cogen.several_modes.args import ProgramArgsMultiple, get_args
from verified_cogen.several_modes.constants import (
MODE_MAPPING,
Expand All @@ -27,161 +29,206 @@
logger = logging.getLogger(__name__)


@dataclass
class ProcessFileConfig:
args: ProgramArgsMultiple
history_dir: pathlib.Path
json_results: pathlib.Path


@dataclass
class SharedState:
lock: Lock
results: "DictProxy[str, int]"
results_avg: "DictProxy[int, float]"


def process_file(
file: pathlib.Path,
args: ProgramArgsMultiple,
history_dir: pathlib.Path,
file_with_name: tuple[pathlib.Path, str],
llm: LLM,
rewriter: Rewriter,
runner: Runner,
config: ProcessFileConfig,
state: SharedState,
) -> Optional[int]:
file, marker_name = file_with_name
try:
mode = Mode(args.insert_conditions_mode)
tries = runner.run_on_file(mode, args.tries, str(file))
mode = Mode(config.args.insert_conditions_mode)
tries = runner.run_on_file(mode, config.args.tries, str(file))
except Exception as e:
print(e)
tries = None

llm.dump_history(history_dir / f"{file.stem}.txt")
display_name = rename_file(file)
with state.lock:
if tries is not None:
state.results[marker_name] = tries
state.results_avg[tries] += 1
logger.info(f"Verified {display_name} in {tries} tries")
else:
state.results[marker_name] = -1
logger.info(f"Failed to verify {display_name}")
with open(config.json_results, "w") as f:
json.dump(dict(state.results), f, indent=2)

llm.dump_history(config.history_dir / f"{file.stem}.txt")

return tries


def main():
args = get_args()
print(args.manual_rewriters)
def run_mode(
manager: SyncManager,
idx: int,
mode: str,
directory: pathlib.Path,
files: list[pathlib.Path],
args: ProgramArgsMultiple,
results_directory: pathlib.Path,
verifier: Verifier,
log_tries_dir: Optional[pathlib.Path],
):
all_removed = MODE_MAPPING[mode]
register_basic_languages(with_removed=all_removed)

assert args.insert_conditions_mode != Mode.REGEX
assert args.dir is not None
logger.info(mode)
log_tries_mode = (
log_tries_dir / f"{idx}_{mode}" if log_tries_dir is not None else None
)

if args.output_logging:
register_output_handler(logger)
if log_tries_mode is not None:
log_tries_mode.mkdir(exist_ok=True)

log_tries_dir = pathlib.Path(args.log_tries) if args.log_tries is not None else None
if log_tries_dir is not None:
log_tries_dir.mkdir(exist_ok=True)
json_avg_results = (
results_directory / f"tries_{directory.name}_{idx}_{mode}_avg.json"
)

directory = pathlib.Path(args.dir)
files = list(directory.glob(ext_glob(args.filter_by_ext)))
assert len(files) > 0, "No files found in the directory"
files.sort()
with open(json_avg_results, "w") as f:
json.dump({}, f)

results_directory = pathlib.Path("results")
results_directory.mkdir(exist_ok=True)
results_avg: "DictProxy[int, float]" = manager.dict([
(i, 0) for i in range(args.tries + 1)
])

verifier = Verifier(args.verifier_command)
lock = manager.Lock()

for idx, mode in enumerate(args.modes):
all_removed = MODE_MAPPING[mode]
register_basic_languages(with_removed=all_removed)
for run in range(args.runs):
logger.info(f"Run {run}")

logger.info(mode)
log_tries_mode = (
log_tries_dir / f"{idx}_{mode}" if log_tries_dir is not None else None
history_dir = results_directory / f"history_{directory.name}_{idx}_{mode}_{run}"
history_dir.mkdir(exist_ok=True)
json_results = (
results_directory / f"tries_{directory.name}_{idx}_{mode}_{run}.json"
)

if log_tries_mode is not None:
log_tries_mode.mkdir(exist_ok=True)
if not json_results.exists():
with open(json_results, "w") as f:
json.dump({}, f)

json_avg_results = (
results_directory / f"tries_{directory.name}_{idx}_{mode}_avg.json"
)
with open(json_results, "r") as f:
results = manager.dict(json.load(f))

with open(json_avg_results, "w") as f:
json.dump({}, f)
results_avg: Dict[int, float] = dict([(i, 0) for i in range(args.tries + 1)])
log_tries = log_tries_mode and (log_tries_mode / f"run_{run}")
if log_tries is not None:
log_tries.mkdir(exist_ok=True)

for run in range(args.runs):
logger.info(f"Run {run}")
config = RunnerConfig(
log_tries=log_tries,
include_text_descriptions=TEXT_DESCRIPTIONS[mode],
remove_implementations=REMOVE_IMPLS_MAPPING[mode],
remove_helpers=(mode == "mode6"),
)

history_dir = (
results_directory / f"history_{directory.name}_{idx}_{mode}_{run}"
)
history_dir.mkdir(exist_ok=True)
json_results = (
results_directory / f"tries_{directory.name}_{idx}_{mode}_{run}.json"
)
files_to_process: list[tuple[pathlib.Path, str, str]] = []
for file in files:
display_name = rename_file(file)
marker_name = str(file.relative_to(directory))
if (
marker_name in results
and isinstance(results[marker_name], int)
and results[marker_name] != -1
):
logger.info(f"Skipping: {display_name} as it has already been verified")
continue
files_to_process.append((file, display_name, marker_name))

llm = LLM(
args.grazie_token,
args.llm_profile,
args.prompts_directory[idx],
args.temperature,
)

if not json_results.exists():
with open(json_results, "w") as f:
json.dump({}, f)
rewriter = construct_rewriter(
extension_from_file_list(files), args.manual_rewriters
)

with open(json_results, "r") as f:
results = json.load(f)
runner = make_runner_cls(
args.bench_type, extension_from_file_list(files), config
)(llm, logger, verifier, rewriter)

log_tries = log_tries_mode and (log_tries_mode / f"run_{run}")
if log_tries is not None:
log_tries.mkdir(exist_ok=True)
state = SharedState(lock, results, results_avg)
with mp.Pool(processes=mp.cpu_count()) as pool:

config = RunnerConfig(
log_tries=log_tries,
include_text_descriptions=TEXT_DESCRIPTIONS[mode],
remove_implementations=REMOVE_IMPLS_MAPPING[mode],
remove_helpers=(mode == "mode6"),
)
def make_arguments(file: pathlib.Path, marker_name: str):
return (
(file, marker_name),
llm,
runner,
ProcessFileConfig(args, history_dir, json_results),
state,
)

files_to_process: list[tuple[pathlib.Path, str, str]] = []
for file in files:
display_name = rename_file(file)
marker_name = str(file.relative_to(directory))
if (
marker_name in results
and isinstance(results[marker_name], int)
and results[marker_name] != -1
):
logger.info(
f"Skipping: {display_name} as it has already been verified"
)
continue
files_to_process.append((file, display_name, marker_name))

llm = LLM(
args.grazie_token,
args.llm_profile,
args.prompts_directory[idx],
args.temperature,
arguments = (
make_arguments(file, marker_name)
for file, _, marker_name in files_to_process
)
pool.starmap(process_file, arguments)

rewriter = construct_rewriter(
extension_from_file_list(files), args.manual_rewriters
)
for key in results_avg.keys():
results_avg[key] = results_avg[key] / args.runs

runner = make_runner_cls(
args.bench_type, extension_from_file_list(files), config
)(llm, logger, verifier, rewriter)
with open(json_avg_results, "w") as f:
json.dump(results_avg, f)

with mp.Pool(processes=mp.cpu_count()) as pool:
logger.info(f"Averaged results for {mode}: {results_avg}")

def make_arguments(file: pathlib.Path):
return file, args, history_dir, llm, rewriter, runner

mp_results = pool.starmap(
process_file,
(make_arguments(file) for file, _, _ in files_to_process),
)
def main():
args = get_args()
print(args.manual_rewriters)

for (file, display_name, marker_name), tries in zip(
files_to_process, mp_results
):
logger.info(f"Processing results for: {display_name}")
assert args.insert_conditions_mode != Mode.REGEX
assert args.dir is not None

if args.output_logging:
register_output_handler(logger)

if tries is not None:
results[marker_name] = tries
results_avg[tries] += 1
logger.info(f"Verified {display_name} in {tries} tries")
else:
results[marker_name] = -1
logger.info(f"Failed to verify {display_name}")
with open(json_results, "w") as f:
json.dump(results, f, indent=2)
log_tries_dir = pathlib.Path(args.log_tries) if args.log_tries is not None else None
if log_tries_dir is not None:
log_tries_dir.mkdir(exist_ok=True)

directory = pathlib.Path(args.dir)
files = list(directory.glob(ext_glob(args.filter_by_ext)))
assert len(files) > 0, "No files found in the directory"
files.sort()

for key in results_avg.keys():
results_avg[key] = results_avg[key] / args.runs
results_directory = pathlib.Path("results")
results_directory.mkdir(exist_ok=True)

with open(json_avg_results, "w") as f:
json.dump(results_avg, f)
verifier = Verifier(args.verifier_command)

logger.info(f"Averaged results for {mode}: {results_avg}")
with mp.Manager() as manager:
for idx, mode in enumerate(args.modes):
run_mode(
manager,
idx,
mode,
directory,
files,
args,
results_directory,
verifier,
log_tries_dir,
)


if __name__ == "__main__":
Expand Down

0 comments on commit 2c9c252

Please sign in to comment.