Update app.py
app.py CHANGED
@@ -46,6 +46,10 @@ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

+tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
+mode_signatures = ["major", "minor"] # Major and minor modes
+
+
pitch_num_dic = {
    'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
    'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
@@ -167,6 +171,8 @@ def split_audio(waveform, sample_rate):



+
+
class Music2emo:
    def __init__(
        self,
@@ -206,6 +212,37 @@ class Music2emo:
        self.music2emo_model.to(self.device)
        self.music2emo_model.eval()

+        self.config = HParams.load("./inference/data/run_config.yaml")
+        self.config.feature['large_voca'] = True
+        self.config.model['num_chords'] = 170
+        model_file = './inference/data/btc_model_large_voca.pt'
+        self.idx_to_voca = idx2voca_chord()
+        self.btc_model = BTC_model(config=self.config.model).to(self.device)
+
+        if os.path.isfile(model_file):
+            checkpoint = torch.load(model_file, map_location=self.device)
+            self.mean = checkpoint['mean']
+            self.std = checkpoint['std']
+            self.btc_model.load_state_dict(checkpoint['model'])
+
+
+        self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
+        self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
+        self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()}
+        self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()}
+
+        with open('inference/data/chord.json', 'r') as f:
+            self.chord_to_idx = json.load(f)
+        with open('inference/data/chord_inv.json', 'r') as f:
+            self.idx_to_chord = json.load(f)
+        self.idx_to_chord = {int(k): v for k, v in self.idx_to_chord.items()} # Ensure keys are ints
+        with open('inference/data/chord_root.json') as json_file:
+            self.chordRootDic = json.load(json_file)
+        with open('inference/data/chord_attr.json') as json_file:
+            self.chordAttrDic = json.load(json_file)
+
+
+
    def predict(self, audio, threshold = 0.5):

        feature_dir = Path("./inference/temp_out")
@@ -263,23 +300,11 @@ class Music2emo:
        final_embedding_mert.to(self.device)

        # --- Chord feature extract ---
-        config = HParams.load("./inference/data/run_config.yaml")
-        config.feature['large_voca'] = True
-        config.model['num_chords'] = 170
-        model_file = './inference/data/btc_model_large_voca.pt'
-        idx_to_chord = idx2voca_chord()
-        model = BTC_model(config=config.model).to(self.device)
-
-        if os.path.isfile(model_file):
-            checkpoint = torch.load(model_file, map_location=self.device)
-            mean = checkpoint['mean']
-            std = checkpoint['std']
-            model.load_state_dict(checkpoint['model'])

        audio_path = audio
        audio_id = audio_path.split("/")[-1][:-4]
        try:
-            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
+            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config)
        except:
            logger.info("audio file failed to load : %s" % audio_path)
            assert(False)
@@ -287,9 +312,9 @@ class Music2emo:
        logger.info("audio file loaded and feature computation success : %s" % audio_path)

        feature = feature.T
-        feature = (feature - mean) / std
+        feature = (feature - self.mean) / self.std
        time_unit = feature_per_second
-        n_timestep = config.model['timestep']
+        n_timestep = self.config.model['timestep']

        num_pad = n_timestep - (feature.shape[0] % n_timestep)
        feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
@@ -298,11 +323,11 @@ class Music2emo:
        start_time = 0.0
        lines = []
        with torch.no_grad():
-            model.eval()
+            self.btc_model.eval()
            feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
            for t in range(num_instance):
-                self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
-                prediction, _ = model.output_layer(self_attn_output)
+                self_attn_output, _ = self.btc_model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
+                prediction, _ = self.btc_model.output_layer(self_attn_output)
                prediction = prediction.squeeze()
                for i in range(n_timestep):
                    if t == 0 and i == 0:
@@ -310,12 +335,12 @@ class Music2emo:
                        continue
                    if prediction[i].item() != prev_chord:
                        lines.append(
-                            '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                            '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                        start_time = time_unit * (n_timestep * t + i)
                        prev_chord = prediction[i].item()
                    if t == num_instance - 1 and i + num_pad == n_timestep:
                        if start_time != time_unit * (n_timestep * t + i):
-                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                        break

        save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
@@ -356,24 +381,9 @@ class Music2emo:
        midi.instruments.append(instrument)
        midi.write(save_path.replace('.lab', '.midi'))

-        tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
-        mode_signatures = ["major", "minor"] # Major and minor modes

-        tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
-        mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
-        idx_to_tonic = {idx: tonic for tonic, idx in tonic_to_idx.items()}
-        idx_to_mode = {idx: mode for mode, idx in mode_to_idx.items()}
-
-        with open('inference/data/chord.json', 'r') as f:
-            chord_to_idx = json.load(f)
-        with open('inference/data/chord_inv.json', 'r') as f:
-            idx_to_chord = json.load(f)
-        idx_to_chord = {int(k): v for k, v in idx_to_chord.items()} # Ensure keys are ints
-        with open('inference/data/chord_root.json') as json_file:
-            chordRootDic = json.load(json_file)
-        with open('inference/data/chord_attr.json') as json_file:
-            chordAttrDic = json.load(json_file)

+
        try:
            midi_file = converter.parse(save_path.replace('.lab', '.midi'))
            key_signature = str(midi_file.analyze('key'))
@@ -390,7 +400,7 @@ class Music2emo:
        else:
            mode = key_signature.split()[-1]

-        encoded_mode = mode_to_idx.get(mode, 0)
+        encoded_mode = self.mode_to_idx.get(mode, 0)
        mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)

        converted_lines = normalize_chord(save_path, key_signature, key_type)
@@ -419,19 +429,19 @@ class Music2emo:
        for start, end, chord in chords:
            chord_arr = chord.split(":")
            if len(chord_arr) == 1:
-                chordRootID = chordRootDic[chord_arr[0]]
+                chordRootID = self.chordRootDic[chord_arr[0]]
                if chord_arr[0] == "N" or chord_arr[0] == "X":
                    chordAttrID = 0
                else:
                    chordAttrID = 1
            elif len(chord_arr) == 2:
-                chordRootID = chordRootDic[chord_arr[0]]
-                chordAttrID = chordAttrDic[chord_arr[1]]
+                chordRootID = self.chordRootDic[chord_arr[0]]
+                chordAttrID = self.chordAttrDic[chord_arr[1]]
            encoded_root.append(chordRootID)
            encoded_attr.append(chordAttrID)

-            if chord in chord_to_idx:
-                encoded.append(chord_to_idx[chord])
+            if chord in self.chord_to_idx:
+                encoded.append(self.chord_to_idx[chord])
            else:
                print(f"Warning: Chord {chord} not found in chord.json. Skipping.")

@@ -585,14 +595,6 @@ with gr.Blocks(css=css) as demo:
                elem_id="output-text"
            )

-    # Add example usage
-    # gr.Examples(
-    #     examples=["inference/input/test.mp3"],
-    #     inputs=input_audio,
-    #     outputs=output_text,
-    #     fn=lambda x: format_prediction(music2emo.predict(x, 0.5)),
-    #     cache_examples=True
-    # )

    predict_btn.click(
        fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),
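
Net effect of the commit: the BTC chord model (btc_model_large_voca.pt), its run_config.yaml hyperparameters, and the chord/key lookup tables are loaded once in Music2emo.__init__ and reused by every predict() call, instead of being re-read from disk inside predict(). A minimal usage sketch with the names from app.py, assuming Music2emo's remaining constructor arguments (elided in the hunks above) keep their defaults:

# Sketch only; assumes the app.py definitions above are in scope and that
# Music2emo() is constructible with its default arguments.
music2emo = Music2emo()  # loads music2emo_model, btc_model, config, mean/std and the chord dicts once

# Repeated calls reuse the cached self.btc_model, self.config, self.mean/self.std,
# self.chord_to_idx, self.chordRootDic, self.chordAttrDic, ...
output = music2emo.predict("inference/input/test.mp3", threshold=0.5)
print(format_prediction(output))

This mirrors the Gradio wiring at the end of the diff, where a single shared music2emo instance serves every predict_btn.click call.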