Skip to content

Commit 823df82

Browse files
Generate notes test 2
1 parent 0f6fd46 commit 823df82

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/model_api.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def generate_notes_test(self, midi_file):
5454
music = abstractor.abstract()
5555

5656
# Generate notes to the right input size for the model, which is (64, 88)
57-
assert music.shape[1] == 88
57+
print(music.shape)
58+
assert music.shape[1] == 88, "Music representation has wrong number of features"
5859
num_quantums = music.shape[0] # Much bigger than 64
5960

6061
for i in range(10):
@@ -64,3 +65,8 @@ def generate_notes_test(self, midi_file):
6465
# Predict the next 16 notes
6566
predictions = self.model.predict(input_seq)
6667
print(f"Prediction {i+1}: {predictions}")
68+
assert predictions.shape == (1, 16, 88), "Predictions have wrong shape"
69+
70+
# Save the predictions to a numpy array
71+
# Reshape the predictions to (16, 88)
72+
predictions = tf.reshape(predictions, (16, 88))

test_api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
api = GimbopAPI()
44
api.generate_notes_test(
5-
"data/maestro-v3.0.0/2018/MIDI-Unprocessed_Chamber1_MID--AUDIO_07_R3_2018_wav--1.midi"
5+
"data/maestro-v3.0.0/2018/MIDI-Unprocessed_Chamber2_MID--AUDIO_09_R3_2018_wav--1.midi"
66
)

0 commit comments

Comments
 (0)