Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simpeg/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import simpeg.dask.potential_fields.base
import simpeg.dask.potential_fields.gravity.simulation
import simpeg.dask.potential_fields.magnetics.simulation
import simpeg.dask.potential_fields.magnetics.simulation_pde
import simpeg.dask.simulation
import simpeg.dask.inverse_problem
import simpeg.dask.objective_function
Expand Down
82 changes: 82 additions & 0 deletions simpeg/dask/potential_fields/magnetics/simulation_pde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from dask import array, compute, delayed
import numpy as np
from ....potential_fields.magnetics import Simulation3DDifferential as Sim
from ....utils import sdiag, mkvc


def distance_weights(locations, cell_centers, cell_volumes, exponent=3, threshold=1e-2):
distance_weights = np.zeros(len(cell_centers))
for loc in locations:
distance = np.linalg.norm(cell_centers - loc, axis=1)
distance_weights += cell_volumes**2.0 * (distance + threshold) ** (
-2 * exponent
)

return distance_weights
Comment thread
domfournier marked this conversation as resolved.
Outdated


def dask_getJtJdiag(self, m, W=None, f=None):
"""
Return the diagonal of JtJ
"""

self.model = m

self.model = m
if W is None:
W = np.ones(self.Jmatrix.shape[0])
else:
W = W.diagonal()

client, worker = self._get_client_worker()

n_threads = self.n_threads(client=client, worker=worker)

chunks = np.array_split(self.survey.receiver_locations, n_threads)
cell_centers = self.mesh.cell_centers.copy()
cell_volumes = self.mesh.cell_volumes.copy()

if client:
cell_centers = client.scatter(cell_centers, workers=worker)
cell_volumes = client.scatter(cell_volumes, workers=worker)
else:
delayed_distance_weights = delayed(distance_weights)

futures = []
for block in chunks:
if client:
futures.append(
client.submit(
distance_weights,
block,
cell_centers,
cell_volumes,
workers=worker,
)
)
else:
futures.append(
array.from_delayed(
delayed_distance_weights(
block,
cell_centers,
cell_volumes,
),
dtype=np.float32,
shape=(
len(block),
len(cell_centers),
),
Comment thread
domfournier marked this conversation as resolved.
)
)

if client:
diag = client.gather(futures)
else:
diag = compute(futures)
Comment thread
domfournier marked this conversation as resolved.
Outdated

diag = np.tile(np.vstack(diag).sum(axis=0), 3)
return mkvc((sdiag(np.sqrt(diag)) @ self.remDeriv).power(2).sum(axis=0))


Sim.getJtJdiag = dask_getJtJdiag
2 changes: 1 addition & 1 deletion simpeg/directives/_save_geoh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def write(self, iteration: int, **_):
if (channel_name in child.name and isinstance(child, FloatData))
]

if children[0] is not None:
if children:
properties += children

if len(properties) == 0:
Expand Down
4 changes: 2 additions & 2 deletions simpeg/directives/_vector_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class VectorInversion(InversionDirective):
chifact_target = 1.0
reference_model = None
mode = "cartesian"
inversion_type = "mvis"
inversion_type = "magnetic vector"
norms = []
alphas = []
cartesian_model = None
Expand Down Expand Up @@ -162,7 +162,7 @@ def endIter(self):

if (
self.invProb.phi_d < self.target
) and self.mode == "cartesian": # and self.inversion_type == 'mvis':
) and self.mode == "cartesian" and self.inversion_type == "magnetic vector":
print("Switching MVI to spherical coordinates")
Comment thread
domfournier marked this conversation as resolved.
self.mode = "spherical"
self.cartesian_model = model
Expand Down
4 changes: 4 additions & 0 deletions simpeg/maps/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,10 @@ def __init__(self, *args):
for arg in args:

if isinstance(arg[1], (int, np.integer)):

if not getattr(self, "_nP", None):
self._nP = int(np.sum([w[1] for w in args]))

Comment thread
domfournier marked this conversation as resolved.
wire = Projection(self.nP, slice(start, start + arg[1]))
start += arg[1]
else:
Expand Down
20 changes: 16 additions & 4 deletions simpeg/potential_fields/magnetics/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ def _getRHS(self, m):
).diagonal()
)

return rhs
return rhs.astype(self.solver_dtype)
Comment thread
domfournier marked this conversation as resolved.

def _getA(self):
A = self._Div * self.MfMuiI * self._DivT
Expand Down Expand Up @@ -1991,13 +1991,23 @@ def _Jtvec(self, m, v, f):
if v is None:
v = np.eye(Q.shape[0])
divt_solve_q = (
self._DivT * (self._Ainv * ((Q * self.MfMuiI * -self._DivT).T * v))
self._DivT
* (
self._Ainv
* ((Q * self.MfMuiI * -self._DivT).T * v).astype(self.solver_dtype)
)
+ Q.T * v
)
del v
else:
divt_solve_q = (
self._DivT * (self._Ainv * ((-self._Div * (self.MfMuiI.T * (Q.T * v)))))
self._DivT
* (
self._Ainv
* ((-self._Div * (self.MfMuiI.T * (Q.T * v)))).astype(
self.solver_dtype
)
)
+ Q.T * v
)

Expand Down Expand Up @@ -2071,7 +2081,9 @@ def _Jvec(self, m, v, f):
self.MfMuiI * Mf_r_mui_deriv * v
)

Ainv_Ddm = self._Ainv * (self._Div * (-dCmu_dm + db_dm))
Ainv_Ddm = self._Ainv * (self._Div * (-dCmu_dm + db_dm)).astype(
self.solver_dtype
)

Jv = Q * (C * Ainv_Ddm + (-dCmu_dm + db_dm))

Expand Down
Loading