forked from hxim/paq8px
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLMS.hpp
More file actions
82 lines (72 loc) · 2.48 KB
/
LMS.hpp
File metadata and controls
82 lines (72 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#pragma once
#include <memory>
#include <cmath> // sqrt
#include "Array.hpp"
#include "Shared.hpp"
/**
 * Adaptive Least Mean Squares (LMS) predictor whose learning rates are
 * normalized RMSProp-style.
 *
 * Intended for audio prediction, e.g. stereo streams, where one part of the
 * filter follows the channel being predicted and the other part follows the
 * opposite channel to exploit cross-channel correlation.
 *
 * The sample history is divided into two components:
 *  - 's': recent samples of the same channel, refreshed on every sample;
 *  - 'd': samples of the other channel, refreshed less frequently.
 * Each component is trained with its own learning rate.
 *
 * Instances are non-copyable. The constructor is protected, so outside code
 * obtains a predictor through create().
 */
class LMS
{
protected:
  Array<float, 32> weights;  // Weights adjusted by update()
  Array<float, 32> eg;       // Gradient accumulator used for the RMSProp normalization
  Array<float, 32> buffer;   // Sample history, laid out as [s same-channel | d other-channel]
  float sameChannelRate;     // Learning rate applied to the same-channel weights
  float otherChannelRate;    // Learning rate applied to the other-channel weights
  float rho;                 // Decay rate of the RMSProp accumulator
  float eps;                 // Small constant for numerical stability
  float prediction;          // Most recent prediction value
  int s;                     // Size of the same-channel part of the history (updated every sample)
  int d;                     // Size of the other-channel part of the history

  /**
   * Build an LMS predictor. Protected: meant to be called by subclasses.
   *
   * @param s                Same-channel buffer size (updated every sample)
   * @param d                Other-channel buffer size
   * @param sameChannelRate  Learning rate for the same-channel weights
   * @param otherChannelRate Learning rate for the other-channel weights
   */
  LMS(int s, int d, float sameChannelRate, float otherChannelRate);

public:
  virtual ~LMS() = default;

  // Copying is disabled.
  LMS(const LMS&) = delete;
  LMS& operator=(const LMS&) = delete;

  /**
   * Static factory returning a heap-allocated predictor owned by the caller.
   *
   * @param simd             SIMD type; presumably selects which implementation
   *                         is instantiated - confirm against the definition
   * @param s                Same-channel buffer size (updated every sample)
   * @param d                Other-channel buffer size
   * @param sameChannelRate  Learning rate for the same-channel weights
   * @param otherChannelRate Learning rate for the other-channel weights
   * @return Owning pointer to the new predictor
   */
  static std::unique_ptr<LMS> create(SIMDType simd, int s, int d, float sameChannelRate, float otherChannelRate);

  /**
   * Produce the prediction for the next sample.
   * Also feeds the other-channel history (the 'd' component).
   *
   * @param sample Input sample taken from the other channel
   * @return Predicted value
   */
  virtual float predict(int sample);

  /**
   * Adapt the weights according to the prediction error.
   * Also feeds the same-channel history (the 's' component).
   *
   * @param sample Input sample taken from this channel
   */
  virtual void update(int sample);

  /**
   * Set all weights, gradients and the history buffer back to zero.
   */
  void reset();
};