diff --git a/.travis.yml b/.travis.yml index 53f7090..5728eb2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,4 +10,5 @@ script: # - cargo clippy -- -D warnings - cargo test - cargo test --features serde + - cargo test --features decimal --tests - cargo package diff --git a/Cargo.toml b/Cargo.toml index 08b4107..6c82be9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,14 +22,18 @@ travis-ci = { repository = "greyblake/ta-rs", branch = "master" } [dependencies] serde = { version = "1.0", features = ["derive"], optional = true } +rust_decimal = { version = "^1.25.0", optional = true, features = ["maths", "rand"] } [dev-dependencies] assert_approx_eq = "1.0.0" csv = "1.1.0" bencher = "0.1.5" -rand = "0.6.5" +rand = "0.8" bincode = "1.3.1" +[features] +decimal = ["dep:rust_decimal"] + [profile.release] lto = true diff --git a/README.md b/README.md index 76e261f..48482cb 100644 --- a/README.md +++ b/README.md @@ -9,17 +9,18 @@ Technical analysis library for Rust. -* [Getting started](#getting-started) -* [Basic ideas](#basic-ideas) -* [List of indicators](#list-of-indicators) -* [Running benchmarks](#running-benchmarks) -* [Donations](#donations) -* [License](#license) -* [Contributors](#contributors) +- [Getting started](#getting-started) +- [Basic ideas](#basic-ideas) +- [List of indicators](#list-of-indicators) +- [Running benchmarks](#running-benchmarks) +- [Donations](#donations) +- [License](#license) +- [Contributors](#contributors) ## Getting started Add to you `Cargo.toml`: + ``` [dependencies] ta = "0.4.0" @@ -47,58 +48,60 @@ Check also the [documentation](https://docs.rs/ta). A data item which represent a stock quote may implement the following traits: -* `Open` -* `High` -* `Low` -* `Close` -* `Volume` +- `Open` +- `High` +- `Low` +- `Close` +- `Volume` It's not necessary to implement all of them, but it must be enough to fulfill requirements for a particular indicator. You probably should prefer using `DataItem` unless you have reasons to implement your own structure. Indicators typically implement the following traits: -* `Next` (often `Next` and `Next<&DataItem>`) - to feed and get the next value -* `Reset` - to reset an indicator -* `Debug` -* `Display` -* `Default` -* `Clone` +- `Next` (often `Next` and `Next<&DataItem>`) - to feed and get the next value +- `Reset` - to reset an indicator +- `Debug` +- `Display` +- `Default` +- `Clone` ## List of indicators So far there are the following indicators available. -* Trend - * Exponential Moving Average (EMA) - * Simple Moving Average (SMA) -* Oscillators - * Relative Strength Index (RSI) - * Fast Stochastic - * Slow Stochastic - * Moving Average Convergence Divergence (MACD) - * Percentage Price Oscillator (PPO) - * Commodity Channel Index (CCI) - * Money Flow Index (MFI) -* Other - * Minimum - * Maximum - * True Range - * Standard Deviation (SD) - * Mean Absolute Deviation (MAD) - * Average True Range (AR) - * Efficiency Ratio (ER) - * Bollinger Bands (BB) - * Chandelier Exit (CE) - * Keltner Channel (KC) - * Rate of Change (ROC) - * On Balance Volume (OBV) - +- Trend + - Exponential Moving Average (EMA) + - Simple Moving Average (SMA) +- Oscillators + - Relative Strength Index (RSI) + - Fast Stochastic + - Slow Stochastic + - Moving Average Convergence Divergence (MACD) + - Percentage Price Oscillator (PPO) + - Commodity Channel Index (CCI) + - Money Flow Index (MFI) +- Other + - Minimum + - Maximum + - True Range + - Standard Deviation (SD) + - Mean Absolute Deviation (MAD) + - Average True Range (AR) + - Efficiency Ratio (ER) + - Bollinger Bands (BB) + - Chandelier Exit (CE) + - Keltner Channel (KC) + - Rate of Change (ROC) + - On Balance Volume (OBV) ## Features -* `serde` - allows to serialize and deserialize indicators. NOTE: the backward compatibility of serialized -data with the future versions of ta is not guaranteed because internal implementation of the indicators is a subject to change. +- `decimal` - when enabled, uses `Decimal` objects from the [`rust_decimal`] crate instead of `f64`. +- `serde` - allows to serialize and deserialize indicators. NOTE: the backward compatibility of serialized + data with the future versions of ta is not guaranteed because internal implementation of the indicators is a subject to change. + +[`rust_decimal`]: https://docs.rs/rust_decimal ## Running benchmarks @@ -112,12 +115,10 @@ You can support the project by donating [NEAR tokens](https://near.org). Our NEAR wallet address is `ta-rs.near` - ## License [MIT](https://github.com/greyblake/ta-rs/blob/master/LICENSE) © [Sergey Potapov](http://greyblake.com/) - ## Contributors - [greyblake](https://github.com/greyblake) Potapov Sergey - creator, maintainer. @@ -129,3 +130,4 @@ Our NEAR wallet address is `ta-rs.near` - [Devin Gunay](https://github.com/dgunay) - serde support - [Youngchan Lee](https://github.com/edwardycl) - bugfix - [tommady](https://github.com/tommady) - get rid of error-chain dependency +- [Luke Sneeringer](https://github.com/lukesneeringer) - Decimal implementation diff --git a/benches/indicators.rs b/benches/indicators.rs index 554bc15..58721a1 100644 --- a/benches/indicators.rs +++ b/benches/indicators.rs @@ -7,18 +7,18 @@ use ta::indicators::{ PercentagePriceOscillator, RateOfChange, RelativeStrengthIndex, SimpleMovingAverage, SlowStochastic, StandardDeviation, TrueRange, WeightedMovingAverage, }; -use ta::{DataItem, Next}; +use ta::{lit, DataItem, Next}; const ITEMS_COUNT: usize = 5_000; fn rand_data_item() -> DataItem { let mut rng = rand::thread_rng(); - let low = rng.gen_range(0.0, 500.0); - let high = rng.gen_range(500.0, 1000.0); - let open = rng.gen_range(low, high); - let close = rng.gen_range(low, high); - let volume = rng.gen_range(0.0, 10_000.0); + let low = rng.gen_range(lit!(0.0)..=lit!(500.0)); + let high = rng.gen_range(lit!(500.0)..=lit!(1000.0)); + let open = rng.gen_range(low..=high); + let close = rng.gen_range(low..=high); + let volume = rng.gen_range(lit!(0.0)..=lit!(10_000.0)); DataItem::builder() .open(open) diff --git a/examples/custom_data_item.rs b/examples/custom_data_item.rs index ac0f84a..edf4745 100644 --- a/examples/custom_data_item.rs +++ b/examples/custom_data_item.rs @@ -1,30 +1,35 @@ use ta::indicators::TrueRange; use ta::{Close, High, Low, Next}; +#[cfg(feature = "decimal")] +type Num = rust_decimal::Decimal; +#[cfg(not(feature = "decimal"))] +type Num = f64; + // You can create your own data items. // You may want it for different purposes, e.g.: // - you data source don't have volume or other fields. // - you want to skip validation to avoid performance penalty. struct Item { - high: f64, - low: f64, - close: f64, + high: Num, + low: Num, + close: Num, } impl Low for Item { - fn low(&self) -> f64 { + fn low(&self) -> Num { self.low } } impl High for Item { - fn high(&self) -> f64 { + fn high(&self) -> Num { self.high } } impl Close for Item { - fn close(&self) -> f64 { + fn close(&self) -> Num { self.close } } @@ -34,7 +39,7 @@ fn main() { let mut reader = csv::Reader::from_path("./examples/data/AMZN.csv").unwrap(); for record in reader.deserialize() { - let (date, _open, high, low, close, _volume): (String, f64, f64, f64, f64, f64) = + let (date, _open, high, low, close, _volume): (String, Num, Num, Num, Num, Num) = record.unwrap(); let item = Item { high, low, close }; let val = tr.next(&item); diff --git a/examples/ema.rs b/examples/ema.rs index e2ae47c..f631269 100644 --- a/examples/ema.rs +++ b/examples/ema.rs @@ -2,12 +2,17 @@ use ta::indicators::ExponentialMovingAverage as Ema; use ta::DataItem; use ta::Next; +#[cfg(feature = "decimal")] +type Num = rust_decimal::Decimal; +#[cfg(not(feature = "decimal"))] +type Num = f64; + fn main() { let mut ema = Ema::new(9).unwrap(); let mut reader = csv::Reader::from_path("./examples/data/AMZN.csv").unwrap(); for record in reader.deserialize() { - let (date, open, high, low, close, volume): (String, f64, f64, f64, f64, f64) = + let (date, open, high, low, close, volume): (String, Num, Num, Num, Num, Num) = record.unwrap(); let dt = DataItem::builder() .open(open) diff --git a/src/data_item.rs b/src/data_item.rs index c40eac4..f3509ee 100644 --- a/src/data_item.rs +++ b/src/data_item.rs @@ -1,5 +1,6 @@ use crate::errors::*; -use crate::traits::{Close, High, Low, Open, Volume}; +use crate::NumberType; +use crate::{lit, Close, High, Low, Open, Volume}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -31,11 +32,11 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone, PartialEq)] pub struct DataItem { - open: f64, - high: f64, - low: f64, - close: f64, - volume: f64, + open: NumberType, + high: NumberType, + low: NumberType, + close: NumberType, + volume: NumberType, } impl DataItem { @@ -45,75 +46,70 @@ impl DataItem { } impl Open for DataItem { - fn open(&self) -> f64 { + fn open(&self) -> NumberType { self.open } } impl High for DataItem { - fn high(&self) -> f64 { + fn high(&self) -> NumberType { self.high } } impl Low for DataItem { - fn low(&self) -> f64 { + fn low(&self) -> NumberType { self.low } } impl Close for DataItem { - fn close(&self) -> f64 { + fn close(&self) -> NumberType { self.close } } impl Volume for DataItem { - fn volume(&self) -> f64 { + fn volume(&self) -> NumberType { self.volume } } +#[derive(Default)] pub struct DataItemBuilder { - open: Option, - high: Option, - low: Option, - close: Option, - volume: Option, + open: Option, + high: Option, + low: Option, + close: Option, + volume: Option, } impl DataItemBuilder { pub fn new() -> Self { - Self { - open: None, - high: None, - low: None, - close: None, - volume: None, - } + Self::default() } - pub fn open(mut self, val: f64) -> Self { + pub fn open(mut self, val: NumberType) -> Self { self.open = Some(val); self } - pub fn high(mut self, val: f64) -> Self { + pub fn high(mut self, val: NumberType) -> Self { self.high = Some(val); self } - pub fn low(mut self, val: f64) -> Self { + pub fn low(mut self, val: NumberType) -> Self { self.low = Some(val); self } - pub fn close(mut self, val: f64) -> Self { + pub fn close(mut self, val: NumberType) -> Self { self.close = Some(val); self } - pub fn volume(mut self, val: f64) -> Self { + pub fn volume(mut self, val: NumberType) -> Self { self.volume = Some(val); self } @@ -128,7 +124,7 @@ impl DataItemBuilder { && low <= high && high >= open && high >= close - && volume >= 0.0 + && volume >= lit!(0.0) { let item = DataItem { open, @@ -153,7 +149,15 @@ mod tests { #[test] fn test_builder() { - fn assert_valid((open, high, low, close, volume): (f64, f64, f64, f64, f64)) { + fn assert_valid( + (open, high, low, close, volume): ( + NumberType, + NumberType, + NumberType, + NumberType, + NumberType, + ), + ) { let result = DataItem::builder() .open(open) .high(high) @@ -164,8 +168,15 @@ mod tests { assert!(result.is_ok()); } - fn assert_invalid(record: (f64, f64, f64, f64, f64)) { - let (open, high, low, close, volume) = record; + fn assert_invalid( + (open, high, low, close, volume): ( + NumberType, + NumberType, + NumberType, + NumberType, + NumberType, + ), + ) { let result = DataItem::builder() .open(open) .high(high) @@ -178,9 +189,9 @@ mod tests { let valid_records = vec![ // open, high, low , close, volume - (20.0, 25.0, 15.0, 21.0, 7500.0), - (10.0, 10.0, 10.0, 10.0, 10.0), - (0.0, 0.0, 0.0, 0.0, 0.0), + (lit!(20.0), lit!(25.0), lit!(15.0), lit!(21.0), lit!(7500.0)), + (lit!(10.0), lit!(10.0), lit!(10.0), lit!(10.0), lit!(10.0)), + (lit!(0.0), lit!(0.0), lit!(0.0), lit!(0.0), lit!(0.0)), ]; for record in valid_records { assert_valid(record) @@ -188,15 +199,15 @@ mod tests { let invalid_records = vec![ // open, high, low , close, volume - (-1.0, 25.0, 15.0, 21.0, 7500.0), - (20.0, -1.0, 15.0, 21.0, 7500.0), - (20.0, 25.0, 15.0, -1.0, 7500.0), - (20.0, 25.0, 15.0, 21.0, -1.0), - (14.9, 25.0, 15.0, 21.0, 7500.0), - (25.1, 25.0, 15.0, 21.0, 7500.0), - (20.0, 25.0, 15.0, 14.9, 7500.0), - (20.0, 25.0, 15.0, 25.1, 7500.0), - (20.0, 15.0, 25.0, 21.0, 7500.0), + (lit!(-1.0), lit!(25.0), lit!(15.0), lit!(21.0), lit!(7500.0)), + (lit!(20.0), lit!(-1.0), lit!(15.0), lit!(21.0), lit!(7500.0)), + (lit!(20.0), lit!(25.0), lit!(15.0), lit!(-1.0), lit!(7500.0)), + (lit!(20.0), lit!(25.0), lit!(15.0), lit!(21.0), lit!(-1.0)), + (lit!(14.9), lit!(25.0), lit!(15.0), lit!(21.0), lit!(7500.0)), + (lit!(25.1), lit!(25.0), lit!(15.0), lit!(21.0), lit!(7500.0)), + (lit!(20.0), lit!(25.0), lit!(15.0), lit!(14.9), lit!(7500.0)), + (lit!(20.0), lit!(25.0), lit!(15.0), lit!(25.1), lit!(7500.0)), + (lit!(20.0), lit!(15.0), lit!(25.0), lit!(21.0), lit!(7500.0)), ]; for record in invalid_records { assert_invalid(record) diff --git a/src/helpers.rs b/src/helpers.rs index dc3c2ea..71b2b96 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -1,16 +1,64 @@ +#[cfg(not(feature = "decimal"))] +mod generics { + pub(crate) type NumberType = f64; + + #[macro_export] + macro_rules! lit { + ($e:expr) => { + $e + }; + } + + #[macro_export] + macro_rules! int { + ($e:expr) => { + $e as f64 + }; + } + + pub use std::f64::INFINITY; + pub use std::f64::NEG_INFINITY; +} + +#[cfg(feature = "decimal")] +mod generics { + pub(crate) type NumberType = rust_decimal::Decimal; + + #[macro_export] + macro_rules! lit { + ($e:expr) => { + ::rust_decimal::Decimal::from_str_exact(stringify!($e)).unwrap() + }; + } + + #[macro_export] + macro_rules! int { + ($e:expr) => { + ::rust_decimal::Decimal::new($e.try_into().unwrap(), 0) + }; + } + + use rust_decimal::Decimal; + pub const INFINITY: Decimal = Decimal::MAX; + pub const NEG_INFINITY: Decimal = Decimal::MIN; +} + +pub(crate) use generics::*; + /// Returns the largest of 3 given numbers. -pub fn max3(a: f64, b: f64, c: f64) -> f64 { +pub fn max3(a: NumberType, b: NumberType, c: NumberType) -> NumberType { a.max(b).max(c) } #[cfg(test)] mod tests { use super::*; + use crate::lit; #[test] fn test_max3() { - assert_eq!(max3(3.0, 2.0, 1.0), 3.0); - assert_eq!(max3(2.0, 3.0, 1.0), 3.0); - assert_eq!(max3(2.0, 1.0, 3.0), 3.0); + assert_eq!(max3(lit!(3.0), lit!(2.0), lit!(1.0)), lit!(3.0)); + assert_eq!(max3(lit!(2.0), lit!(3.0), lit!(1.0)), lit!(3.0)); + assert_eq!(max3(lit!(2.0), lit!(1.0), lit!(3.0)), lit!(3.0)); } } diff --git a/src/indicators/average_true_range.rs b/src/indicators/average_true_range.rs index 2892ef5..ae32b16 100644 --- a/src/indicators/average_true_range.rs +++ b/src/indicators/average_true_range.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::{ExponentialMovingAverage, TrueRange}; -use crate::{Close, High, Low, Next, Period, Reset}; +use crate::{Close, High, Low, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -79,16 +79,16 @@ impl Period for AverageTrueRange { } } -impl Next for AverageTrueRange { - type Output = f64; +impl Next for AverageTrueRange { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { self.ema.next(self.true_range.next(input)) } } impl Next<&T> for AverageTrueRange { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.ema.next(self.true_range.next(input)) @@ -117,6 +117,7 @@ impl fmt::Display for AverageTrueRange { #[cfg(test)] mod tests { use super::*; + use crate::lit; use crate::test_helper::*; test_indicator!(AverageTrueRange); @@ -130,28 +131,28 @@ mod tests { fn test_next() { let mut atr = AverageTrueRange::new(3).unwrap(); - let bar1 = Bar::new().high(10).low(7.5).close(9); - let bar2 = Bar::new().high(11).low(9).close(9.5); + let bar1 = Bar::new().high(10).low(lit!(7.5)).close(9); + let bar2 = Bar::new().high(11).low(9).close(lit!(9.5)); let bar3 = Bar::new().high(9).low(5).close(8); - assert_eq!(atr.next(&bar1), 2.5); - assert_eq!(atr.next(&bar2), 2.25); - assert_eq!(atr.next(&bar3), 3.375); + assert_eq!(atr.next(&bar1), lit!(2.5)); + assert_eq!(atr.next(&bar2), lit!(2.25)); + assert_eq!(atr.next(&bar3), lit!(3.375)); } #[test] fn test_reset() { let mut atr = AverageTrueRange::new(9).unwrap(); - let bar1 = Bar::new().high(10).low(7.5).close(9); - let bar2 = Bar::new().high(11).low(9).close(9.5); + let bar1 = Bar::new().high(10).low(lit!(7.5)).close(9); + let bar2 = Bar::new().high(11).low(9).close(lit!(9.5)); atr.next(&bar1); atr.next(&bar2); atr.reset(); let bar3 = Bar::new().high(60).low(15).close(51); - assert_eq!(atr.next(&bar3), 45.0); + assert_eq!(atr.next(&bar3), lit!(45.0)); } #[test] diff --git a/src/indicators/bollinger_bands.rs b/src/indicators/bollinger_bands.rs index 84f2062..c58c40d 100644 --- a/src/indicators/bollinger_bands.rs +++ b/src/indicators/bollinger_bands.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::StandardDeviation as Sd; -use crate::{Close, Next, Period, Reset}; +use crate::{lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -50,19 +50,19 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone)] pub struct BollingerBands { period: usize, - multiplier: f64, + multiplier: NumberType, sd: Sd, } #[derive(Debug, Clone, PartialEq)] pub struct BollingerBandsOutput { - pub average: f64, - pub upper: f64, - pub lower: f64, + pub average: NumberType, + pub upper: NumberType, + pub lower: NumberType, } impl BollingerBands { - pub fn new(period: usize, multiplier: f64) -> Result { + pub fn new(period: usize, multiplier: NumberType) -> Result { Ok(Self { period, multiplier, @@ -70,7 +70,7 @@ impl BollingerBands { }) } - pub fn multiplier(&self) -> f64 { + pub fn multiplier(&self) -> NumberType { self.multiplier } } @@ -81,10 +81,10 @@ impl Period for BollingerBands { } } -impl Next for BollingerBands { +impl Next for BollingerBands { type Output = BollingerBandsOutput; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let sd = self.sd.next(input); let mean = self.sd.mean(); @@ -112,7 +112,7 @@ impl Reset for BollingerBands { impl Default for BollingerBands { fn default() -> Self { - Self::new(9, 2_f64).unwrap() + Self::new(9, lit!(2.0)).unwrap() } } @@ -131,61 +131,61 @@ mod tests { #[test] fn test_new() { - assert!(BollingerBands::new(0, 2_f64).is_err()); - assert!(BollingerBands::new(1, 2_f64).is_ok()); - assert!(BollingerBands::new(2, 2_f64).is_ok()); + assert!(BollingerBands::new(0, lit!(2.0)).is_err()); + assert!(BollingerBands::new(1, lit!(2.0)).is_ok()); + assert!(BollingerBands::new(2, lit!(2.0)).is_ok()); } #[test] fn test_next() { - let mut bb = BollingerBands::new(3, 2.0_f64).unwrap(); - - let a = bb.next(2.0); - let b = bb.next(5.0); - let c = bb.next(1.0); - let d = bb.next(6.25); - - assert_eq!(round(a.average), 2.0); - assert_eq!(round(b.average), 3.5); - assert_eq!(round(c.average), 2.667); - assert_eq!(round(d.average), 4.083); - - assert_eq!(round(a.upper), 2.0); - assert_eq!(round(b.upper), 6.5); - assert_eq!(round(c.upper), 6.066); - assert_eq!(round(d.upper), 8.562); - - assert_eq!(round(a.lower), 2.0); - assert_eq!(round(b.lower), 0.5); - assert_eq!(round(c.lower), -0.733); - assert_eq!(round(d.lower), -0.395); + let mut bb = BollingerBands::new(3, lit!(2.0)).unwrap(); + + let a = bb.next(lit!(2.0)); + let b = bb.next(lit!(5.0)); + let c = bb.next(lit!(1.0)); + let d = bb.next(lit!(6.25)); + + assert_eq!(round(a.average), lit!(2.0)); + assert_eq!(round(b.average), lit!(3.5)); + assert_eq!(round(c.average), lit!(2.667)); + assert_eq!(round(d.average), lit!(4.083)); + + assert_eq!(round(a.upper), lit!(2.0)); + assert_eq!(round(b.upper), lit!(6.5)); + assert_eq!(round(c.upper), lit!(6.066)); + assert_eq!(round(d.upper), lit!(8.562)); + + assert_eq!(round(a.lower), lit!(2.0)); + assert_eq!(round(b.lower), lit!(0.5)); + assert_eq!(round(c.lower), lit!(-0.733)); + assert_eq!(round(d.lower), lit!(-0.395)); } #[test] fn test_reset() { - let mut bb = BollingerBands::new(5, 2.0_f64).unwrap(); + let mut bb = BollingerBands::new(5, lit!(2.0)).unwrap(); - let out = bb.next(3.0); + let out = bb.next(lit!(3.0)); - assert_eq!(out.average, 3.0); - assert_eq!(out.upper, 3.0); - assert_eq!(out.lower, 3.0); + assert_eq!(out.average, lit!(3.0)); + assert_eq!(out.upper, lit!(3.0)); + assert_eq!(out.lower, lit!(3.0)); - bb.next(2.5); - bb.next(3.5); - bb.next(4.0); + bb.next(lit!(2.5)); + bb.next(lit!(3.5)); + bb.next(lit!(4.0)); - let out = bb.next(2.0); + let out = bb.next(lit!(2.0)); - assert_eq!(out.average, 3.0); - assert_eq!(round(out.upper), 4.414); - assert_eq!(round(out.lower), 1.586); + assert_eq!(out.average, lit!(3.0)); + assert_eq!(round(out.upper), lit!(4.414)); + assert_eq!(round(out.lower), lit!(1.586)); bb.reset(); - let out = bb.next(3.0); - assert_eq!(out.average, 3.0); - assert_eq!(out.upper, 3.0); - assert_eq!(out.lower, 3.0); + let out = bb.next(lit!(3.0)); + assert_eq!(out.average, lit!(3.0)); + assert_eq!(out.upper, lit!(3.0)); + assert_eq!(out.lower, lit!(3.0)); } #[test] @@ -195,7 +195,7 @@ mod tests { #[test] fn test_display() { - let bb = BollingerBands::new(10, 3.0_f64).unwrap(); + let bb = BollingerBands::new(10, crate::int!(3)).unwrap(); assert_eq!(format!("{}", bb), "BB(10, 3)"); } } diff --git a/src/indicators/chandelier_exit.rs b/src/indicators/chandelier_exit.rs index 7182fd8..2998cef 100644 --- a/src/indicators/chandelier_exit.rs +++ b/src/indicators/chandelier_exit.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::errors::Result; use crate::indicators::{AverageTrueRange, Maximum, Minimum}; -use crate::{Close, High, Low, Next, Period, Reset}; +use crate::{lit, Close, High, Low, Next, NumberType, Period, Reset}; /// Chandelier Exit (CE). /// @@ -57,11 +57,11 @@ pub struct ChandelierExit { atr: AverageTrueRange, min: Minimum, max: Maximum, - multiplier: f64, + multiplier: NumberType, } impl ChandelierExit { - pub fn new(period: usize, multiplier: f64) -> Result { + pub fn new(period: usize, multiplier: NumberType) -> Result { Ok(Self { atr: AverageTrueRange::new(period)?, min: Minimum::new(period)?, @@ -70,18 +70,18 @@ impl ChandelierExit { }) } - pub fn multiplier(&self) -> f64 { + pub fn multiplier(&self) -> NumberType { self.multiplier } } #[derive(Debug, Clone, PartialEq)] pub struct ChandelierExitOutput { - pub long: f64, - pub short: f64, + pub long: NumberType, + pub short: NumberType, } -impl From for (f64, f64) { +impl From for (NumberType, NumberType) { fn from(ce: ChandelierExitOutput) -> Self { (ce.long, ce.short) } @@ -118,7 +118,7 @@ impl Reset for ChandelierExit { impl Default for ChandelierExit { fn default() -> Self { - Self::new(22, 3.0).unwrap() + Self::new(22, lit!(3.0)).unwrap() } } @@ -136,56 +136,56 @@ mod tests { type Ce = ChandelierExit; - fn round(nums: (f64, f64)) -> (f64, f64) { - let n0 = (nums.0 * 100.0).round() / 100.0; - let n1 = (nums.1 * 100.0).round() / 100.0; + fn round(nums: (NumberType, NumberType)) -> (NumberType, NumberType) { + let n0 = (nums.0 * lit!(100.0)).round() / lit!(100.0); + let n1 = (nums.1 * lit!(100.0)).round() / lit!(100.0); (n0, n1) } #[test] fn test_new() { - assert!(Ce::new(0, 0.0).is_err()); - assert!(Ce::new(1, 1.0).is_ok()); - assert!(Ce::new(22, 3.0).is_ok()); + assert!(Ce::new(0, lit!(0.0)).is_err()); + assert!(Ce::new(1, lit!(1.0)).is_ok()); + assert!(Ce::new(22, lit!(3.0)).is_ok()); } #[test] fn test_next_bar() { - let mut ce = Ce::new(5, 2.0).unwrap(); + let mut ce = Ce::new(5, lit!(2.0)).unwrap(); - let bar1 = Bar::new().high(2).low(1).close(1.5); - assert_eq!(round(ce.next(&bar1).into()), (0.0, 3.0)); + let bar1 = Bar::new().high(2).low(1).close(lit!(1.5)); + assert_eq!(round(ce.next(&bar1).into()), (lit!(0.0), lit!(3.0))); let bar2 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(ce.next(&bar2).into()), (1.33, 4.67)); + assert_eq!(round(ce.next(&bar2).into()), (lit!(1.33), lit!(4.67))); let bar3 = Bar::new().high(9).low(7).close(8); - assert_eq!(round(ce.next(&bar3).into()), (3.22, 6.78)); + assert_eq!(round(ce.next(&bar3).into()), (lit!(3.22), lit!(6.78))); let bar4 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(ce.next(&bar4).into()), (1.81, 8.19)); + assert_eq!(round(ce.next(&bar4).into()), (lit!(1.81), lit!(8.19))); let bar5 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(ce.next(&bar5).into()), (2.88, 7.12)); + assert_eq!(round(ce.next(&bar5).into()), (lit!(2.88), lit!(7.12))); - let bar6 = Bar::new().high(2).low(1).close(1.5); - assert_eq!(round(ce.next(&bar6).into()), (2.92, 7.08)); + let bar6 = Bar::new().high(2).low(1).close(lit!(1.5)); + assert_eq!(round(ce.next(&bar6).into()), (lit!(2.92), lit!(7.08))); } #[test] fn test_reset() { - let mut ce = Ce::new(5, 2.0).unwrap(); + let mut ce = Ce::new(5, lit!(2.0)).unwrap(); - let bar1 = Bar::new().high(2).low(1).close(1.5); + let bar1 = Bar::new().high(2).low(1).close(lit!(1.5)); let bar2 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(ce.next(&bar1).into()), (0.0, 3.0)); - assert_eq!(round(ce.next(&bar2).into()), (1.33, 4.67)); + assert_eq!(round(ce.next(&bar1).into()), (lit!(0.0), lit!(3.0))); + assert_eq!(round(ce.next(&bar2).into()), (lit!(1.33), lit!(4.67))); ce.reset(); - assert_eq!(round(ce.next(&bar1).into()), (0.0, 3.0)); - assert_eq!(round(ce.next(&bar2).into()), (1.33, 4.67)); + assert_eq!(round(ce.next(&bar1).into()), (lit!(0.0), lit!(3.0))); + assert_eq!(round(ce.next(&bar2).into()), (lit!(1.33), lit!(4.67))); } #[test] @@ -195,7 +195,7 @@ mod tests { #[test] fn test_display() { - let indicator = Ce::new(10, 5.0).unwrap(); + let indicator = Ce::new(10, crate::int!(5)).unwrap(); assert_eq!(format!("{}", indicator), "CE(10, 5)"); } } diff --git a/src/indicators/commodity_channel_index.rs b/src/indicators/commodity_channel_index.rs index bcaba2f..2281556 100644 --- a/src/indicators/commodity_channel_index.rs +++ b/src/indicators/commodity_channel_index.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::errors::Result; use crate::indicators::{MeanAbsoluteDeviation, SimpleMovingAverage}; -use crate::{Close, High, Low, Next, Period, Reset}; +use crate::{lit, Close, High, Low, Next, NumberType, Period, Reset}; /// Commodity Channel Index (CCI) /// @@ -51,18 +51,18 @@ impl Period for CommodityChannelIndex { } impl Next<&T> for CommodityChannelIndex { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { - let tp = (input.close() + input.high() + input.low()) / 3.0; + let tp = (input.close() + input.high() + input.low()) / lit!(3.0); let sma = self.sma.next(tp); let mad = self.mad.next(input); - if mad == 0.0 { - return 0.0; + if mad == lit!(0.0) { + return lit!(0.0); } - (tp - sma) / (mad * 0.015) + (tp - sma) / (mad * lit!(0.015)) } } @@ -100,39 +100,39 @@ mod tests { fn test_next_bar() { let mut cci = CommodityChannelIndex::new(5).unwrap(); - let bar1 = Bar::new().high(2).low(1).close(1.5); - assert_eq!(round(cci.next(&bar1)), 0.0); + let bar1 = Bar::new().high(2).low(1).close(lit!(1.5)); + assert_eq!(round(cci.next(&bar1)), lit!(0.0)); let bar2 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(cci.next(&bar2)), 66.667); + assert_eq!(round(cci.next(&bar2)), lit!(66.667)); let bar3 = Bar::new().high(9).low(7).close(8); - assert_eq!(round(cci.next(&bar3)), 100.0); + assert_eq!(round(cci.next(&bar3)), lit!(100.0)); let bar4 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(cci.next(&bar4)), -13.793); + assert_eq!(round(cci.next(&bar4)), lit!(-13.793)); let bar5 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(cci.next(&bar5)), -13.514); + assert_eq!(round(cci.next(&bar5)), lit!(-13.514)); - let bar6 = Bar::new().high(2).low(1).close(1.5); - assert_eq!(round(cci.next(&bar6)), -126.126); + let bar6 = Bar::new().high(2).low(1).close(lit!(1.5)); + assert_eq!(round(cci.next(&bar6)), lit!(-126.126)); } #[test] fn test_reset() { let mut cci = CommodityChannelIndex::new(5).unwrap(); - let bar1 = Bar::new().high(2).low(1).close(1.5); + let bar1 = Bar::new().high(2).low(1).close(lit!(1.5)); let bar2 = Bar::new().high(5).low(3).close(4); - assert_eq!(round(cci.next(&bar1)), 0.0); - assert_eq!(round(cci.next(&bar2)), 66.667); + assert_eq!(round(cci.next(&bar1)), lit!(0.0)); + assert_eq!(round(cci.next(&bar2)), lit!(66.667)); cci.reset(); - assert_eq!(round(cci.next(&bar1)), 0.0); - assert_eq!(round(cci.next(&bar2)), 66.667); + assert_eq!(round(cci.next(&bar1)), lit!(0.0)); + assert_eq!(round(cci.next(&bar2)), lit!(66.667)); } #[test] diff --git a/src/indicators/efficiency_ratio.rs b/src/indicators/efficiency_ratio.rs index 8ad5b50..60b8126 100644 --- a/src/indicators/efficiency_ratio.rs +++ b/src/indicators/efficiency_ratio.rs @@ -1,7 +1,7 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::traits::{Close, Next, Period, Reset}; +use crate::{lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -36,7 +36,7 @@ pub struct EfficiencyRatio { period: usize, index: usize, count: usize, - deque: Box<[f64]>, + deque: Box<[NumberType]>, } impl EfficiencyRatio { @@ -47,7 +47,7 @@ impl EfficiencyRatio { period, index: 0, count: 0, - deque: vec![0.0; period].into_boxed_slice(), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } @@ -59,10 +59,10 @@ impl Period for EfficiencyRatio { } } -impl Next for EfficiencyRatio { - type Output = f64; +impl Next for EfficiencyRatio { + type Output = NumberType; - fn next(&mut self, input: f64) -> f64 { + fn next(&mut self, input: NumberType) -> NumberType { let first = if self.count >= self.period { self.deque[self.index] } else { @@ -77,7 +77,7 @@ impl Next for EfficiencyRatio { 0 }; - let mut volatility = 0.0; + let mut volatility = lit!(0.0); let mut previous = first; for n in &self.deque[self.index..self.count] { volatility += (previous - n).abs(); @@ -93,9 +93,9 @@ impl Next for EfficiencyRatio { } impl Next<&T> for EfficiencyRatio { - type Output = f64; + type Output = NumberType; - fn next(&mut self, input: &T) -> f64 { + fn next(&mut self, input: &T) -> NumberType { self.next(input.close()) } } @@ -105,7 +105,7 @@ impl Reset for EfficiencyRatio { self.index = 0; self.count = 0; for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -139,29 +139,29 @@ mod tests { fn test_next() { let mut er = EfficiencyRatio::new(3).unwrap(); - assert_eq!(round(er.next(3.0)), 1.0); - assert_eq!(round(er.next(5.0)), 1.0); - assert_eq!(round(er.next(2.0)), 0.2); - assert_eq!(round(er.next(3.0)), 0.0); - assert_eq!(round(er.next(1.0)), 0.667); - assert_eq!(round(er.next(3.0)), 0.2); - assert_eq!(round(er.next(4.0)), 0.2); - assert_eq!(round(er.next(6.0)), 1.0); + assert_eq!(round(er.next(lit!(3.0))), lit!(1.0)); + assert_eq!(round(er.next(lit!(5.0))), lit!(1.0)); + assert_eq!(round(er.next(lit!(2.0))), lit!(0.2)); + assert_eq!(round(er.next(lit!(3.0))), lit!(0.0)); + assert_eq!(round(er.next(lit!(1.0))), lit!(0.667)); + assert_eq!(round(er.next(lit!(3.0))), lit!(0.2)); + assert_eq!(round(er.next(lit!(4.0))), lit!(0.2)); + assert_eq!(round(er.next(lit!(6.0))), lit!(1.0)); } #[test] fn test_reset() { let mut er = EfficiencyRatio::new(3).unwrap(); - er.next(3.0); - er.next(5.0); + er.next(lit!(3.0)); + er.next(lit!(5.0)); er.reset(); - assert_eq!(round(er.next(3.0)), 1.0); - assert_eq!(round(er.next(5.0)), 1.0); - assert_eq!(round(er.next(2.0)), 0.2); - assert_eq!(round(er.next(3.0)), 0.0); + assert_eq!(round(er.next(lit!(3.0))), lit!(1.0)); + assert_eq!(round(er.next(lit!(5.0))), lit!(1.0)); + assert_eq!(round(er.next(lit!(2.0))), lit!(0.2)); + assert_eq!(round(er.next(lit!(3.0))), lit!(0.0)); } #[test] diff --git a/src/indicators/exponential_moving_average.rs b/src/indicators/exponential_moving_average.rs index 3cfa1a3..c8f74e5 100644 --- a/src/indicators/exponential_moving_average.rs +++ b/src/indicators/exponential_moving_average.rs @@ -1,7 +1,7 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::{Close, Next, Period, Reset}; +use crate::{int, lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -57,8 +57,8 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone)] pub struct ExponentialMovingAverage { period: usize, - k: f64, - current: f64, + k: NumberType, + current: NumberType, is_new: bool, } @@ -68,8 +68,8 @@ impl ExponentialMovingAverage { 0 => Err(TaError::InvalidParameter), _ => Ok(Self { period, - k: 2.0 / (period + 1) as f64, - current: 0.0, + k: lit!(2.0) / int!(period + 1), + current: NumberType::default(), is_new: true, }), } @@ -82,22 +82,22 @@ impl Period for ExponentialMovingAverage { } } -impl Next for ExponentialMovingAverage { - type Output = f64; +impl Next for ExponentialMovingAverage { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { if self.is_new { self.is_new = false; self.current = input; } else { - self.current = self.k * input + (1.0 - self.k) * self.current; + self.current = self.k * input + (lit!(1.0) - self.k) * self.current; } self.current } } impl Next<&T> for ExponentialMovingAverage { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.close()) @@ -106,7 +106,7 @@ impl Next<&T> for ExponentialMovingAverage { impl Reset for ExponentialMovingAverage { fn reset(&mut self) { - self.current = 0.0; + self.current = NumberType::default(); self.is_new = true; } } @@ -140,30 +140,30 @@ mod tests { fn test_next() { let mut ema = ExponentialMovingAverage::new(3).unwrap(); - assert_eq!(ema.next(2.0), 2.0); - assert_eq!(ema.next(5.0), 3.5); - assert_eq!(ema.next(1.0), 2.25); - assert_eq!(ema.next(6.25), 4.25); + assert_eq!(ema.next(lit!(2.0)), lit!(2.0)); + assert_eq!(ema.next(lit!(5.0)), lit!(3.5)); + assert_eq!(ema.next(lit!(1.0)), lit!(2.25)); + assert_eq!(ema.next(lit!(6.25)), lit!(4.25)); let mut ema = ExponentialMovingAverage::new(3).unwrap(); let bar1 = Bar::new().close(2); let bar2 = Bar::new().close(5); - assert_eq!(ema.next(&bar1), 2.0); - assert_eq!(ema.next(&bar2), 3.5); + assert_eq!(ema.next(&bar1), lit!(2.0)); + assert_eq!(ema.next(&bar2), lit!(3.5)); } #[test] fn test_reset() { let mut ema = ExponentialMovingAverage::new(5).unwrap(); - assert_eq!(ema.next(4.0), 4.0); - ema.next(10.0); - ema.next(15.0); - ema.next(20.0); - assert_ne!(ema.next(4.0), 4.0); + assert_eq!(ema.next(lit!(4.0)), lit!(4.0)); + ema.next(lit!(10.0)); + ema.next(lit!(15.0)); + ema.next(lit!(20.0)); + assert_ne!(ema.next(lit!(4.0)), lit!(4.0)); ema.reset(); - assert_eq!(ema.next(4.0), 4.0); + assert_eq!(ema.next(lit!(4.0)), lit!(4.0)); } #[test] diff --git a/src/indicators/fast_stochastic.rs b/src/indicators/fast_stochastic.rs index 27f7b0f..bd59522 100644 --- a/src/indicators/fast_stochastic.rs +++ b/src/indicators/fast_stochastic.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::{Maximum, Minimum}; -use crate::{Close, High, Low, Next, Period, Reset}; +use crate::{lit, Close, High, Low, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -64,25 +64,25 @@ impl Period for FastStochastic { } } -impl Next for FastStochastic { - type Output = f64; +impl Next for FastStochastic { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let min = self.minimum.next(input); let max = self.maximum.next(input); if min == max { // When only 1 input was given, than min and max are the same, // therefore it makes sense to return 50 - 50.0 + lit!(50.0) } else { - (input - min) / (max - min) * 100.0 + (input - min) / (max - min) * lit!(100.0) } } } impl Next<&T> for FastStochastic { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { let highest = self.maximum.next(input.high()); @@ -91,9 +91,9 @@ impl Next<&T> for FastStochastic { if highest == lowest { // To avoid division by zero, return 50.0 - 50.0 + lit!(50.0) } else { - (close - lowest) / (highest - lowest) * 100.0 + (close - lowest) / (highest - lowest) * lit!(100.0) } } } @@ -133,23 +133,23 @@ mod tests { #[test] fn test_next_with_f64() { let mut stoch = FastStochastic::new(3).unwrap(); - assert_eq!(stoch.next(0.0), 50.0); - assert_eq!(stoch.next(200.0), 100.0); - assert_eq!(stoch.next(100.0), 50.0); - assert_eq!(stoch.next(120.0), 20.0); - assert_eq!(stoch.next(115.0), 75.0); + assert_eq!(stoch.next(lit!(0.0)), lit!(50.0)); + assert_eq!(stoch.next(lit!(200.0)), lit!(100.0)); + assert_eq!(stoch.next(lit!(100.0)), lit!(50.0)); + assert_eq!(stoch.next(lit!(120.0)), lit!(20.0)); + assert_eq!(stoch.next(lit!(115.0)), lit!(75.0)); } #[test] fn test_next_with_bars() { let test_data = vec![ // high, low , close, expected - (20.0, 20.0, 20.0, 50.0), // min = 20, max = 20 - (30.0, 10.0, 25.0, 75.0), // min = 10, max = 30 - (40.0, 20.0, 16.0, 20.0), // min = 10, max = 40 - (35.0, 15.0, 19.0, 30.0), // min = 10, max = 40 - (30.0, 20.0, 25.0, 40.0), // min = 15, max = 40 - (35.0, 25.0, 30.0, 75.0), // min = 15, max = 35 + (lit!(20.0), lit!(20.0), lit!(20.0), lit!(50.0)), // min = 20, max = 20 + (lit!(30.0), lit!(10.0), lit!(25.0), lit!(75.0)), // min = 10, max = 30 + (lit!(40.0), lit!(20.0), lit!(16.0), lit!(20.0)), // min = 10, max = 40 + (lit!(35.0), lit!(15.0), lit!(19.0), lit!(30.0)), // min = 10, max = 40 + (lit!(30.0), lit!(20.0), lit!(25.0), lit!(40.0)), // min = 15, max = 40 + (lit!(35.0), lit!(25.0), lit!(30.0), lit!(75.0)), // min = 15, max = 35 ]; let mut stoch = FastStochastic::new(3).unwrap(); @@ -163,15 +163,15 @@ mod tests { #[test] fn test_reset() { let mut indicator = FastStochastic::new(10).unwrap(); - assert_eq!(indicator.next(10.0), 50.0); - assert_eq!(indicator.next(210.0), 100.0); - assert_eq!(indicator.next(10.0), 0.0); - assert_eq!(indicator.next(60.0), 25.0); + assert_eq!(indicator.next(lit!(10.0)), lit!(50.0)); + assert_eq!(indicator.next(lit!(210.0)), lit!(100.0)); + assert_eq!(indicator.next(lit!(10.0)), lit!(0.0)); + assert_eq!(indicator.next(lit!(60.0)), lit!(25.0)); indicator.reset(); - assert_eq!(indicator.next(10.0), 50.0); - assert_eq!(indicator.next(20.0), 100.0); - assert_eq!(indicator.next(12.5), 25.0); + assert_eq!(indicator.next(lit!(10.0)), lit!(50.0)); + assert_eq!(indicator.next(lit!(20.0)), lit!(100.0)); + assert_eq!(indicator.next(lit!(12.5)), lit!(25.0)); } #[test] diff --git a/src/indicators/keltner_channel.rs b/src/indicators/keltner_channel.rs index a2c8b27..94d69ef 100644 --- a/src/indicators/keltner_channel.rs +++ b/src/indicators/keltner_channel.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::{AverageTrueRange, ExponentialMovingAverage}; -use crate::{Close, High, Low, Next, Period, Reset}; +use crate::{int, lit, Close, High, Low, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -51,20 +51,20 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone)] pub struct KeltnerChannel { period: usize, - multiplier: f64, + multiplier: NumberType, atr: AverageTrueRange, ema: ExponentialMovingAverage, } #[derive(Debug, Clone, PartialEq)] pub struct KeltnerChannelOutput { - pub average: f64, - pub upper: f64, - pub lower: f64, + pub average: NumberType, + pub upper: NumberType, + pub lower: NumberType, } impl KeltnerChannel { - pub fn new(period: usize, multiplier: f64) -> Result { + pub fn new(period: usize, multiplier: NumberType) -> Result { Ok(Self { period, multiplier, @@ -73,7 +73,7 @@ impl KeltnerChannel { }) } - pub fn multiplier(&self) -> f64 { + pub fn multiplier(&self) -> NumberType { self.multiplier } } @@ -84,10 +84,10 @@ impl Period for KeltnerChannel { } } -impl Next for KeltnerChannel { +impl Next for KeltnerChannel { type Output = KeltnerChannelOutput; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let atr = self.atr.next(input); let average = self.ema.next(input); @@ -103,7 +103,7 @@ impl Next<&T> for KeltnerChannel { type Output = KeltnerChannelOutput; fn next(&mut self, input: &T) -> Self::Output { - let typical_price = (input.close() + input.high() + input.low()) / 3.0; + let typical_price = (input.close() + input.high() + input.low()) / lit!(3.0); let average = self.ema.next(typical_price); let atr = self.atr.next(input); @@ -125,7 +125,7 @@ impl Reset for KeltnerChannel { impl Default for KeltnerChannel { fn default() -> Self { - Self::new(10, 2_f64).unwrap() + Self::new(10, int!(2)).unwrap() } } @@ -144,84 +144,84 @@ mod tests { #[test] fn test_new() { - assert!(KeltnerChannel::new(0, 2_f64).is_err()); - assert!(KeltnerChannel::new(1, 2_f64).is_ok()); - assert!(KeltnerChannel::new(2, 2_f64).is_ok()); + assert!(KeltnerChannel::new(0, lit!(2.0)).is_err()); + assert!(KeltnerChannel::new(1, lit!(2.0)).is_ok()); + assert!(KeltnerChannel::new(2, lit!(2.0)).is_ok()); } #[test] fn test_next() { - let mut kc = KeltnerChannel::new(3, 2.0_f64).unwrap(); - - let a = kc.next(2.0); - let b = kc.next(5.0); - let c = kc.next(1.0); - let d = kc.next(6.25); - - assert_eq!(round(a.average), 2.0); - assert_eq!(round(b.average), 3.5); - assert_eq!(round(c.average), 2.25); - assert_eq!(round(d.average), 4.25); - - assert_eq!(round(a.upper), 2.0); - assert_eq!(round(b.upper), 6.5); - assert_eq!(round(c.upper), 7.75); - assert_eq!(round(d.upper), 12.25); - - assert_eq!(round(a.lower), 2.0); - assert_eq!(round(b.lower), 0.5); - assert_eq!(round(c.lower), -3.25); - assert_eq!(round(d.lower), -3.75); + let mut kc = KeltnerChannel::new(3, lit!(2.0)).unwrap(); + + let a = kc.next(lit!(2.0)); + let b = kc.next(lit!(5.0)); + let c = kc.next(lit!(1.0)); + let d = kc.next(lit!(6.25)); + + assert_eq!(round(a.average), lit!(2.0)); + assert_eq!(round(b.average), lit!(3.5)); + assert_eq!(round(c.average), lit!(2.25)); + assert_eq!(round(d.average), lit!(4.25)); + + assert_eq!(round(a.upper), lit!(2.0)); + assert_eq!(round(b.upper), lit!(6.5)); + assert_eq!(round(c.upper), lit!(7.75)); + assert_eq!(round(d.upper), lit!(12.25)); + + assert_eq!(round(a.lower), lit!(2.0)); + assert_eq!(round(b.lower), lit!(0.5)); + assert_eq!(round(c.lower), lit!(-3.25)); + assert_eq!(round(d.lower), lit!(-3.75)); } #[test] fn test_next_with_data_item() { - let mut kc = KeltnerChannel::new(3, 2.0_f64).unwrap(); + let mut kc = KeltnerChannel::new(3, lit!(2.0)).unwrap(); - let dt1 = Bar::new().low(1.2).high(1.7).close(1.3); // typical_price = 1.4 + let dt1 = Bar::new().low(lit!(1.2)).high(lit!(1.7)).close(lit!(1.3)); // typical_price = 1.4 let o1 = kc.next(&dt1); - assert_eq!(round(o1.average), 1.4); - assert_eq!(round(o1.lower), 0.4); - assert_eq!(round(o1.upper), 2.4); + assert_eq!(round(o1.average), lit!(1.4)); + assert_eq!(round(o1.lower), lit!(0.4)); + assert_eq!(round(o1.upper), lit!(2.4)); - let dt2 = Bar::new().low(1.3).high(1.8).close(1.4); // typical_price = 1.5 + let dt2 = Bar::new().low(lit!(1.3)).high(lit!(1.8)).close(lit!(1.4)); // typical_price = 1.5 let o2 = kc.next(&dt2); - assert_eq!(round(o2.average), 1.45); - assert_eq!(round(o2.lower), 0.45); - assert_eq!(round(o2.upper), 2.45); + assert_eq!(round(o2.average), lit!(1.45)); + assert_eq!(round(o2.lower), lit!(0.45)); + assert_eq!(round(o2.upper), lit!(2.45)); - let dt3 = Bar::new().low(1.4).high(1.9).close(1.5); // typical_price = 1.6 + let dt3 = Bar::new().low(lit!(1.4)).high(lit!(1.9)).close(lit!(1.5)); // typical_price = 1.6 let o3 = kc.next(&dt3); - assert_eq!(round(o3.average), 1.525); - assert_eq!(round(o3.lower), 0.525); - assert_eq!(round(o3.upper), 2.525); + assert_eq!(round(o3.average), lit!(1.525)); + assert_eq!(round(o3.lower), lit!(0.525)); + assert_eq!(round(o3.upper), lit!(2.525)); } #[test] fn test_reset() { - let mut kc = KeltnerChannel::new(5, 2.0_f64).unwrap(); + let mut kc = KeltnerChannel::new(5, lit!(2.0)).unwrap(); - let out = kc.next(3.0); + let out = kc.next(lit!(3.0)); - assert_eq!(out.average, 3.0); - assert_eq!(out.upper, 3.0); - assert_eq!(out.lower, 3.0); + assert_eq!(out.average, lit!(3.0)); + assert_eq!(out.upper, lit!(3.0)); + assert_eq!(out.lower, lit!(3.0)); - kc.next(2.5); - kc.next(3.5); - kc.next(4.0); + kc.next(lit!(2.5)); + kc.next(lit!(3.5)); + kc.next(lit!(4.0)); - let out = kc.next(2.0); + let out = kc.next(lit!(2.0)); - assert_eq!(round(out.average), 2.914); - assert_eq!(round(out.upper), 4.864); - assert_eq!(round(out.lower), 0.963); + assert_eq!(round(out.average), lit!(2.914)); + assert_eq!(round(out.upper), lit!(4.864)); + assert_eq!(round(out.lower), lit!(0.963)); kc.reset(); - let out = kc.next(3.0); - assert_eq!(out.average, 3.0); - assert_eq!(out.lower, 3.0); - assert_eq!(out.upper, 3.0); + let out = kc.next(lit!(3.0)); + assert_eq!(out.average, lit!(3.0)); + assert_eq!(out.lower, lit!(3.0)); + assert_eq!(out.upper, lit!(3.0)); } #[test] @@ -231,7 +231,7 @@ mod tests { #[test] fn test_display() { - let kc = KeltnerChannel::new(10, 3.0_f64).unwrap(); + let kc = KeltnerChannel::new(10, int!(3)).unwrap(); assert_eq!(format!("{}", kc), "KC(10, 3)"); } } diff --git a/src/indicators/maximum.rs b/src/indicators/maximum.rs index 55c9428..b4e614c 100644 --- a/src/indicators/maximum.rs +++ b/src/indicators/maximum.rs @@ -1,7 +1,8 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::{High, Next, Period, Reset}; +use crate::helpers::NEG_INFINITY; +use crate::{High, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -30,7 +31,7 @@ pub struct Maximum { period: usize, max_index: usize, cur_index: usize, - deque: Box<[f64]>, + deque: Box<[NumberType]>, } impl Maximum { @@ -41,13 +42,13 @@ impl Maximum { period, max_index: 0, cur_index: 0, - deque: vec![f64::NEG_INFINITY; period].into_boxed_slice(), + deque: vec![NEG_INFINITY; period].into_boxed_slice(), }), } } fn find_max_index(&self) -> usize { - let mut max = f64::NEG_INFINITY; + let mut max = NEG_INFINITY; let mut index: usize = 0; for (i, &val) in self.deque.iter().enumerate() { @@ -67,10 +68,10 @@ impl Period for Maximum { } } -impl Next for Maximum { - type Output = f64; +impl Next for Maximum { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { self.deque[self.cur_index] = input; if input > self.deque[self.max_index] { @@ -90,7 +91,7 @@ impl Next for Maximum { } impl Next<&T> for Maximum { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.high()) @@ -100,7 +101,7 @@ impl Next<&T> for Maximum { impl Reset for Maximum { fn reset(&mut self) { for i in 0..self.period { - self.deque[i] = f64::NEG_INFINITY; + self.deque[i] = NEG_INFINITY; } } } @@ -120,6 +121,7 @@ impl fmt::Display for Maximum { #[cfg(test)] mod tests { use super::*; + use crate::lit; use crate::test_helper::*; test_indicator!(Maximum); @@ -134,40 +136,40 @@ mod tests { fn test_next() { let mut max = Maximum::new(3).unwrap(); - assert_eq!(max.next(4.0), 4.0); - assert_eq!(max.next(1.2), 4.0); - assert_eq!(max.next(5.0), 5.0); - assert_eq!(max.next(3.0), 5.0); - assert_eq!(max.next(4.0), 5.0); - assert_eq!(max.next(0.0), 4.0); - assert_eq!(max.next(-1.0), 4.0); - assert_eq!(max.next(-2.0), 0.0); - assert_eq!(max.next(-1.5), -1.0); + assert_eq!(max.next(lit!(4.0)), lit!(4.0)); + assert_eq!(max.next(lit!(1.2)), lit!(4.0)); + assert_eq!(max.next(lit!(5.0)), lit!(5.0)); + assert_eq!(max.next(lit!(3.0)), lit!(5.0)); + assert_eq!(max.next(lit!(4.0)), lit!(5.0)); + assert_eq!(max.next(lit!(0.0)), lit!(4.0)); + assert_eq!(max.next(lit!(-1.0)), lit!(4.0)); + assert_eq!(max.next(lit!(-2.0)), lit!(0.0)); + assert_eq!(max.next(lit!(-1.5)), lit!(-1.0)); } #[test] fn test_next_with_bars() { - fn bar(high: f64) -> Bar { + fn bar(high: NumberType) -> Bar { Bar::new().high(high) } let mut max = Maximum::new(2).unwrap(); - assert_eq!(max.next(&bar(1.1)), 1.1); - assert_eq!(max.next(&bar(4.0)), 4.0); - assert_eq!(max.next(&bar(3.5)), 4.0); - assert_eq!(max.next(&bar(2.0)), 3.5); + assert_eq!(max.next(&bar(lit!(1.1))), lit!(1.1)); + assert_eq!(max.next(&bar(lit!(4.0))), lit!(4.0)); + assert_eq!(max.next(&bar(lit!(3.5))), lit!(4.0)); + assert_eq!(max.next(&bar(lit!(2.0))), lit!(3.5)); } #[test] fn test_reset() { let mut max = Maximum::new(100).unwrap(); - assert_eq!(max.next(4.0), 4.0); - assert_eq!(max.next(10.0), 10.0); - assert_eq!(max.next(4.0), 10.0); + assert_eq!(max.next(lit!(4.0)), lit!(4.0)); + assert_eq!(max.next(lit!(10.0)), lit!(10.0)); + assert_eq!(max.next(lit!(4.0)), lit!(10.0)); max.reset(); - assert_eq!(max.next(4.0), 4.0); + assert_eq!(max.next(lit!(4.0)), lit!(4.0)); } #[test] diff --git a/src/indicators/mean_absolute_deviation.rs b/src/indicators/mean_absolute_deviation.rs index 05f5c05..d335bb5 100644 --- a/src/indicators/mean_absolute_deviation.rs +++ b/src/indicators/mean_absolute_deviation.rs @@ -4,7 +4,7 @@ use std::fmt; use serde::{Deserialize, Serialize}; use crate::errors::{Result, TaError}; -use crate::{Close, Next, Period, Reset}; +use crate::{int, lit, Close, Next, NumberType, Period, Reset}; /// Mean Absolute Deviation (MAD) /// @@ -33,8 +33,8 @@ pub struct MeanAbsoluteDeviation { period: usize, index: usize, count: usize, - sum: f64, - deque: Box<[f64]>, + sum: NumberType, + deque: Box<[NumberType]>, } impl MeanAbsoluteDeviation { @@ -45,8 +45,8 @@ impl MeanAbsoluteDeviation { period, index: 0, count: 0, - sum: 0.0, - deque: vec![0.0; period].into_boxed_slice(), + sum: lit!(0.0), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } @@ -58,12 +58,12 @@ impl Period for MeanAbsoluteDeviation { } } -impl Next for MeanAbsoluteDeviation { - type Output = f64; +impl Next for MeanAbsoluteDeviation { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { self.sum = if self.count < self.period { - self.count = self.count + 1; + self.count += 1; self.sum + input } else { self.sum + input - self.deque[self.index] @@ -76,18 +76,18 @@ impl Next for MeanAbsoluteDeviation { 0 }; - let mean = self.sum / self.count as f64; + let mean = self.sum / int!(self.count); - let mut mad = 0.0; + let mut mad = lit!(0.0); for value in &self.deque[..self.count] { mad += (value - mean).abs(); } - mad / self.count as f64 + mad / int!(self.count) } } impl Next<&T> for MeanAbsoluteDeviation { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.close()) @@ -98,9 +98,9 @@ impl Reset for MeanAbsoluteDeviation { fn reset(&mut self) { self.index = 0; self.count = 0; - self.sum = 0.0; + self.sum = lit!(0.0); for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -134,25 +134,25 @@ mod tests { fn test_next() { let mut mad = MeanAbsoluteDeviation::new(5).unwrap(); - assert_eq!(round(mad.next(1.5)), 0.0); - assert_eq!(round(mad.next(4.0)), 1.25); - assert_eq!(round(mad.next(8.0)), 2.333); - assert_eq!(round(mad.next(4.0)), 1.813); - assert_eq!(round(mad.next(4.0)), 1.48); - assert_eq!(round(mad.next(1.5)), 1.48); + assert_eq!(round(mad.next(lit!(1.5))), lit!(0.0)); + assert_eq!(round(mad.next(lit!(4.0))), lit!(1.25)); + assert_eq!(round(mad.next(lit!(8.0))), lit!(2.333)); + assert_eq!(round(mad.next(lit!(4.0))), lit!(1.813)); + assert_eq!(round(mad.next(lit!(4.0))), lit!(1.48)); + assert_eq!(round(mad.next(lit!(1.5))), lit!(1.48)); } #[test] fn test_reset() { let mut mad = MeanAbsoluteDeviation::new(5).unwrap(); - assert_eq!(round(mad.next(1.5)), 0.0); - assert_eq!(round(mad.next(4.0)), 1.25); + assert_eq!(round(mad.next(lit!(1.5))), lit!(0.0)); + assert_eq!(round(mad.next(lit!(4.0))), lit!(1.25)); mad.reset(); - assert_eq!(round(mad.next(1.5)), 0.0); - assert_eq!(round(mad.next(4.0)), 1.25); + assert_eq!(round(mad.next(lit!(1.5))), lit!(0.0)); + assert_eq!(round(mad.next(lit!(4.0))), lit!(1.25)); } #[test] diff --git a/src/indicators/minimum.rs b/src/indicators/minimum.rs index fc1eaad..c6722a4 100644 --- a/src/indicators/minimum.rs +++ b/src/indicators/minimum.rs @@ -1,7 +1,8 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::{Low, Next, Period, Reset}; +use crate::helpers::INFINITY; +use crate::{Low, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -29,7 +30,7 @@ pub struct Minimum { period: usize, min_index: usize, cur_index: usize, - deque: Box<[f64]>, + deque: Box<[NumberType]>, } impl Minimum { @@ -40,13 +41,13 @@ impl Minimum { period, min_index: 0, cur_index: 0, - deque: vec![f64::INFINITY; period].into_boxed_slice(), + deque: vec![INFINITY; period].into_boxed_slice(), }), } } fn find_min_index(&self) -> usize { - let mut min = f64::INFINITY; + let mut min = INFINITY; let mut index: usize = 0; for (i, &val) in self.deque.iter().enumerate() { @@ -66,10 +67,10 @@ impl Period for Minimum { } } -impl Next for Minimum { - type Output = f64; +impl Next for Minimum { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { self.deque[self.cur_index] = input; if input < self.deque[self.min_index] { @@ -89,7 +90,7 @@ impl Next for Minimum { } impl Next<&T> for Minimum { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.low()) @@ -99,7 +100,7 @@ impl Next<&T> for Minimum { impl Reset for Minimum { fn reset(&mut self) { for i in 0..self.period { - self.deque[i] = f64::INFINITY; + self.deque[i] = INFINITY; } } } @@ -119,6 +120,7 @@ impl fmt::Display for Minimum { #[cfg(test)] mod tests { use super::*; + use crate::lit; use crate::test_helper::*; test_indicator!(Minimum); @@ -133,41 +135,41 @@ mod tests { fn test_next() { let mut min = Minimum::new(3).unwrap(); - assert_eq!(min.next(4.0), 4.0); - assert_eq!(min.next(1.2), 1.2); - assert_eq!(min.next(5.0), 1.2); - assert_eq!(min.next(3.0), 1.2); - assert_eq!(min.next(4.0), 3.0); - assert_eq!(min.next(6.0), 3.0); - assert_eq!(min.next(7.0), 4.0); - assert_eq!(min.next(8.0), 6.0); - assert_eq!(min.next(-9.0), -9.0); - assert_eq!(min.next(0.0), -9.0); + assert_eq!(min.next(lit!(4.0)), lit!(4.0)); + assert_eq!(min.next(lit!(1.2)), lit!(1.2)); + assert_eq!(min.next(lit!(5.0)), lit!(1.2)); + assert_eq!(min.next(lit!(3.0)), lit!(1.2)); + assert_eq!(min.next(lit!(4.0)), lit!(3.0)); + assert_eq!(min.next(lit!(6.0)), lit!(3.0)); + assert_eq!(min.next(lit!(7.0)), lit!(4.0)); + assert_eq!(min.next(lit!(8.0)), lit!(6.0)); + assert_eq!(min.next(lit!(-9.0)), lit!(-9.0)); + assert_eq!(min.next(lit!(0.0)), lit!(-9.0)); } #[test] fn test_next_with_bars() { - fn bar(low: f64) -> Bar { + fn bar(low: NumberType) -> Bar { Bar::new().low(low) } let mut min = Minimum::new(3).unwrap(); - assert_eq!(min.next(&bar(4.0)), 4.0); - assert_eq!(min.next(&bar(4.0)), 4.0); - assert_eq!(min.next(&bar(1.2)), 1.2); - assert_eq!(min.next(&bar(5.0)), 1.2); + assert_eq!(min.next(&bar(lit!(4.0))), lit!(4.0)); + assert_eq!(min.next(&bar(lit!(4.0))), lit!(4.0)); + assert_eq!(min.next(&bar(lit!(1.2))), lit!(1.2)); + assert_eq!(min.next(&bar(lit!(5.0))), lit!(1.2)); } #[test] fn test_reset() { let mut min = Minimum::new(10).unwrap(); - assert_eq!(min.next(5.0), 5.0); - assert_eq!(min.next(7.0), 5.0); + assert_eq!(min.next(lit!(5.0)), lit!(5.0)); + assert_eq!(min.next(lit!(7.0)), lit!(5.0)); min.reset(); - assert_eq!(min.next(8.0), 8.0); + assert_eq!(min.next(lit!(8.0)), lit!(8.0)); } #[test] diff --git a/src/indicators/money_flow_index.rs b/src/indicators/money_flow_index.rs index 212e9e2..5c1199a 100644 --- a/src/indicators/money_flow_index.rs +++ b/src/indicators/money_flow_index.rs @@ -1,7 +1,7 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::{Close, High, Low, Next, Period, Reset, Volume}; +use crate::{lit, Close, High, Low, Next, NumberType, Period, Reset, Volume}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -60,10 +60,10 @@ pub struct MoneyFlowIndex { period: usize, index: usize, count: usize, - previous_typical_price: f64, - total_positive_money_flow: f64, - total_negative_money_flow: f64, - deque: Box<[f64]>, + previous_typical_price: NumberType, + total_positive_money_flow: NumberType, + total_negative_money_flow: NumberType, + deque: Box<[NumberType]>, } impl MoneyFlowIndex { @@ -74,10 +74,10 @@ impl MoneyFlowIndex { period, index: 0, count: 0, - previous_typical_price: 0.0, - total_positive_money_flow: 0.0, - total_negative_money_flow: 0.0, - deque: vec![0.0; period].into_boxed_slice(), + previous_typical_price: lit!(0.0), + total_positive_money_flow: lit!(0.0), + total_negative_money_flow: lit!(0.0), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } @@ -90,10 +90,10 @@ impl Period for MoneyFlowIndex { } impl Next<&T> for MoneyFlowIndex { - type Output = f64; + type Output = NumberType; - fn next(&mut self, input: &T) -> f64 { - let tp = (input.close() + input.high() + input.low()) / 3.0; + fn next(&mut self, input: &T) -> NumberType { + let tp = (input.close() + input.high() + input.low()) / lit!(3.0); self.index = if self.index + 1 < self.period { self.index + 1 @@ -102,10 +102,10 @@ impl Next<&T> for MoneyFlowIndex { }; if self.count < self.period { - self.count = self.count + 1; + self.count += 1; if self.count == 1 { self.previous_typical_price = tp; - return 50.0; + return lit!(50.0); } } else { let popped = self.deque[self.index]; @@ -125,13 +125,13 @@ impl Next<&T> for MoneyFlowIndex { self.total_negative_money_flow += raw_money_flow; self.deque[self.index] = -raw_money_flow; } else { - self.deque[self.index] = 0.0; + self.deque[self.index] = lit!(0.0); } self.previous_typical_price = tp; self.total_positive_money_flow / (self.total_positive_money_flow + self.total_negative_money_flow) - * 100.0 + * lit!(100.0) } } @@ -151,11 +151,11 @@ impl Reset for MoneyFlowIndex { fn reset(&mut self) { self.index = 0; self.count = 0; - self.previous_typical_price = 0.0; - self.total_positive_money_flow = 0.0; - self.total_negative_money_flow = 0.0; + self.previous_typical_price = lit!(0.0); + self.total_positive_money_flow = lit!(0.0); + self.total_negative_money_flow = lit!(0.0); for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -175,45 +175,53 @@ mod tests { fn test_next_bar() { let mut mfi = MoneyFlowIndex::new(3).unwrap(); - let bar1 = Bar::new().high(3).low(1).close(2).volume(500.0); - assert_eq!(round(mfi.next(&bar1)), 50.0); + let bar1 = Bar::new().high(3).low(1).close(2).volume(500); + assert_eq!(round(mfi.next(&bar1)), lit!(50.0)); - let bar2 = Bar::new().high(2.3).low(2.0).close(2.3).volume(1000.0); - assert_eq!(round(mfi.next(&bar2)), 100.0); + let bar2 = Bar::new() + .high(lit!(2.3)) + .low(lit!(2.0)) + .close(lit!(2.3)) + .volume(1000); + assert_eq!(round(mfi.next(&bar2)), lit!(100.0)); - let bar3 = Bar::new().high(9).low(7).close(8).volume(200.0); - assert_eq!(round(mfi.next(&bar3)), 100.0); + let bar3 = Bar::new().high(9).low(7).close(8).volume(200); + assert_eq!(round(mfi.next(&bar3)), lit!(100.0)); - let bar4 = Bar::new().high(5).low(3).close(4).volume(500.0); - assert_eq!(round(mfi.next(&bar4)), 65.517); + let bar4 = Bar::new().high(5).low(3).close(4).volume(500); + assert_eq!(round(mfi.next(&bar4)), lit!(65.517)); - let bar5 = Bar::new().high(4).low(2).close(3).volume(5000.0); - assert_eq!(round(mfi.next(&bar5)), 8.602); + let bar5 = Bar::new().high(4).low(2).close(3).volume(5000); + assert_eq!(round(mfi.next(&bar5)), lit!(8.602)); - let bar6 = Bar::new().high(2).low(1).close(1.5).volume(6000.0); - assert_eq!(round(mfi.next(&bar6)), 0.0); + let bar6 = Bar::new().high(2).low(1).close(lit!(1.5)).volume(6000); + assert_eq!(round(mfi.next(&bar6)), lit!(0.0)); - let bar7 = Bar::new().high(2).low(2).close(2).volume(7000.0); - assert_eq!(round(mfi.next(&bar7)), 36.842); + let bar7 = Bar::new().high(2).low(2).close(2).volume(7000); + assert_eq!(round(mfi.next(&bar7)), lit!(36.842)); - let bar8 = Bar::new().high(2).low(2).close(2).volume(7000.0); - assert_eq!(round(mfi.next(&bar8)), 60.87); + let bar8 = Bar::new().high(2).low(2).close(2).volume(7000); + assert_eq!(round(mfi.next(&bar8)), lit!(60.87)); } #[test] fn test_reset() { let mut mfi = MoneyFlowIndex::new(3).unwrap(); - let bar1 = Bar::new().high(3).low(1).close(2).volume(500.0); - let bar2 = Bar::new().high(2.3).low(2.0).close(2.3).volume(1000.0); + let bar1 = Bar::new().high(3).low(1).close(2).volume(500); + let bar2 = Bar::new() + .high(lit!(2.3)) + .low(lit!(2.0)) + .close(lit!(2.3)) + .volume(1000); - assert_eq!(round(mfi.next(&bar1)), 50.0); - assert_eq!(round(mfi.next(&bar2)), 100.0); + assert_eq!(round(mfi.next(&bar1)), lit!(50.0)); + assert_eq!(round(mfi.next(&bar2)), lit!(100.0)); mfi.reset(); - assert_eq!(round(mfi.next(&bar1)), 50.0); - assert_eq!(round(mfi.next(&bar2)), 100.0); + assert_eq!(round(mfi.next(&bar1)), lit!(50.0)); + assert_eq!(round(mfi.next(&bar2)), lit!(100.0)); } #[test] diff --git a/src/indicators/moving_average_convergence_divergence.rs b/src/indicators/moving_average_convergence_divergence.rs index b9bafc8..918d5df 100644 --- a/src/indicators/moving_average_convergence_divergence.rs +++ b/src/indicators/moving_average_convergence_divergence.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::ExponentialMovingAverage as Ema; -use crate::{Close, Next, Period, Reset}; +use crate::{Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -71,21 +71,21 @@ impl MovingAverageConvergenceDivergence { #[derive(Debug, Clone, PartialEq)] pub struct MovingAverageConvergenceDivergenceOutput { - pub macd: f64, - pub signal: f64, - pub histogram: f64, + pub macd: NumberType, + pub signal: NumberType, + pub histogram: NumberType, } -impl From for (f64, f64, f64) { +impl From for (NumberType, NumberType, NumberType) { fn from(mo: MovingAverageConvergenceDivergenceOutput) -> Self { (mo.macd, mo.signal, mo.histogram) } } -impl Next for MovingAverageConvergenceDivergence { +impl Next for MovingAverageConvergenceDivergence { type Output = MovingAverageConvergenceDivergenceOutput; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let fast_val = self.fast_ema.next(input); let slow_val = self.slow_ema.next(input); @@ -138,15 +138,16 @@ impl fmt::Display for MovingAverageConvergenceDivergence { #[cfg(test)] mod tests { use super::*; + use crate::lit; use crate::test_helper::*; type Macd = MovingAverageConvergenceDivergence; test_indicator!(Macd); - fn round(nums: (f64, f64, f64)) -> (f64, f64, f64) { - let n0 = (nums.0 * 100.0).round() / 100.0; - let n1 = (nums.1 * 100.0).round() / 100.0; - let n2 = (nums.2 * 100.0).round() / 100.0; + fn round(nums: (NumberType, NumberType, NumberType)) -> (NumberType, NumberType, NumberType) { + let n0 = (nums.0 * lit!(100.0)).round() / lit!(100.0); + let n1 = (nums.1 * lit!(100.0)).round() / lit!(100.0); + let n2 = (nums.2 * lit!(100.0)).round() / lit!(100.0); (n0, n1, n2) } @@ -162,25 +163,55 @@ mod tests { fn test_macd() { let mut macd = Macd::new(3, 6, 4).unwrap(); - assert_eq!(round(macd.next(2.0).into()), (0.0, 0.0, 0.0)); - assert_eq!(round(macd.next(3.0).into()), (0.21, 0.09, 0.13)); - assert_eq!(round(macd.next(4.2).into()), (0.52, 0.26, 0.26)); - assert_eq!(round(macd.next(7.0).into()), (1.15, 0.62, 0.54)); - assert_eq!(round(macd.next(6.7).into()), (1.15, 0.83, 0.32)); - assert_eq!(round(macd.next(6.5).into()), (0.94, 0.87, 0.07)); + assert_eq!( + round(macd.next(lit!(2.0)).into()), + (lit!(0.0), lit!(0.0), lit!(0.0)) + ); + assert_eq!( + round(macd.next(lit!(3.0)).into()), + (lit!(0.21), lit!(0.09), lit!(0.13)) + ); + assert_eq!( + round(macd.next(lit!(4.2)).into()), + (lit!(0.52), lit!(0.26), lit!(0.26)) + ); + assert_eq!( + round(macd.next(lit!(7.0)).into()), + (lit!(1.15), lit!(0.62), lit!(0.54)) + ); + assert_eq!( + round(macd.next(lit!(6.7)).into()), + (lit!(1.15), lit!(0.83), lit!(0.32)) + ); + assert_eq!( + round(macd.next(lit!(6.5)).into()), + (lit!(0.94), lit!(0.87), lit!(0.07)) + ); } #[test] fn test_reset() { let mut macd = Macd::new(3, 6, 4).unwrap(); - assert_eq!(round(macd.next(2.0).into()), (0.0, 0.0, 0.0)); - assert_eq!(round(macd.next(3.0).into()), (0.21, 0.09, 0.13)); + assert_eq!( + round(macd.next(lit!(2.0)).into()), + (lit!(0.0), lit!(0.0), lit!(0.0)) + ); + assert_eq!( + round(macd.next(lit!(3.0)).into()), + (lit!(0.21), lit!(0.09), lit!(0.13)) + ); macd.reset(); - assert_eq!(round(macd.next(2.0).into()), (0.0, 0.0, 0.0)); - assert_eq!(round(macd.next(3.0).into()), (0.21, 0.09, 0.13)); + assert_eq!( + round(macd.next(lit!(2.0)).into()), + (lit!(0.0), lit!(0.0), lit!(0.0)) + ); + assert_eq!( + round(macd.next(lit!(3.0)).into()), + (lit!(0.21), lit!(0.09), lit!(0.13)) + ); } #[test] diff --git a/src/indicators/on_balance_volume.rs b/src/indicators/on_balance_volume.rs index b9ebf25..e513809 100644 --- a/src/indicators/on_balance_volume.rs +++ b/src/indicators/on_balance_volume.rs @@ -1,6 +1,6 @@ use std::fmt; -use crate::{Close, Next, Reset, Volume}; +use crate::{lit, Close, Next, NumberType, Reset, Volume}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -62,27 +62,27 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct OnBalanceVolume { - obv: f64, - prev_close: f64, + obv: NumberType, + prev_close: NumberType, } impl OnBalanceVolume { pub fn new() -> Self { Self { - obv: 0.0, - prev_close: 0.0, + obv: lit!(0.0), + prev_close: lit!(0.0), } } } impl Next<&T> for OnBalanceVolume { - type Output = f64; + type Output = NumberType; - fn next(&mut self, input: &T) -> f64 { + fn next(&mut self, input: &T) -> NumberType { if input.close() > self.prev_close { - self.obv = self.obv + input.volume(); + self.obv += input.volume(); } else if input.close() < self.prev_close { - self.obv = self.obv - input.volume(); + self.obv -= input.volume(); } self.prev_close = input.close(); self.obv @@ -103,8 +103,8 @@ impl fmt::Display for OnBalanceVolume { impl Reset for OnBalanceVolume { fn reset(&mut self) { - self.obv = 0.0; - self.prev_close = 0.0; + self.obv = lit!(0.0); + self.prev_close = lit!(0.0); } } @@ -117,40 +117,40 @@ mod tests { fn test_next_bar() { let mut obv = OnBalanceVolume::new(); - let bar1 = Bar::new().close(1.5).volume(1000.0); - let bar2 = Bar::new().close(5).volume(5000.0); - let bar3 = Bar::new().close(4).volume(9000.0); - let bar4 = Bar::new().close(4).volume(4000.0); + let bar1 = Bar::new().close(lit!(1.5)).volume(1000); + let bar2 = Bar::new().close(5).volume(5000); + let bar3 = Bar::new().close(4).volume(9000); + let bar4 = Bar::new().close(4).volume(4000); - assert_eq!(obv.next(&bar1), 1000.0); + assert_eq!(obv.next(&bar1), lit!(1000.0)); //close > prev_close - assert_eq!(obv.next(&bar2), 6000.0); + assert_eq!(obv.next(&bar2), lit!(6000.0)); // close < prev_close - assert_eq!(obv.next(&bar3), -3000.0); + assert_eq!(obv.next(&bar3), lit!(-3000.0)); // close == prev_close - assert_eq!(obv.next(&bar4), -3000.0); + assert_eq!(obv.next(&bar4), lit!(-3000.0)); } #[test] fn test_reset() { let mut obv = OnBalanceVolume::new(); - let bar1 = Bar::new().close(1.5).volume(1000.0); - let bar2 = Bar::new().close(4).volume(2000.0); - let bar3 = Bar::new().close(8).volume(3000.0); + let bar1 = Bar::new().close(lit!(1.5)).volume(1000); + let bar2 = Bar::new().close(4).volume(2000); + let bar3 = Bar::new().close(8).volume(3000); - assert_eq!(obv.next(&bar1), 1000.0); - assert_eq!(obv.next(&bar2), 3000.0); - assert_eq!(obv.next(&bar3), 6000.0); + assert_eq!(obv.next(&bar1), lit!(1000.0)); + assert_eq!(obv.next(&bar2), lit!(3000.0)); + assert_eq!(obv.next(&bar3), lit!(6000.0)); obv.reset(); - assert_eq!(obv.next(&bar1), 1000.0); - assert_eq!(obv.next(&bar2), 3000.0); - assert_eq!(obv.next(&bar3), 6000.0); + assert_eq!(obv.next(&bar1), lit!(1000.0)); + assert_eq!(obv.next(&bar2), lit!(3000.0)); + assert_eq!(obv.next(&bar3), lit!(6000.0)); } #[test] diff --git a/src/indicators/percentage_price_oscillator.rs b/src/indicators/percentage_price_oscillator.rs index bbcd24d..565a5a4 100644 --- a/src/indicators/percentage_price_oscillator.rs +++ b/src/indicators/percentage_price_oscillator.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::ExponentialMovingAverage as Ema; -use crate::{Close, Next, Period, Reset}; +use crate::{lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -71,25 +71,25 @@ impl PercentagePriceOscillator { #[derive(Debug, Clone, PartialEq)] pub struct PercentagePriceOscillatorOutput { - pub ppo: f64, - pub signal: f64, - pub histogram: f64, + pub ppo: NumberType, + pub signal: NumberType, + pub histogram: NumberType, } -impl From for (f64, f64, f64) { +impl From for (NumberType, NumberType, NumberType) { fn from(po: PercentagePriceOscillatorOutput) -> Self { (po.ppo, po.signal, po.histogram) } } -impl Next for PercentagePriceOscillator { +impl Next for PercentagePriceOscillator { type Output = PercentagePriceOscillatorOutput; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let fast_val = self.fast_ema.next(input); let slow_val = self.slow_ema.next(input); - let ppo = (fast_val - slow_val) / slow_val * 100.0; + let ppo = (fast_val - slow_val) / slow_val * lit!(100.0); let signal = self.signal_ema.next(ppo); let histogram = ppo - signal; @@ -140,15 +140,27 @@ mod tests { use super::*; use crate::test_helper::*; type Ppo = PercentagePriceOscillator; + #[cfg(feature = "decimal")] + use rust_decimal::Decimal; test_indicator!(Ppo); + #[cfg(not(feature = "decimal"))] fn round(nums: (f64, f64, f64)) -> (f64, f64, f64) { let n0 = (nums.0 * 100.0).round() / 100.0; let n1 = (nums.1 * 100.0).round() / 100.0; let n2 = (nums.2 * 100.0).round() / 100.0; (n0, n1, n2) } + #[cfg(feature = "decimal")] + fn round(nums: (Decimal, Decimal, Decimal)) -> (Decimal, Decimal, Decimal) { + use rust_decimal::prelude::RoundingStrategy::MidpointAwayFromZero; + ( + nums.0.round_dp_with_strategy(2, MidpointAwayFromZero), + nums.1.round_dp_with_strategy(2, MidpointAwayFromZero), + nums.2.round_dp_with_strategy(2, MidpointAwayFromZero), + ) + } #[test] fn test_new() { @@ -162,25 +174,55 @@ mod tests { fn test_next() { let mut ppo = Ppo::new(3, 6, 4).unwrap(); - assert_eq!(round(ppo.next(2.0).into()), (0.0, 0.0, 0.0)); - assert_eq!(round(ppo.next(3.0).into()), (9.38, 3.75, 5.63)); - assert_eq!(round(ppo.next(4.2).into()), (18.26, 9.56, 8.71)); - assert_eq!(round(ppo.next(7.0).into()), (28.62, 17.18, 11.44)); - assert_eq!(round(ppo.next(6.7).into()), (24.01, 19.91, 4.09)); - assert_eq!(round(ppo.next(6.5).into()), (17.84, 19.08, -1.24)); + assert_eq!( + round(ppo.next(lit!(2.0)).into()), + (lit!(0.0), lit!(0.0), lit!(0.0)) + ); + assert_eq!( + round(ppo.next(lit!(3.0)).into()), + (lit!(9.38), lit!(3.75), lit!(5.63)) + ); + assert_eq!( + round(ppo.next(lit!(4.2)).into()), + (lit!(18.26), lit!(9.56), lit!(8.71)) + ); + assert_eq!( + round(ppo.next(lit!(8.0)).into()), + (lit!(31.70), lit!(18.41), lit!(13.29)) + ); + assert_eq!( + round(ppo.next(lit!(6.7)).into()), + (lit!(23.94), lit!(20.63), lit!(3.32)) + ); + assert_eq!( + round(ppo.next(lit!(6.5)).into()), + (lit!(16.98), lit!(19.17), lit!(-2.19)) + ); } #[test] fn test_reset() { let mut ppo = Ppo::new(3, 6, 4).unwrap(); - assert_eq!(round(ppo.next(2.0).into()), (0.0, 0.0, 0.0)); - assert_eq!(round(ppo.next(3.0).into()), (9.38, 3.75, 5.63)); + assert_eq!( + round(ppo.next(lit!(2.0)).into()), + (lit!(0.0), lit!(0.0), lit!(0.0)) + ); + assert_eq!( + round(ppo.next(lit!(3.0)).into()), + (lit!(9.38), lit!(3.75), lit!(5.63)) + ); ppo.reset(); - assert_eq!(round(ppo.next(2.0).into()), (0.0, 0.0, 0.0)); - assert_eq!(round(ppo.next(3.0).into()), (9.38, 3.75, 5.63)); + assert_eq!( + round(ppo.next(lit!(2.0)).into()), + (lit!(0.0), lit!(0.0), lit!(0.0)) + ); + assert_eq!( + round(ppo.next(lit!(3.0)).into()), + (lit!(9.38), lit!(3.75), lit!(5.63)) + ); } #[test] diff --git a/src/indicators/rate_of_change.rs b/src/indicators/rate_of_change.rs index 969076a..beea169 100644 --- a/src/indicators/rate_of_change.rs +++ b/src/indicators/rate_of_change.rs @@ -1,7 +1,7 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::traits::{Close, Next, Period, Reset}; +use crate::{lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -45,7 +45,7 @@ pub struct RateOfChange { period: usize, index: usize, count: usize, - deque: Box<[f64]>, + deque: Box<[NumberType]>, } impl RateOfChange { @@ -56,7 +56,7 @@ impl RateOfChange { period, index: 0, count: 0, - deque: vec![0.0; period].into_boxed_slice(), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } @@ -68,10 +68,10 @@ impl Period for RateOfChange { } } -impl Next for RateOfChange { - type Output = f64; +impl Next for RateOfChange { + type Output = NumberType; - fn next(&mut self, input: f64) -> f64 { + fn next(&mut self, input: NumberType) -> NumberType { let previous = if self.count > self.period { self.deque[self.index] } else { @@ -90,14 +90,14 @@ impl Next for RateOfChange { 0 }; - (input - previous) / previous * 100.0 + (input - previous) / previous * lit!(100.0) } } impl Next<&T> for RateOfChange { - type Output = f64; + type Output = NumberType; - fn next(&mut self, input: &T) -> f64 { + fn next(&mut self, input: &T) -> NumberType { self.next(input.close()) } } @@ -119,7 +119,7 @@ impl Reset for RateOfChange { self.index = 0; self.count = 0; for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -142,38 +142,38 @@ mod tests { fn test_next_f64() { let mut roc = RateOfChange::new(3).unwrap(); - assert_eq!(round(roc.next(10.0)), 0.0); - assert_eq!(round(roc.next(10.4)), 4.0); - assert_eq!(round(roc.next(10.57)), 5.7); - assert_eq!(round(roc.next(10.8)), 8.0); - assert_eq!(round(roc.next(10.9)), 4.808); - assert_eq!(round(roc.next(10.0)), -5.393); + assert_eq!(round(roc.next(lit!(10.0))), lit!(0.0)); + assert_eq!(round(roc.next(lit!(10.4))), lit!(4.0)); + assert_eq!(round(roc.next(lit!(10.57))), lit!(5.7)); + assert_eq!(round(roc.next(lit!(10.8))), lit!(8.0)); + assert_eq!(round(roc.next(lit!(10.9))), lit!(4.808)); + assert_eq!(round(roc.next(lit!(10.0))), lit!(-5.393)); } #[test] fn test_next_bar() { - fn bar(close: f64) -> Bar { + fn bar(close: NumberType) -> Bar { Bar::new().close(close) } let mut roc = RateOfChange::new(3).unwrap(); - assert_eq!(round(roc.next(&bar(10.0))), 0.0); - assert_eq!(round(roc.next(&bar(10.4))), 4.0); - assert_eq!(round(roc.next(&bar(10.57))), 5.7); + assert_eq!(round(roc.next(&bar(lit!(10.0)))), lit!(0.0)); + assert_eq!(round(roc.next(&bar(lit!(10.4)))), lit!(4.0)); + assert_eq!(round(roc.next(&bar(lit!(10.57)))), lit!(5.7)); } #[test] fn test_reset() { let mut roc = RateOfChange::new(3).unwrap(); - roc.next(12.3); - roc.next(15.0); + roc.next(lit!(12.3)); + roc.next(lit!(15.0)); roc.reset(); - assert_eq!(round(roc.next(10.0)), 0.0); - assert_eq!(round(roc.next(10.4)), 4.0); - assert_eq!(round(roc.next(10.57)), 5.7); + assert_eq!(round(roc.next(lit!(10.0))), lit!(0.0)); + assert_eq!(round(roc.next(lit!(10.4))), lit!(4.0)); + assert_eq!(round(roc.next(lit!(10.57))), lit!(5.7)); } } diff --git a/src/indicators/relative_strength_index.rs b/src/indicators/relative_strength_index.rs index 5da5d1f..59843ed 100644 --- a/src/indicators/relative_strength_index.rs +++ b/src/indicators/relative_strength_index.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::ExponentialMovingAverage as Ema; -use crate::{Close, Next, Period, Reset}; +use crate::{lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -75,7 +75,7 @@ pub struct RelativeStrengthIndex { period: usize, up_ema_indicator: Ema, down_ema_indicator: Ema, - prev_val: f64, + prev_val: NumberType, is_new: bool, } @@ -85,7 +85,7 @@ impl RelativeStrengthIndex { period, up_ema_indicator: Ema::new(period)?, down_ema_indicator: Ema::new(period)?, - prev_val: 0.0, + prev_val: lit!(0.0), is_new: true, }) } @@ -97,35 +97,33 @@ impl Period for RelativeStrengthIndex { } } -impl Next for RelativeStrengthIndex { - type Output = f64; +impl Next for RelativeStrengthIndex { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { - let mut up = 0.0; - let mut down = 0.0; + fn next(&mut self, input: NumberType) -> Self::Output { + let mut up = lit!(0.0); + let mut down = lit!(0.0); if self.is_new { self.is_new = false; // Initialize with some small seed numbers to avoid division by zero - up = 0.1; - down = 0.1; + up = lit!(0.1); + down = lit!(0.1); + } else if input > self.prev_val { + up = input - self.prev_val; } else { - if input > self.prev_val { - up = input - self.prev_val; - } else { - down = self.prev_val - input; - } + down = self.prev_val - input; } self.prev_val = input; let up_ema = self.up_ema_indicator.next(up); let down_ema = self.down_ema_indicator.next(down); - 100.0 * up_ema / (up_ema + down_ema) + lit!(100.0) * up_ema / (up_ema + down_ema) } } impl Next<&T> for RelativeStrengthIndex { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.close()) @@ -135,7 +133,7 @@ impl Next<&T> for RelativeStrengthIndex { impl Reset for RelativeStrengthIndex { fn reset(&mut self) { self.is_new = true; - self.prev_val = 0.0; + self.prev_val = lit!(0.0); self.up_ema_indicator.reset(); self.down_ema_indicator.reset(); } @@ -169,21 +167,21 @@ mod tests { #[test] fn test_next() { let mut rsi = RelativeStrengthIndex::new(3).unwrap(); - assert_eq!(rsi.next(10.0), 50.0); - assert_eq!(rsi.next(10.5).round(), 86.0); - assert_eq!(rsi.next(10.0).round(), 35.0); - assert_eq!(rsi.next(9.5).round(), 16.0); + assert_eq!(rsi.next(lit!(10.0)), lit!(50.0)); + assert_eq!(rsi.next(lit!(10.5)).round(), lit!(86.0)); + assert_eq!(rsi.next(lit!(10.0)).round(), lit!(35.0)); + assert_eq!(rsi.next(lit!(9.5)).round(), lit!(16.0)); } #[test] fn test_reset() { let mut rsi = RelativeStrengthIndex::new(3).unwrap(); - assert_eq!(rsi.next(10.0), 50.0); - assert_eq!(rsi.next(10.5).round(), 86.0); + assert_eq!(rsi.next(lit!(10.0)), lit!(50.0)); + assert_eq!(rsi.next(lit!(10.5)).round(), lit!(86.0)); rsi.reset(); - assert_eq!(rsi.next(10.0).round(), 50.0); - assert_eq!(rsi.next(10.5).round(), 86.0); + assert_eq!(rsi.next(lit!(10.0)).round(), lit!(50.0)); + assert_eq!(rsi.next(lit!(10.5)).round(), lit!(86.0)); } #[test] diff --git a/src/indicators/simple_moving_average.rs b/src/indicators/simple_moving_average.rs index 9edcb50..fe52fc8 100644 --- a/src/indicators/simple_moving_average.rs +++ b/src/indicators/simple_moving_average.rs @@ -1,7 +1,7 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::{Close, Next, Period, Reset}; +use crate::{int, lit, Close, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -45,8 +45,8 @@ pub struct SimpleMovingAverage { period: usize, index: usize, count: usize, - sum: f64, - deque: Box<[f64]>, + sum: NumberType, + deque: Box<[NumberType]>, } impl SimpleMovingAverage { @@ -57,8 +57,8 @@ impl SimpleMovingAverage { period, index: 0, count: 0, - sum: 0.0, - deque: vec![0.0; period].into_boxed_slice(), + sum: lit!(0.0), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } @@ -70,10 +70,10 @@ impl Period for SimpleMovingAverage { } } -impl Next for SimpleMovingAverage { - type Output = f64; +impl Next for SimpleMovingAverage { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let old_val = self.deque[self.index]; self.deque[self.index] = input; @@ -88,12 +88,12 @@ impl Next for SimpleMovingAverage { } self.sum = self.sum - old_val + input; - self.sum / (self.count as f64) + self.sum / int!(self.count) } } impl Next<&T> for SimpleMovingAverage { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.close()) @@ -104,9 +104,9 @@ impl Reset for SimpleMovingAverage { fn reset(&mut self) { self.index = 0; self.count = 0; - self.sum = 0.0; + self.sum = lit!(0.0); for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -139,37 +139,37 @@ mod tests { #[test] fn test_next() { let mut sma = SimpleMovingAverage::new(4).unwrap(); - assert_eq!(sma.next(4.0), 4.0); - assert_eq!(sma.next(5.0), 4.5); - assert_eq!(sma.next(6.0), 5.0); - assert_eq!(sma.next(6.0), 5.25); - assert_eq!(sma.next(6.0), 5.75); - assert_eq!(sma.next(6.0), 6.0); - assert_eq!(sma.next(2.0), 5.0); + assert_eq!(sma.next(lit!(4.0)), lit!(4.0)); + assert_eq!(sma.next(lit!(5.0)), lit!(4.5)); + assert_eq!(sma.next(lit!(6.0)), lit!(5.0)); + assert_eq!(sma.next(lit!(6.0)), lit!(5.25)); + assert_eq!(sma.next(lit!(6.0)), lit!(5.75)); + assert_eq!(sma.next(lit!(6.0)), lit!(6.0)); + assert_eq!(sma.next(lit!(2.0)), lit!(5.0)); } #[test] fn test_next_with_bars() { - fn bar(close: f64) -> Bar { + fn bar(close: NumberType) -> Bar { Bar::new().close(close) } let mut sma = SimpleMovingAverage::new(3).unwrap(); - assert_eq!(sma.next(&bar(4.0)), 4.0); - assert_eq!(sma.next(&bar(4.0)), 4.0); - assert_eq!(sma.next(&bar(7.0)), 5.0); - assert_eq!(sma.next(&bar(1.0)), 4.0); + assert_eq!(sma.next(&bar(lit!(4.0))), lit!(4.0)); + assert_eq!(sma.next(&bar(lit!(4.0))), lit!(4.0)); + assert_eq!(sma.next(&bar(lit!(7.0))), lit!(5.0)); + assert_eq!(sma.next(&bar(lit!(1.0))), lit!(4.0)); } #[test] fn test_reset() { let mut sma = SimpleMovingAverage::new(4).unwrap(); - assert_eq!(sma.next(4.0), 4.0); - assert_eq!(sma.next(5.0), 4.5); - assert_eq!(sma.next(6.0), 5.0); + assert_eq!(sma.next(lit!(4.0)), lit!(4.0)); + assert_eq!(sma.next(lit!(5.0)), lit!(4.5)); + assert_eq!(sma.next(lit!(6.0)), lit!(5.0)); sma.reset(); - assert_eq!(sma.next(99.0), 99.0); + assert_eq!(sma.next(lit!(99.0)), lit!(99.0)); } #[test] diff --git a/src/indicators/slow_stochastic.rs b/src/indicators/slow_stochastic.rs index c84e3a6..dbf5ba7 100644 --- a/src/indicators/slow_stochastic.rs +++ b/src/indicators/slow_stochastic.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::errors::Result; use crate::indicators::{ExponentialMovingAverage, FastStochastic}; -use crate::{Close, High, Low, Next, Period, Reset}; +use crate::{Close, High, Low, Next, NumberType, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -44,16 +44,16 @@ impl SlowStochastic { } } -impl Next for SlowStochastic { - type Output = f64; +impl Next for SlowStochastic { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { self.ema.next(self.fast_stochastic.next(input)) } } impl Next<&T> for SlowStochastic { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.ema.next(self.fast_stochastic.next(input)) @@ -87,6 +87,7 @@ impl fmt::Display for SlowStochastic { #[cfg(test)] mod tests { use super::*; + use crate::lit; use crate::test_helper::*; test_indicator!(SlowStochastic); @@ -101,23 +102,23 @@ mod tests { #[test] fn test_next_with_f64() { let mut stoch = SlowStochastic::new(3, 2).unwrap(); - assert_eq!(stoch.next(10.0), 50.0); - assert_eq!(stoch.next(50.0).round(), 83.0); - assert_eq!(stoch.next(50.0).round(), 94.0); - assert_eq!(stoch.next(30.0).round(), 31.0); - assert_eq!(stoch.next(55.0).round(), 77.0); + assert_eq!(stoch.next(lit!(10.0)), lit!(50.0)); + assert_eq!(stoch.next(lit!(50.0)).round(), lit!(83.0)); + assert_eq!(stoch.next(lit!(50.0)).round(), lit!(94.0)); + assert_eq!(stoch.next(lit!(30.0)).round(), lit!(31.0)); + assert_eq!(stoch.next(lit!(55.0)).round(), lit!(77.0)); } #[test] fn test_next_with_bars() { let test_data = vec![ // high, low , close, expected - (30.0, 10.0, 25.0, 75.0), - (20.0, 20.0, 20.0, 58.0), - (40.0, 20.0, 16.0, 33.0), - (35.0, 15.0, 19.0, 22.0), - (30.0, 20.0, 25.0, 34.0), - (35.0, 25.0, 30.0, 61.0), + (lit!(30.0), lit!(10.0), lit!(25.0), lit!(75.0)), + (lit!(20.0), lit!(20.0), lit!(20.0), lit!(58.0)), + (lit!(40.0), lit!(20.0), lit!(16.0), lit!(33.0)), + (lit!(35.0), lit!(15.0), lit!(19.0), lit!(22.0)), + (lit!(30.0), lit!(20.0), lit!(25.0), lit!(34.0)), + (lit!(35.0), lit!(25.0), lit!(30.0), lit!(61.0)), ]; let mut stoch = SlowStochastic::new(3, 2).unwrap(); @@ -131,12 +132,12 @@ mod tests { #[test] fn test_reset() { let mut stoch = SlowStochastic::new(3, 2).unwrap(); - assert_eq!(stoch.next(10.0), 50.0); - assert_eq!(stoch.next(50.0).round(), 83.0); - assert_eq!(stoch.next(50.0).round(), 94.0); + assert_eq!(stoch.next(lit!(10.0)), lit!(50.0)); + assert_eq!(stoch.next(lit!(50.0)).round(), lit!(83.0)); + assert_eq!(stoch.next(lit!(50.0)).round(), lit!(94.0)); stoch.reset(); - assert_eq!(stoch.next(10.0), 50.0); + assert_eq!(stoch.next(lit!(10.0)), lit!(50.0)); } #[test] diff --git a/src/indicators/standard_deviation.rs b/src/indicators/standard_deviation.rs index 115e35a..4832faf 100644 --- a/src/indicators/standard_deviation.rs +++ b/src/indicators/standard_deviation.rs @@ -1,7 +1,9 @@ use std::fmt; use crate::errors::{Result, TaError}; -use crate::{Close, Next, Period, Reset}; +use crate::{int, lit, Close, Next, NumberType, Period, Reset}; +#[cfg(feature = "decimal")] +use rust_decimal::MathematicalOps; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -45,9 +47,9 @@ pub struct StandardDeviation { period: usize, index: usize, count: usize, - m: f64, - m2: f64, - deque: Box<[f64]>, + m: NumberType, + m2: NumberType, + deque: Box<[NumberType]>, } impl StandardDeviation { @@ -58,14 +60,14 @@ impl StandardDeviation { period, index: 0, count: 0, - m: 0.0, - m2: 0.0, - deque: vec![0.0; period].into_boxed_slice(), + m: lit!(0.0), + m2: lit!(0.0), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } - pub(super) fn mean(&self) -> f64 { + pub(super) fn mean(&self) -> NumberType { self.m } } @@ -76,10 +78,10 @@ impl Period for StandardDeviation { } } -impl Next for StandardDeviation { - type Output = f64; +impl Next for StandardDeviation { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let old_val = self.deque[self.index]; self.deque[self.index] = input; @@ -92,26 +94,31 @@ impl Next for StandardDeviation { if self.count < self.period { self.count += 1; let delta = input - self.m; - self.m += delta / self.count as f64; + self.m += delta / int!(self.count); let delta2 = input - self.m; self.m2 += delta * delta2; } else { let delta = input - old_val; let old_m = self.m; - self.m += delta / self.period as f64; + self.m += delta / int!(self.period); let delta2 = input - self.m + old_val - old_m; self.m2 += delta * delta2; } - if self.m2 < 0.0 { - self.m2 = 0.0; + if self.m2 < lit!(0.0) { + self.m2 = lit!(0.0); } - (self.m2 / self.count as f64).sqrt() + #[cfg(not(feature = "decimal"))] + return (self.m2 / int!(self.count)).sqrt(); + #[cfg(feature = "decimal")] + return (self.m2 / int!(self.count)) + .sqrt() + .expect("Invalid (probably negative) number sent."); } } impl Next<&T> for StandardDeviation { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.close()) @@ -122,10 +129,10 @@ impl Reset for StandardDeviation { fn reset(&mut self) { self.index = 0; self.count = 0; - self.m = 0.0; - self.m2 = 0.0; + self.m = lit!(0.0); + self.m2 = lit!(0.0); for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -158,59 +165,59 @@ mod tests { #[test] fn test_next() { let mut sd = StandardDeviation::new(4).unwrap(); - assert_eq!(sd.next(10.0), 0.0); - assert_eq!(sd.next(20.0), 5.0); - assert_eq!(round(sd.next(30.0)), 8.165); - assert_eq!(round(sd.next(20.0)), 7.071); - assert_eq!(round(sd.next(10.0)), 7.071); - assert_eq!(round(sd.next(100.0)), 35.355); + assert_eq!(sd.next(lit!(10.0)), lit!(0.0)); + assert_eq!(sd.next(lit!(20.0)), lit!(5.0)); + assert_eq!(round(sd.next(lit!(30.0))), lit!(8.165)); + assert_eq!(round(sd.next(lit!(20.0))), lit!(7.071)); + assert_eq!(round(sd.next(lit!(10.0))), lit!(7.071)); + assert_eq!(round(sd.next(lit!(100.0))), lit!(35.355)); } #[test] fn test_next_floating_point_error() { let mut sd = StandardDeviation::new(6).unwrap(); - assert_eq!(sd.next(1.872), 0.0); - assert_eq!(round(sd.next(1.0)), 0.436); - assert_eq!(round(sd.next(1.0)), 0.411); - assert_eq!(round(sd.next(1.0)), 0.378); - assert_eq!(round(sd.next(1.0)), 0.349); - assert_eq!(round(sd.next(1.0)), 0.325); - assert_eq!(round(sd.next(1.0)), 0.0); + assert_eq!(sd.next(lit!(1.872)), lit!(0.0)); + assert_eq!(round(sd.next(lit!(1.0))), lit!(0.436)); + assert_eq!(round(sd.next(lit!(1.0))), lit!(0.411)); + assert_eq!(round(sd.next(lit!(1.0))), lit!(0.378)); + assert_eq!(round(sd.next(lit!(1.0))), lit!(0.349)); + assert_eq!(round(sd.next(lit!(1.0))), lit!(0.325)); + assert_eq!(round(sd.next(lit!(1.0))), lit!(0.0)); } #[test] fn test_next_with_bars() { - fn bar(close: f64) -> Bar { + fn bar(close: NumberType) -> Bar { Bar::new().close(close) } let mut sd = StandardDeviation::new(4).unwrap(); - assert_eq!(sd.next(&bar(10.0)), 0.0); - assert_eq!(sd.next(&bar(20.0)), 5.0); - assert_eq!(round(sd.next(&bar(30.0))), 8.165); - assert_eq!(round(sd.next(&bar(20.0))), 7.071); - assert_eq!(round(sd.next(&bar(10.0))), 7.071); - assert_eq!(round(sd.next(&bar(100.0))), 35.355); + assert_eq!(sd.next(&bar(lit!(10.0))), lit!(0.0)); + assert_eq!(sd.next(&bar(lit!(20.0))), lit!(5.0)); + assert_eq!(round(sd.next(&bar(lit!(30.0)))), lit!(8.165)); + assert_eq!(round(sd.next(&bar(lit!(20.0)))), lit!(7.071)); + assert_eq!(round(sd.next(&bar(lit!(10.0)))), lit!(7.071)); + assert_eq!(round(sd.next(&bar(lit!(100.0)))), lit!(35.355)); } #[test] fn test_next_same_values() { let mut sd = StandardDeviation::new(3).unwrap(); - assert_eq!(sd.next(4.2), 0.0); - assert_eq!(sd.next(4.2), 0.0); - assert_eq!(sd.next(4.2), 0.0); - assert_eq!(sd.next(4.2), 0.0); + assert_eq!(sd.next(lit!(4.2)), lit!(0.0)); + assert_eq!(sd.next(lit!(4.2)), lit!(0.0)); + assert_eq!(sd.next(lit!(4.2)), lit!(0.0)); + assert_eq!(sd.next(lit!(4.2)), lit!(0.0)); } #[test] fn test_reset() { let mut sd = StandardDeviation::new(4).unwrap(); - assert_eq!(sd.next(10.0), 0.0); - assert_eq!(sd.next(20.0), 5.0); - assert_eq!(round(sd.next(30.0)), 8.165); + assert_eq!(sd.next(lit!(10.0)), lit!(0.0)); + assert_eq!(sd.next(lit!(20.0)), lit!(5.0)); + assert_eq!(round(sd.next(lit!(30.0))), lit!(8.165)); sd.reset(); - assert_eq!(sd.next(20.0), 0.0); + assert_eq!(sd.next(lit!(20.0)), lit!(0.0)); } #[test] diff --git a/src/indicators/true_range.rs b/src/indicators/true_range.rs index 49cf876..6741cae 100644 --- a/src/indicators/true_range.rs +++ b/src/indicators/true_range.rs @@ -1,7 +1,7 @@ use std::fmt; use crate::helpers::max3; -use crate::{Close, High, Low, Next, Reset}; +use crate::{lit, Close, High, Low, Next, NumberType, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -52,7 +52,7 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct TrueRange { - prev_close: Option, + prev_close: Option, } impl TrueRange { @@ -73,13 +73,13 @@ impl fmt::Display for TrueRange { } } -impl Next for TrueRange { - type Output = f64; +impl Next for TrueRange { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { + fn next(&mut self, input: NumberType) -> Self::Output { let distance = match self.prev_close { Some(prev) => (input - prev).abs(), - None => 0.0, + None => lit!(0.0), }; self.prev_close = Some(input); distance @@ -87,7 +87,7 @@ impl Next for TrueRange { } impl Next<&T> for TrueRange { - type Output = f64; + type Output = NumberType; fn next(&mut self, bar: &T) -> Self::Output { let max_dist = match self.prev_close { @@ -120,37 +120,37 @@ mod tests { #[test] fn test_next_f64() { let mut tr = TrueRange::new(); - assert_eq!(round(tr.next(2.5)), 0.0); - assert_eq!(round(tr.next(3.6)), 1.1); - assert_eq!(round(tr.next(3.3)), 0.3); + assert_eq!(round(tr.next(lit!(2.5))), lit!(0.0)); + assert_eq!(round(tr.next(lit!(3.6))), lit!(1.1)); + assert_eq!(round(tr.next(lit!(3.3))), lit!(0.3)); } #[test] fn test_next_bar() { let mut tr = TrueRange::new(); - let bar1 = Bar::new().high(10).low(7.5).close(9); - let bar2 = Bar::new().high(11).low(9).close(9.5); + let bar1 = Bar::new().high(10).low(lit!(7.5)).close(9); + let bar2 = Bar::new().high(11).low(9).close(lit!(9.5)); let bar3 = Bar::new().high(9).low(5).close(8); - assert_eq!(tr.next(&bar1), 2.5); - assert_eq!(tr.next(&bar2), 2.0); - assert_eq!(tr.next(&bar3), 4.5); + assert_eq!(tr.next(&bar1), lit!(2.5)); + assert_eq!(tr.next(&bar2), lit!(2.0)); + assert_eq!(tr.next(&bar3), lit!(4.5)); } #[test] fn test_reset() { let mut tr = TrueRange::new(); - let bar1 = Bar::new().high(10).low(7.5).close(9); - let bar2 = Bar::new().high(11).low(9).close(9.5); + let bar1 = Bar::new().high(10).low(lit!(7.5)).close(9); + let bar2 = Bar::new().high(11).low(9).close(lit!(9.5)); tr.next(&bar1); tr.next(&bar2); tr.reset(); let bar3 = Bar::new().high(60).low(15).close(51); - assert_eq!(tr.next(&bar3), 45.0); + assert_eq!(tr.next(&bar3), lit!(45.0)); } #[test] diff --git a/src/indicators/weighted_moving_average.rs b/src/indicators/weighted_moving_average.rs index 5da1f83..33eb850 100644 --- a/src/indicators/weighted_moving_average.rs +++ b/src/indicators/weighted_moving_average.rs @@ -1,6 +1,7 @@ use std::fmt; use crate::errors::{Result, TaError}; +use crate::{int, lit, NumberType}; use crate::{Close, Next, Period, Reset}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -45,10 +46,10 @@ pub struct WeightedMovingAverage { period: usize, index: usize, count: usize, - weight: f64, - sum: f64, - sum_flat: f64, - deque: Box<[f64]>, + weight: NumberType, + sum: NumberType, + sum_flat: NumberType, + deque: Box<[NumberType]>, } impl WeightedMovingAverage { @@ -59,10 +60,10 @@ impl WeightedMovingAverage { period, index: 0, count: 0, - weight: 0.0, - sum: 0.0, - sum_flat: 0.0, - deque: vec![0.0; period].into_boxed_slice(), + weight: lit!(0.0), + sum: lit!(0.0), + sum_flat: lit!(0.0), + deque: vec![lit!(0.0); period].into_boxed_slice(), }), } } @@ -74,11 +75,11 @@ impl Period for WeightedMovingAverage { } } -impl Next for WeightedMovingAverage { - type Output = f64; +impl Next for WeightedMovingAverage { + type Output = NumberType; - fn next(&mut self, input: f64) -> Self::Output { - let old_val: f64 = self.deque[self.index]; + fn next(&mut self, input: NumberType) -> Self::Output { + let old_val: NumberType = self.deque[self.index]; self.deque[self.index] = input; self.index = if self.index + 1 < self.period { @@ -89,18 +90,18 @@ impl Next for WeightedMovingAverage { if self.count < self.period { self.count += 1; - self.weight = self.count as f64; + self.weight = int!(self.count); self.sum += input * self.weight } else { self.sum = self.sum - self.sum_flat + (input * self.weight); } self.sum_flat = self.sum_flat - old_val + input; - self.sum / (self.weight * (self.weight + 1.0) / 2.0) + self.sum / (self.weight * (self.weight + lit!(1.0)) / lit!(2.0)) } } impl Next<&T> for WeightedMovingAverage { - type Output = f64; + type Output = NumberType; fn next(&mut self, input: &T) -> Self::Output { self.next(input.close()) @@ -111,11 +112,11 @@ impl Reset for WeightedMovingAverage { fn reset(&mut self) { self.index = 0; self.count = 0; - self.weight = 0.0; - self.sum = 0.0; - self.sum_flat = 0.0; + self.weight = lit!(0.0); + self.sum = lit!(0.0); + self.sum_flat = lit!(0.0); for i in 0..self.period { - self.deque[i] = 0.0; + self.deque[i] = lit!(0.0); } } } @@ -149,30 +150,30 @@ mod tests { fn test_next() { let mut wma = WeightedMovingAverage::new(3).unwrap(); - assert_eq!(wma.next(12.0), 12.0); - assert_eq!(wma.next(3.0), 6.0); // (1*12 + 2*3) / 3 = 6.0 - assert_eq!(wma.next(3.0), 4.5); // (1*12 + 2*3 + 3*3) / 6 = 4.5 - assert_eq!(wma.next(5.0), 4.0); // (1*3 + 2*3 + 3*5) / 6 = 4.0 + assert_eq!(wma.next(lit!(12.0)), lit!(12.0)); + assert_eq!(wma.next(lit!(3.0)), lit!(6.0)); // (1*12 + 2*3) / 3 = 6.0 + assert_eq!(wma.next(lit!(3.0)), lit!(4.5)); // (1*12 + 2*3 + 3*3) / 6 = 4.5 + assert_eq!(wma.next(lit!(5.0)), lit!(4.0)); // (1*3 + 2*3 + 3*5) / 6 = 4.0 let mut wma = WeightedMovingAverage::new(3).unwrap(); let bar1 = Bar::new().close(2); let bar2 = Bar::new().close(5); - assert_eq!(wma.next(&bar1), 2.0); - assert_eq!(wma.next(&bar2), 4.0); + assert_eq!(wma.next(&bar1), lit!(2.0)); + assert_eq!(wma.next(&bar2), lit!(4.0)); } #[test] fn test_reset() { let mut wma = WeightedMovingAverage::new(5).unwrap(); - assert_eq!(wma.next(4.0), 4.0); - wma.next(10.0); - wma.next(15.0); - wma.next(20.0); - assert_ne!(wma.next(4.0), 4.0); + assert_eq!(wma.next(lit!(4.0)), lit!(4.0)); + wma.next(lit!(10.0)); + wma.next(lit!(15.0)); + wma.next(lit!(20.0)); + assert_ne!(wma.next(lit!(4.0)), lit!(4.0)); wma.reset(); - assert_eq!(wma.next(4.0), 4.0); + assert_eq!(wma.next(lit!(4.0)), lit!(4.0)); } #[test] diff --git a/src/lib.rs b/src/lib.rs index b42ed9e..2bf11f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,12 +52,14 @@ //! * [Rate of Change (ROC)](indicators/struct.RateOfChange.html) //! * [On Balance Volume (OBV)](indicators/struct.OnBalanceVolume.html) //! +#[macro_use] +mod helpers; +pub(crate) use helpers::NumberType; + #[cfg(test)] #[macro_use] mod test_helper; -mod helpers; - pub mod errors; pub mod indicators; diff --git a/src/test_helper.rs b/src/test_helper.rs index 96dbb1d..72b788e 100644 --- a/src/test_helper.rs +++ b/src/test_helper.rs @@ -1,103 +1,111 @@ -use super::{Close, High, Low, Open, Volume}; +use super::{lit, Close, High, Low, NumberType, Open, Volume}; #[derive(Debug, PartialEq)] pub struct Bar { - open: f64, - high: f64, - low: f64, - close: f64, - volume: f64, + open: NumberType, + high: NumberType, + low: NumberType, + close: NumberType, + volume: NumberType, } impl Bar { pub fn new() -> Self { Self { - open: 0.0, - close: 0.0, - low: 0.0, - high: 0.0, - volume: 0.0, + open: lit!(0.0), + close: lit!(0.0), + low: lit!(0.0), + high: lit!(0.0), + volume: lit!(0.0), } } - //pub fn open>(mut self, val :T ) -> Self { + //pub fn open>(mut self, val :T ) -> Self { // self.open = val.into(); // self //} - pub fn high>(mut self, val: T) -> Self { + pub fn high>(mut self, val: T) -> Self { self.high = val.into(); self } - pub fn low>(mut self, val: T) -> Self { + pub fn low>(mut self, val: T) -> Self { self.low = val.into(); self } - pub fn close>(mut self, val: T) -> Self { + pub fn close>(mut self, val: T) -> Self { self.close = val.into(); self } - pub fn volume(mut self, val: f64) -> Self { - self.volume = val; + pub fn volume>(mut self, val: T) -> Self { + self.volume = val.into(); self } } impl Open for Bar { - fn open(&self) -> f64 { + fn open(&self) -> NumberType { self.open } } impl Close for Bar { - fn close(&self) -> f64 { + fn close(&self) -> NumberType { self.close } } impl Low for Bar { - fn low(&self) -> f64 { + fn low(&self) -> NumberType { self.low } } impl High for Bar { - fn high(&self) -> f64 { + fn high(&self) -> NumberType { self.high } } impl Volume for Bar { - fn volume(&self) -> f64 { + fn volume(&self) -> NumberType { self.volume } } -pub fn round(num: f64) -> f64 { +#[cfg(not(feature = "decimal"))] +pub fn round(num: NumberType) -> NumberType { (num * 1000.0).round() / 1000.00 } +#[cfg(feature = "decimal")] +pub fn round(num: NumberType) -> NumberType { + use rust_decimal::prelude::RoundingStrategy; + num.round_dp_with_strategy(3, RoundingStrategy::MidpointAwayFromZero) +} + macro_rules! test_indicator { ($i:tt) => { #[test] fn test_indicator() { + use crate::lit; let bar = Bar::new(); // ensure Default trait is implemented let mut indicator = $i::default(); - // ensure Next is implemented - let first_output = indicator.next(12.3); + // ensure Next is implemented + let first_output = indicator.next(lit!(12.3)); // ensure next accepts &DataItem as well indicator.next(&bar); // ensure Reset is implemented and works correctly indicator.reset(); - assert_eq!(indicator.next(12.3), first_output); + assert_eq!(indicator.next(lit!(12.3)), first_output); // ensure Display is implemented format!("{}", indicator); diff --git a/src/traits.rs b/src/traits.rs index 520e383..ed9543a 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,5 +1,6 @@ +use crate::NumberType; + // Indicator traits -// /// Resets an indicator to the initial state. pub trait Reset { @@ -27,25 +28,25 @@ pub trait Next { /// Open price of a particular period. pub trait Open { - fn open(&self) -> f64; + fn open(&self) -> NumberType; } /// Close price of a particular period. pub trait Close { - fn close(&self) -> f64; + fn close(&self) -> NumberType; } /// Lowest price of a particular period. pub trait Low { - fn low(&self) -> f64; + fn low(&self) -> NumberType; } /// Highest price of a particular period. pub trait High { - fn high(&self) -> f64; + fn high(&self) -> NumberType; } /// Trading volume of a particular trading period. pub trait Volume { - fn volume(&self) -> f64; + fn volume(&self) -> NumberType; }