Skip to content

Commit 13b9234

Browse files
committed
✨ feat(interpolation): 添加样条插值算法,支持浮点数和双精度数据类型,更新相关函数和结构体
1 parent 8e8904a commit 13b9234

File tree

4 files changed

+218
-40
lines changed

4 files changed

+218
-40
lines changed

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 蒙蒙plus
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

SConscript

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from building import *
2+
import os
3+
4+
cwd = GetCurrentDir()
5+
src = Glob('src/algorithms/*.c')
6+
CPPPATH = [cwd+"/src"]
7+
8+
group = DefineGroup('third_party', src, depend = [''], CPPPATH = CPPPATH)
9+
10+
list = os.listdir(cwd)
11+
for item in list:
12+
if os.path.isfile(os.path.join(cwd, item, 'SConscript')):
13+
group = group + SConscript(os.path.join(item, 'SConscript'))
14+
15+
Return('group')

src/algorithms/interpolation.c

Lines changed: 129 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -367,27 +367,15 @@ void hermite_interpolate(void *x, void *y, void *x_table, void *y_table, int tab
367367
else if (data_type == DATA_TYPE_DOUBLE) { *(double *)y = y_val; }
368368
}
369369

370-
// 样条插值系数结构体
371-
typedef struct
372-
{
373-
double *a;
374-
double *b;
375-
double *c;
376-
double *d;
377-
} SplineCoefficients;
378370
// 计算样条插值系数
379-
SplineCoefficients *calculate_spline_coefficients(const double *x_table, const double *y_table, int table_size)
371+
void calculate_spline_coefficientsf(SplineCoefficients *coeff, const float *x_table, const float *y_table,
372+
int table_size)
380373
{
381-
SplineCoefficients *coeff = (SplineCoefficients *)malloc(sizeof(SplineCoefficients));
382-
coeff->a = (double *)malloc(table_size * sizeof(double));
383-
coeff->b = (double *)malloc(table_size * sizeof(double));
384-
coeff->c = (double *)malloc(table_size * sizeof(double));
385-
coeff->d = (double *)malloc(table_size * sizeof(double));
386-
double *h = (double *)malloc((table_size - 1) * sizeof(double));
387-
double *alpha = (double *)malloc(table_size * sizeof(double));
388-
double *l = (double *)malloc(table_size * sizeof(double));
389-
double *mu = (double *)malloc(table_size * sizeof(double));
390-
double *z = (double *)malloc(table_size * sizeof(double));
374+
float *h = (float *)malloc((table_size - 1) * sizeof(float));
375+
float *alpha = (float *)malloc(table_size * sizeof(float));
376+
float *l = (float *)malloc(table_size * sizeof(float));
377+
float *mu = (float *)malloc(table_size * sizeof(float));
378+
float *z = (float *)malloc(table_size * sizeof(float));
391379
// 计算h[i] = x[i+1] - x[i]
392380
for (int i = 0; i < table_size - 1; i++) { h[i] = x_table[i + 1] - x_table[i]; }
393381
// 计算alpha[i]
@@ -423,34 +411,137 @@ SplineCoefficients *calculate_spline_coefficients(const double *x_table, const d
423411
free(l);
424412
free(mu);
425413
free(z);
426-
return coeff;
427414
}
428-
// 样条插值函数
429-
void spline_interpolate(const void *x, void *y, const void *x_table, const void *y_table, int table_size,
430-
DataType data_type)
415+
void calculate_spline_coefficientsd(SplineCoefficients *coeff, const double *x_table, const double *y_table,
416+
int table_size)
431417
{
432-
if (data_type != DATA_TYPE_FLOAT && data_type != DATA_TYPE_DOUBLE)
418+
double *h = (double *)malloc((table_size - 1) * sizeof(double));
419+
double *alpha = (double *)malloc(table_size * sizeof(double));
420+
double *l = (double *)malloc(table_size * sizeof(double));
421+
double *mu = (double *)malloc(table_size * sizeof(double));
422+
double *z = (double *)malloc(table_size * sizeof(double));
423+
// 计算h[i] = x[i+1] - x[i]
424+
for (int i = 0; i < table_size - 1; i++) { h[i] = x_table[i + 1] - x_table[i]; }
425+
// 计算alpha[i]
426+
for (int i = 1; i < table_size - 1; i++)
433427
{
434-
printf("Unsupported data type for spline interpolation.\n");
435-
return;
428+
alpha[i] = (3.0 / h[i]) * (y_table[i + 1] - y_table[i]) - (3.0 / h[i - 1]) * (y_table[i] - y_table[i - 1]);
429+
}
430+
// 自然边界条件:二阶导数为0
431+
l[0] = 1.0;
432+
mu[0] = 0.0;
433+
z[0] = 0.0;
434+
// 前向消元
435+
for (int i = 1; i < table_size - 1; i++)
436+
{
437+
l[i] = 2.0 * (x_table[i + 1] - x_table[i - 1]) - h[i - 1] * mu[i - 1];
438+
mu[i] = h[i] / l[i];
439+
z[i] = (alpha[i] - h[i - 1] * z[i - 1]) / l[i];
440+
}
441+
// 自然边界条件:二阶导数为0
442+
l[table_size - 1] = 1.0;
443+
z[table_size - 1] = 0.0;
444+
coeff->c[table_size - 1] = 0.0;
445+
// 后向代入
446+
for (int j = table_size - 2; j >= 0; j--)
447+
{
448+
coeff->c[j] = z[j] - mu[j] * coeff->c[j + 1];
449+
coeff->b[j] = (y_table[j + 1] - y_table[j]) / h[j] - h[j] * (coeff->c[j + 1] + 2.0 * coeff->c[j]) / 3.0;
450+
coeff->d[j] = (coeff->c[j + 1] - coeff->c[j]) / (3.0 * h[j]);
451+
coeff->a[j] = y_table[j];
436452
}
437-
SplineCoefficients *coeff = calculate_spline_coefficients((const double *)x_table, (const double *)y_table,
438-
table_size);
439-
double x_val = (data_type == DATA_TYPE_FLOAT) ? *(float *)x : *(double *)x;
440-
int i = 0;
441-
// 找到x所在的区间
442-
while (i < table_size - 1 && x_val > ((double *)x_table)[i + 1]) { i++; }
443-
// 计算插值结果
444-
double dx = x_val - ((double *)x_table)[i];
445-
double result = coeff->a[i] + coeff->b[i] * dx + coeff->c[i] * dx * dx + coeff->d[i] * dx * dx * dx;
446-
if (data_type == DATA_TYPE_FLOAT) { *(float *)y = (float)result; }
447-
else { *(double *)y = result; }
453+
free(h);
454+
free(alpha);
455+
free(l);
456+
free(mu);
457+
free(z);
458+
}
459+
SplineCoefficients *splinecoeff_create(int table_size)
460+
{
461+
SplineCoefficients *coeff = (SplineCoefficients *)malloc(sizeof(SplineCoefficients));
462+
coeff->a = (double *)malloc(table_size * sizeof(double));
463+
coeff->b = (double *)malloc(table_size * sizeof(double));
464+
coeff->c = (double *)malloc(table_size * sizeof(double));
465+
coeff->d = (double *)malloc(table_size * sizeof(double));
466+
467+
return coeff;
468+
}
469+
void splinecoeff_destory(SplineCoefficients *coeff)
470+
{
448471
free(coeff->a);
449472
free(coeff->b);
450473
free(coeff->c);
451474
free(coeff->d);
452475
free(coeff);
453476
}
477+
void calculate_spline_coefficients(SplineCoefficients *coeff, const void *x_table, const void *y_table, int table_size,
478+
DataType data_type)
479+
{
480+
if (DATA_TYPE_FLOAT == data_type)
481+
{
482+
calculate_spline_coefficientsf(coeff, (const float *)x_table, (const float *)y_table, table_size);
483+
}
484+
else if (DATA_TYPE_DOUBLE == data_type)
485+
{
486+
calculate_spline_coefficientsd(coeff, (const double *)x_table, (const double *)y_table, table_size);
487+
}
488+
}
489+
490+
/**
491+
* @brief 计算插值结果
492+
*
493+
* @param coeff 样条插值系数
494+
* @param x_val 待插值的自变量
495+
* @param x_table x表
496+
* @param table_size 表大小
497+
* @return
498+
*/
499+
double calculate_spline_result(SplineCoefficients *coeff, const void *x_val, const void *x_table, int table_size,
500+
DataType data_type)
501+
{
502+
double result = 0;
503+
double dx;
504+
int i = 0;
505+
if (DATA_TYPE_FLOAT == data_type)
506+
{
507+
float *x_valf = (float *)x_val;
508+
float *x_tablef = (float *)x_table;
509+
// 找到x所在的区间
510+
while (i < table_size - 1 && *x_valf > x_tablef[i + 1]) { i++; }
511+
// 计算插值结果
512+
dx = *x_valf - x_tablef[i];
513+
}
514+
else if (DATA_TYPE_DOUBLE == data_type)
515+
{
516+
double *x_vald = (double *)x_val;
517+
double *x_tabled = (double *)x_table;
518+
// 找到x所在的区间
519+
while (i < table_size - 1 && *x_vald > x_tabled[i + 1]) { i++; }
520+
// 计算插值结果
521+
dx = *x_vald - x_tabled[i];
522+
}
523+
result = coeff->a[i] + coeff->b[i] * dx + coeff->c[i] * dx * dx + coeff->d[i] * dx * dx * dx;
524+
return result;
525+
}
526+
// 样条插值函数
527+
bool spline_interpolate(const void *x, void *y, const void *x_table, const void *y_table, int table_size,
528+
DataType data_type)
529+
{
530+
SplineCoefficients *coeff = splinecoeff_create(table_size);
531+
532+
if (data_type == DATA_TYPE_FLOAT)
533+
{
534+
calculate_spline_coefficientsf(coeff, (const float *)x_table, (const float *)y_table, table_size);
535+
*(float *)y = calculate_spline_result(coeff, x, x_table, table_size, DATA_TYPE_FLOAT);
536+
}
537+
else if (data_type == DATA_TYPE_DOUBLE)
538+
{
539+
calculate_spline_coefficientsd(coeff, (const double *)x_table, (const double *)y_table, table_size);
540+
*(double *)y = calculate_spline_result(coeff, x, x_table, table_size, DATA_TYPE_DOUBLE);
541+
}
542+
splinecoeff_destory(coeff);
543+
return true;
544+
}
454545

455546

456547
// 插值函数接口
@@ -474,10 +565,10 @@ void interpolate(void *x, void *y, void *x_table, void *y_table, int table_size,
474565
break;
475566
case INTERP_HERMITE: /*!< 埃尔米特插值 */
476567
hermite_interpolate(x, y, x_table, y_table, table_size, data_type);
477-
478568
break;
479569
case INTERP_SPLINE: /*!< B样条插值 */
480570
spline_interpolate(x, y, x_table, y_table, table_size, data_type);
481571
break;
482572
}
483573
}
574+

