Skip to content

Commit

Permalink
Added merge argument to ali-decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
t13m committed Mar 6, 2018
1 parent c3f83fe commit 6c07909
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 20 deletions.
23 changes: 19 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ execute_process(
RESULT_VARIABLE RESULT_TF_LIB
OUTPUT_STRIP_TRAILING_WHITESPACE
)

# Ask the Python interpreter in ${PYTHONBIN} for the compiler flags TensorFlow
# reports via tf.sysconfig.get_compile_flags(), joined into one space-separated string.
execute_process(
  COMMAND ${PYTHONBIN} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
  OUTPUT_VARIABLE DEFAULT_TF_CXX_FLAGS
  ERROR_VARIABLE ERROR_TF_CXX_FLAGS
  RESULT_VARIABLE RESULT_TF_CXX_FLAGS
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
# The result used to be captured but never inspected, so a missing interpreter or
# a failed `import tensorflow` silently produced an empty flag string. Surface it
# as a warning (not a fatal error) so existing configure behaviour is unchanged.
if(NOT RESULT_TF_CXX_FLAGS EQUAL 0)
  message(WARNING "Could not query TensorFlow compile flags with '${PYTHONBIN}' (result: ${RESULT_TF_CXX_FLAGS}): ${ERROR_TF_CXX_FLAGS}")
endif()

# Same query for the linker flags reported by tf.sysconfig.get_link_flags().
execute_process(
  COMMAND ${PYTHONBIN} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
  OUTPUT_VARIABLE DEFAULT_TF_LINK_FLAGS
  ERROR_VARIABLE ERROR_TF_LINK_FLAGS
  RESULT_VARIABLE RESULT_TF_LINK_FLAGS
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(NOT RESULT_TF_LINK_FLAGS EQUAL 0)
  message(WARNING "Could not query TensorFlow link flags with '${PYTHONBIN}' (result: ${RESULT_TF_LINK_FLAGS}): ${ERROR_TF_LINK_FLAGS}")
endif()

# Append the discovered flags to the global C++ compile flags (plus -std=c++11)
# and to the linker flags used for shared libraries.
set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DEFAULT_TF_CXX_FLAGS} -std=c++11" )
set( CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${DEFAULT_TF_LINK_FLAGS}" )
#message("TF_INC is set: ${DEFAULT_TF_INC}")
#message("TF_LIB is set: ${DEFAULT_TF_LIB}")
#set(TF_INC "${DEFAULT_TF_INC}" CACHE PATH "Path of tensorflow including files")
Expand All @@ -47,10 +64,8 @@ set(SOURCE_FILES
)

add_library(kaldi_readers SHARED)
target_compile_definitions(kaldi_readers
PUBLIC
-D_GLIBCXX_USE_CXX11_ABI=0
)


# -fPIC
set_property(TARGET kaldi_readers PROPERTY POSITION_INDEPENDENT_CODE ON)
target_sources(kaldi_readers
Expand Down
4 changes: 2 additions & 2 deletions example/read-ali.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import tensorflow as tf
kaldi_module = tf.load_op_library("../build/libkaldi_readers.so")
kaldi_module = tf.load_op_library("../cmake-build-release/libkaldi_readers.so")

def main():
value_rspecific = "./data/ali.ark:6"
rspec = tf.constant(value_rspecific, tf.string)
ali_raw_value = kaldi_module.read_kaldi_post_and_ali(rspec, is_reading_post=False)
ali_value = kaldi_module.decode_kaldi_ali(ali_raw_value, tf.int32, is_reading_post=False)
ali_value = kaldi_module.decode_kaldi_ali(ali_raw_value, tf.int32, is_reading_post=False, merge=False)
ali_value.set_shape([None])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
Expand Down
4 changes: 2 additions & 2 deletions example/read-compressed-matrix.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import tensorflow as tf
kaldi_module = tf.load_op_library("../build/libkaldi_readers.so")
kaldi_module = tf.load_op_library("../cmake-build-release/libkaldi_readers.so")

def main():
value_rspecific = "./data/matrix.compressed.ark:6"
rspec = tf.constant(value_rspecific, tf.string)
#feats_raw_value = kaldi_module.read_kaldi_matrix(rspec)
#feats_value = kaldi_module.decode_kaldi_matrix(feats_raw_value, tf.float32)
feats_value = kaldi_module.read_and_decode_kaldi_matrix(rspec, left_padding=3, right_padding=4)
feats_value = kaldi_module.read_and_decode_kaldi_matrix(rspec, left_padding=0, right_padding=0)
feats_value.set_shape([None, 4])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
Expand Down
2 changes: 1 addition & 1 deletion example/read-matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main():
for rspec_value in scplist:
feats = sess.run(feats_value, feed_dict={rspec: rspec_value})
print(rspec_value)
print(feats)
print(feats.shape)


if __name__ == "__main__":
Expand Down
2 changes: 0 additions & 2 deletions example/read-uncompressed-matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
def main():
value_rspecific = "./data/matrix.nocompress.ark:59"
rspec = tf.constant(value_rspecific, tf.string)
#feats_raw_value = kaldi_module.read_kaldi_matrix(rspec)
#feats_value = kaldi_module.decode_kaldi_matrix(feats_raw_value, tf.float32)
feats_value = kaldi_module.read_and_decode_kaldi_matrix(rspec, left_padding=3, right_padding=4)
feats_value.set_shape([None, 4])
sess = tf.Session()
Expand Down
61 changes: 53 additions & 8 deletions kaldi-ali.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ scpline: scalar. /path/to/ark.file:12345
// Kernel constructor: reads the op attributes once, when the kernel is built,
// and caches them in the members out_type_, is_reading_post_ and merge_.
// Every lookup is wrapped in OP_REQUIRES_OK, so a GetAttr that does not return
// OK is reported through `context` instead of being ignored.
explicit DecodeKaldiAliOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_type_));
OP_REQUIRES_OK(context, context->GetAttr("is_reading_post", &is_reading_post_));
// "merge" selects the Compute branch that emits a label only when it differs
// from the previous frame's label.
OP_REQUIRES_OK(context, context->GetAttr("merge", &merge_));
}

