File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed
src/ezmsg/learn/dim_reduce Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -146,15 +146,18 @@ def _process(self, message: AxisArray) -> AxisArray:
146146 d2 = np .prod (in_dat .shape [len (off_targ_axes ) + 1 :])
147147 in_dat = in_dat .reshape ((- 1 , d2 ))
148148
149+ replace_kwargs = {
150+ "axes" : {** self ._state .template .axes , iter_axis : message .axes [iter_axis ]},
151+ }
152+
149153 # Transform data
150154 if hasattr (self ._state .estimator , "components_" ):
151155 decomp_dat = self ._state .estimator .transform (in_dat ).reshape (
152156 (- 1 ,) + self ._state .template .data .shape [1 :]
153157 )
154- return replace (self ._state .template , data = decomp_dat )
155- else :
156- # No components yet, return empty template
157- return self ._state .template
158+ replace_kwargs ["data" ] = decomp_dat
159+
160+ return replace (self ._state .template , ** replace_kwargs )
158161
159162 def partial_fit (self , message : AxisArray ) -> None :
160163 # Check if we need to reset state
You can’t perform that action at this time.
0 commit comments