diff --git a/src/include/matrix/bit/bitarray.hpp b/src/include/matrix/bit/bitarray.hpp new file mode 100644 index 0000000..7998c4b --- /dev/null +++ b/src/include/matrix/bit/bitarray.hpp @@ -0,0 +1,301 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +template +concept PackTypeConcept = std::is_integral_v && std::is_unsigned_v; + +template +concept ContainerConcept = requires(Container c, size_t idx) { + { c.at(idx) } -> std::same_as; + { c.size() } -> std::convertible_to; + { c.max_size() } -> std::convertible_to; + c.begin(); + c.end(); +}; +template +concept ResizableContainerConcept = requires(Container c, size_t new_size, ValueType val) { + c.resize(new_size); + c.assign(new_size, val); +}; + +template> +requires PackTypeConcept && ContainerConcept +class BitArray { +private: + static constexpr size_t _PACKBITS = sizeof(PackType) * 8; //< 1パックあたりのビット数 + Container _data; + size_t _bits; + + /** + * @brief 指定されたビット数に必要な要素数(パック数)を計算します。 + * @param bits ビット数 + * @return 必要な要素数 + */ + static size_t _get_pack_size(size_t bits) { + return (bits + _PACKBITS - 1) / _PACKBITS; + } + + /** + * @brief 指定された要素数に相当するビット数を計算します。 + * @param packs パック数 + * @return ビット数 + */ + static size_t _get_bits_size(const size_t& packs) { + return packs * _PACKBITS; + } + + /** + * @brief 指定された位置のビットの値を取得します。 + * @param n ビットインデックス + * @return 指定された位置のビットの値 + * @throws std::out_of_range ビットインデックスが範囲外の場合 + */ + bool _get_bit(size_t n) const { + if (n >= _bits) throw std::out_of_range("not enough bits to get"); + const size_t pack_idx = n / _PACKBITS; + const size_t bit_idx = _PACKBITS - 1 - n % _PACKBITS; + return (this->_data[pack_idx] >> bit_idx) & 1; + } + + /** + * @brief 指定された位置のビットを設定します。 + * @param n ビットインデックス + * @param val 設定する値 + * @throws std::out_of_range ビットインデックスが範囲外の場合 + */ + void _set_bit(size_t n, bool val) { + if (n >= this->_bits) throw std::out_of_range("not enough bits to set"); + const size_t pack_idx = n / _PACKBITS; + const size_t bit_idx = _PACKBITS - 1 - n % _PACKBITS; + if (val) { + this->_data[pack_idx] |= (PackType(1) << bit_idx); + } else { + this->_data[pack_idx] &= ~(PackType(1) << bit_idx); + } + } + +public: + // constructors + BitArray(); + explicit BitArray(const Container& a): _data(a), _bits(_get_bits_size(a.size())) + {}; + + /** + * @brief BitArrayを指定されたビット数で初期化します。 + * @param bits ビット数 + */ + explicit BitArray(const size_t& bits): _bits(bits){ + size_t req_size = _get_pack_size(bits); + + if(_data.max_size() < req_size) + throw std::runtime_error("Container max_size is too small for the specified number of bits."); + + if constexpr (ResizableContainerConcept) { + _data.assign(req_size, static_cast(false)); + } else { + std::fill(_data.begin(), _data.end(), static_cast(false)); + } + } + + /** + * @brief BitArrayを指定されたビット数と初期値で初期化します。 + * @param bits ビット数 + * @param init 初期値 + */ + explicit BitArray(const size_t& bits, const bool& init): _bits(bits){ + PackType init_val = static_cast(0); + size_t req_size = _get_pack_size(bits); + + if(init) + init_val = ~init_val; + + if(this->_data.max_size() < req_size) + throw std::runtime_error("Container max_size is too small for the specified number of bits."); + + if constexpr (ResizableContainerConcept) { + _data.assign(req_size, init_val); + } else { + std::fill(_data.begin(), _data.end(), init_val); + } + } + auto operator<=>(const BitArray&) const = default; + + bool operator [](const size_t& bits){ + return this->_get_bit(bits); + } + bool operator ()(const size_t& bits){ + return this->_get_bit(bits); + } + + /** + * @brief BitArrayのサイズを変更します。 + * @param bits 新しいビット数 + * @return 自身の参照 + */ + BitArray& resize(const size_t& bits) + requires ResizableContainerConcept + { + this->_bits = bits; + size_t req_size = _get_pack_size(bits); + this->_data.resize(req_size); + + return *this; + } + + /** + * @brief BitArrayのビット数を取得します。 + * @return ビット数 + */ + size_t size() const { + return this->_bits; + } + + /** + * @brief BitArrayの最大ビット数を取得します。 + * @return 最大ビット数 + */ + size_t max_size() const { + return _get_bits_size(this->_data.max_size()); + } + + /** + * @brief BitArrayのパック数を取得します。 + * @return パック数 + */ + size_t pack_size() const { + return this->_data.size(); + } + + /** + * @brief 指定された位置のビットを取得します。 + * @param n ビットインデックス + * @return ビットの値 + */ + bool at(const size_t& n) const { + return this->_get_bit(n); + } + + /** + * @brief 指定された位置のパックの値を取得します。 + * @param n パックインデックス + * @return パックの値 + */ + PackType pack_at(const size_t& n) const { + return this->_data.at(n); + } + + /** + * @brief 指定された位置のビットを設定します。 + * @param n ビットインデックス + * @param val 設定する値 + */ + BitArray& set(const size_t& n, const bool& val){ + this->_set_bit(n, val); + + return *this; + } + + /** + * @brief BitArrayの各パックに値を設定します。 + * @param func 値を生成する関数 + */ + BitArray& set_pack(const size_t& n, const PackType& val){ + this->_data.at(n) = val; + + return *this; + } + + /** + * @brief BitArrayの各ビットに値を設定します。 + * @param func 値を生成する関数 + */ + template requires std::invocable + BitArray& set_all(const Func& func){ + for(size_t i = 0; i < this->_bits; ++i){ + this->_set_bit(i, func(i)); + } + + return *this; + } + + /** + * @brief BitArrayの各パックに値を設定します。 + * @param func 値を生成する関数 + */ + template requires std::invocable + BitArray& set_all_pack(const Func& func){ + for(size_t i = 0; i < this->pack_size(); ++i){ + this->_data.at(i) = func(i); + } + + return *this; + } + + /** + * @brief ランダムな値でBitArrayを初期化します。 + * @param seed 乱数生成器のシード + */ + BitArray& set_random(uint32_t seed = std::random_device{}()){ + std::default_random_engine engine(seed); + std::uniform_int_distribution dist(0, static_cast(-1)); + + for(size_t i = 0; i < this->pack_size(); ++i){ + this->_data.at(i) = static_cast(dist(engine)); + } + + return *this; + } + + /** + * @brief BitArrayの末尾にビットを追加します。 + * @param val 追加するビットの値 + * @return 自身の参照 + */ + BitArray& emplace_back(const bool& val){ + size_t new_bits = this->_bits + 1; + size_t req_size = _get_pack_size(new_bits); + + if(this->_data.max_size() < req_size) + throw std::runtime_error("Container max_size is too small for the specified number of bits."); + + if constexpr (ResizableContainerConcept) { + if(req_size > this->_data.size()) + this->_data.emplace_back(static_cast(0)); + } else { + if(req_size > this->_data.size()) + throw std::runtime_error("Container size is too small for the specified number of bits."); + } + + this->_bits = new_bits; + this->_set_bit(this->_bits-1, val); + + return *this; + } + + /** + * @brief BitArrayを出力ストリームに書き込みます。 + * @param os 出力ストリーム + * @param ba 書き込むBitArray + * @return 出力ストリーム + */ + friend std::ostream& operator << (std::ostream& os, const BitArray& ba){ + size_t total_bits = ba.size(); + for(size_t i = 0; i < ba.pack_size(); i++){ + if(total_bits < (i+1)*_PACKBITS){ + std::string b = std::bitset(ba.pack_at(i)).to_string(); + size_t remain = (i+1)*_PACKBITS - total_bits; + + os << b.substr(0, _PACKBITS - remain); + }else{ + os << std::bitset(ba.pack_at(i)) << " "; + } + } + return os; + } +}; diff --git a/src/include/matrix/concepts.hpp b/src/include/matrix/concepts.hpp new file mode 100644 index 0000000..0ac1e25 --- /dev/null +++ b/src/include/matrix/concepts.hpp @@ -0,0 +1,42 @@ +#ifndef SANAE_NEURALNETWORK_CONSEPTS_HPP +#define SANAE_NEURALNETWORK_CONSEPTS_HPP + +#include +#include +#include + +// std::executionポリシー判定用の型 +template +concept StdExecPolicy = std::is_execution_policy_v>; + +// std::vectorまたはstd::array判定用の型 +template struct is_vector_or_array : std::false_type {}; +template struct is_vector_or_array> : std::true_type {}; +template struct is_vector_or_array> : std::true_type {}; +template +concept VectorOrArray = is_vector_or_array>::value; + +// std::array判定用の型 +template struct is_std_array : std::false_type {}; +template struct is_std_array> : std::true_type {}; +template +concept StdArray = is_std_array>::value; + +// BLAS使用判定用の型 +template struct can_use_blas : std::false_type {}; +#if defined(USE_OPENBLAS) +// OpenBlas + template<> struct can_use_blas : std::true_type {}; + template<> struct can_use_blas : std::true_type {}; +#elif defined(USE_CUBLAS) +// cuBLAS + template<> struct can_use_blas : std::true_type {}; + template<> struct can_use_blas : std::true_type {}; +#elif defined(USE_CLBLAST) +// clBLAST + template<> struct can_use_blas : std::true_type {}; + template<> struct can_use_blas : std::true_type {}; +#endif +template concept CanUseBlas = can_use_blas::value; + +#endif // SANAE_NEURALNETWORK_CONSEPTS_HPP \ No newline at end of file diff --git a/src/include/matrix/matrix.h b/src/include/matrix/matrix.h index 605b086..fa11cb1 100644 --- a/src/include/matrix/matrix.h +++ b/src/include/matrix/matrix.h @@ -7,42 +7,8 @@ #include #include #include -#include -#include - -// std::executionポリシー判定用の型 -template -concept StdExecPolicy = std::is_execution_policy_v>; - -// std::vectorまたはstd::array判定用の型 -template struct is_vector_or_array : std::false_type {}; -template struct is_vector_or_array> : std::true_type {}; -template struct is_vector_or_array> : std::true_type {}; -template -concept VectorOrArray = is_vector_or_array>::value; - -// std::array判定用の型 -template struct is_std_array : std::false_type {}; -template struct is_std_array> : std::true_type {}; -template -concept StdArray = is_std_array>::value; - -// BLAS使用判定用の型 -template struct can_use_blas : std::false_type {}; -#if defined(USE_OPENBLAS) -// OpenBlas - template<> struct can_use_blas : std::true_type {}; - template<> struct can_use_blas : std::true_type {}; -#elif defined(USE_CUBLAS) -// cuBLAS - template<> struct can_use_blas : std::true_type {}; - template<> struct can_use_blas : std::true_type {}; -#elif defined(USE_CLBLAST) -// clBLAST - template<> struct can_use_blas : std::true_type {}; - template<> struct can_use_blas : std::true_type {}; -#endif -template concept CanUseBlas = can_use_blas::value; +#include +#include "./concepts.hpp" /** * @brief 汎用的な行列クラスを提供します。 diff --git a/src/main.cpp b/src/main.cpp index 26bbcfa..75e2702 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,11 +1,25 @@ #include "matrixtest.hpp" #include "nntest.hpp" #include "matrixbenchmark.hpp" +#include "include/matrix/bit/bitarray.hpp" int main(){ - run_matrix_tests(); - run_benchmarks(); - run_nntest(); + // run_matrix_tests(); + // run_benchmarks(); + // run_nntest(); + + try{ + BitArray<> bits(11, true); + for(size_t i = 0; i < bits.size(); ++i){ + bits.set(i, 0); + std::cout << bits << std::endl; + } + + bits.emplace_back(1); + std::cout << bits << std::endl; + }catch(std::exception e){ + std::cout << e.what() << std::endl; + } return 0; } \ No newline at end of file