void Compute(OpKernelContext* context) override {
Expand All @@ -127,14 +128,35 @@ scpline: scalar. /path/to/ark.file:12345
errors::InvalidArgument(
"DecodeKaldiAliOp requires input string size = 1"
)
)
);
const string& in_str = flat_in(0);
str_size = in_str.size();

const char* in_data = reinterpret_cast<const char*>(flat_in(0).data());
TensorShape out_shape;
int32 num_elem = *reinterpret_cast<const int32*>(in_data + 1);
out_shape.AddDim(num_elem);
if (!merge_) {
out_shape.AddDim(num_elem);
} else {
int32 prev_elem = -1;
int32 count = 0;
const char* p = in_data + 5;
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx ++) {
int32 curr_elem;
if (is_reading_post_) {
curr_elem = *reinterpret_cast<const int32*>(p + 5 + 1);
p += 15;
} else {
curr_elem = *reinterpret_cast<const int32*>(p + 1);
p += 5;
}
if (curr_elem != prev_elem) {
count ++;
prev_elem = curr_elem;
}
}
out_shape.AddDim(count);
}

if (str_size == -1 || str_size == 0) { // Empty input
Tensor* output_tensor = nullptr;
Expand All @@ -150,21 +172,43 @@ scpline: scalar. /path/to/ark.file:12345

int32* out_data = out.data();
const char* in_bytes = in_data + 5;
if (is_reading_post_) {
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx++) {
out_data[frame_idx] = *reinterpret_cast<const int32*>(in_bytes + 5 + 1);
in_bytes += 15;
if (!merge_) {
if (is_reading_post_) {
int32 prev_elem = -1;
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx++) {
out_data[frame_idx] = *reinterpret_cast<const int32 *>(in_bytes + 5 + 1);
in_bytes += 15;
}
} else {
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx++) {
out_data[frame_idx] = *reinterpret_cast<const int32 *>(in_bytes + 1);
in_bytes += 5;
}
}
} else {
int32 prev_elem = -1;
int32 count = 0;
for (int32 frame_idx = 0; frame_idx < num_elem; frame_idx++) {
out_data[frame_idx] = *reinterpret_cast<const int32*>(in_bytes + 1);
in_bytes += 5;
int32 curr_elem;
if (is_reading_post_) {
curr_elem = *reinterpret_cast<const int32*>(in_bytes + 5 + 1);
in_bytes += 15;
} else {
curr_elem = *reinterpret_cast<const int32*>(in_bytes + 1);
in_bytes += 5;
}
if (curr_elem != prev_elem) {
out_data[count] = curr_elem;
count ++;
prev_elem = curr_elem;
}
}
}
}

private:
bool is_reading_post_;
bool merge_;
DataType out_type_;

};
Expand All @@ -176,6 +220,7 @@ scpline: scalar. /path/to/ark.file:12345
.Output("output: out_type")
.Attr("out_type: {int32}")
.Attr("is_reading_post: bool")
.Attr("merge: bool")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Reinterpret the bytes of a string as a kaldi ali
Expand Down
2 changes: 1 addition & 1 deletion kaldi-matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ scpline: scalar. /path/to/ark.file:12345
errors::InvalidArgument(
"DecodeKaldiArk requires input string size = 1"
)
)
);
const string& in_str = flat_in(0);
str_size = in_str.size();

Expand Down

0 comments on commit 6c07909

Please sign in to comment.