-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 2b92b5a
Showing
21 changed files
with
692 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
cmake_minimum_required(VERSION 3.8) | ||
project(kaldi_reader_standalone) | ||
|
||
set(CMAKE_CXX_STANDARD 11) | ||
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 5.1) | ||
# c++ regex is used in the code, so the version of gcc must be greater than 4.9 | ||
message(FATAL_ERROR "VERSION OF GCC MUST BE GREATER THAN 5.1") | ||
endif() | ||
|
||
set(PYTHONBIN "python" CACHE PATH "Path of python with tensorflow installed") | ||
|
||
execute_process( | ||
COMMAND ${PYTHONBIN} -c "import tensorflow as tf; print(tf.sysconfig.get_include())" | ||
OUTPUT_VARIABLE DEFAULT_TF_INC | ||
ERROR_VARIABLE ERROR_TF_INC | ||
RESULT_VARIABLE RESULT_TF_INC | ||
OUTPUT_STRIP_TRAILING_WHITESPACE | ||
) | ||
execute_process( | ||
COMMAND ${PYTHONBIN} -c "import tensorflow as tf; print(tf.sysconfig.get_lib())" | ||
OUTPUT_VARIABLE DEFAULT_TF_LIB | ||
ERROR_VARIABLE ERROR_TF_LIB | ||
RESULT_VARIABLE RESULT_TF_LIB | ||
OUTPUT_STRIP_TRAILING_WHITESPACE | ||
) | ||
#message("TF_INC is set: ${DEFAULT_TF_INC}") | ||
#message("TF_LIB is set: ${DEFAULT_TF_LIB}") | ||
#set(TF_INC "${DEFAULT_TF_INC}" CACHE PATH "Path of tensorflow including files") | ||
#set(TF_LIB "${DEFAULT_TF_LIB}" CACHE PATH "Path of tensorflow linking libraries") | ||
|
||
set(TF_INC "${DEFAULT_TF_INC}") | ||
set(TF_LIB "${DEFAULT_TF_LIB}") | ||
|
||
if ("${TF_INC}" STREQUAL "" OR "${TF_LIB}" STREQUAL "") | ||
message(FATAL_ERROR "TF_INC and TF_LIB not set. Please set both variable manually, or set correct PYTHONBIN var.") | ||
endif() | ||
|
||
|
||
message("TF_INC is set: ${TF_INC}") | ||
message("TF_LIB is set: ${TF_LIB}") | ||
|
||
set(SOURCE_FILES | ||
kaldi-matrix.cc | ||
kaldi-ali.cc | ||
shape-funcs.cc | ||
) | ||
|
||
add_library(kaldi_readers SHARED) | ||
target_compile_definitions(kaldi_readers | ||
PUBLIC | ||
-D_GLIBCXX_USE_CXX11_ABI=0 | ||
) | ||
# -fPIC | ||
set_property(TARGET kaldi_readers PROPERTY POSITION_INDEPENDENT_CODE ON) | ||
target_sources(kaldi_readers | ||
PRIVATE | ||
${SOURCE_FILES} | ||
) | ||
|
||
target_include_directories(kaldi_readers | ||
PRIVATE | ||
${TF_INC} | ||
${TF_INC}/external/nsync/public | ||
) | ||
target_link_libraries(kaldi_readers | ||
PRIVATE | ||
${TF_LIB}/libtensorflow_framework.so | ||
) | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
ali-1 25 70 3013 4 0 222 444 111 | ||
ali-2 1 2 3 4 5 6 7 8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
ali-1 ali.ark:6 | ||
ali-2 ali.ark:59 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
mat-1 [ | ||
1.0 2.0 3.0 4.0 | ||
2.0 3.0 4.0 5.0 ] | ||
mat-2 [ | ||
7.0 7.0 7.0 7.0 | ||
2.0 3.0 4.0 5.0 ] |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
mat-1 matrix.compressed.ark:6 | ||
mat-2 matrix.compressed.ark:73 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
mat-1 matrix.nocompress.ark:6 | ||
mat-2 matrix.nocompress.ark:59 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
post-1 [ 25 1 ] [ 70 1 ] [ 3013 1 ] [ 4 1 ] [ 0 1 ] [ 222 1 ] [ 444 1 ] [ 111 1 ] | ||
post-2 [ 1 1 ] [ 2 1 ] [ 3 1 ] [ 4 1 ] [ 5 1 ] [ 6 1 ] [ 7 1 ] [ 8 1 ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
post-1 post.ark:7 | ||
post-2 post.ark:141 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import tensorflow as tf | ||
kaldi_module = tf.load_op_library("../build/libkaldi_readers.so") | ||
|
||
def main(): | ||
value_rspecific = "./data/ali.ark:6" | ||
rspec = tf.constant(value_rspecific, tf.string) | ||
ali_raw_value = kaldi_module.read_kaldi_post_and_ali(rspec, is_reading_post=False) | ||
ali_value = kaldi_module.decode_kaldi_ali(ali_raw_value, tf.int32, is_reading_post=False) | ||
ali_value.set_shape([None]) | ||
sess = tf.Session() | ||
sess.run(tf.global_variables_initializer()) | ||
ali = sess.run(ali_value) | ||
print(ali.shape) | ||
print(ali) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import tensorflow as tf | ||
kaldi_module = tf.load_op_library("../build/libkaldi_readers.so") | ||
|
||
def main(): | ||
value_rspecific = "./data/matrix.compressed.ark:6" | ||
rspec = tf.constant(value_rspecific, tf.string) | ||
feats_raw_value = kaldi_module.read_kaldi_matrix(rspec) | ||
feats_value = kaldi_module.decode_kaldi_matrix(feats_raw_value, tf.float32) | ||
feats_value.set_shape([None, 4]) | ||
sess = tf.Session() | ||
sess.run(tf.global_variables_initializer()) | ||
feats = sess.run(feats_value) | ||
print(feats.shape) | ||
print(feats) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import tensorflow as tf | ||
kaldi_module = tf.load_op_library("../build/libkaldi_readers.so") | ||
|
||
def main(): | ||
value_rspecific = "./data/post.ark:141" | ||
rspec = tf.constant(value_rspecific, tf.string) | ||
ali_raw_value = kaldi_module.read_kaldi_post_and_ali(rspec, is_reading_post=True) | ||
ali_value = kaldi_module.decode_kaldi_ali(ali_raw_value, tf.int32, is_reading_post=True) | ||
ali_value.set_shape([None]) | ||
sess = tf.Session() | ||
sess.run(tf.global_variables_initializer()) | ||
ali = sess.run(ali_value) | ||
print(ali.shape) | ||
print(ali) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import tensorflow as tf | ||
kaldi_module = tf.load_op_library("../build/libkaldi_readers.so") | ||
|
||
def main(): | ||
value_rspecific = "./data/matrix.nocompress.ark:59" | ||
rspec = tf.constant(value_rspecific, tf.string) | ||
feats_raw_value = kaldi_module.read_kaldi_matrix(rspec) | ||
feats_value = kaldi_module.decode_kaldi_matrix(feats_raw_value, tf.float32) | ||
feats_value.set_shape([None, 4]) | ||
sess = tf.Session() | ||
sess.run(tf.global_variables_initializer()) | ||
feats = sess.run(feats_value) | ||
print(feats.shape) | ||
print(feats) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
#include <memory> | ||
#include <regex> | ||
#include "tensorflow/core/framework/reader_base.h" | ||
#include "tensorflow/core/framework/reader_op_kernel.h" | ||
#include "tensorflow/core/lib/core/errors.h" | ||
#include "tensorflow/core/lib/io/buffered_inputstream.h" | ||
#include "tensorflow/core/lib/io/random_inputstream.h" | ||
#include "tensorflow/core/lib/io/zlib_compression_options.h" | ||
#include "tensorflow/core/lib/io/zlib_inputstream.h" | ||
#include "tensorflow/core/lib/strings/strcat.h" | ||
#include "tensorflow/core/platform/env.h" | ||
|
||
#include "shape-funcs.hh" | ||
|
||
namespace tensorflow { | ||
using shape_util::ScalarInputsAndOutputs; | ||
using shape_util::TwoElementOutput; | ||
|
||
static Status ReadKaldiPostAndAli(Env* env, const string& ark_path, uint64 ark_offset, bool is_reading_post, string* contents) { | ||
enum { kBufferSize = 256 << 10 /* 256 kB */ }; | ||
|
||
std::unique_ptr<RandomAccessFile> file_; | ||
std::unique_ptr<io::InputStreamInterface> buffered_inputstream_; | ||
|
||
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(ark_path, &file_)); | ||
buffered_inputstream_.reset( | ||
new io::BufferedInputStream(file_.get(), kBufferSize)); | ||
TF_RETURN_IF_ERROR(buffered_inputstream_->SkipNBytes(ark_offset)); | ||
|
||
// Actural reading start from here | ||
string binary; | ||
TF_RETURN_IF_ERROR(buffered_inputstream_->ReadNBytes(2, &binary)); | ||
CHECK_EQ(binary[0], '\0'); | ||
CHECK_EQ(binary[1], 'B'); | ||
string header_buffer; | ||
TF_RETURN_IF_ERROR(buffered_inputstream_->ReadNBytes(1, &header_buffer)); | ||
if (header_buffer[0] == '\4') { | ||
// This is a vector of int | ||
string size_str; | ||
buffered_inputstream_->ReadNBytes(4, &size_str); | ||
int32 size = *reinterpret_cast<const int32*>(size_str.data()); | ||
string data; | ||
if (is_reading_post) { | ||
for (int32 outer_vec_idx = 0; outer_vec_idx < size; outer_vec_idx++) { | ||
// <1> <4> [<1> <4> <1> <4>] [<1> <4> <1> <4>] | ||
string inner_size_str; | ||
buffered_inputstream_->ReadNBytes(5, &inner_size_str); | ||
int32 inner_size = *reinterpret_cast<const int32 *>(inner_size_str.data() + 1); | ||
string inner_vec_data; | ||
buffered_inputstream_->ReadNBytes(inner_size * 10, &inner_vec_data); | ||
data += inner_size_str + inner_vec_data; | ||
} | ||
} else { | ||
TF_RETURN_IF_ERROR(buffered_inputstream_->ReadNBytes(size * 5, &data)); | ||
} | ||
*contents = header_buffer + size_str + data; | ||
} else { | ||
return Status(error::UNAVAILABLE, "Unknown Kaldi Post or Ali: " + header_buffer); | ||
} | ||
} | ||
|
||
class ReadKaldiPostAndAliOp : public OpKernel { | ||
public: | ||
using OpKernel::OpKernel; | ||
explicit ReadKaldiPostAndAliOp(OpKernelConstruction *context) | ||
:OpKernel(context), | ||
id_pat_("^(\\S+):(\\d+)") | ||
{ | ||
OP_REQUIRES_OK(context, context->GetAttr("is_reading_post", &is_reading_post_)); | ||
} | ||
void Compute(OpKernelContext* context) override { | ||
|
||
const Tensor* input; | ||
|
||
OP_REQUIRES_OK(context, context->input("scpline", &input)); | ||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(input->shape()), | ||
errors::InvalidArgument( | ||
"Input filename tensor must be scalar, but had shape: ", | ||
input->shape().DebugString())); | ||
|
||
Tensor* output = nullptr; | ||
OP_REQUIRES_OK(context, context->allocate_output("contents", | ||
TensorShape({}), &output)); | ||
const std::regex id_pat("^(\\S+):(\\d+)"); | ||
std::smatch m; | ||
string half_scp_line = input->scalar<string>()(); | ||
bool matched = std::regex_search(half_scp_line, m, id_pat); | ||
OP_REQUIRES(context, matched, Status(error::INVALID_ARGUMENT, "Script line is " + half_scp_line)); | ||
string ark_path = m[1]; | ||
string ark_offset_str = m[2]; | ||
uint64 ark_offset = std::stoull(ark_offset_str); | ||
|
||
OP_REQUIRES_OK(context, | ||
ReadKaldiPostAndAli(context->env(), ark_path, ark_offset, is_reading_post_, | ||
&output->scalar<string>()())); | ||
} | ||
private: | ||
bool is_reading_post_; | ||
const std::regex id_pat_; | ||
}; | ||
REGISTER_KERNEL_BUILDER(Name("ReadKaldiPostAndAli").Device(DEVICE_CPU), ReadKaldiPostAndAliOp); | ||
|
||
REGISTER_OP("ReadKaldiPostAndAli") | ||
.Attr("is_reading_post: bool") | ||
.Input("scpline: string") | ||
.Output("contents: string") | ||
.SetShapeFn(ScalarInputsAndOutputs) | ||
.Doc(R"doc( | ||
Reads and outputs the entire contents of the input kaldi post or ali ark filename. | ||
scpline: scalar. /path/to/ark.file:12345 | ||
)doc"); | ||
|
||
class DecodeKaldiAliOp : public OpKernel { | ||
public: | ||
explicit DecodeKaldiAliOp(OpKernelConstruction* context) : OpKernel(context) { | ||
OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_type_)); | ||
OP_REQUIRES_OK(context, context->GetAttr("is_reading_post", &is_reading_post_)); | ||
} | ||
|
||
void Compute(OpKernelContext* context) override { | ||
const auto& input = context->input(0); | ||
int64 str_size = -1; | ||
auto flat_in = input.flat<string>(); | ||
OP_REQUIRES(context, flat_in.size() == 1, | ||
errors::InvalidArgument( | ||
"DecodeKaldiAliOp requires input string size = 1" | ||
) | ||
) | ||
const string& in_str = flat_in(0); | ||
str_size = in_str.size(); | ||
|
||
const char* in_data = reinterpret_cast<const char*>(flat_in(0).data()); | ||
TensorShape out_shape; | ||
int32 num_elem = *reinterpret_cast<const int32*>(in_data + 1); | ||
out_shape.AddDim(num_elem); | ||
|
||
if (str_size == -1 || str_size == 0) { // Empty input | ||
Tensor* output_tensor = nullptr; | ||
OP_REQUIRES_OK(context, context->allocate_output("output", out_shape, | ||
&output_tensor)); | ||
return; | ||
} | ||
|
||
Tensor* output_tensor = nullptr; | ||
OP_REQUIRES_OK( | ||
context, context->allocate_output("output", out_shape, &output_tensor)); | ||
auto out = output_tensor->flat<int32>(); | ||
|
||
int32* out_data = out.data(); | ||
const char* in_bytes = in_data + 5; | ||
if (is_reading_post_) { | ||
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx++) { | ||
out_data[frame_idx] = *reinterpret_cast<const int32*>(in_bytes + 5 + 1); | ||
in_bytes += 15; | ||
} | ||
} else { | ||
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx++) { | ||
out_data[frame_idx] = *reinterpret_cast<const int32*>(in_bytes + 1); | ||
in_bytes += 5; | ||
} | ||
} | ||
} | ||
|
||
private: | ||
bool is_reading_post_; | ||
DataType out_type_; | ||
|
||
}; | ||
|
||
REGISTER_KERNEL_BUILDER(Name("DecodeKaldiAli").Device(DEVICE_CPU), DecodeKaldiAliOp); | ||
|
||
REGISTER_OP("DecodeKaldiAli") | ||
.Input("bytes: string") | ||
.Output("output: out_type") | ||
.Attr("out_type: {int32}") | ||
.Attr("is_reading_post: bool") | ||
.SetShapeFn(shape_inference::UnknownShape) | ||
.Doc(R"doc( | ||
Reinterpret the bytes of a string as a kaldi ali | ||
)doc"); | ||
|
||
|
||
} // namespace tensorflow |
Oops, something went wrong.