Skip to content

Commit abd9c7c

Browse files
committed
Update the code-generation setup
1 parent bab76a2 commit abd9c7c

File tree

7 files changed

+33535
-33500
lines changed

7 files changed

+33535
-33500
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "lapack-sys"]
2+
path = lapack-sys
3+
url = https://github.com/blas-lapack-rs/lapack-sys.git

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ for (one, another) in w.iter().zip(&[2.0, 2.0, 5.0]) {
2626
}
2727
```
2828

29+
## Development
30+
31+
The code is generated via a Python script based on the content the `lapack-sys`
32+
submodule. To re-generate, run the following commands:
33+
34+
```sh
35+
./bin/generate.py > src/lapack-sys.rs
36+
rustfmt src/lapack-sys.rs
37+
```
38+
2939
## Contribution
3040

3141
Your contribution is highly appreciated. Do not hesitate to open an issue or a

bin/function.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
name_re = re.compile('\s*pub fn (?:LAPACKE_)?(\w+[a-z0-9])(_?)')
55
return_re = re.compile('(?:\s*->\s*([^;]+))?')
66

7+
78
class Function(object):
9+
810
def __init__(self, name, args, ret):
911
self.name = name
1012
self.args = args
@@ -16,7 +18,7 @@ def parse(line):
1618
if name is None:
1719
return None
1820

19-
assert(line[0] == '(')
21+
assert line[0] == '('
2022
line = line[1:]
2123
args = []
2224
while True:
@@ -32,23 +34,27 @@ def parse(line):
3234

3335
return Function(name, args, ret)
3436

37+
3538
def pull_argument(s):
3639
match = argument_re.match(s)
3740
if match is None:
3841
return None, None, s
3942
return match.group(1), match.group(2), s[match.end(3):]
4043

44+
4145
def pull_name(s):
4246
match = name_re.match(s)
43-
assert(match is not None)
47+
assert match is not None
4448
return match.group(1), s[match.end(2):]
4549

50+
4651
def pull_return(s):
4752
match = return_re.match(s)
4853
if match is None:
4954
return None, s
5055
return match.group(1), s[match.end(1):]
5156

57+
5258
def read_functions(path):
5359
lines = []
5460
with open(path) as file:

bin/generate.py

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
#!/usr/bin/env python
22

3-
from function import Function, read_functions
43
import argparse
54
import os
65
import re
76

8-
select_re = re.compile('LAPACK_(\w)_SELECT(\d)')
7+
from function import Function
8+
from function import read_functions
99

10-
def is_const(name, cty):
11-
return '*const' in cty
10+
select_re = re.compile('LAPACK_(\w)_SELECT(\d)')
1211

13-
def is_mut(name, cty):
14-
return '*mut' in cty
1512

1613
def is_scalar(name, cty, f):
1714
return (
@@ -69,20 +66,8 @@ def is_scalar(name, cty, f):
6966
name.startswith('vers')
7067
)
7168

72-
def translate_argument(name, cty, f):
73-
if is_const(name, cty):
74-
base = translate_type_base(cty, f)
75-
if is_scalar(name, cty, f):
76-
return base
77-
else:
78-
return '&[{}]'.format(base)
79-
elif is_mut(name, cty):
80-
base = translate_type_base(cty, f)
81-
if is_scalar(name, cty, f):
82-
return '&mut {}'.format(base)
83-
else:
84-
return '&mut [{}]'.format(base)
8569

70+
def translate_argument(name, cty, f):
8671
m = select_re.match(cty)
8772
if m is not None:
8873
if m.group(1) == 'S':
@@ -94,25 +79,47 @@ def translate_argument(name, cty, f):
9479
elif m.group(1) == 'Z':
9580
return 'Select{}C64'.format(m.group(2))
9681

97-
assert False, 'cannot translate `{}: {}`'.format(name, cty)
82+
base = translate_type_base(cty)
83+
if '*const' in cty:
84+
if is_scalar(name, cty, f):
85+
return base
86+
else:
87+
return '&[{}]'.format(base)
88+
elif '*mut' in cty:
89+
if is_scalar(name, cty, f):
90+
return '&mut {}'.format(base)
91+
else:
92+
return '&mut [{}]'.format(base)
93+
94+
return base
95+
96+
97+
def translate_type_base(cty):
98+
cty = cty.replace('__BindgenComplex<f32>', 'lapack_complex_float')
99+
cty = cty.replace('__BindgenComplex<f64>', 'lapack_complex_double')
100+
cty = cty.replace('f32', 'c_float')
101+
cty = cty.replace('f64', 'c_double')
98102

99-
def translate_type_base(cty, f):
100103
if 'c_char' in cty:
101104
return 'u8'
102-
elif 'lapack_int' in cty or 'lapack_logical' in cty:
105+
elif 'c_int' in cty:
103106
return 'i32'
104-
elif 'lapack_complex_double' in cty:
105-
return 'c64'
106-
elif 'lapack_complex_float' in cty:
107-
return 'c32'
108-
elif 'c_double' in cty:
109-
return 'f64'
110107
elif 'c_float' in cty:
111108
return 'f32'
109+
elif 'c_double' in cty:
110+
return 'f64'
111+
elif 'lapack_complex_float' in cty:
112+
return 'c32'
113+
elif 'lapack_complex_double' in cty:
114+
return 'c64'
115+
116+
assert False, 'cannot translate `{}`'.format(cty)
112117

113-
assert False, 'cannot translate `{}` in `{}`'.format(cty, f.name)
114118

115119
def translate_body_argument(name, rty):
120+
if rty.startswith('Select'):
121+
return 'transmute({})'.format(name)
122+
116123
if rty == 'u8':
117124
return '&({} as c_char)'.format(name)
118125
elif rty == '&mut u8':
@@ -130,7 +137,7 @@ def translate_body_argument(name, rty):
130137
elif rty.startswith('f'):
131138
return '&{}'.format(name)
132139
elif rty.startswith('&mut f'):
133-
return '{}'.format(name)
140+
return name
134141
elif rty.startswith('&[f'):
135142
return '{}.as_ptr()'.format(name)
136143
elif rty.startswith('&mut [f'):
@@ -145,57 +152,69 @@ def translate_body_argument(name, rty):
145152
elif rty.startswith('&mut [c'):
146153
return '{}.as_mut_ptr() as *mut _'.format(name)
147154

148-
if rty.startswith('Select'):
149-
return 'transmute({})'.format(name)
150-
151155
assert False, 'cannot translate `{}: {}`'.format(name, rty)
152156

157+
153158
def translate_return_type(cty):
154-
if cty == 'c_float':
159+
cty = cty.replace('lapack_float_return', 'c_float')
160+
cty = cty.replace('f64', 'c_double')
161+
162+
if cty == 'c_int':
163+
return 'i32'
164+
elif cty == 'c_float':
155165
return 'f32'
156166
elif cty == 'c_double':
157167
return 'f64'
158168

159169
assert False, 'cannot translate `{}`'.format(cty)
160170

171+
161172
def format_header(f):
162173
args = format_header_arguments(f)
163174
if f.ret is None:
164175
return 'pub unsafe fn {}({})'.format(f.name, args)
165176
else:
166-
return 'pub unsafe fn {}({}) -> {}'.format(f.name, args, translate_return_type(f.ret))
177+
return 'pub unsafe fn {}({}) -> {}'.format(f.name, args,
178+
translate_return_type(f.ret))
179+
167180

168181
def format_body(f):
169182
return 'ffi::{}_({})'.format(f.name, format_body_arguments(f))
170183

184+
171185
def format_header_arguments(f):
172186
s = []
173187
for arg in f.args:
174188
s.append('{}: {}'.format(arg[0], translate_argument(*arg, f=f)))
175189
return ', '.join(s)
176190

191+
177192
def format_body_arguments(f):
178193
s = []
179194
for arg in f.args:
180195
rty = translate_argument(*arg, f=f)
181196
s.append(translate_body_argument(arg[0], rty))
182197
return ', '.join(s)
183198

199+
184200
def prepare(code):
185-
lines = filter(lambda line: not re.match(r'^\s*//.*', line), code.split('\n'))
201+
lines = filter(lambda line: not re.match(r'^\s*//.*', line),
202+
code.split('\n'))
186203
lines = re.sub(r'\s+', ' ', ''.join(lines)).strip().split(';')
187204
lines = filter(lambda line: not re.match(r'^\s*$', line), lines)
188205
return [Function.parse(line) for line in lines]
189206

207+
190208
def do(functions):
191209
for f in functions:
192210
print('\n#[inline]')
193211
print(format_header(f) + ' {')
194212
print(' ' + format_body(f) + '\n}')
195213

214+
196215
if __name__ == '__main__':
197216
parser = argparse.ArgumentParser()
198-
parser.add_argument('--sys', required=True)
217+
parser.add_argument('--sys', default='lapack-sys')
199218
arguments = parser.parse_args()
200-
path = os.path.join(arguments.sys, 'src', 'lib.rs')
219+
path = os.path.join(arguments.sys, 'src', 'lapack.rs')
201220
do(prepare(read_functions(path)))

lapack-sys

Submodule lapack-sys added at b655927

0 commit comments

Comments
 (0)