Skip to content

Commit 8e8904a

Browse files
committed
✨ feat(r_square): 添加R²计算功能,更新二次拟合函数参数,优化测试用例
1 parent b2d5571 commit 8e8904a

File tree

6 files changed

+133
-47
lines changed

6 files changed

+133
-47
lines changed

src/algorithm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@
99
#include "algorithms/array_utils.h"
1010
#include "algorithms/kalman_filter.h"
1111
#include "algorithms/interpolation.h"
12+
#include "algorithms/r_square.h"
13+
1214
#endif // ALGORITHM_MODULE_H

src/algorithms/polyfit.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ void linear_curve_fit(float *x, float *y, int size, float *slope, float *interce
144144
*intercept = (sum_y - *slope * sum_x) / size;
145145
}
146146
// 二次拟合函数
147-
void quadratic_fit(float *x, float *y, int n, float *a, float *b, float *c)
147+
void quadratic_fit(float *x, float *y, int n, float *coeff)
148148
{
149149
float sum_x = 0, sum_x2 = 0, sum_x3 = 0, sum_x4 = 0;
150150
float sum_y = 0, sum_xy = 0, sum_x2y = 0;
@@ -181,7 +181,7 @@ void quadratic_fit(float *x, float *y, int n, float *a, float *b, float *c)
181181
}
182182
}
183183

184-
*a = B[2] / A[2][2];
185-
*b = (B[1] - A[1][2] * (*a)) / A[1][1];
186-
*c = (B[0] - A[0][1] * (*b) - A[0][2] * (*a)) / A[0][0];
184+
coeff[2] = B[2] / A[2][2];
185+
coeff[1] = (B[1] - A[1][2] * (coeff[2])) / A[1][1];
186+
coeff[0] = (B[0] - A[0][1] * (coeff[1]) - A[0][2] * coeff[2]) / A[0][0];
187187
}

src/algorithms/polyfit.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,9 @@ void linear_curve_fit(float *x, float *y, int size, float *slope, float *interce
2727
* @param x
2828
* @param y
2929
* @param n
30-
* @param a
31-
* @param b
32-
* @param c
30+
* @param coeff 输出的系数
3331
*/
34-
void quadratic_fit(float *x, float *y, int n, float *a, float *b, float *c);
32+
void quadratic_fit(float *x, float *y, int n, float *coeff);
3533

3634
/**
3735
* @brief 三次拟合函数

src/algorithms/r_square.c

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "r_square.h"
2+
#include <math.h>
3+
4+
// 计算多项式预测值
5+
static float predict_float(float x, const float *coeff, int n)
6+
{
7+
float result = 0.0f;
8+
for (int i = 0; i <= n; i++) { result += coeff[i] * powf(x, i); }
9+
return result;
10+
}
11+
12+
static double predict_double(double x, const double *coeff, int n)
13+
{
14+
double result = 0.0;
15+
for (int i = 0; i <= n; i++) { result += coeff[i] * pow(x, i); }
16+
return result;
17+
}
18+
19+
// 计算R?的函数实现(float版本)
20+
float r_square_float(const float *d_X, const float *d_Y, int d_N, const float *coeff, int n)
21+
{
22+
float sse = 0.0f, sst = 0.0f, y_mean = 0.0f;
23+
24+
// 计算实际值的平均值
25+
for (int i = 0; i < d_N; i++) { y_mean += d_Y[i]; }
26+
y_mean /= d_N;
27+
28+
// 计算SSE和SST
29+
for (int i = 0; i < d_N; i++)
30+
{
31+
float y_pred = predict_float(d_X[i], coeff, n);
32+
sse += (y_pred - d_Y[i]) * (y_pred - d_Y[i]);
33+
sst += (d_Y[i] - y_mean) * (d_Y[i] - y_mean);
34+
}
35+
36+
// 计算R?
37+
return 1.0f - (sse / sst);
38+
}
39+
40+
// 计算R?的函数实现(double版本)
41+
double r_square_double(const double *d_X, const double *d_Y, int d_N, const double *coeff, int n)
42+
{
43+
double sse = 0.0, sst = 0.0, y_mean = 0.0;
44+
45+
// 计算实际值的平均值
46+
for (int i = 0; i < d_N; i++) { y_mean += d_Y[i]; }
47+
y_mean /= d_N;
48+
49+
// 计算SSE和SST
50+
for (int i = 0; i < d_N; i++)
51+
{
52+
double y_pred = predict_double(d_X[i], coeff, n);
53+
sse += (y_pred - d_Y[i]) * (y_pred - d_Y[i]);
54+
sst += (d_Y[i] - y_mean) * (d_Y[i] - y_mean);
55+
}
56+
57+
// 计算R?
58+
return 1.0 - (sse / sst);
59+
}

src/algorithms/r_square.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/**
2+
* @file r_square.h
3+
* @brief
4+
* 在C语言中,你可以通过以下步骤来计算这些拟合优度参数:
5+
* 计算预测值:使用拟合得到的多项式系数,计算每个数据点的预测值。
6+
* 计算SSE:对于每个数据点,计算预测值与实际值之间的差异,平方后求和。
7+
* 计算R?:
8+
* 计算实际值的平均值。
9+
* 计算总平方和(SST),即每个实际值与平均值的差异的平方和。
10+
* R? = 1 - (SSE / SST)
11+
* 计算Adjusted R?:
12+
* Adjusted R? = 1 - [(1 - R?) * (n - 1) / (n - p - 1)]
13+
* 其中,n是数据点的数量,p是模型中参数的个数。
14+
* 计算RMSE:
15+
* RMSE = sqrt(SSE / n)
16+
* @author mengplus ([email protected])
17+
* @version 0.1
18+
* @date 2025-03-25
19+
* @copyright Copyright (c) 2025 Zhengzhou GL. TECH Co.,Ltd
20+
*
21+
*/
22+
23+
#ifndef R_SQUARE_H
24+
#define R_SQUARE_H
25+
26+
#include <stddef.h>
27+
28+
// 计算R?的函数声明
29+
float r_square_float(const float *d_X, const float *d_Y, int d_N, const float *coeff, int n);
30+
double r_square_double(const double *d_X, const double *d_Y, int d_N, const double *coeff, int n);
31+
32+
#endif // R_SQUARE_H

test/test_algorithms.c

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@ START_TEST(test_meanFilterFloat)
1111
}
1212
END_TEST
1313

