Skip to content

Commit ea516c7

Browse files
Merge pull request #278 from OceanParcels/ckernels
For JITParticles, Kernels can now also be written directly in C (instead of Python translation to C)
2 parents cf758ce + c908f96 commit ea516c7

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

parcels/codegenerator.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,21 @@ def visit_Call(self, node):
300300
"""Generate C code for simple C-style function calls. Please
301301
note that starred and keyword arguments are currently not
302302
supported."""
303+
pointer_args = False
303304
for a in node.args:
304305
self.visit(a)
305-
ccode_args = ", ".join([a.ccode for a in node.args])
306+
if a.ccode == 'pointer_args':
307+
pointer_args = True
308+
continue
309+
if isinstance(a, FieldNode):
310+
a.ccode = a.obj.name
311+
elif isinstance(a, ParticleNode):
312+
continue
313+
elif pointer_args:
314+
a.ccode = "&%s" % a.ccode
315+
ccode_args = ", ".join([a.ccode for a in node.args[pointer_args:]])
306316
try:
317+
self.visit(node.func)
307318
node.ccode = "%s(%s)" % (node.func.ccode, ccode_args)
308319
except:
309320
raise RuntimeError("Error in converting Kernel to C. See http://oceanparcels.org/#writing-parcels-kernels for hints and tips")
@@ -543,7 +554,7 @@ def __init__(self, fieldset, ptype=None):
543554
self.fieldset = fieldset
544555
self.ptype = ptype
545556

546-
def generate(self, funcname, field_args, const_args, kernel_ast):
557+
def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
547558
ccode = []
548559

549560
# Add include for Parcels and math header
@@ -560,6 +571,9 @@ def generate(self, funcname, field_args, const_args, kernel_ast):
560571

561572
ccode += [str(c.Typedef(c.GenerableStruct("", vdecl, declname=self.ptype.name)))]
562573

574+
if c_include:
575+
ccode += [c_include]
576+
563577
# Insert kernel code
564578
ccode += [str(kernel_ast)]
565579

parcels/include/parcels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ typedef struct
5757
CGridIndex *gridIndices;
5858
} CGridIndexSet;
5959

60+
6061
static inline ErrorCode search_indices_vertical_z(float z, int zdim, float *zvals, int *k, double *zeta)
6162
{
6263
if (z < zvals[0] || z > zvals[zdim-1]) {return ERROR_OUT_OF_BOUNDS;}

parcels/kernel.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Kernel(object):
4848
"""
4949

5050
def __init__(self, fieldset, ptype, pyfunc=None, funcname=None,
51-
funccode=None, py_ast=None, funcvars=None):
51+
funccode=None, py_ast=None, funcvars=None, c_include=""):
5252
self.fieldset = fieldset
5353
self.ptype = ptype
5454

@@ -104,8 +104,13 @@ def __init__(self, fieldset, ptype, pyfunc=None, funcname=None,
104104
del self.field_args['UV']
105105
self.const_args = kernelgen.const_args
106106
loopgen = LoopGenerator(fieldset, ptype)
107+
if path.isfile(c_include):
108+
with open(c_include, 'r') as f:
109+
c_include_str = f.read()
110+
else:
111+
c_include_str = c_include
107112
self.ccode = loopgen.generate(self.funcname, self.field_args, self.const_args,
108-
kernel_ccode)
113+
kernel_ccode, c_include_str)
109114

110115
basename = path.join(get_cache_dir(), self._cache_key)
111116
self.src_file = "%s.c" % basename

parcels/particleset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,10 +545,10 @@ def density(self, field=None, particle_val=None, relative=False, area_scale=True
545545

546546
return Density
547547

548-
def Kernel(self, pyfunc):
548+
def Kernel(self, pyfunc, c_include=""):
549549
"""Wrapper method to convert a `pyfunc` into a :class:`parcels.kernel.Kernel` object
550550
based on `fieldset` and `ptype` of the ParticleSet"""
551-
return Kernel(self.fieldset, self.ptype, pyfunc=pyfunc)
551+
return Kernel(self.fieldset, self.ptype, pyfunc=pyfunc, c_include=c_include)
552552

553553
def ParticleFile(self, *args, **kwargs):
554554
"""Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile`

tests/customed_header.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
static inline void func(CField *f, float *lon, float *dt)
3+
{
4+
float (*data)[f->xdim] = (float (*)[f->xdim]) f->data;
5+
float u = data[2][1];
6+
*lon += u * *dt;
7+
}

tests/test_kernel_language.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55
import random as py_random
6+
from os import path
67

78

89
ptype = {'scipy': ScipyParticle, 'jit': JITParticle}
@@ -150,3 +151,38 @@ class TestParticle(ptype[mode]):
150151
'random.%s(%s)' % (rngfunc, ', '.join([str(a) for a in rngargs])))
151152
pset.execute(kernel, endtime=1., dt=1.)
152153
assert np.allclose(np.array([p.p for p in pset]), series, rtol=1e-12)
154+
155+
156+
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
157+
@pytest.mark.parametrize('c_inc', ['str', 'file'])
158+
def test_c_kernel(fieldset, mode, c_inc):
159+
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0.5], lat=[0])
160+
161+
def func(U, lon, dt):
162+
u = U.data[0, 2, 1]
163+
return lon + u * dt
164+
165+
if c_inc == 'str':
166+
c_include = """
167+
static inline void func(CField *f, float *lon, float *dt)
168+
{
169+
float (*data)[f->xdim] = (float (*)[f->xdim]) f->data;
170+
float u = data[2][1];
171+
*lon += u * *dt;
172+
}
173+
"""
174+
else:
175+
c_include = path.join(path.dirname(__file__), 'customed_header.h')
176+
177+
def ckernel(particle, fieldset, time, dt):
178+
func('pointer_args', fieldset.U, particle.lon, dt)
179+
180+
def pykernel(particle, fieldset, time, dt):
181+
particle.lon = func(fieldset.U, particle.lon, dt)
182+
183+
if mode == 'scipy':
184+
kernel = pset.Kernel(pykernel)
185+
else:
186+
kernel = pset.Kernel(ckernel, c_include=c_include)
187+
pset.execute(kernel, endtime=3., dt=3.)
188+
assert np.allclose(pset[0].lon, 0.81578948)

0 commit comments

Comments
 (0)