3333# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
3434# POSSIBILITY OF SUCH DAMAGE.
3535#
36+ from typing import Tuple , Optional
37+ import os
38+ import logging
39+
40+ import numpy as np
3641import scipy .sparse as sparse
3742import scipy .linalg as linalg
38- import numpy as np
39- import os
4043import matplotlib .pyplot as plt
41- import logging
4244import matplotlib .colors as colors
45+
46+ import allensdk .internal .brain_observatory .mask_set as mask_set
4347from allensdk .config .manifest import Manifest
4448
45- def demix_time_dep_masks (raw_traces , stack , masks ):
46- '''
4749
48- :param raw_traces: extracted traces
49- :param stack: movie (same length as traces)
50- :param masks: binary roi masks
51- :return: demixed traces
52- '''
50+ def identify_valid_masks (mask_array ):
51+ ms = mask_set .MaskSet (masks = mask_array .astype (bool ))
52+ valid_masks = np .ones (mask_array .shape [0 ]).astype (bool )
53+
54+ # detect duplicates
55+ duplicates = ms .detect_duplicates (overlap_threshold = 0.9 )
56+ if len (duplicates ) > 0 :
57+ valid_masks [duplicates .keys ()] = False
58+
59+ # detect unions, only for remaining valid masks
60+ valid_idxs = np .where (valid_masks )
61+ ms = mask_set .MaskSet (masks = mask_array [valid_idxs ].astype (bool ))
62+ unions = ms .detect_unions ()
63+
64+ if len (unions ) > 0 :
65+ un_idxs = unions .keys ()
66+ valid_masks [valid_idxs [0 ][un_idxs ]] = False
67+
68+ return valid_masks
69+
70+
71+ def _demix_point (source_frame : np .ndarray , mask_traces : np .ndarray ,
72+ flat_masks : sparse ,
73+ pixels_per_mask : np .ndarray ) -> Optional [np .ndarray ]:
74+ """
75+ Helper function to run demixing for single point in time for a
76+ source with overlapping traces.
77+
78+ Parameters
79+ ==========
80+ source_frame: values of movie source at the single time point,
81+ unraveled in the x-y dimension (1d array of length HxW )
82+ flat_masks: 2d-array of binary masks unraveled in the x-y dimension
83+ mask traces: values of mask trace at single time point (1d-array of
84+ length n, where `n` is number of masks)
85+ pixels_per_mask: Number of pixels for each mask associated with
86+ trace (1d-array of length `n`)
87+
88+ Returns
89+ =======
90+ Array of demixed trace values for each mask if all trace data is
91+ nonzero. Otherwise, returns None.
92+ """
93+ mask_weighted_trace = mask_traces * pixels_per_mask
94+
95+ # Skip if there is zero signal anywhere in one of the traces
96+ if (mask_weighted_trace == 0 ).any ():
97+ return None
98+ norm_mat = sparse .diags (pixels_per_mask / mask_weighted_trace , offsets = 0 )
99+ source_mat = sparse .diags (source_frame , offsets = 0 )
100+ source_mask_projection = flat_masks .dot (source_mat )
101+ weighted_masks = norm_mat .dot (source_mask_projection )
102+ # cast to dense numpy array for linear solver because solution is dense
103+ overlap = flat_masks .dot (weighted_masks .T ).toarray ()
104+ try :
105+ demix_traces = linalg .solve (overlap , mask_weighted_trace )
106+ except linalg .LinAlgError :
107+ logging .warning ("Singular matrix, using least squares to solve." )
108+ x , _ , _ , _ = linalg .lstsq (overlap , mask_weighted_trace )
109+ demix_traces = x
110+ return demix_traces
111+
112+
113+ def demix_time_dep_masks (raw_traces : np .ndarray , stack : np .ndarray ,
114+ masks : np .ndarray ) -> Tuple [np .ndarray , list ]:
115+ """
116+ Demix traces of potentially overlapping masks extraced from a single
117+ 2p recording.
118+
119+ :param raw_traces: 2d array of traces for each mask, of dimensions
120+ (t, n), where `t` is the number of time points and `n` is the
121+ number of masks.
122+ :param stack: 3d array representing a 1p recording movie, of
123+ dimensions (t, H, W).
124+ :param masks: 3d array of binary roi masks, of shape (n, H, W),
125+ where `n` is the number of masks, and HW are the dimensions of
126+ an individual frame in the movie `stack`.
127+ :return: Tuple of demixed traces and whether each frame was skipped
128+ in the demixing calculation.
129+ """
53130 N , T = raw_traces .shape
54131 _ , x , y = masks .shape
55132 P = x * y
@@ -58,8 +135,6 @@ def demix_time_dep_masks(raw_traces, stack, masks):
58135 stack = stack .reshape (T , P )
59136
60137 num_pixels_in_mask = np .sum (masks , axis = (1 , 2 ))
61- F = raw_traces .T * num_pixels_in_mask # shape (T,N)
62- F = F .T
63138
64139 flat_masks = masks .reshape (N , P )
65140 flat_masks = sparse .csr_matrix (flat_masks )
@@ -68,31 +143,16 @@ def demix_time_dep_masks(raw_traces, stack, masks):
68143 demix_traces = np .zeros ((N , T ))
69144
70145 for t in range (T ):
71-
72- weighted_mask_sum = F [:, t ]
73- drop_test = (weighted_mask_sum == 0 )
74-
75- if np .sum (drop_test == 0 ):
76- norm_mat = sparse .diags (num_pixels_in_mask / weighted_mask_sum , offsets = 0 )
77- stack_t = sparse .diags (stack [t ], offsets = 0 )
78-
79- flat_weighted_masks = norm_mat .dot (flat_masks .dot (stack_t ))
80-
81- overlap = flat_masks .dot (flat_weighted_masks .T ).toarray () # cast to dense numpy array for linear solver because solution is dense
82- try :
83- demix_traces [:, t ] = linalg .solve (overlap , F [:, t ])
84- except linalg .LinAlgError as e :
85- logging .warning ("singular matrix, using least squares" )
86- x , _ , _ , _ = linalg .lstsq (overlap , F [:, t ])
87- demix_traces [:, t ] = x
88-
146+ demixed_point = _demix_point (
147+ stack [t ], raw_traces [:, t ], flat_masks , num_pixels_in_mask )
148+ if demixed_point is not None :
149+ demix_traces [:, t ] = demixed_point
89150 drop_frames .append (False )
90-
91151 else :
92152 drop_frames .append (True )
93-
94153 return demix_traces , drop_frames
95154
155+
96156def plot_traces (raw_trace , demix_trace , roi_id , roi_ind , save_file ):
97157 fig , ax = plt .subplots ()
98158
@@ -103,12 +163,15 @@ def plot_traces(raw_trace, demix_trace, roi_id, roi_ind, save_file):
103163 plt .savefig (save_file )
104164 plt .close (fig )
105165
166+
106167def find_zero_baselines (traces ):
107168 means = traces .mean (axis = 1 )
108- stds = traces .std (axis = 1 )
169+ stds = traces .std (axis = 1 )
109170 return np .where ((means - stds ) < 0 )
110-
111- def plot_negative_baselines (raw_traces , demix_traces , mask_array , roi_ids_mask , plot_dir , ext = 'png' ):
171+
172+
173+ def plot_negative_baselines (raw_traces , demix_traces , mask_array ,
174+ roi_ids_mask , plot_dir , ext = 'png' ):
112175 N , T = raw_traces .shape
113176 _ , x , y = mask_array .shape
114177
@@ -134,11 +197,10 @@ def plot_negative_baselines(raw_traces, demix_traces, mask_array, roi_ids_mask,
134197 overlap_inds .update (zero_inds )
135198
136199 return list (overlap_inds )
137-
138200
139-
140-
141- def plot_negative_transients ( raw_traces , demix_traces , valid_roi , mask_array , roi_ids_mask , plot_dir , ext = 'png' ):
201+
202+ def plot_negative_transients ( raw_traces , demix_traces , valid_roi , mask_array ,
203+ roi_ids_mask , plot_dir , ext = 'png' ):
142204
143205 N , T = raw_traces .shape
144206 _ , x , y = mask_array .shape
@@ -227,6 +289,7 @@ def find_negative_baselines(trace):
227289 stds = trace .std (axis = 1 )
228290 return np .where ((means + stds ) < 0 )
229291
292+
230293def find_negative_transients_threshold (trace , window = 500 , length = 10 , std_devs = 3 ):
231294 trace = np .pad (trace , pad_width = (window - 1 , 0 ), mode = 'constant' , constant_values = [np .mean (trace [:window ])])
232295 rolling_mean = np .mean (rolling_window (trace , window ), - 1 )
@@ -332,4 +395,3 @@ def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile):
332395
333396 plt .savefig (savefile )
334397 plt .close (fig )
335-
0 commit comments