|
3 | 3 | import numpy as np |
4 | 4 | import pytest |
5 | 5 | import random as py_random |
| 6 | +from os import path |
6 | 7 |
|
7 | 8 |
|
8 | 9 | ptype = {'scipy': ScipyParticle, 'jit': JITParticle} |
@@ -150,3 +151,38 @@ class TestParticle(ptype[mode]): |
150 | 151 | 'random.%s(%s)' % (rngfunc, ', '.join([str(a) for a in rngargs]))) |
151 | 152 | pset.execute(kernel, endtime=1., dt=1.) |
152 | 153 | assert np.allclose(np.array([p.p for p in pset]), series, rtol=1e-12) |
| 154 | + |
| 155 | + |
| 156 | +@pytest.mark.parametrize('mode', ['scipy', 'jit']) |
| 157 | +@pytest.mark.parametrize('c_inc', ['str', 'file']) |
| 158 | +def test_c_kernel(fieldset, mode, c_inc): |
| 159 | + pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0.5], lat=[0]) |
| 160 | + |
| 161 | + def func(U, lon, dt): |
| 162 | + u = U.data[0, 2, 1] |
| 163 | + return lon + u * dt |
| 164 | + |
| 165 | + if c_inc == 'str': |
| 166 | + c_include = """ |
| 167 | + static inline void func(CField *f, float *lon, float *dt) |
| 168 | + { |
| 169 | + float (*data)[f->xdim] = (float (*)[f->xdim]) f->data; |
| 170 | + float u = data[2][1]; |
| 171 | + *lon += u * *dt; |
| 172 | + } |
| 173 | + """ |
| 174 | + else: |
| 175 | + c_include = path.join(path.dirname(__file__), 'customed_header.h') |
| 176 | + |
| 177 | + def ckernel(particle, fieldset, time, dt): |
| 178 | + func('pointer_args', fieldset.U, particle.lon, dt) |
| 179 | + |
| 180 | + def pykernel(particle, fieldset, time, dt): |
| 181 | + particle.lon = func(fieldset.U, particle.lon, dt) |
| 182 | + |
| 183 | + if mode == 'scipy': |
| 184 | + kernel = pset.Kernel(pykernel) |
| 185 | + else: |
| 186 | + kernel = pset.Kernel(ckernel, c_include=c_include) |
| 187 | + pset.execute(kernel, endtime=3., dt=3.) |
| 188 | + assert np.allclose(pset[0].lon, 0.81578948) |
0 commit comments