Add option to silence manage_jobs stats, only view logger output
araistrick authored and pvl-bot committed Oct 28, 2024
1 parent 4e7b9cb commit 9e7bb70
Showing 1 changed file with 43 additions and 24 deletions.
67 changes: 43 additions & 24 deletions infinigen/datagen/manage_jobs.py
@@ -587,7 +587,7 @@ def inflight(s):
             )
 
         if max_stuck_at_task is not None and stuck_at_next >= max_stuck_at_task:
-            logging.info(
+            logging.debug(
                 f"{seed} - Not launching due to {stuck_at_next=} >"
                 f" {max_stuck_at_task} for {started_if_launch=}"
             )
@@ -602,12 +602,12 @@ def inflight(s):
         queued_key = (JobState.Queued, taskname.split("_")[0])
         queued = state_counts.get(queued_key, 0)
         if max_queued_task is not None and queued >= max_queued_task:
-            logging.info(
+            logging.debug(
                 f"{seed} - Not launching due to {queued=} > {max_queued_task} for {taskname}"
             )
             continue
         if max_queued_total is not None and total_queued >= max_queued_total:
-            logging.info(
+            logging.debug(
                 f"{seed} - Not launching due to {total_queued=} > {max_queued_total} for {taskname}"
             )
             continue
@@ -643,21 +643,13 @@ def compute_control_state(args, totals, elapsed, num_concurrent):
     return control_state
 
 
-def record_states(stats, totals, control_state):
-    pretty_stats = copy(stats)
-    pretty_stats.update({f"control_state/{k}": v for k, v in control_state.items()})
-    pretty_stats.update({f"{k}/total": v for k, v in totals.items()})
-
-    if wandb is not None:
-        wandb.log(pretty_stats)
-    print("=" * 60)
-    for k, v in sorted(pretty_stats.items()):
-        print(f"{k.ljust(30)} : {v}")
-    print("-" * 60)
-
-
 @gin.configurable
-def manage_datagen_jobs(all_scenes, elapsed, num_concurrent, disk_sleep_threshold=0.95):
+def manage_datagen_jobs(
+    all_scenes: list[dict],
+    elapsed: float,
+    num_concurrent: int,
+    disk_sleep_threshold=0.95,
+):
     if LocalScheduleHandler._inst is not None:
         sys.path = ORIG_SYS_PATH  # hacky workaround because bpy module breaks with multiprocessing
         LocalScheduleHandler.instance().poll()
@@ -674,7 +666,6 @@ def manage_datagen_jobs(all_scenes, elapsed, num_concurrent, disk_sleep_threshold=0.95):
     )  # may be less due to jobs_to_launch optional kwargs, or running out of num_jobs
 
     pd.DataFrame.from_records(all_scenes).to_csv(args.output_folder / "scenes_db.csv")
-    record_states(stats, totals, control_state)
 
     # Dont launch new scenes if disk is getting full
     if control_state["disk_usage"] > disk_sleep_threshold:
@@ -687,12 +678,18 @@ def manage_datagen_jobs(all_scenes, elapsed, num_concurrent, disk_sleep_threshold=0.95):
             wait_duration=3 * 60 * 60,
         )
         time.sleep(60)
-        return
+        return {}
 
     for scene, taskname, queue_func in new_jobs:
         logger.info(f"{scene['seed']} - running {taskname}")
         run_task(queue_func, args.output_folder / str(scene["seed"]), scene, taskname)
 
+    log_stats = copy(stats)
+    log_stats.update({f"control_state/{k}": v for k, v in control_state.items()})
+    log_stats.update({f"{k}/total": v for k, v in totals.items()})
+
+    return log_stats
+
 
 @gin.configurable
 def main(args, shuffle=True, wandb_project="render", upload_commandfile_method=None):
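
Note: the stats dictionary assembled above is flat and namespaced ("control_state/..." and ".../total" keys), so the caller can hand it straight to wandb.log or print it as a table. A minimal sketch of that assembly with made-up values (the real keys come from the scheduler's per-state counts, not from this example):

    from copy import copy

    # Hypothetical inputs, for illustration only; the real dicts come from the scheduler.
    stats = {"queued/fine": 2, "running/rendershort": 4}
    totals = {"fine": 10, "rendershort": 10}
    control_state = {"disk_usage": 0.42, "n_in_flight": 6}

    log_stats = copy(stats)
    log_stats.update({f"control_state/{k}": v for k, v in control_state.items()})
    log_stats.update({f"{k}/total": v for k, v in totals.items()})

    for k, v in sorted(log_stats.items()):
        print(f"{k.ljust(30)} : {v}")
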
@@ -721,10 +718,16 @@ def main(args, shuffle=True, wandb_project="render", upload_commandfile_method=None):
             mode=args.wandb_mode,
         )
 
+    filehandler = logging.FileHandler(str(args.output_folder / "jobs.log"))
+    filehandler.setLevel(logging.INFO)
+
+    streamhandler = logging.StreamHandler()
+    streamhandler.setLevel(args.loglevel)
+
     logging.basicConfig(
-        filename=str(args.output_folder / "jobs.log"),
-        level=args.loglevel,
         format="[%(asctime)s]: %(message)s",
+        handlers=[filehandler, streamhandler],
     )
 
     print(f"Using {get_slurm_banned_nodes()=}")
@@ -737,10 +740,25 @@ def main(args, shuffle=True, wandb_project="render", upload_commandfile_method=None):
     start_time = datetime.now()
     while any(j["all_done"] == SceneState.NotDone for j in all_scenes):
         now = datetime.now()
-        print(
-            f'{args.output_folder} {start_time.strftime("%m/%d %I:%M%p")} -> {now.strftime("%m/%d %I:%M%p")}'
+
+        if args.print_stats:
+            print(
+                f'{args.output_folder} {start_time.strftime("%m/%d %I:%M%p")} -> {now.strftime("%m/%d %I:%M%p")}'
+            )
+
+        log_stats = manage_datagen_jobs(
+            all_scenes, elapsed=(now - start_time).total_seconds()
         )
-        manage_datagen_jobs(all_scenes, elapsed=(now - start_time).total_seconds())
+
+        if wandb is not None:
+            wandb.log(log_stats)
+
+        if args.print_stats:
+            print("=" * 60)
+            for k, v in sorted(log_stats.items()):
+                print(f"{k.ljust(30)} : {v}")
+            print("-" * 60)
+
         time.sleep(2)
 
     any_crashed = any(j.get("any_fatal_crash", False) for j in all_scenes)
@@ -842,11 +860,12 @@ def main(args, shuffle=True, wandb_project="render", upload_commandfile_method=None):
         action="store_const",
         dest="loglevel",
         const=logging.DEBUG,
-        default=logging.INFO,
+        default=logging.WARNING,
     )
     parser.add_argument(
         "-v", "--verbose", action="store_const", dest="loglevel", const=logging.INFO
     )
+    parser.add_argument("--print_stats", type=int, default=1)
     args = parser.parse_args()
 
     using_upload = any("upload" in x for x in args.pipeline_configs)
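
Note: the net effect of the argparse changes is that with no flags the console stays at WARNING while --print_stats defaults to 1, so the periodic stats table is still printed; --print_stats 0 silences the table, and -v / the debug flag raise console verbosity. A self-contained sketch of how the flags resolve (the "-d"/"--debug" option strings are assumptions; only their kwargs appear in this hunk):

    import argparse
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--debug",  # assumed option strings, not shown in this diff
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
        default=logging.WARNING,
    )
    parser.add_argument(
        "-v", "--verbose", action="store_const", dest="loglevel", const=logging.INFO
    )
    parser.add_argument("--print_stats", type=int, default=1)

    for argv in ([], ["-v"], ["-d", "--print_stats", "0"]):
        ns = parser.parse_args(argv)
        print(argv, logging.getLevelName(ns.loglevel), ns.print_stats)
    # [] WARNING 1
    # ['-v'] INFO 1
    # ['-d', '--print_stats', '0'] DEBUG 0
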