src/algorithms/interpolation.h

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define INTERPOLATION_H
33

44
#include <stdint.h>
5-
5+
#include <stdbool.h>
66
// 支持的数据类型
77
typedef enum
88
{
@@ -19,7 +19,7 @@ typedef enum
1919
INTERP_CUBIC, /*!< 三次插值 */
2020
INTERP_LAGRANGE, /*!< 拉格朗日插值 */
2121
INTERP_HERMITE, /*!< 埃尔米特插值 */
22-
INTERP_SPLINE /*!< B样条插值 计算结果与matlab 略有出入*/
22+
INTERP_SPLINE /*!< B样条插值 计算结果与matlab 略有出入 ont suppose DATA_TYPE_INT*/
2323
} InterpolationType;
2424

2525
/**
@@ -35,5 +35,56 @@ typedef enum
3535
*/
3636
void interpolate(void *x, void *y, void *x_table, void *y_table, int table_size, DataType data_type,
3737
InterpolationType interp_type);
38+
// 样条插值系数结构体
39+
typedef struct
40+
{
41+
double *a;
42+
double *b;
43+
double *c;
44+
double *d;
45+
} SplineCoefficients;
3846

47+
/**
48+
* @brief 创造系数结构体
49+
*
50+
* @param table_size 待插值数据长度
51+
* @return
52+
*/
53+
SplineCoefficients *splinecoeff_create(int table_size);
54+
void splinecoeff_destory(SplineCoefficients *coeff);
55+
/**
56+
* @brief 样条插值函数 计算插值系数
57+
*
58+
*
59+
* @param coeff 样条插值系数
60+
* @param x_table x表
61+
* @param y_table y表
62+
* @param table_size 表大小
63+
*/
64+
void calculate_spline_coefficients(SplineCoefficients *coeff, const void *x_table, const void *y_table, int table_size,
65+
DataType data_type);
66+
/**
67+
* @brief 计算插值结果
68+
*
69+
* @param coeff 样条插值系数
70+
* @param x_val 待插值的自变量
71+
* @param x_table x表
72+
* @param table_size 表大小
73+
* @return
74+
*/
75+
double calculate_spline_result(SplineCoefficients *coeff, const void *x_val, const void *x_table, int table_size,
76+
DataType data_type);
77+
78+
/**
79+
* @brief 计算样条插值结果
80+
* 是上面操作的整合步骤
81+
* @param x 待插值的自变量
82+
* @param y 插值结果
83+
* @param x_table x表
84+
* @param y_table y表
85+
* @param table_size 表大小
86+
* @return true 插值成功 false 插值失败
87+
*/
88+
bool spline_interpolate(const void *x, void *y, const void *x_table, const void *y_table, int table_size,
89+
DataType data_type);
3990
#endif // INTERPOLATION_H

0 commit comments

Comments
 (0)