1111from  urllib  import  request 
1212
1313import  numpy  as  np 
14+ 
15+ try :
16+     import  cupy  as  cp 
17+ 
18+     cupy_available  =  True 
19+ except  ImportError :
20+     cupy_available  =  False 
21+ 
1422from  surrogate_model_definitions  import  surrogate_model 
1523
1624try :
2028    sys .exit (0 )
2129
2230from  impactx  import  (
31+     Config ,
32+     CoordSystem ,
2333    ImpactX ,
2434    ImpactXParIter ,
25-     TransformationDirection ,
2635    coordinate_transformation ,
2736    distribution ,
2837    elements ,
2938)
3039
40+ # CPU/GPU logic 
41+ if  Config .have_gpu :
42+     if  cupy_available :
43+         array  =  cp .array 
44+         stack  =  cp .stack 
45+         device  =  torch .device ("cuda" )
46+     else :
47+         print ("Warning: GPU found but cupy not available! Try managed..." )
48+         array  =  np .array 
49+         stack  =  np .stack 
50+         device  =  torch .device ("cpu" )
51+     if  Config .gpu_backend  ==  "SYCL" :
52+         print ("Warning: SYCL GPU backend not yet implemented for Python" )
53+ 
54+ else :
55+     array  =  np .array 
56+     stack  =  np .stack 
57+     device  =  torch .device ("cpu" )
58+ 
3159
3260def  download_and_unzip (url , data_dir ):
3361    request .urlretrieve (url , data_dir )
@@ -50,6 +78,7 @@ def download_and_unzip(url, data_dir):
5078    surrogate_model (
5179        dataset_dir  +  f"dataset_beam_stage_{ i }  .pt" ,
5280        model_dir  +  f"beam_stage_{ i }  _model.pt" ,
81+         device = device ,
5382    )
5483    for  i  in  range (N_stage )
5584]
@@ -78,47 +107,62 @@ def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
78107        self .ds  =  surrogate_length 
79108
80109    def  surrogate_push (self , pc , step ):
81-         array  =  np .array 
82- 
83110        ref_part  =  pc .ref_particle ()
84111        ref_z_i  =  ref_part .z 
85112        ref_z_i_LPA  =  ref_z_i  -  self .stage_start 
86113        ref_z_f  =  ref_z_i  +  self .surrogate_length 
87114
88115        ref_part_tensor  =  torch .tensor (
89-             [ref_part .x , ref_part .y , ref_z_i_LPA , ref_part .px , ref_part .py , ref_part .pz ]
116+             [
117+                 ref_part .x ,
118+                 ref_part .y ,
119+                 ref_z_i_LPA ,
120+                 ref_part .px ,
121+                 ref_part .py ,
122+                 ref_part .pz ,
123+             ],
124+             dtype = torch .float64 ,
125+             device = device ,
90126        )
91-         ref_beta_gamma  =  np .sqrt (torch .sum (ref_part_tensor [3 :] **  2 ))
127+         ref_beta_gamma  =  torch .sqrt (torch .sum (ref_part_tensor [3 :] **  2 ))
92128
93129        with  torch .no_grad ():
94-             ref_part_model_final  =  self .surrogate_model (ref_part_tensor . float () )
130+             ref_part_model_final  =  self .surrogate_model (ref_part_tensor )
95131        ref_uz_f  =  ref_part_model_final [5 ]
96132        ref_beta_gamma_final  =  (
97133            ref_uz_f   # NOT np.sqrt(torch.sum(ref_part_model_final[3:]**2)) 
98134        )
99-         ref_part_final  =  torch .tensor ([0 , 0 , ref_z_f , 0 , 0 , ref_uz_f ])
135+         ref_part_final  =  torch .tensor (
136+             [0 , 0 , ref_z_f , 0 , 0 , ref_uz_f ], dtype = torch .float64 , device = device 
137+         )
100138
101139        # transform 
102-         coordinate_transformation (pc , TransformationDirection . to_fixed_t )
140+         coordinate_transformation (pc , direction = CoordSystem . t )
103141
104142        for  lvl  in  range (pc .finest_level  +  1 ):
105143            for  pti  in  ImpactXParIter (pc , level = lvl ):
106-                 aos  =  pti .aos ()
107-                 aos_arr  =  array (aos , copy = False )
108- 
109144                soa  =  pti .soa ()
110-                 real_arrays  =  soa .GetRealData ()
111-                 px  =  array (real_arrays [0 ], copy = False )
112-                 py  =  array (real_arrays [1 ], copy = False )
113-                 pt  =  array (real_arrays [2 ], copy = False )
114-                 data_arr  =  (
115-                     torch .tensor (
116-                         np .vstack (
117-                             [aos_arr ["x" ], aos_arr ["y" ], aos_arr ["z" ], real_arrays [:3 ]]
118-                         )
119-                     )
120-                     .float ()
121-                     .T 
145+                 real_arrays  =  soa .get_real_data ()
146+                 x  =  array (real_arrays [0 ], copy = False )
147+                 y  =  array (real_arrays [1 ], copy = False )
148+                 t  =  array (real_arrays [2 ], copy = False )
149+                 px  =  array (real_arrays [3 ], copy = False )
150+                 py  =  array (real_arrays [4 ], copy = False )
151+                 pt  =  array (real_arrays [5 ], copy = False )
152+                 data_arr  =  torch .tensor (
153+                     stack (
154+                         [
155+                             x ,
156+                             y ,
157+                             t ,
158+                             px ,
159+                             py ,
160+                             py ,
161+                         ],
162+                         axis = 1 ,
163+                     ),
164+                     dtype = torch .float64 ,
165+                     device = device ,
122166                )
123167
124168                data_arr [:, 0 ] +=  ref_part .x 
@@ -135,7 +179,7 @@ def surrogate_push(self, pc, step):
135179                #     # assume for now it is 
136180
137181                with  torch .no_grad ():
138-                     data_arr_post_model  =  self .surrogate_model (data_arr . float () )
182+                     data_arr_post_model  =  self .surrogate_model (data_arr )
139183
140184                #  need to add stage start to z 
141185                data_arr_post_model [:, 2 ] +=  self .stage_start 
@@ -146,9 +190,9 @@ def surrogate_push(self, pc, step):
146190                    data_arr_post_model [:, 3  +  ii ] -=  ref_part_final [3  +  ii ]
147191                    data_arr_post_model [:, 3  +  ii ] /=  ref_beta_gamma_final 
148192
149-                 aos_arr [ "x" ] =  data_arr_post_model [:, 0 ]
150-                 aos_arr [ "y" ] =  data_arr_post_model [:, 1 ]
151-                 aos_arr [ "z" ] =  data_arr_post_model [:, 2 ]
193+                 x [: ] =  data_arr_post_model [:, 0 ]
194+                 y [: ] =  data_arr_post_model [:, 1 ]
195+                 t [: ] =  data_arr_post_model [:, 2 ]
152196                px [:] =  data_arr_post_model [:, 3 ]
153197                py [:] =  data_arr_post_model [:, 4 ]
154198                pt [:] =  data_arr_post_model [:, 5 ]
@@ -160,7 +204,7 @@ def surrogate_push(self, pc, step):
160204        ref_part .x  =  ref_part_final [0 ]
161205        ref_part .y  =  ref_part_final [1 ]
162206        ref_part .z  =  ref_part_final [2 ]
163-         ref_gamma  =  np .sqrt (1  +  ref_beta_gamma_final ** 2 )
207+         ref_gamma  =  torch .sqrt (1  +  ref_beta_gamma_final ** 2 )
164208        ref_part .px  =  ref_part_final [3 ]
165209        ref_part .py  =  ref_part_final [4 ]
166210        ref_part .pz  =  ref_part_final [5 ]
@@ -173,7 +217,7 @@ def surrogate_push(self, pc, step):
173217        # ref_part.s += pge1.ds 
174218        # ref_part.t += pge1.ds / ref_beta 
175219
176-         coordinate_transformation (pc , TransformationDirection . to_fixed_s )
220+         coordinate_transformation (pc , direction = CoordSystem . s )
177221        ## Done! 
178222
179223
0 commit comments