152 changes: 98 additions & 54 deletions tmol/optimization/lbfgs_armijo.py
@@ -210,6 +210,7 @@ def step(self, closure):
Despite the name, this performs the full LBFGS minimization trajectory.
Stores lots of information in self.state
"""
# L-BFGS only works with a single parameter group
assert len(self.param_groups) == 1

group = self.param_groups[0]
@@ -226,27 +227,45 @@ def step(self, closure):
state.setdefault("func_evals", 0)
state.setdefault("n_iter", 0)

# Optimization: work directly with the single parameter tensor
assert len(self._params) == 1, "This optimized version requires a single parameter tensor"
param = self._params[0]

# evaluate initial f(x)
orig_loss = closure()
loss = float(orig_loss)
current_evals = 1
state["func_evals"] += 1

# ... and df/dx
x = self._gather_flat_x()
flat_grad = self._gather_flat_grad()
# ... and df/dx - direct reference instead of gather
x = param.data.view(-1)
flat_grad = param.grad.data.view(-1)
max_grad = flat_grad.max()

# tensors cached in state
d = state.get("d") # search direction
t = state.get("t") # stepsize

old_dirs = state.get("old_dirs") # history of directions
old_stps = state.get("old_stps") # history of stepsizes

prev_flat_grad = state.get("prev_flat_grad") # previous grad
prev_loss = state.get("prev_loss") # previous energy

# Pre-allocate stacked matrices for L-BFGS (reused each iteration)
L = x.numel()
if "old_dirs_mat" not in state:
state["old_dirs_mat"] = torch.empty(
(history_size, L), device=x.device, dtype=x.dtype
)
state["old_stps_mat"] = torch.empty(
(history_size, L), device=x.device, dtype=x.dtype
)
state["history_start"] = 0 # Circular buffer start index
state["history_count"] = 0 # Number of items in history

old_dirs_mat = state["old_dirs_mat"]
old_stps_mat = state["old_stps_mat"]
history_start = state["history_start"]
history_count = state["history_count"]

n_iter = 0

while n_iter < max_iter:
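
The pre-allocated old_dirs_mat / old_stps_mat buffers replace per-iteration Python lists: once history_size entries exist, a new pair overwrites the oldest row, so nothing is reallocated or shifted. A minimal, self-contained sketch of that bookkeeping (the helper names are illustrative, not part of this module):

import torch

def push_history(mat, start, count, vec):
    # Write vec into a fixed-size rolling history stored as rows of mat.
    size = mat.shape[0]
    if count < size:            # still filling the buffer
        idx = count
        count += 1
    else:                       # buffer full: overwrite the oldest row
        idx = start
        start = (start + 1) % size
    mat[idx].copy_(vec)
    return start, count

def ordered_rows(mat, start, count):
    # Rows in oldest-to-newest order, as the two-loop recursion expects.
    size = mat.shape[0]
    if count < size:
        return mat[:count]
    idx = torch.cat([torch.arange(start, size), torch.arange(0, start)])
    return mat[idx]

hist = torch.empty(3, 4)
start, count = 0, 0
for k in range(5):
    start, count = push_history(hist, start, count, torch.full((4,), float(k)))
print(ordered_rows(hist, start, count))   # rows holding 2, 3, 4, in that order
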
@@ -257,48 +276,75 @@ def step(self, closure):
if state["n_iter"] == 1:
# initialize
d = flat_grad.neg()
old_dirs = []
old_stps = []
history_count = 0
else:
# do lbfgs update (update memory)
y = flat_grad.sub(prev_flat_grad)
s = d.mul(t)
ys = y.dot(s) # y*s
if ys > 1e-10:
# updating memory
if len(old_dirs) == history_size:
# shift history by one (limited-memory)
old_dirs.pop(0)
old_stps.pop(0)

# store new direction/step
old_dirs.append(y)
old_stps.append(s)
# updating memory - write directly into circular buffer
if history_count < history_size:
# Still filling up the buffer
idx = history_count
history_count += 1
else:
# Buffer full, overwrite oldest entry
idx = history_start
history_start = (history_start + 1) % history_size

old_dirs_mat[idx].copy_(y)
old_stps_mat[idx].copy_(s)

# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
num_old = len(old_dirs)

if "ro" not in state:
state["ro"] = [None] * history_size
state["al"] = [None] * history_size
ro = state["ro"]
al = state["al"]

for i in range(num_old):
ro[i] = 1.0 / old_dirs[i].dot(old_stps[i])

# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
for i in range(num_old - 1, -1, -1):
al[i] = old_stps[i].dot(q) * ro[i]
q.add_(old_dirs[i], alpha=-al[i])

# r/d is the final direction
d = r = q
for i in range(num_old):
be_i = old_dirs[i].dot(r) * ro[i]
r.add_(old_stps[i], alpha=al[i] - be_i)
if history_count == 0:
# No history: use steepest descent direction
d = r = flat_grad.neg()
else:
# Arrange the history oldest -> newest (a slice when not full, a gathered copy once full)
if history_count < history_size:
old_dirs_view = old_dirs_mat[:history_count]
old_stps_view = old_stps_mat[:history_count]
else:
# Buffer full, need to reorder: [start:end] + [0:start]
indices = torch.cat(
[
torch.arange(
history_start, history_size, device=x.device
),
torch.arange(0, history_start, device=x.device),
]
)
old_dirs_view = old_dirs_mat[indices]
old_stps_view = old_stps_mat[indices]

# Compute all ro values in one batched operation
ro = 1.0 / torch.sum(old_dirs_view * old_stps_view, dim=1)

# First loop: backward pass - fully batched
q = flat_grad.neg()

# Compute all dot products at once: old_stps_view @ q
stps_dot_q = torch.mv(old_stps_view, q)
al = stps_dot_q * ro

# Compute cumulative updates in reverse order
al_flipped = torch.flip(al, dims=[0])
old_dirs_flipped = torch.flip(old_dirs_view, dims=[0])

# Apply all updates at once: q -= old_dirs_flipped.T @ al_flipped
q.add_(torch.mv(old_dirs_flipped.t(), al_flipped), alpha=-1.0)

# Second loop: forward pass - fully batched
r = q
be = (
torch.mv(old_dirs_view, r) * ro
) # All dot products in one matmul

# Single batched update: r += old_stps_view.T @ (al - be)
r.add_(torch.mv(old_stps_view.t(), al - be))

d = r

if prev_flat_grad is None:
prev_flat_grad = flat_grad.clone()
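
The batched passes above lean on torch.mv over the stacked history instead of per-vector dot products in a Python loop. A small sketch of the identities being used, on random data (purely illustrative):

import torch

torch.manual_seed(0)
S = torch.randn(5, 100)    # stacked history rows, one vector per row
q = torch.randn(100)

# One matrix-vector product gives the same numbers as five separate dots.
batched = torch.mv(S, q)
looped = torch.stack([S[i].dot(q) for i in range(S.shape[0])])
assert torch.allclose(batched, looped)

# The transpose direction used for the updates: S.T @ coeffs is a
# coefficient-weighted sum of the history rows.
coeffs = torch.randn(5)
combo = torch.mv(S.t(), coeffs)
assert torch.allclose(combo, (coeffs[:, None] * S).sum(dim=0))
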
@@ -333,9 +379,13 @@ def step(self, closure):
# we do not need to compute gradients in here
self.ls_func_evals = 0

# Optimization: save original position and work directly with param.data
x_backup = x.clone()

def linefn(alpha_test):
self.ls_func_evals += 1
self._set_x_from_flat(x + alpha_test * d)
# Direct parameter update - eliminates _set_x_from_flat overhead
x.copy_(x_backup).add_(d, alpha=alpha_test)
E = closure()
return E.to(dtype=gtd.dtype)
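
Each trial step size in the line search is evaluated by restoring the saved flat view and adding alpha * d in place, so no fresh tensors are built per probe and only function values are needed. A toy sketch of that probe pattern on a quadratic (names illustrative, not the module's own Armijo routine):

import torch

param = torch.zeros(3, requires_grad=True)

def closure():
    # toy objective: ||param - 1||^2; the line search only needs its value
    return ((param - 1.0) ** 2).sum()

x = param.data.view(-1)     # flat view sharing storage with param
x_backup = x.clone()
d = torch.ones(3)           # some descent direction

def linefn(alpha):
    # overwrite the parameters in place, then re-evaluate the objective
    x.copy_(x_backup).add_(d, alpha=alpha)
    return float(closure())

print(linefn(0.5), linefn(1.0))   # 0.75 and 0.0; alpha = 1 hits the minimum
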

@@ -351,11 +401,12 @@ def linefn(alpha_test):
minstep=1e-12,
)

# update
x = x + t * d
self._set_x_from_flat(x)
# update - direct modification
x.copy_(x_backup).add_(d, alpha=t)

closure() # fd: needed for derivatives, but adds an extra func eval...
flat_grad = self._gather_flat_grad()

flat_grad = param.grad.data.view(-1) # Direct reference
max_grad = flat_grad.max()

# update func eval
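
Reading the gradient as param.grad.data.view(-1) avoids the copy that gathering across parameters would make; it is valid here only because the single-tensor assumption was asserted earlier. A quick sketch of the aliasing this relies on (illustrative):

import torch

param = torch.randn(4, 5, requires_grad=True)
loss = (param ** 2).sum()
loss.backward()

flat_grad = param.grad.data.view(-1)          # no copy: a flat view of the same storage
flat_grad[0] = 0.0
assert param.grad.view(-1)[0].item() == 0.0   # the original gradient sees the edit
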
@@ -374,17 +425,10 @@ def linefn(alpha_test):
if 2 * abs(loss - prev_loss) <= rtol * (abs(loss) + abs(prev_loss) + 1e-10):
break
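
This stopping test compares successive energies on a symmetric relative scale, guarded against both values being near zero. In isolation it reads as below (the rtol value here is illustrative; the optimizer supplies its own setting):

def converged(loss, prev_loss, rtol):
    # symmetric relative change between consecutive energies
    return 2.0 * abs(loss - prev_loss) <= rtol * (abs(loss) + abs(prev_loss) + 1e-10)

assert converged(1.0000001, 1.0, rtol=1e-6)       # tiny relative change: stop
assert not converged(2.0, 1.0, rtol=1e-6)         # large relative change: keep going
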

# report if we have hit max cycles (mimicking R3)
# if state['n_iter'] == max_iter - 1:
# print(
# "LBFGS_Armijo finished ", max_iter,
# " cycles without converging."
# )

state["d"] = d
state["t"] = t
state["old_dirs"] = old_dirs
state["old_stps"] = old_stps
state["history_start"] = history_start
state["history_count"] = history_count
state["prev_flat_grad"] = prev_flat_grad
state["prev_loss"] = prev_loss
