From c290e2a7478429fb45f9362bab41a8754c8b993a Mon Sep 17 00:00:00 2001 From: Eric Phipps Date: Tue, 22 Oct 2024 13:26:03 -0600 Subject: [PATCH] Allow setting index-base for input/output tensor in convert_tensor --- src/Genten_TensorIO.cpp | 8 +++++-- src/Genten_TensorIO.hpp | 2 ++ tools/convert_tensor.cpp | 48 +++++++++++++++++++++++++--------------- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/Genten_TensorIO.cpp b/src/Genten_TensorIO.cpp index 79fa5af6e6..99446bf9d5 100644 --- a/src/Genten_TensorIO.cpp +++ b/src/Genten_TensorIO.cpp @@ -1088,7 +1088,9 @@ queryFile() template TensorWriter:: TensorWriter(const std::string& fname, - const bool comp) : filename(fname), compressed(comp) {} + const ttb_indx ib, + const bool comp) : + filename(fname), index_base(ib), compressed(comp) {} template void @@ -1119,7 +1121,9 @@ writeText(const SptensorT& X) const { Sptensor X_host = create_mirror_view(X); deep_copy(X_host, X); - export_sptensor(filename, X_host, true, 15, true, compressed); + if (index_base != 0 && index_base != 1) + Genten::error("Writing a sparse tensor requires index base of 0 or 1"); + export_sptensor(filename, X_host, true, 15, index_base==0, compressed); } template diff --git a/src/Genten_TensorIO.hpp b/src/Genten_TensorIO.hpp index bafa3dc096..5a60e7296e 100644 --- a/src/Genten_TensorIO.hpp +++ b/src/Genten_TensorIO.hpp @@ -188,6 +188,7 @@ template class TensorWriter { public: TensorWriter(const std::string& filename, + const ttb_indx index_base = 0, const bool compressed = false); void writeBinary(const SptensorT& X, @@ -199,6 +200,7 @@ class TensorWriter { void writeText(const TensorT& X) const; private: std::string filename; + ttb_indx index_base; bool compressed; }; diff --git a/tools/convert_tensor.cpp b/tools/convert_tensor.cpp index b881dba4c5..7338c27c97 100644 --- a/tools/convert_tensor.cpp +++ b/tools/convert_tensor.cpp @@ -43,7 +43,7 @@ template void print_tensor_stats(const TensorType& x) { - std::cout << " Stats: "; + 
std::cout << " Stats: "; const ttb_indx nd = x.ndims(); for (ttb_indx i=0; i void save_tensor(const TensorType& x_in, const std::string& filename, - const std::string format, const std::string type, bool gz, + const std::string format, const std::string type, + const ttb_indx index_base, bool gz, bool header) { std::cout << "\nOutput:\n" - << " File: " << filename << std::endl - << " Format: " << format << std::endl - << " Type: " << type; + << " File: " << filename << std::endl + << " Index Base: " << index_base << std::endl + << " Format: " << format << std::endl + << " Type: " << type; if (type == "text" && gz) std::cout << " (compressed)"; if (type == "binary" && !header) std::cout << " (no header)"; std::cout << std::endl; - Genten::TensorWriter writer(filename,gz); + Genten::TensorWriter writer( + filename,index_base,gz); if (format == "sparse") { Genten::Sptensor x_out(x_in); print_tensor_stats(x_out); @@ -97,10 +100,12 @@ void save_tensor(const TensorType& x_in, const std::string& filename, } void read_tensor_file(const std::string& filename, + const ttb_indx index_base, std::string& format, std::string& type, bool gz, Genten::Sptensor& x_sparse, Genten::Tensor& x_dense) { - Genten::TensorReader reader(filename,0,gz); + Genten::TensorReader reader( + filename,index_base,gz); reader.read(); if (reader.isSparse()) { @@ -125,14 +130,16 @@ int main(int argc, char* argv[]) auto args = Genten::build_arg_list(argc,argv); const bool help = Genten::parse_ttb_bool(args, "--help", "--no-help", false); - if (argc < 9 || argc > 11 || help) { + if (argc < 9 || argc > 16 || help) { std::cout << "\nconvert-tensor: a helper utility for converting tensor data between\n" << "tensor formats (sparse or dense), and file types (text or binary).\n\n" << "Usage: " << argv[0] << " --input-file --output-file --output-format --output-type [options] \n" << "Options:\n" - << " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n" - << " --output-gz Output tensor is 
Gzip compressed (text-only, default: off)\n" - << " --output-header Write header to output file (binary-only, default: on)\n"; + << " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n" + << " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n" + << " --output-header Write header to output file (binary-only, default: on)\n" + << " --input-index-base Starting index for input tensor (sparse-only, default: 0)\n" + << " --output-index-base Starting index for output tensor (sparse-only, default: 0)\n"; return 0; } @@ -153,6 +160,10 @@ int main(int argc, char* argv[]) Genten::parse_ttb_bool(args, "--output-gz", "--no-output-gz", false); const bool output_header = Genten::parse_ttb_bool(args, "--output-header", "--no-output-header", true); + const ttb_indx input_index_base = + Genten::parse_ttb_indx(args, "--input-index-base", 0, 0, INT_MAX); + const ttb_indx output_index_base = + Genten::parse_ttb_indx(args, "--output-index-base", 0, 0, INT_MAX); if (input_filename == "") Genten::error("input filename must be specified"); @@ -168,29 +179,30 @@ int main(int argc, char* argv[]) Genten::error("No header option only supported for binary output files"); std::cout << "\nInput:\n" - << " File: " << input_filename << std::endl; + << " File: " << input_filename << std::endl + << " Index base: " << input_index_base << std::endl; std::string input_format = "unknown"; std::string input_type = "unknown"; Genten::Sptensor x_sparse; Genten::Tensor x_dense; - read_tensor_file(input_filename, input_format, input_type, input_gz, - x_sparse, x_dense); + read_tensor_file(input_filename, input_index_base, input_format, input_type, + input_gz, x_sparse, x_dense); - std::cout << " Format: " << input_format << std::endl - << " Type: " << input_type; + std::cout << " Format: " << input_format << std::endl + << " Type: " << input_type; if (input_type == "text" && input_gz) std::cout << " (compressed)"; std::cout << std::endl; if (input_format == "sparse") 
{ print_tensor_stats(x_sparse); save_tensor(x_sparse, output_filename, output_format, output_type, - output_gz, output_header); + output_index_base, output_gz, output_header); } else if (input_format == "dense") { print_tensor_stats(x_dense); save_tensor(x_dense, output_filename, output_format, output_type, - output_gz, output_header); + output_index_base, output_gz, output_header); } else Genten::error("Invalid input tensor format!");