Skip to content

Commit

Permalink
Less mallocs
Browse files: browse the repository at this point in the history
  • Loading branch information
baAlex committed May 16, 2022
1 parent 6f600af commit 4ec46ca
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 31 deletions.
18 changes: 9 additions & 9 deletions resources/ans-cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ template <typename T> class Cdf
}

public:
Cdf(const T* message, size_t len, uint32_t scale_to = 0)
Cdf(const std::vector<T>& message, uint32_t scale_to = 0)
{
// Cumulative distribution function (CDF) following Pasco (1976, p.10)
// (in my own nomenclature):
Expand All @@ -75,8 +75,8 @@ template <typename T> class Cdf
{
// Count symbols
auto hashmap = std::unordered_map<T, uint32_t>();
for (const T* m = message; m != (message + len); m++)
hashmap[*m]++;
for (auto& s : message)
hashmap[s]++;

// Create sorted table
for (const auto& i : hashmap)
Expand All @@ -92,7 +92,7 @@ template <typename T> class Cdf
[](CdfEntry<T> a, CdfEntry<T> b) { return (a.frequency > b.frequency); });
}

const auto h = entropy(table, len);
const auto h = entropy(table, message.size());

// Accumulate frequencies
uint32_t cumulative = 0;
Expand Down Expand Up @@ -193,20 +193,20 @@ template <typename T> class Cdf

std::cout << "Message: \"";

for (const T* i = message; i != message + std::min(len, MESSAGE_MAX_PRINT); i++)
for (size_t i = 0; i < std::min(message.size(), MESSAGE_MAX_PRINT); i++)
{
if (*i != static_cast<T>('\n'))
std::cout << *i;
if (message[i] != static_cast<T>('\n'))
std::cout << message[i];
else
std::cout << " ";
}

std::cout << "\"\n";
std::cout << "Length: " << len << " symbols\n";
std::cout << "Length: " << message.size() << " symbols\n";
std::cout << "Unique symbols: " << table.size() << "\n";
std::cout << "Maximum cumulative: " << max_cumulative_ << ((scale_to != 0) ? " (scaled)\n" : "\n");
std::cout << "Entropy: " << h << " bits per symbol" << ((scale_to != 0) ? " (before scaling)\n" : "\n");
std::cout << "Shannon target: " << (h * static_cast<double>(len)) / 8.0 << " bytes"
std::cout << "Shannon target: " << (h / 8.0) * static_cast<double>(message.size()) << " bytes"
<< ((scale_to != 0) ? " (before scaling)\n" : "\n");

for (size_t i = 0; i < std::min(table.size(), SYMBOLS_MAX_PRINT); i++)
Expand Down
4 changes: 2 additions & 2 deletions resources/ans1-core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ int main()
try
{
const auto message = std::vector<char>(s.begin(), s.end());
const auto cdf = Cdf<char>(message.data(), message.size());
const auto cdf = Cdf<char>(message);
const auto state = AnsEncode(cdf, message);
AnsDecode(cdf, state);
}
Expand All @@ -145,7 +145,7 @@ int main()

try
{
const auto cdf = Cdf<uint16_t>(message.data(), message.size());
const auto cdf = Cdf<uint16_t>(message);
const auto state = AnsEncode(cdf, message);
AnsDecode(cdf, state);
}
Expand Down
45 changes: 25 additions & 20 deletions resources/ans2-normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ USE OR PERFORMANCE OF THIS SOFTWARE.

#include <fstream>
#include <iostream>
#include <memory>
#include <vector>

