Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting index-base for input/output tensor in convert_tensor #5

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/Genten_TensorIO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,9 @@ queryFile()
template <typename ExecSpace>
TensorWriter<ExecSpace>::
TensorWriter(const std::string& fname,
const bool comp) : filename(fname), compressed(comp) {}
const ttb_indx ib,
const bool comp) :
filename(fname), index_base(ib), compressed(comp) {}

template <typename ExecSpace>
void
Expand Down Expand Up @@ -1119,7 +1121,9 @@ writeText(const SptensorT<ExecSpace>& X) const
{
Sptensor X_host = create_mirror_view(X);
deep_copy(X_host, X);
export_sptensor(filename, X_host, true, 15, true, compressed);
if (index_base != 0 && index_base != 1)
Genten::error("Writing a sparse tensor requires index base of 0 or 1");
export_sptensor(filename, X_host, true, 15, index_base==0, compressed);
}

template <typename ExecSpace>
Expand Down
2 changes: 2 additions & 0 deletions src/Genten_TensorIO.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ template <typename ExecSpace>
class TensorWriter {
public:
TensorWriter(const std::string& filename,
const ttb_indx index_base = 0,
const bool compressed = false);

void writeBinary(const SptensorT<ExecSpace>& X,
Expand All @@ -199,6 +200,7 @@ class TensorWriter {
void writeText(const TensorT<ExecSpace>& X) const;
private:
std::string filename;
ttb_indx index_base;
bool compressed;
};

Expand Down
48 changes: 30 additions & 18 deletions tools/convert_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
template <typename TensorType>
void print_tensor_stats(const TensorType& x)
{
std::cout << " Stats: ";
std::cout << " Stats: ";
const ttb_indx nd = x.ndims();
for (ttb_indx i=0; i<nd; ++i) {
std::cout << x.size(i);
Expand All @@ -65,19 +65,22 @@ void print_tensor_stats(const TensorType& x)

template <typename TensorType>
void save_tensor(const TensorType& x_in, const std::string& filename,
const std::string format, const std::string type, bool gz,
const std::string format, const std::string type,
const ttb_indx index_base, bool gz,
bool header)
{
std::cout << "\nOutput:\n"
<< " File: " << filename << std::endl
<< " Format: " << format << std::endl
<< " Type: " << type;
<< " File: " << filename << std::endl
<< " Index Base: " << index_base << std::endl
<< " Format: " << format << std::endl
<< " Type: " << type;
if (type == "text" && gz)
std::cout << " (compressed)";
if (type == "binary" && !header)
std::cout << " (no header)";
std::cout << std::endl;
Genten::TensorWriter<Genten::DefaultHostExecutionSpace> writer(filename,gz);
Genten::TensorWriter<Genten::DefaultHostExecutionSpace> writer(
filename,index_base,gz);
if (format == "sparse") {
Genten::Sptensor x_out(x_in);
print_tensor_stats(x_out);
Expand All @@ -97,10 +100,12 @@ void save_tensor(const TensorType& x_in, const std::string& filename,
}

void read_tensor_file(const std::string& filename,
const ttb_indx index_base,
std::string& format, std::string& type, bool gz,
Genten::Sptensor& x_sparse, Genten::Tensor& x_dense)
{
Genten::TensorReader<Genten::DefaultHostExecutionSpace> reader(filename,0,gz);
Genten::TensorReader<Genten::DefaultHostExecutionSpace> reader(
filename,index_base,gz);
reader.read();

if (reader.isSparse()) {
Expand All @@ -125,14 +130,16 @@ int main(int argc, char* argv[])
auto args = Genten::build_arg_list(argc,argv);
const bool help =
Genten::parse_ttb_bool(args, "--help", "--no-help", false);
if (argc < 9 || argc > 11 || help) {
if (argc < 9 || argc > 16 || help) {
std::cout << "\nconvert-tensor: a helper utility for converting tensor data between\n"
<< "tensor formats (sparse or dense), and file types (text or binary).\n\n"
<< "Usage: " << argv[0] << " --input-file <string> --output-file <string> --output-format <sparse|dense> --output-type <text|binary> [options] \n"
<< "Options:\n"
<< " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-header Write header to output file (binary-only, default: on)\n";
<< " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-header Write header to output file (binary-only, default: on)\n"
<< " --input-index-base Starting index for input tensor (sparse-only, default: 0)\n"
<< " --output-index-base Starting index for output tensor (sparse-only, default: 0)\n";
return 0;
}

Expand All @@ -153,6 +160,10 @@ int main(int argc, char* argv[])
Genten::parse_ttb_bool(args, "--output-gz", "--no-output-gz", false);
const bool output_header =
Genten::parse_ttb_bool(args, "--output-header", "--no-output-header", true);
const ttb_indx input_index_base =
Genten::parse_ttb_indx(args, "--input-index-base", 0, 0, INT_MAX);
const ttb_indx output_index_base =
Genten::parse_ttb_indx(args, "--output-index-base", 0, 0, INT_MAX);

if (input_filename == "")
Genten::error("input filename must be specified");
Expand All @@ -168,29 +179,30 @@ int main(int argc, char* argv[])
Genten::error("No header option only supported for binary output files");

std::cout << "\nInput:\n"
<< " File: " << input_filename << std::endl;
<< " File: " << input_filename << std::endl
<< " Index base: " << input_index_base << std::endl;

std::string input_format = "unknown";
std::string input_type = "unknown";
Genten::Sptensor x_sparse;
Genten::Tensor x_dense;
read_tensor_file(input_filename, input_format, input_type, input_gz,
x_sparse, x_dense);
read_tensor_file(input_filename, input_index_base, input_format, input_type,
input_gz, x_sparse, x_dense);

std::cout << " Format: " << input_format << std::endl
<< " Type: " << input_type;
std::cout << " Format: " << input_format << std::endl
<< " Type: " << input_type;
if (input_type == "text" && input_gz)
std::cout << " (compressed)";
std::cout << std::endl;
if (input_format == "sparse") {
print_tensor_stats(x_sparse);
save_tensor(x_sparse, output_filename, output_format, output_type,
output_gz, output_header);
output_index_base, output_gz, output_header);
}
else if (input_format == "dense") {
print_tensor_stats(x_dense);
save_tensor(x_dense, output_filename, output_format, output_type,
output_gz, output_header);
output_index_base, output_gz, output_header);
}
else
Genten::error("Invalid input tensor format!");
Expand Down
Loading