kjysmu committed
Commit 9870f94 · verified · 1 Parent(s): e881d42

Update app.py

Files changed (1)
  1. app.py +52 -50
app.py CHANGED
@@ -46,6 +46,10 @@ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
 
 PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
 
+tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
+mode_signatures = ["major", "minor"] # Major and minor modes
+
+
 pitch_num_dic = {
     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
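
Note: tonic_signatures and mode_signatures now sit at module level next to the other pitch tables because the Music2emo constructor (see the @@ -206,6 +212,37 @@ hunk below) builds its key/mode lookup dictionaries from them. For orientation, a tiny illustrative snippet of how these tables relate (the comprehension simply restates the literal dict above):

# Illustrative only: pitch names, pitch-class numbers, and mode indices.
PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
pitch_num_dic = {p: i for i, p in enumerate(PITCH_CLASS)}   # same mapping as the literal in app.py
mode_signatures = ["major", "minor"]

print(pitch_num_dic['G#'])             # 8
print(PITCH_CLASS[8])                  # G#
print(mode_signatures.index("minor"))  # 1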
@@ -167,6 +171,8 @@ def split_audio(waveform, sample_rate):
 
 
 
+
+
 class Music2emo:
     def __init__(
         self,
@@ -206,6 +212,37 @@ class Music2emo:
         self.music2emo_model.to(self.device)
         self.music2emo_model.eval()
 
+        self.config = HParams.load("./inference/data/run_config.yaml")
+        self.config.feature['large_voca'] = True
+        self.config.model['num_chords'] = 170
+        model_file = './inference/data/btc_model_large_voca.pt'
+        self.idx_to_voca = idx2voca_chord()
+        self.btc_model = BTC_model(config=self.config.model).to(self.device)
+
+        if os.path.isfile(model_file):
+            checkpoint = torch.load(model_file, map_location=self.device)
+            self.mean = checkpoint['mean']
+            self.std = checkpoint['std']
+            self.btc_model.load_state_dict(checkpoint['model'])
+
+
+        self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
+        self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
+        self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()}
+        self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()}
+
+        with open('inference/data/chord.json', 'r') as f:
+            self.chord_to_idx = json.load(f)
+        with open('inference/data/chord_inv.json', 'r') as f:
+            self.idx_to_chord = json.load(f)
+            self.idx_to_chord = {int(k): v for k, v in self.idx_to_chord.items()} # Ensure keys are ints
+        with open('inference/data/chord_root.json') as json_file:
+            self.chordRootDic = json.load(json_file)
+        with open('inference/data/chord_attr.json') as json_file:
+            self.chordAttrDic = json.load(json_file)
+
+
+
     def predict(self, audio, threshold = 0.5):
 
         feature_dir = Path("./inference/temp_out")
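
These additions hoist the BTC chord-recognition model, its feature normalization statistics (mean/std), and the chord/key lookup tables out of predict() and into __init__, so they are loaded once per Music2emo instance rather than on every prediction. A minimal usage sketch of the resulting behaviour (assuming the constructor's default arguments; the audio file names are hypothetical):

# Sketch only: the chord model, checkpoint statistics, and JSON dictionaries
# are now cached on self, so repeated predict() calls reuse them instead of
# reloading the BTC checkpoint and chord/key JSON files from disk each time.
music2emo = Music2emo()                      # one-time model and dictionary loading
for clip in ["song_a.mp3", "song_b.mp3"]:    # hypothetical input files
    output = music2emo.predict(clip, threshold=0.5)
    print(output)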
@@ -263,23 +300,11 @@
         final_embedding_mert.to(self.device)
 
         # --- Chord feature extract ---
-        config = HParams.load("./inference/data/run_config.yaml")
-        config.feature['large_voca'] = True
-        config.model['num_chords'] = 170
-        model_file = './inference/data/btc_model_large_voca.pt'
-        idx_to_chord = idx2voca_chord()
-        model = BTC_model(config=config.model).to(self.device)
-
-        if os.path.isfile(model_file):
-            checkpoint = torch.load(model_file, map_location=self.device)
-            mean = checkpoint['mean']
-            std = checkpoint['std']
-            model.load_state_dict(checkpoint['model'])
 
         audio_path = audio
         audio_id = audio_path.split("/")[-1][:-4]
         try:
-            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
+            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config)
         except:
             logger.info("audio file failed to load : %s" % audio_path)
             assert(False)
@@ -287,9 +312,9 @@
         logger.info("audio file loaded and feature computation success : %s" % audio_path)
 
         feature = feature.T
-        feature = (feature - mean) / std
+        feature = (feature - self.mean) / self.std
         time_unit = feature_per_second
-        n_timestep = config.model['timestep']
+        n_timestep = self.config.model['timestep']
 
         num_pad = n_timestep - (feature.shape[0] % n_timestep)
         feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
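
With the statistics cached on self, the feature matrix is normalized and then zero-padded so its frame count is a multiple of self.config.model['timestep']; the decoding loop in the next hunk consumes it in fixed-length windows. A small sketch of the padding arithmetic (the feature dimension and timestep values are hypothetical, and the derivation of num_instance, which is not shown in this diff, is an assumption):

import numpy as np

n_timestep = 10                  # stands in for self.config.model['timestep']
feature = np.zeros((23, 144))    # hypothetical: 23 frames of a 144-dim feature

num_pad = n_timestep - (feature.shape[0] % n_timestep)      # 10 - 3 = 7
feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)

num_instance = feature.shape[0] // n_timestep   # assumed: 30 // 10 = 3 windows
print(feature.shape, num_instance)              # (30, 144) 3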
@@ -298,11 +323,11 @@
         start_time = 0.0
         lines = []
         with torch.no_grad():
-            model.eval()
+            self.btc_model.eval()
             feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
             for t in range(num_instance):
-                self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
-                prediction, _ = model.output_layer(self_attn_output)
+                self_attn_output, _ = self.btc_model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
+                prediction, _ = self.btc_model.output_layer(self_attn_output)
                 prediction = prediction.squeeze()
                 for i in range(n_timestep):
                     if t == 0 and i == 0:
@@ -310,12 +335,12 @@
                         continue
                     if prediction[i].item() != prev_chord:
                         lines.append(
-                            '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                            '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                         start_time = time_unit * (n_timestep * t + i)
                         prev_chord = prediction[i].item()
                     if t == num_instance - 1 and i + num_pad == n_timestep:
                         if start_time != time_unit * (n_timestep * t + i):
-                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                         break
 
         save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
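
The loop above run-length encodes the frame-wise chord predictions: each time the predicted class changes, a "start end label" line is written, with indices decoded through self.idx_to_voca. A self-contained sketch of that pattern (the frame indices, time unit, and label map are stand-ins, not values from the model):

# Stand-in data: per-frame chord class indices and a toy index-to-label map.
frame_chords = [5, 5, 5, 12, 12, 0]
time_unit = 0.092                                 # seconds per frame (illustrative)
idx_to_voca = {0: "N", 5: "C:maj", 12: "A:min"}   # tiny substitute for idx2voca_chord()

lines, start, prev = [], 0.0, frame_chords[0]
for i, chord in enumerate(frame_chords[1:], start=1):
    if chord != prev:                             # chord changed: close the previous segment
        lines.append('%.3f %.3f %s\n' % (start, time_unit * i, idx_to_voca[prev]))
        start, prev = time_unit * i, chord
lines.append('%.3f %.3f %s\n' % (start, time_unit * len(frame_chords), idx_to_voca[prev]))
print(''.join(lines), end='')
# 0.000 0.276 C:maj
# 0.276 0.460 A:min
# 0.460 0.552 N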
@@ -356,24 +381,9 @@
         midi.instruments.append(instrument)
         midi.write(save_path.replace('.lab', '.midi'))
 
-        tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
-        mode_signatures = ["major", "minor"] # Major and minor modes
 
-        tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
-        mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
-        idx_to_tonic = {idx: tonic for tonic, idx in tonic_to_idx.items()}
-        idx_to_mode = {idx: mode for mode, idx in mode_to_idx.items()}
-
-        with open('inference/data/chord.json', 'r') as f:
-            chord_to_idx = json.load(f)
-        with open('inference/data/chord_inv.json', 'r') as f:
-            idx_to_chord = json.load(f)
-            idx_to_chord = {int(k): v for k, v in idx_to_chord.items()} # Ensure keys are ints
-        with open('inference/data/chord_root.json') as json_file:
-            chordRootDic = json.load(json_file)
-        with open('inference/data/chord_attr.json') as json_file:
-            chordAttrDic = json.load(json_file)
 
+
         try:
             midi_file = converter.parse(save_path.replace('.lab', '.midi'))
             key_signature = str(midi_file.analyze('key'))
@@ -390,7 +400,7 @@
         else:
             mode = key_signature.split()[-1]
 
-        encoded_mode = mode_to_idx.get(mode, 0)
+        encoded_mode = self.mode_to_idx.get(mode, 0)
         mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)
 
         converted_lines = normalize_chord(save_path, key_signature, key_type)
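
The key string returned by the MIDI key analysis typically ends in the mode name (e.g. "C major" or "a minor"), so the mode is taken from the last token and encoded with the self.mode_to_idx mapping built in __init__, falling back to index 0 ("major") for anything unexpected. A small sketch using the mappings defined in this commit (the key strings are illustrative):

mode_signatures = ["major", "minor"]
mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}

for key_signature in ["C major", "a minor", "d dorian"]:
    mode = key_signature.split()[-1]
    print(mode, "->", mode_to_idx.get(mode, 0))   # major -> 0, minor -> 1, dorian -> 0 (fallback)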
@@ -419,19 +429,19 @@
         for start, end, chord in chords:
             chord_arr = chord.split(":")
             if len(chord_arr) == 1:
-                chordRootID = chordRootDic[chord_arr[0]]
+                chordRootID = self.chordRootDic[chord_arr[0]]
                 if chord_arr[0] == "N" or chord_arr[0] == "X":
                     chordAttrID = 0
                 else:
                     chordAttrID = 1
             elif len(chord_arr) == 2:
-                chordRootID = chordRootDic[chord_arr[0]]
-                chordAttrID = chordAttrDic[chord_arr[1]]
+                chordRootID = self.chordRootDic[chord_arr[0]]
+                chordAttrID = self.chordAttrDic[chord_arr[1]]
             encoded_root.append(chordRootID)
             encoded_attr.append(chordAttrID)
 
-            if chord in chord_to_idx:
-                encoded.append(chord_to_idx[chord])
+            if chord in self.chord_to_idx:
+                encoded.append(self.chord_to_idx[chord])
             else:
                 print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
 
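
Each chord label is split on ":" into a root and an attribute: roots are looked up in the chord_root.json mapping and attributes in chord_attr.json, with N/X (no-chord/unknown) assigned attribute 0 and a bare root assigned attribute 1. A sketch of that split with stand-in dictionaries (the real ID values come from the JSON files under inference/data/):

# Hypothetical IDs; the real mappings are self.chordRootDic / self.chordAttrDic.
chordRootDic = {"N": 0, "X": 0, "C": 1, "A": 10}
chordAttrDic = {"maj": 1, "min": 2}

def encode_chord(chord):
    parts = chord.split(":")
    root_id = chordRootDic[parts[0]]
    if len(parts) == 1:
        attr_id = 0 if parts[0] in ("N", "X") else 1   # N/X get 0; a bare root gets 1 (major in the assumed mapping)
    else:
        attr_id = chordAttrDic[parts[1]]
    return root_id, attr_id

print(encode_chord("C:maj"), encode_chord("A:min"), encode_chord("N"))   # (1, 1) (10, 2) (0, 0)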
@@ -585,14 +595,6 @@ with gr.Blocks(css=css) as demo:
             elem_id="output-text"
         )
 
-        # Add example usage
-        # gr.Examples(
-        #     examples=["inference/input/test.mp3"],
-        #     inputs=input_audio,
-        #     outputs=output_text,
-        #     fn=lambda x: format_prediction(music2emo.predict(x, 0.5)),
-        #     cache_examples=True
-        # )
 
         predict_btn.click(
             fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),