#include "ans-cdf.hpp"
Expand Down Expand Up @@ -70,11 +71,13 @@ acc_decode_t D(acc_decode_t accumulator, uint32_t frequency, uint32_t cumulative


template <typename InputT, typename OutputT>
std::vector<OutputT> AnsEncode(const Cdf<InputT>& cdf, const std::vector<InputT>& message)
std::unique_ptr<std::vector<OutputT>> AnsEncode(const Cdf<InputT>& cdf, const std::vector<InputT>& message)
{
std::cout << "\nEncode:\n";

auto output = std::vector<OutputT>();
auto output = std::make_unique<std::vector<OutputT>>();
output->reserve(message.size());

acc_encode_t acc = ACC_INITIAL_STATE;

size_t normalizations = 0; // For stats
Expand All @@ -89,7 +92,7 @@ std::vector<OutputT> AnsEncode(const Cdf<InputT>& cdf, const std::vector<InputT>
normalizations = 0;
while (C(acc, e.frequency, e.cumulative, cdf.m()) > (L * B) - 1) // [2] see above
{
output.insert(output.begin(), 1, static_cast<OutputT>(acc % static_cast<acc_encode_t>(B)));
output->insert(output->begin(), 1, static_cast<OutputT>(acc % static_cast<acc_encode_t>(B)));
acc = acc / static_cast<acc_encode_t>(B);

normalizations++;
Expand All @@ -109,7 +112,7 @@ std::vector<OutputT> AnsEncode(const Cdf<InputT>& cdf, const std::vector<InputT>
{
if (normalizations == 1)
std::cout << " - '" << e.symbol << "' (f: " << e.frequency << ", c: " << e.cumulative << ")\t->\t"
<< acc << " (1 normalization, " << output[0] << ")\n";
<< acc << " (1 normalization, " << output->at(0) << ")\n";
else
std::cout << " - '" << e.symbol << "' (f: " << e.frequency << ", c: " << e.cumulative << ")\t->\t"
<< acc << " (" << normalizations << " normalizations)\n";
Expand All @@ -123,11 +126,11 @@ std::vector<OutputT> AnsEncode(const Cdf<InputT>& cdf, const std::vector<InputT>
// Accumulator remainder (output)
while (acc != 0)
{
output.insert(output.begin(), 1, static_cast<OutputT>(acc % static_cast<acc_encode_t>(B)));
output->insert(output->begin(), 1, static_cast<OutputT>(acc % static_cast<acc_encode_t>(B)));
acc = acc / static_cast<acc_encode_t>(B);

total_normalizations++;
std::cout << " - acc remainder\t->\t" << acc << " (" << output[0] << ")\n";
std::cout << " - acc remainder\t->\t" << acc << " (" << output->at(0) << ")\n";
}

// Bye!
Expand All @@ -138,11 +141,13 @@ std::vector<OutputT> AnsEncode(const Cdf<InputT>& cdf, const std::vector<InputT>


template <typename OutputT, typename InputT>
std::vector<OutputT> AnsDecode(const Cdf<OutputT>& cdf, const std::vector<InputT>& bitcode, size_t message_len)
std::unique_ptr<std::vector<OutputT>> AnsDecode(const Cdf<OutputT>& cdf, const std::vector<InputT>& bitcode,
size_t message_len)
{
std::cout << "\nDecode:\n";

auto output = std::vector<OutputT>();
auto output = std::make_unique<std::vector<OutputT>>();
output->reserve(message_len);

acc_decode_t acc = 0;
size_t input_i = 0;
Expand All @@ -169,7 +174,7 @@ std::vector<OutputT> AnsDecode(const Cdf<OutputT>& cdf, const std::vector<InputT
const CdfEntry<OutputT>& e = cdf.of_point(modulo_point);
acc = D(acc, e.frequency, e.cumulative, cdf.m(), modulo_point);

output.emplace_back(e.symbol);
output->emplace_back(e.symbol);

// Normalize (input)
normalizations = 0;
Expand Down Expand Up @@ -214,27 +219,27 @@ std::vector<OutputT> AnsDecode(const Cdf<OutputT>& cdf, const std::vector<InputT
}


template <typename OutputT> std::vector<OutputT> ReadFile(const std::string& filename)
template <typename OutputT> std::unique_ptr<std::vector<OutputT>> ReadFile(const std::string& filename)
{
const size_t READ_BLOCK_SIZE = 512 * 1024;

auto data = std::vector<OutputT>();
auto data = std::make_unique<std::vector<OutputT>>();
auto file = std::fstream(filename, std::ios::in | std::ios::binary);

if (file.is_open() == false)
throw std::runtime_error("Input error.");

for (auto blocks = 0;; blocks++)
{
const auto append_at = data.size();
const auto append_at = data->size();

data.resize(data.size() + READ_BLOCK_SIZE / sizeof(OutputT));
file.read(reinterpret_cast<char*>(data.data()) + append_at * sizeof(OutputT), READ_BLOCK_SIZE);
data->resize(data->size() + READ_BLOCK_SIZE / sizeof(OutputT));
file.read(reinterpret_cast<char*>(data->data()) + append_at * sizeof(OutputT), READ_BLOCK_SIZE);

if (file.eof() == true)
{
data.resize(data.size() - (READ_BLOCK_SIZE - static_cast<size_t>(file.gcount())) / sizeof(OutputT));
std::cout << "Input: '" << filename << "', " << data.size() << " symbols, read in " << blocks + 1
data->resize(data->size() - (READ_BLOCK_SIZE - static_cast<size_t>(file.gcount())) / sizeof(OutputT));
std::cout << "Input: '" << filename << "', " << data->size() << " symbols, read in " << blocks + 1
<< " blocks\n";
break;
}
Expand Down Expand Up @@ -266,12 +271,12 @@ int main(int argc, const char* argv[])
try
{
const auto data = ReadFile<uint8_t>(argv[1]);
const auto cdf = Cdf<uint8_t>(data.data(), data.size(), 1 << 16); // [1]
const auto bitstream = AnsEncode<uint8_t, uint16_t>(cdf, data);
const auto decoded_data = AnsDecode<uint8_t, uint16_t>(cdf, bitstream, data.size());
const auto cdf = Cdf<uint8_t>(*data, 1 << 16); // [1]
const auto bitstream = AnsEncode<uint8_t, uint16_t>(cdf, *data);
const auto decoded_data = AnsDecode<uint8_t, uint16_t>(cdf, *bitstream, data->size());

if (argc > 2)
WriteFile(decoded_data, argv[2]);
WriteFile(*decoded_data, argv[2]);
}
catch (const std::exception& e)
{
Expand Down

0 comments on commit 4ec46ca

Please sign in to comment.