Skip to content

Commit

Permalink
Val fix (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
achalddave authored Dec 12, 2023
1 parent 73e7b37 commit 81aeb87
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions open_lm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self):

def reset(self):
self.points = []
self.points_tensor = None

def update(self, val):
self.points.append(val)
Expand All @@ -53,13 +54,13 @@ def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
lower = None
upper = None

points_tensor = torch.cat(self.points)
num_points = self.points.shape[0]
self.points_tensor = torch.cat(self.points)
num_points = self.points_tensor.shape[0]

estimates = []
for _ in range(num_samples):
i = np.random.choice(num_points, size=num_points)
estimate = torch.sum(points_tensor[i]) / num_points
estimate = torch.sum(self.points_tensor[i]) / num_points
estimates.append(estimate.item())

half = (100 - interval) / 2
Expand Down Expand Up @@ -384,8 +385,8 @@ def evaluate(model, data, start_epoch, args, writer):

lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci()
lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci()
num_seqs = losses_seq_ci_m.points.shape[0]
num_toks = losses_tok_ci_m.points.shape[0]
num_seqs = losses_seq_ci_m.points_tensor.shape[0]
num_toks = losses_tok_ci_m.points_tensor.shape[0]

# Save eval loss / etc.
log_data = {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _read_reqs(relpath):

setuptools.setup(
name="open_lm",
version="0.0.21",
version="0.0.22",
author=[
"Suchin Gururangan*",
"Mitchell Wortsman*",
Expand Down

0 comments on commit 81aeb87

Please sign in to comment.