forked from piyawatm/oculus-draconis
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
72 lines (61 loc) · 2.2 KB
/
utils.py
File metadata and controls
72 lines (61 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import random
import numpy as np
import torch
import yaml
from argparse import Namespace
from torchvision.utils import save_image as tv_save_image
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value handed to every seeding function below.
    """
    # Same seed, same order: Python stdlib, NumPy global state, PyTorch
    # default generator, then every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and turn off the benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def count_parameters(model):
    """Count the trainable parameters of ``model``, expressed in millions.

    Only parameters with ``requires_grad`` set contribute to the count.

    Args:
        model: Object whose ``parameters()`` yields items exposing
            ``requires_grad`` and ``numel()``.

    Returns:
        The number of trainable elements divided by 1e6.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6
def save_checkpoint(model, path: str):
    """Save ``model.state_dict()`` to ``path``, creating parent dirs if needed.

    Args:
        model: Object exposing ``state_dict()``.
        path: Destination file. May be a bare filename (saved relative to
            the current working directory) or include directories, which
            are created when missing.
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"[checkpoint] Saved → {path}")
def load_checkpoint(model, path: str, device="cuda"):
    """Restore weights from ``path`` into ``model`` and hand the model back.

    Args:
        model: Object exposing ``load_state_dict()``.
        path: Checkpoint file to read with ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object, after its state dict has been loaded.
    """
    # NOTE(review): depending on the installed PyTorch version, torch.load
    # may unpickle arbitrary objects -- only load checkpoints you trust.
    model.load_state_dict(torch.load(path, map_location=device))
    print(f"[checkpoint] Loaded ← {path}")
    return model
def load_config(path):
    """
    Load YAML config and return it as a flat Namespace.

    Top-level keys whose value is a mapping (e.g. 'model', 'data',
    'training') are flattened one level: their entries become attributes of
    the result and the section name itself is dropped. Other top-level keys
    are kept as-is. When two sources define the same key, the one that
    appears later in the file wins.

    Args:
        path: Path to the YAML file (read as UTF-8).

    Returns:
        argparse.Namespace with one attribute per flattened key. An empty
        file yields an empty Namespace.

    Raises:
        ValueError: If the top level of the document is not a mapping.
    """
    with open(path, 'r', encoding='utf-8') as f:
        raw_config = yaml.safe_load(f)

    # safe_load() returns None for an empty document; treat it as "no
    # settings" instead of crashing on None.items().
    if raw_config is None:
        return Namespace()
    if not isinstance(raw_config, dict):
        raise ValueError(
            f"Config {path!r} must contain a mapping at the top level, "
            f"got {type(raw_config).__name__}"
        )

    # Flatten the config so we can access attributes directly (e.g. conf.batch_size)
    flat_config = {}
    for key, value in raw_config.items():
        if isinstance(value, dict):
            flat_config.update(value)
        else:
            flat_config[key] = value
    return Namespace(**flat_config)
def save_image(tensor, path, **kwargs):
    """Wrapper around torchvision save_image to ensure dir exists.

    Args:
        tensor: Image data, forwarded unchanged to torchvision's save_image.
        path: Destination file. May be a bare filename or include
            directories, which are created when missing.
        **kwargs: Forwarded unchanged to torchvision's save_image.
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tv_save_image(tensor, path, **kwargs)
class Logger:
    """Tiny text logger; writes both to stdout and file.

    The log file is truncated on construction and written as UTF-8. The
    instance can be used as a context manager so the file is closed even if
    the body raises::

        with Logger("runs/train.log") as logger:
            logger.log("epoch 1")
    """

    def __init__(self, path):
        """Open ``path`` for writing, creating its parent directory if needed."""
        parent = os.path.dirname(path)
        # dirname() is "" for a bare filename and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when there is one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Explicit UTF-8 so non-ASCII messages do not depend on the locale.
        self.f = open(path, "w", encoding="utf-8")

    def log(self, msg):
        """Print ``msg`` and append ``str(msg)`` plus a newline to the file."""
        print(msg)
        self.f.write(str(msg) + "\n")
        # Flush after every message so the file is current if the run dies.
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()

    def __enter__(self):
        """Return the logger itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file; exceptions from the body are not suppressed."""
        self.close()