
Commit 130ba9c

resolve merge conflicts
1 parent adc8f2b commit 130ba9c

File tree

7 files changed: +51 additions, -82 deletions


README.md

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ JoeyS2T is built on [PyTorch](https://pytorch.org/). Please make sure you have a
 We tested JoeyS2T with
 - python 3.10
 - torch 1.12.1
+- torchaudio 0.12.1
 - cuda 11.6
 
 Clone this repository and install via pip:
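A quick way to confirm the tested versions locally; a minimal sketch, assuming torch and torchaudio are already installed:

import torch
import torchaudio

# Versions the README lists as tested; yours may differ.
print(torch.__version__)       # expected: 1.12.1
print(torchaudio.__version__)  # expected: 0.12.1
print(torch.version.cuda)      # expected: 11.6 (None on CPU-only builds)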

joeynmt/datasets.py

Lines changed: 3 additions & 1 deletion
@@ -162,14 +162,16 @@ def _is_valid(s, t, has_trg):
             trg, trg_length = None, None
 
         return Batch(
-            src=torch.tensor(src).long(),
+            src=(torch.tensor(src).long()
+                 if self.task == "MT" else torch.tensor(src).float()),
             src_length=torch.tensor(src_length).long(),
             trg=torch.tensor(trg).long() if trg else None,
             trg_length=torch.tensor(trg_length).long() if trg_length else None,
             device=device,
             pad_index=pad_index,
             has_trg=self.has_trg,
             is_train=self.split == "train",
+            task=self.task,
         )
 
     def make_iter(
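The conditional dtype matters because MT batches carry integer token IDs while speech batches carry real-valued audio features. A self-contained sketch of the same logic with toy data; `task` here is a stand-in for the dataset's `self.task` attribute from the diff:

import torch

task = "S2T"  # stand-in for self.task
src_tokens = [[4, 17, 9, 3]]                     # MT: vocabulary indices
src_features = [[[0.12, -0.53], [0.31, 0.94]]]   # S2T: e.g. log-Mel frames

src = (torch.tensor(src_tokens).long()
       if task == "MT" else torch.tensor(src_features).float())
print(src.dtype)  # torch.float32 for speech input, torch.int64 for MT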

joeynmt/model.py

Lines changed: 0 additions & 1 deletion
@@ -92,7 +92,6 @@ def loss_function(self, cfg: Tuple):
         self.decoder.ctc_output_layer = None
         self._loss_function = loss_function
 
-    @torch.autocast(device_type=DEVICE_TYPE)
     def forward(self,
                 return_type: str = None,
                 **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
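Dropping the decorator does not rule out mixed precision; PyTorch's autocast can also be applied as a context manager at the call site. A minimal sketch, assuming a CUDA device is available (not JoeyNMT's actual training loop):

import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda"):
    y = model(x)  # matmuls inside the context run in reduced precision
print(y.dtype)    # typically torch.float16 under CUDA autocast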

joeynmt/tokenizers.py

Lines changed: 1 addition & 2 deletions
@@ -556,7 +556,6 @@ def build_tokenizer(data_cfg: Dict) -> Dict[str, BasicTokenizer]:
         src_lang: _build_tokenizer(data_cfg["src"]),
         trg_lang: _build_tokenizer(data_cfg["trg"]),
     }
-    log_str = "Tokenizer" if task == "MT" else "SpeechProcessor"
-    logger.info("%s %s: %s", src_lang, log_str, tokenizer[src_lang])
+    logger.info("%s Tokenizer: %s", src_lang, tokenizer[src_lang])
     logger.info("%s Tokenizer: %s", trg_lang, tokenizer[trg_lang])
     return tokenizer
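The surviving log calls use logging's lazy %-style formatting, so the arguments are only interpolated into the message when the INFO level is actually emitted. A standalone sketch with stand-in values (the dict entries are made up, not joeynmt objects):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

tokenizer = {"de": "SentencePiece(...)", "en": "SentencePiece(...)"}  # stand-ins
logger.info("%s Tokenizer: %s", "de", tokenizer["de"])
logger.info("%s Tokenizer: %s", "en", tokenizer["en"])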

scripts/discord_joey.py

Lines changed: 0 additions & 2 deletions
@@ -21,7 +21,6 @@
 - Slash Commands:
     https://guide.pycord.dev/interactions/application-commands/slash-commands
 """
-import re
 from functools import partial
 from pathlib import Path
 
@@ -40,7 +39,6 @@
 from joeynmt.tokenizers import build_tokenizer
 from joeynmt.vocabulary import build_vocab
 
-
 TOKEN = "your-bot-token-here"  # replace with your bot token
 guild = 123456789  # replace with your guild ID

test/unit/test_data.py

Lines changed: 2 additions & 4 deletions
@@ -117,8 +117,7 @@ def testIteratorBatchShape(self):
 
         # make train batches (filtered by length)
         train_iter = iter(
-            make_data_iter(
-                train_data,
+            train_data.make_iter(
                 batch_size=2,
                 batch_type="sentence",
                 shuffle=True,
@@ -136,8 +135,7 @@
 
         # make test batches (not filtered by length)
         test_iter = iter(
-            make_data_iter(
-                test_data,
+            test_data.make_iter(
                 batch_size=2,
                 batch_type="sentence",
                 shuffle=False,
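The tests now call make_iter on the dataset object itself instead of passing the dataset to a free make_data_iter function. A hypothetical miniature of that refactor (ToyDataset and its contents are made up for illustration, not joeynmt's API):

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def make_iter(self, batch_size: int, shuffle: bool = False) -> DataLoader:
        # the dataset owns iterator construction, mirroring the new call shape
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)

train_data = ToyDataset(list(range(10)))
train_iter = iter(train_data.make_iter(batch_size=2, shuffle=True))
print(next(train_iter))  # a batch of 2 items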

test/unit/test_transformer_encoder.py

Lines changed: 44 additions & 72 deletions
@@ -59,50 +59,34 @@ def test_transformer_encoder_forward(self):
                          torch.Size([batch_size, time_dim, self.hidden_size]))
         self.assertEqual(hidden, None)
 
+        # yapf: disable
         output_target = torch.Tensor([
-            [[
-                1.9728e-01, -1.2042e-01, 8.0998e-02, 1.3411e-03, -3.5960e-01,
-                -5.2988e-01, -5.6056e-01, -3.5297e-01, 2.6680e-01, 2.8343e-01,
-                -3.7342e-01, -5.9113e-03
-            ],
-            [
-                8.9687e-02, -1.2491e-01, 7.7809e-02, -1.3499e-03, -2.7002e-01,
-                -4.7312e-01, -5.7981e-01, -4.1998e-01, 1.0457e-01, 2.9726e-01,
-                -3.9461e-01, 8.1598e-02
-            ],
-            [
-                3.4988e-02, -1.3020e-01, 6.0043e-02, 2.7782e-02, -3.1483e-01,
-                -3.8940e-01, -5.5557e-01, -5.9540e-01, -2.9808e-02, 3.1468e-01,
-                -4.5809e-01, 4.3312e-03
-            ],
-            [
-                1.2234e-01, -1.3285e-01, 6.3068e-02, -2.3343e-02, -2.3519e-01,
-                -4.0794e-01, -5.6063e-01,
-                -5.5484e-01, -1.1272e-01,
-                3.0103e-01, -4.0983e-01, 3.3038e-02
-            ]],
-            [[
-                9.8597e-02, -1.2121e-01, 1.0718e-01, -2.2644e-02, -4.0282e-01,
-                -4.2646e-01, -5.9981e-01,
-                -3.7200e-01, 1.9538e-01, 2.7036e-01, -3.4072e-01, -1.7965e-03
-            ],
-            [
-                8.8470e-02, -1.2618e-01, 5.3351e-02, -1.8531e-02, -3.3834e-01,
-                -4.9047e-01, -5.7063e-01, -4.9790e-01, 2.2070e-01, 3.3964e-01,
-                -4.1604e-01, 2.3519e-02
-            ],
-            [
-                5.8373e-02, -1.2706e-01, 1.0598e-01, 9.3256e-05, -3.0493e-01,
-                -4.4406e-01, -5.4723e-01, -5.2214e-01, 8.0374e-02, 2.6307e-01,
-                -4.4571e-01, 8.7052e-02
-            ],
-            [
-                7.9567e-02, -1.2977e-01, 1.1731e-01, 2.6198e-02, -2.4024e-01,
-                -4.2161e-01, -5.7604e-01, -7.3298e-01, 1.6698e-01, 3.1454e-01,
-                -4.9189e-01, 2.4027e-02
-            ]]
+            [[1.9728e-01, -1.2042e-01, 8.0998e-02, 1.3411e-03, -3.5960e-01,
+              -5.2988e-01, -5.6056e-01, -3.5297e-01, 2.6680e-01, 2.8343e-01,
+              -3.7342e-01, -5.9112e-03],
+             [8.9687e-02, -1.2491e-01, 7.7809e-02, -1.3500e-03, -2.7002e-01,
+              -4.7312e-01, -5.7981e-01, -4.1998e-01, 1.0457e-01, 2.9726e-01,
+              -3.9461e-01, 8.1598e-02],
+             [3.4988e-02, -1.3020e-01, 6.0043e-02, 2.7782e-02, -3.1483e-01,
+              -3.8940e-01, -5.5557e-01, -5.9540e-01, -2.9808e-02, 3.1468e-01,
+              -4.5809e-01, 4.3313e-03],
+             [1.2234e-01, -1.3285e-01, 6.3068e-02, -2.3343e-02, -2.3519e-01,
+              -4.0794e-01, -5.6063e-01, -5.5484e-01, -1.1272e-01, 3.0103e-01,
+              -4.0983e-01, 3.3038e-02]],
+            [[9.8597e-02, -1.2121e-01, 1.0718e-01, -2.2644e-02, -4.0282e-01,
+              -4.2646e-01, -5.9981e-01, -3.7200e-01, 1.9538e-01, 2.7036e-01,
+              -3.4072e-01, -1.7965e-03],
+             [8.8470e-02, -1.2618e-01, 5.3351e-02, -1.8531e-02, -3.3834e-01,
+              -4.9047e-01, -5.7063e-01, -4.9790e-01, 2.2070e-01, 3.3964e-01,
+              -4.1604e-01, 2.3519e-02],
+             [5.8373e-02, -1.2706e-01, 1.0598e-01, 9.3255e-05, -3.0493e-01,
+              -4.4406e-01, -5.4723e-01, -5.2214e-01, 8.0374e-02, 2.6307e-01,
+              -4.4571e-01, 8.7052e-02],
+             [7.9567e-02, -1.2977e-01, 1.1731e-01, 2.6198e-02, -2.4024e-01,
+              -4.2161e-01, -5.7604e-01, -7.3298e-01, 1.6698e-01, 3.1454e-01,
+              -4.9189e-01, 2.4027e-02]],
         ])
-        torch.testing.assert_close(output, output_target)
+        torch.testing.assert_close(output, output_target, rtol=1e-4, atol=1e-4)
 
         for layer in encoder.layers:
             self.assertTrue(isinstance(layer, TransformerEncoderLayer))
@@ -118,7 +102,7 @@ def test_transformer_encoder_forward(self):
             self.assertEqual(layer._layer_norm_position, self.layer_norm)
 
 
-class TestSubsampler(TensorTestCase):
+class TestSubsampler(unittest.TestCase):
 
     def setUp(self):
         self.hidden_size = 12
@@ -149,32 +133,20 @@ def test_subsampler_forward(self):
         # x shape [batch_size, seq_len, emb_dim]: [2, 9, 10] -> [2, 3, 12]
         self.assertEqual(x.size(), torch.Size([batch_size, 3, self.hidden_size]))
 
-        x_target = torch.tensor([[[
-            -0.4831, -0.0188, -0.0643, 0.2323, 0.1843, -0.0599, 0.0333, -0.0295, 0.0926,
-            0.0629, 0.4416, -0.3737
-        ],
-        [
-            -0.0230, 0.0513, -0.2007, -0.2211, 0.7072, 0.0523,
-            -0.0546, 0.0382, -0.0606, -0.8240, -0.3379,
-            -0.7052
-        ],
-        [
-            0.0229, 0.1770, -0.2644, -0.5954, 0.8251, -0.0118,
-            -0.0228, -0.2697, 0.1242, 0.1570, -0.2263, -0.9022
-        ]],
-        [[
-            -0.4647, 0.0986, -0.1160, 0.0453, 0.2717, -0.0112,
-            0.0018, 0.0935, 0.2077, -0.2647, 0.3621, -0.4435
-        ],
-        [
-            0.0116, -0.1874, -0.0305, -0.5209, 0.7063,
-            -0.0522, 0.0577, 0.4307, 0.1027, -0.1947, 0.0964,
-            -0.8076
-        ],
-        [
-            -0.2909, -0.0827, -0.1345, -0.4011, 0.4482,
-            0.4247, 0.2187, -0.2467, 0.0096, -0.2841, 0.0799,
-            -1.2243
-        ]]])
-        self.assertTensorAlmostEqual(x, x_target)
-        self.assertTensorAlmostEqual(x_length, torch.tensor([3, 3]))
+        # yapf: disable
+        x_target = torch.tensor([
+            [[-0.4831, -0.0188, -0.0643, 0.2323, 0.1843, -0.0599, 0.0333,
+              -0.0295, 0.0926, 0.0629, 0.4416, -0.3737],
+             [-0.0230, 0.0513, -0.2007, -0.2211, 0.7072, 0.0523, -0.0546,
+              0.0382, -0.0606, -0.8240, -0.3379, -0.7052],
+             [0.0229, 0.1770, -0.2644, -0.5954, 0.8251, -0.0118, -0.0228,
+              -0.2697, 0.1242, 0.1570, -0.2263, -0.9022]],
+            [[-0.4647, 0.0986, -0.1160, 0.0453, 0.2717, -0.0112, 0.0018,
+              0.0935, 0.2077, -0.2647, 0.3621, -0.4435],
+             [0.0116, -0.1874, -0.0305, -0.5209, 0.7063, -0.0522, 0.0577,
+              0.4307, 0.1027, -0.1947, 0.0964, -0.8076],
+             [-0.2909, -0.0827, -0.1345, -0.4011, 0.4482, 0.4247, 0.2187,
+              -0.2467, 0.0096, -0.2841, 0.0799, -1.2243]],
+        ])
+        torch.testing.assert_close(x, x_target, rtol=1e-4, atol=1e-4)
+        torch.testing.assert_close(x_length, torch.tensor([3, 3]))
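Both test classes now compare tensors with torch.testing.assert_close and explicit tolerances instead of the custom assertTensorAlmostEqual helper. A standalone example of how the tolerances combine:

import torch

actual = torch.tensor([1.00004, 2.0000])
expected = torch.tensor([1.0, 2.0])

# passes if |actual - expected| <= atol + rtol * |expected| elementwise
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
print("close enough")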