14-
START_TEST(test_meanFilterint32)
15-
{
16-
int data[] = {1, 2, 3, 4, 5};
17-
ck_assert_int_eq(meanFilterint32(data, 5), 3.0);
18-
}
19-
END_TEST
2014
START_TEST(test_medianFilter)
2115
{
16+
{
17+
int data[] = {1, 2, 3, 4, 5};
18+
ck_assert_int_eq(meanFilterint32(data, 5), 3.0);
19+
}
2220
float data[] = {21, 30, 15, 25, 35};
2321
float median = 0;
2422
medianFilter(data, 5, sizeof(float), &median, compareFloat);
@@ -99,12 +97,6 @@ START_TEST(test_kalman_filter_multiple_updates)
9997
ck_assert(final_error < 0.2);
10098
}
10199
END_TEST
102-
103-
// 辅助函数:比较浮点数是否相等(允许一定的误差)
104-
int float_equal(float a, float b, float epsilon)
105-
{
106-
return fabs(a - b) < epsilon;
107-
}
108100
// 测试用例 1:简单的线性数据
109101
START_TEST(test_linear_curve_fit_simple)
110102
{
@@ -114,8 +106,8 @@ START_TEST(test_linear_curve_fit_simple)
114106
float slope, intercept;
115107
linear_curve_fit(x, y, size, &slope, &intercept);
116108
// 预期结果:斜率为 2,截距为 0
117-
ck_assert(float_equal(slope, 2.0f, 1e-6));
118-
ck_assert(float_equal(intercept, 0.0f, 1e-6));
109+
ck_assert_float_eq_tol(slope, 2.0f, 1e-6);
110+
ck_assert_float_eq_tol(intercept, 0.0f, 1e-6);
119111
}
120112
END_TEST
121113
// 测试用例 2:带截距的线性数据
@@ -127,8 +119,8 @@ START_TEST(test_linear_curve_fit_intercept)
127119
float slope, intercept;
128120
linear_curve_fit(x, y, size, &slope, &intercept);
129121
// 预期结果:斜率为 2,截距为 1
130-
ck_assert(float_equal(slope, 2.0f, 1e-6));
131-
ck_assert(float_equal(intercept, 1.0f, 1e-6));
122+
ck_assert_float_eq_tol(slope, 2.0f, 1e-6);
123+
ck_assert_float_eq_tol(intercept, 1.0f, 1e-6);
132124
}
133125
END_TEST
134126
// 测试用例 3:随机数据
@@ -140,8 +132,8 @@ START_TEST(test_linear_curve_fit_random)
140132
float slope, intercept;
141133
linear_curve_fit(x, y, size, &slope, &intercept);
142134
// 预期结果:斜率接近 2,截距接近 1
143-
ck_assert(float_equal(slope, 1.65f, 1e-2));
144-
ck_assert(float_equal(intercept, 1.1f, 1e-2));
135+
ck_assert_float_eq_tol(slope, 1.65f, 1e-2);
136+
ck_assert_float_eq_tol(intercept, 1.1f, 1e-2);
145137
}
146138
END_TEST
147139
// 测试用例 1:简单的二次拟合
@@ -150,12 +142,15 @@ START_TEST(test_quadratic_fit_simple)
150142
float x[] = {0, 1, 2, 3, 4};
151143
float y[] = {1, 4, 9, 16, 25};
152144
int size = sizeof(x) / sizeof(x[0]);
153-
float a, b, c;
154-
quadratic_fit(x, y, size, &a, &b, &c);
155-
// 预期结果:a = 1, b = 2, c = 1 (y = 1*x^2 + 2*x + 1)
156-
ck_assert(float_equal(a, 1.0f, 1e-6));
157-
ck_assert(float_equal(b, 2.0f, 1e-6));
158-
ck_assert(float_equal(c, 1.0f, 1e-6));
145+
float coeff[3];
146+
quadratic_fit(x, y, size, coeff);
147+
// 预期结果:y = p1+p2*x+p3*x^2
148+
ck_assert_float_eq_tol(coeff[2], 1.0f, 1e-6);
149+
ck_assert_float_eq_tol(coeff[1], 2.0f, 1e-6);
150+
ck_assert_float_eq_tol(coeff[0], 1.0f, 1e-6);
151+
// 计算R²
152+
float r2_float = r_square_float(x, y, 5, coeff, 2);
153+
ck_assert_float_eq_tol(r2_float, 1.0f, 1e-6);
159154
}
160155
END_TEST
161156
// 测试用例 2:带线性项和常数的二次拟合
@@ -164,12 +159,12 @@ START_TEST(test_quadratic_fit_with_terms)
164159
float x[] = {0, 1, 2, 3, 4};
165160
float y[] = {2, 5, 12, 23, 38};
166161
int size = sizeof(x) / sizeof(x[0]);
167-
float a, b, c;
168-
quadratic_fit(x, y, size, &a, &b, &c);
169-
// 预期结果:a = 2, b = 1, c = 2 (y = 2x^2 + 1x + 2)
170-
ck_assert(float_equal(a, 2.0f, 1e-6));
171-
ck_assert(float_equal(b, 1.0f, 1e-6));
172-
ck_assert(float_equal(c, 2.0f, 1e-6));
162+
float coeff[3];
163+
quadratic_fit(x, y, size, coeff);
164+
// 预期结果:y = p1+p2*x+p3*x^2
165+
ck_assert_float_eq_tol(coeff[2], 2.0f, 1e-6);
166+
ck_assert_float_eq_tol(coeff[1], 1.0f, 1e-6);
167+
ck_assert_float_eq_tol(coeff[0], 2.0f, 1e-6);
173168
}
174169
END_TEST
175170
// 测试用例 3:随机数据的二次拟合
@@ -178,12 +173,12 @@ START_TEST(test_quadratic_fit_random)
178173
float x[] = {-2, -1, 0, 1, 2};
179174
float y[] = {8, -1, -6, -1, 8};
180175
int size = sizeof(x) / sizeof(x[0]);
181-
float a, b, c;
182-
quadratic_fit(x, y, size, &a, &b, &c);
183-
// 预期结果:a =3.285714, b = 0.0, c = -4.971429 (y = ax^2 + bx - c)
184-
ck_assert(float_equal(a, 3.285714f, 1e-6));
185-
ck_assert(float_equal(b, 0.0f, 1e-6));
186-
ck_assert(float_equal(c, -4.971429f, 1e-6));
176+
float coeff[3];
177+
quadratic_fit(x, y, size, coeff);
178+
// 预期结果:y = p1+p2*x+p3*x^2
179+
ck_assert_double_eq_tol(coeff[2], 3.285714f, 1e-6);
180+
ck_assert_double_eq_tol(coeff[1], 0.0f, 1e-6);
181+
ck_assert_double_eq_tol(coeff[0], -4.971429f, 1e-6);
187182
}
188183
END_TEST
189184

@@ -203,7 +198,9 @@ START_TEST(test_cubic_fit)
203198
double coef[4];
204199

205200
polyfit(x, y, n, 3, coef);
206-
201+
// 计算R²
202+
float r2_float = r_square_double(x, y, n, coef, 3);
203+
ck_assert_double_eq_tol(r2_float, 0.961957f, 1e-6);
207204
// printf("拟合的三次多项式为: y = %.12fx^3 + %.12fx^2 + %.12fx + %.12f\n", coef[3], coef[2], coef[1], coef[0]);
208205

209206
double x_max_y, max_y;
@@ -456,8 +453,6 @@ Suite *algorithms_suite(void)
456453

457454
/* Core test case */
458455
tc_mean = tcase_create("mean");
459-
460-
tcase_add_test(tc_mean, test_meanFilterint32);
461456
tcase_add_test(tc_mean, test_meanFilterFloat);
462457
tcase_add_test(tc_mean, test_medianFilter);
463458
tcase_add_test(tc_mean, test_kalman_filter_init);

0 commit comments

Comments
 (0)