@@ -367,27 +367,15 @@ void hermite_interpolate(void *x, void *y, void *x_table, void *y_table, int tab
367367 else if (data_type == DATA_TYPE_DOUBLE ) { * (double * )y = y_val ; }
368368}
369369
370- // 样条插值系数结构体
371- typedef struct
372- {
373- double * a ;
374- double * b ;
375- double * c ;
376- double * d ;
377- } SplineCoefficients ;
378370// 计算样条插值系数
379- SplineCoefficients * calculate_spline_coefficients (const double * x_table , const double * y_table , int table_size )
371+ void calculate_spline_coefficientsf (SplineCoefficients * coeff , const float * x_table , const float * y_table ,
372+ int table_size )
380373{
381- SplineCoefficients * coeff = (SplineCoefficients * )malloc (sizeof (SplineCoefficients ));
382- coeff -> a = (double * )malloc (table_size * sizeof (double ));
383- coeff -> b = (double * )malloc (table_size * sizeof (double ));
384- coeff -> c = (double * )malloc (table_size * sizeof (double ));
385- coeff -> d = (double * )malloc (table_size * sizeof (double ));
386- double * h = (double * )malloc ((table_size - 1 ) * sizeof (double ));
387- double * alpha = (double * )malloc (table_size * sizeof (double ));
388- double * l = (double * )malloc (table_size * sizeof (double ));
389- double * mu = (double * )malloc (table_size * sizeof (double ));
390- double * z = (double * )malloc (table_size * sizeof (double ));
374+ float * h = (float * )malloc ((table_size - 1 ) * sizeof (float ));
375+ float * alpha = (float * )malloc (table_size * sizeof (float ));
376+ float * l = (float * )malloc (table_size * sizeof (float ));
377+ float * mu = (float * )malloc (table_size * sizeof (float ));
378+ float * z = (float * )malloc (table_size * sizeof (float ));
391379 // 计算h[i] = x[i+1] - x[i]
392380 for (int i = 0 ; i < table_size - 1 ; i ++ ) { h [i ] = x_table [i + 1 ] - x_table [i ]; }
393381 // 计算alpha[i]
@@ -423,34 +411,137 @@ SplineCoefficients *calculate_spline_coefficients(const double *x_table, const d
423411 free (l );
424412 free (mu );
425413 free (z );
426- return coeff ;
427414}
428- // 样条插值函数
429- void spline_interpolate (const void * x , void * y , const void * x_table , const void * y_table , int table_size ,
430- DataType data_type )
415+ void calculate_spline_coefficientsd (SplineCoefficients * coeff , const double * x_table , const double * y_table ,
416+ int table_size )
431417{
432- if (data_type != DATA_TYPE_FLOAT && data_type != DATA_TYPE_DOUBLE )
418+ double * h = (double * )malloc ((table_size - 1 ) * sizeof (double ));
419+ double * alpha = (double * )malloc (table_size * sizeof (double ));
420+ double * l = (double * )malloc (table_size * sizeof (double ));
421+ double * mu = (double * )malloc (table_size * sizeof (double ));
422+ double * z = (double * )malloc (table_size * sizeof (double ));
423+ // 计算h[i] = x[i+1] - x[i]
424+ for (int i = 0 ; i < table_size - 1 ; i ++ ) { h [i ] = x_table [i + 1 ] - x_table [i ]; }
425+ // 计算alpha[i]
426+ for (int i = 1 ; i < table_size - 1 ; i ++ )
433427 {
434- printf ("Unsupported data type for spline interpolation.\n" );
435- return ;
428+ alpha [i ] = (3.0 / h [i ]) * (y_table [i + 1 ] - y_table [i ]) - (3.0 / h [i - 1 ]) * (y_table [i ] - y_table [i - 1 ]);
429+ }
430+ // 自然边界条件:二阶导数为0
431+ l [0 ] = 1.0 ;
432+ mu [0 ] = 0.0 ;
433+ z [0 ] = 0.0 ;
434+ // 前向消元
435+ for (int i = 1 ; i < table_size - 1 ; i ++ )
436+ {
437+ l [i ] = 2.0 * (x_table [i + 1 ] - x_table [i - 1 ]) - h [i - 1 ] * mu [i - 1 ];
438+ mu [i ] = h [i ] / l [i ];
439+ z [i ] = (alpha [i ] - h [i - 1 ] * z [i - 1 ]) / l [i ];
440+ }
441+ // 自然边界条件:二阶导数为0
442+ l [table_size - 1 ] = 1.0 ;
443+ z [table_size - 1 ] = 0.0 ;
444+ coeff -> c [table_size - 1 ] = 0.0 ;
445+ // 后向代入
446+ for (int j = table_size - 2 ; j >= 0 ; j -- )
447+ {
448+ coeff -> c [j ] = z [j ] - mu [j ] * coeff -> c [j + 1 ];
449+ coeff -> b [j ] = (y_table [j + 1 ] - y_table [j ]) / h [j ] - h [j ] * (coeff -> c [j + 1 ] + 2.0 * coeff -> c [j ]) / 3.0 ;
450+ coeff -> d [j ] = (coeff -> c [j + 1 ] - coeff -> c [j ]) / (3.0 * h [j ]);
451+ coeff -> a [j ] = y_table [j ];
436452 }
437- SplineCoefficients * coeff = calculate_spline_coefficients ((const double * )x_table , (const double * )y_table ,
438- table_size );
439- double x_val = (data_type == DATA_TYPE_FLOAT ) ? * (float * )x : * (double * )x ;
440- int i = 0 ;
441- // 找到x所在的区间
442- while (i < table_size - 1 && x_val > ((double * )x_table )[i + 1 ]) { i ++ ; }
443- // 计算插值结果
444- double dx = x_val - ((double * )x_table )[i ];
445- double result = coeff -> a [i ] + coeff -> b [i ] * dx + coeff -> c [i ] * dx * dx + coeff -> d [i ] * dx * dx * dx ;
446- if (data_type == DATA_TYPE_FLOAT ) { * (float * )y = (float )result ; }
447- else { * (double * )y = result ; }
453+ free (h );
454+ free (alpha );
455+ free (l );
456+ free (mu );
457+ free (z );
458+ }
459+ SplineCoefficients * splinecoeff_create (int table_size )
460+ {
461+ SplineCoefficients * coeff = (SplineCoefficients * )malloc (sizeof (SplineCoefficients ));
462+ coeff -> a = (double * )malloc (table_size * sizeof (double ));
463+ coeff -> b = (double * )malloc (table_size * sizeof (double ));
464+ coeff -> c = (double * )malloc (table_size * sizeof (double ));
465+ coeff -> d = (double * )malloc (table_size * sizeof (double ));
466+
467+ return coeff ;
468+ }
469+ void splinecoeff_destory (SplineCoefficients * coeff )
470+ {
448471 free (coeff -> a );
449472 free (coeff -> b );
450473 free (coeff -> c );
451474 free (coeff -> d );
452475 free (coeff );
453476}
477+ void calculate_spline_coefficients (SplineCoefficients * coeff , const void * x_table , const void * y_table , int table_size ,
478+ DataType data_type )
479+ {
480+ if (DATA_TYPE_FLOAT == data_type )
481+ {
482+ calculate_spline_coefficientsf (coeff , (const float * )x_table , (const float * )y_table , table_size );
483+ }
484+ else if (DATA_TYPE_DOUBLE == data_type )
485+ {
486+ calculate_spline_coefficientsd (coeff , (const double * )x_table , (const double * )y_table , table_size );
487+ }
488+ }
489+
490+ /**
491+ * @brief 计算插值结果
492+ *
493+ * @param coeff 样条插值系数
494+ * @param x_val 待插值的自变量
495+ * @param x_table x表
496+ * @param table_size 表大小
497+ * @return
498+ */
499+ double calculate_spline_result (SplineCoefficients * coeff , const void * x_val , const void * x_table , int table_size ,
500+ DataType data_type )
501+ {
502+ double result = 0 ;
503+ double dx ;
504+ int i = 0 ;
505+ if (DATA_TYPE_FLOAT == data_type )
506+ {
507+ float * x_valf = (float * )x_val ;
508+ float * x_tablef = (float * )x_table ;
509+ // 找到x所在的区间
510+ while (i < table_size - 1 && * x_valf > x_tablef [i + 1 ]) { i ++ ; }
511+ // 计算插值结果
512+ dx = * x_valf - x_tablef [i ];
513+ }
514+ else if (DATA_TYPE_DOUBLE == data_type )
515+ {
516+ double * x_vald = (double * )x_val ;
517+ double * x_tabled = (double * )x_table ;
518+ // 找到x所在的区间
519+ while (i < table_size - 1 && * x_vald > x_tabled [i + 1 ]) { i ++ ; }
520+ // 计算插值结果
521+ dx = * x_vald - x_tabled [i ];
522+ }
523+ result = coeff -> a [i ] + coeff -> b [i ] * dx + coeff -> c [i ] * dx * dx + coeff -> d [i ] * dx * dx * dx ;
524+ return result ;
525+ }
526+ // 样条插值函数
527+ bool spline_interpolate (const void * x , void * y , const void * x_table , const void * y_table , int table_size ,
528+ DataType data_type )
529+ {
530+ SplineCoefficients * coeff = splinecoeff_create (table_size );
531+
532+ if (data_type == DATA_TYPE_FLOAT )
533+ {
534+ calculate_spline_coefficientsf (coeff , (const float * )x_table , (const float * )y_table , table_size );
535+ * (float * )y = calculate_spline_result (coeff , x , x_table , table_size , DATA_TYPE_FLOAT );
536+ }
537+ else if (data_type == DATA_TYPE_DOUBLE )
538+ {
539+ calculate_spline_coefficientsd (coeff , (const double * )x_table , (const double * )y_table , table_size );
540+ * (double * )y = calculate_spline_result (coeff , x , x_table , table_size , DATA_TYPE_DOUBLE );
541+ }
542+ splinecoeff_destory (coeff );
543+ return true;
544+ }
454545
455546
456547// 插值函数接口
@@ -474,10 +565,10 @@ void interpolate(void *x, void *y, void *x_table, void *y_table, int table_size,
474565 break ;
475566 case INTERP_HERMITE : /*!< 埃尔米特插值 */
476567 hermite_interpolate (x , y , x_table , y_table , table_size , data_type );
477-
478568 break ;
479569 case INTERP_SPLINE : /*!< B样条插值 */
480570 spline_interpolate (x , y , x_table , y_table , table_size , data_type );
481571 break ;
482572 }
483573}
574+
0 commit comments