@@ -375,7 +375,6 @@ def dist(cls, mus, covs, logp, **kwargs):
375375 def rv_op (cls , mus , covs , logp , size = None ):
376376 # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
377377 mus_ , covs_ = mus .type (), covs .type ()
378- print (f"mus_.type.shape: { mus_ .type .shape } , covs_.type.shape: { covs_ .type .shape } " )
379378
380379 logp_ = logp .type ()
381380 rng = pytensor .shared (np .random .default_rng ())
@@ -385,7 +384,6 @@ def recursion(mus, covs, rng):
385384 mus = pt .moveaxis (mus , - 2 , 0 )
386385 if covs .ndim > 3 :
387386 covs = pt .moveaxis (covs , - 3 , 0 )
388- print (f"mus.type.shape: { mus .type .shape } , covs.type.shape: { covs .type .shape } " )
389387
390388 def step (mu , cov , rng ):
391389 new_rng , mvn = pm .MvNormal .dist (mu = mu , cov = cov , rng = rng , method = "svd" ).owner .outputs
@@ -394,32 +392,26 @@ def step(mu, cov, rng):
394392 mvn_seq , updates = pytensor .scan (
395393 step , sequences = [mus , covs ], non_sequences = [rng ], strict = True , n_steps = mus .shape [0 ]
396394 )
397- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
398395 mvn_seq = pt .specify_shape (mvn_seq , mus .type .shape )
399396
400397 # Move time axis back to position -2 so batches are on the left
401398 if mvn_seq .ndim > 2 :
402399 mvn_seq = pt .moveaxis (mvn_seq , 0 , - 2 )
403- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
404400
405401 (seq_mvn_rng ,) = tuple (updates .values ())
406402
407- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
408-
409403 return [seq_mvn_rng , mvn_seq ]
410404
411405 mvn_seq_op = KalmanFilterRV (
412406 inputs = [mus_ , covs_ , logp_ , rng ], outputs = recursion (mus_ , covs_ , rng ), ndim_supp = 2
413407 )
414408
415409 mvn_seq = mvn_seq_op (mus , covs , logp , rng )
416- print (f"mvn_seq.type.shape: { mvn_seq .type .shape } " )
417410 return mvn_seq
418411
419412
420413@_logprob .register (KalmanFilterRV )
421414def sequence_mvnormal_logp (op , values , mus , covs , logp , rng , ** kwargs ):
422- print (values [0 ].type .shape , mus .type .shape , covs .type .shape )
423415 return check_parameters (
424416 logp ,
425417 pt .eq (values [0 ].shape [- 2 ], mus .shape [- 2 ]),
0 commit comments