Skip to content

Commit f09232d

Browse files
committed
SLATE: prevent Cofunction reassembly
1 parent 382e63f commit f09232d

File tree

5 files changed

+59
-6
lines changed

5 files changed

+59
-6
lines changed

firedrake/adjoint_utils/variational_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def wrapper(self, *args, **kwargs):
2727
# Try again without expanding derivatives,
2828
# as dFdu might have been simplied to an empty Form
2929
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
30-
except (TypeError, NotImplementedError):
30+
except (ValueError, TypeError, NotImplementedError):
3131
self._ad_adj_F = None
3232
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
3333
self._ad_count_map = {}

firedrake/assemble.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,11 @@ def preprocess_base_form(expr, mat_type=None, form_compiler_parameters=None):
838838
if mat_type != "matfree":
839839
# Don't expand derivatives if `mat_type` is 'matfree'
840840
# For "matfree", Form evaluation is delayed
841-
expr = BaseFormAssembler.expand_derivatives_form(expr, form_compiler_parameters)
841+
try:
842+
expr = BaseFormAssembler.expand_derivatives_form(expr, form_compiler_parameters)
843+
except ValueError:
844+
# BaseForms with SLATE tensors are not fully supported in UFL.
845+
pass
842846
if not isinstance(expr, (ufl.form.Form, slate.TensorBase)):
843847
# => No restructuring needed for Form and slate.TensorBase
844848
expr = BaseFormAssembler.restructure_base_form_preorder(expr)

firedrake/slate/slate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1430,7 +1430,9 @@ def as_slate(F):
14301430
return F
14311431
elif isinstance(F, Form):
14321432
return Tensor(F)
1433-
elif isinstance(F, (Function, Cofunction)):
1433+
elif isinstance(F, Function):
1434+
# Do not implicitly cast Cofunctions as Slate tensors
1435+
# because assemble(AssembledVector(F)) repeats element summation on F
14341436
return AssembledVector(F)
14351437
else:
14361438
raise TypeError(f"Cannot convert {type(F).__name__} into a slate.Tensor")

firedrake/variational_solver.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,22 @@ def dm(self):
132132
return self.u_restrict.function_space().dm
133133

134134
@staticmethod
135-
def compute_bc_lifting(J, u):
135+
def compute_bc_lifting(J, u, L=0):
136136
"""Return the action of the bilinear form J (without bcs) on a Function u."""
137137
if isinstance(J, MatrixBase) and J.has_bcs:
138138
# Extract the full form without bcs
139139
if not isinstance(J.a, (ufl.BaseForm, slate.slate.TensorBase)):
140140
raise TypeError(f"Could not remove bcs from {type(J).__name__}.")
141141
J = J.a
142-
return ufl_expr.action(J, u)
142+
F = ufl_expr.action(J, u)
143+
if L != 0:
144+
try:
145+
F = F - L
146+
except TypeError:
147+
# Slate expressions do not combine with Cofunctions
148+
# because assemble(AssembledVector(L)) repeats element summation on L
149+
F = ufl.FormSum((F, 1), (L, -1))
150+
return F
143151

144152

145153
class NonlinearVariationalSolver(OptionsManager, NonlinearVariationalSolverMixin):
@@ -399,7 +407,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,
399407
raise TypeError("Provided RHS is a '%s', not a Form or Slate Tensor" % type(L).__name__)
400408
if len(L.arguments()) != 1 and not L.empty():
401409
raise ValueError("Provided RHS is not a linear form")
402-
F = self.compute_bc_lifting(a, u) - L
410+
F = self.compute_bc_lifting(a, u, L=L)
403411

404412
super(LinearVariationalProblem, self).__init__(F, u, bcs=bcs, J=a, Jp=aP,
405413
form_compiler_parameters=form_compiler_parameters,

tests/firedrake/slate/test_linear_algebra.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,42 @@ def test_inverse_action(mat_type, rhs_type):
150150
x = Function(V)
151151
assemble(action(Ainv, b), tensor=x)
152152
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)
153+
154+
155+
@pytest.mark.parametrize("mat_type, rhs_type", [
156+
("slate", "slate"), ("slate", "form"), ("slate", "cofunction"),
157+
("aij", "cofunction"), ("aij", "form"),
158+
("matfree", "cofunction"), ("matfree", "form")])
159+
def test_solve_interface(mat_type, rhs_type):
160+
mesh = UnitSquareMesh(1, 1)
161+
V = FunctionSpace(mesh, "HDivT", 0)
162+
u = TrialFunction(V)
163+
v = TestFunction(V)
164+
f = Function(V).assign(1.0)
165+
bcs = DirichletBC(V, f, "on_boundary")
166+
167+
a = avg(inner(u, v))*dS + inner(u, v)*ds
168+
A = Tensor(a)
169+
if mat_type != "slate":
170+
A = assemble(A, bcs=bcs, mat_type=mat_type)
171+
bcs = None
172+
173+
L = action(a, f)
174+
if rhs_type == "form":
175+
b = L
176+
elif rhs_type == "cofunction":
177+
b = assemble(L)
178+
elif rhs_type == "slate":
179+
b = Tensor(L)
180+
else:
181+
raise ValueError("Invalid rhs type")
182+
183+
sp = None
184+
if mat_type == "matfree":
185+
sp = {"pc_type": "none"}
186+
187+
x = Function(V)
188+
problem = LinearVariationalProblem(A, b, x, bcs=bcs)
189+
solver = LinearVariationalSolver(problem, solver_parameters=sp)
190+
solver.solve()
191+
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)

0 commit comments

Comments
 (0)