1111
1212
1313def  plotTrajectoriesFile (filename , mode = '2d' , tracerfile = None , tracerfield = 'P' ,
14-                          tracerlon = 'x' , tracerlat = 'y' , recordedvar = None ):
14+                          tracerlon = 'x' , tracerlat = 'y' , recordedvar = None ,  show_plt = True ):
1515    """Quick and simple plotting of Parcels trajectories 
1616
1717    :param filename: Name of Parcels-generated NetCDF file with particle positions 
@@ -24,6 +24,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
2424    :param tracerlat: Name of latitude dimension of variable to show as background 
2525    :param recordedvar: Name of variable used to color particles in scatter-plot. 
2626                Only works in 'movie2d' or 'movie2d_notebook' mode. 
27+     :param show_plt: Boolean whether plot should directly be show (for py.test) 
2728    """ 
2829
2930    if  plt  is  None :
@@ -34,10 +35,10 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
3435    lon  =  pfile .variables ['lon' ]
3536    lat  =  pfile .variables ['lat' ]
3637    z  =  pfile .variables ['z' ]
38+     time  =  pfile .variables ['time' ][:]
3739    if  len (lon .shape ) ==  1 :
3840        type  =  'indexed' 
3941        id  =  pfile .variables ['trajectory' ][:]
40-         time  =  pfile .variables ['time' ][:]
4142    else :
4243        type  =  'array' 
4344
@@ -59,7 +60,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
5960            for  p  in  range (len (lon )):
6061                ax .plot (lon [p , :], lat [p , :], z [p , :], '.-' )
6162        elif  type  ==  'indexed' :
62-             for  t  in  range ( max ( id ) + 1 ):
63+             for  t  in  np . unique ( id ):
6364                ax .plot (lon [id  ==  t ], lat [id  ==  t ],
6465                        z [id  ==  t ], '.-' )
6566        ax .set_xlabel ('Longitude' )
@@ -69,11 +70,18 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
6970        if  type  ==  'array' :
7071            plt .plot (np .transpose (lon ), np .transpose (lat ), '.-' )
7172        elif  type  ==  'indexed' :
72-             for  t  in  range ( max ( id ) + 1 ):
73+             for  t  in  np . unique ( id ):
7374                plt .plot (lon [id  ==  t ], lat [id  ==  t ], '.-' )
7475        plt .xlabel ('Longitude' )
7576        plt .ylabel ('Latitude' )
7677    elif  mode  ==  'movie2d'  or  'movie2d_notebook' :
78+         if  type  ==  'array'  and  any (time [:, 0 ] !=  time [0 , 0 ]):
79+             # since particles don't start at the same time, treat as indexed 
80+             type  =  'indexed' 
81+             id  =  pfile .variables ['trajectory' ][:].flatten ()
82+             lon  =  lon [:].flatten ()
83+             lat  =  lat [:].flatten ()
84+             time  =  time .flatten ()
7785
7886        fig  =  plt .figure ()
7987        ax  =  plt .axes (xlim = (np .amin (lon ), np .amax (lon )), ylim = (np .amin (lat ), np .amax (lat )))
@@ -84,7 +92,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
8492            mintime  =  min (time )
8593            scat  =  ax .scatter (lon [time  ==  mintime ], lat [time  ==  mintime ],
8694                              s = 60 , cmap = plt .get_cmap ('autumn' ))
87-             frames  =  np .unique (time )
95+             frames  =  np .unique (time [ ~ np . isnan ( time )] )
8896
8997        def  animate (t ):
9098            if  type  ==  'array' :
@@ -102,7 +110,9 @@ def animate(t):
102110        plt .close ()
103111        return  anim 
104112    else :
105-         plt .show ()
113+         if  show_plt :
114+             plt .show ()
115+         return  plt 
106116
107117
108118if  __name__  ==  "__main__" :
@@ -125,4 +135,5 @@ def animate(t):
125135
126136    plotTrajectoriesFile (args .particlefile , mode = args .mode , tracerfile = args .tracerfile ,
127137                         tracerfield = args .tracerfilefield , tracerlon = args .tracerfilelon ,
128-                          tracerlat = args .tracerfilelat , recordedvar = args .recordedvar )
138+                          tracerlat = args .tracerfilelat , recordedvar = args .recordedvar ,
139+                          show_plt = True )
0 commit comments