@@ -1,9 +1,9 @@
-#include <torch/script.h>
-#include <torch/torch.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/script.h>
+#include <torch/torch.h>
 
 using namespace std;
 
@@ -81,18 +81,21 @@ void forced_align_impl(
     auto curIdxOffset = t % 2;
     auto prevIdxOffset = (t - 1) % 2;
     for (auto j = 0; j < S; ++j) {
-      alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
+      alphas_a[curIdxOffset * S + j] = -std::numeric_limits<
+          scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
     }
     if (start == 0) {
-      alphas_a[curIdxOffset * S] =
-          alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
+      alphas_a[curIdxOffset * S] = alphas_a[prevIdxOffset * S] +
+          logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
       backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
       startloop += 1;
     }
 
     for (auto i = startloop; i < end; i++) {
       auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
-      auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
+      auto x1 =
+          alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i
+                                               // - 1];
       auto x2 = -std::numeric_limits<scalar_t>::infinity();
 
       auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -103,7 +106,8 @@ void forced_align_impl(
       // (i != 1) just ensures we don't access targets[i - 2] if its i < 2
       if (i % 2 != 0 && i != 1 &&
           targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
-        x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
+        x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i -
+                                                  // 2];
       }
       scalar_t result = 0.0;
       if (x2 > x1 && x2 > x0) {
@@ -116,12 +120,14 @@ void forced_align_impl(
         result = x0;
         backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
       }
-      alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
+      alphas_a[curIdxOffset * S + i] = result +
+          logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
     }
   }
   auto idx1 = (T - 1) % 2;
-  auto ltrIdx = alphas_a[S * idx1 + S - 1] >
-      alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
+  auto ltrIdx = alphas_a[S * idx1 + S - 1] > alphas_a[S * idx1 + S - 2]
+      ? S - 1
+      : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
   delete[] alphas_a;
   // path stores the token index for each time step after force alignment.
   for (auto t = T - 1; t > -1; t--) {
@@ -194,15 +200,9 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
           logProbs, targets, blank, paths);
     }
   });
-  return std::make_tuple(
-      paths,
-      logProbs
-  );
+  return std::make_tuple(paths, logProbs);
 }
 
-
-
-
 TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
   m.impl("forced_align", &compute);
 }