#!/usr/bin/env python3
"""Recipe for training a Transformer ASR system with CommonVoice.
The system employs an encoder, a decoder, and an attention mechanism
between them. Decoding is performed with (CTC/Att joint) beamsearch.

To run this recipe, do the following:
> python train.py hparams/conformer_large.yaml

Authors
 * Titouan Parcollet 2021, 2024
 * Jianyuan Zhong 2020
 * Pooneh Mousavi 2023
"""

import functools
import json
import os
import sys
import time
from collections import defaultdict
from typing import List

import structlog
import torch
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm

import speechbrain as sb
from speechbrain.core import AMPConfig, Stage, get_logger
from speechbrain.dataio.dataloader import DataLoader, LoopedLoader
from speechbrain.dataio.preprocess import AudioNormalizer
from speechbrain.tokenizers.SentencePiece import SentencePiece
from speechbrain.utils.data_utils import split_path, undo_padding
from speechbrain.utils.distributed import if_main_process
from speechbrain.utils.epoch_loop import EpochCounter
from speechbrain.utils.fetching import LocalStrategy, fetch

logger = structlog.get_logger(__name__)


def log_step_to_file(message):
    """Append a single training-log message to a plain-text log file."""
    with open("training_logs_file.log", "a") as f:
        f.write(message + "\n")


def timer(func):
    """Decorator to measure the execution time of a function."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        elapsed_time = time.time() - start_time
        # Print in red
        print(
            f"\033[91mFunction '{func.__name__}' took {elapsed_time:.4f} seconds to complete.\033[0m"
        )
        return result

    return wrapper


def print_red(*args):
    """Print the given arguments in red."""
    output = " ".join(str(x) for x in args)
    print(f"\033[91m{output}\033[0m")


class TimerContext:
    """Context manager to measure the execution time of a code block."""

    def __init__(self, block_name="", print_=False):
        self.block_name = block_name
        self.print_ = print_

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_time = time.time() - self.start_time
        if self.print_:
            if exc_type is None:
                print(
                    f"\033[91mBlock '{self.block_name}' took {elapsed_time:.4f} seconds to complete.\033[0m"
                )
            else:
                print(
                    f"\033[91mBlock '{self.block_name}' failed after {elapsed_time:.4f} seconds with {exc_type.__name__}: {exc_val}\033[0m"
                )
                logger.exception(
                    f"Block '{self.block_name}' failed after {elapsed_time:.4f} seconds with {exc_type.__name__}: {exc_val}"
                )
        # Return False so that exceptions are always propagated.
        return False


class ASR(sb.core.Brain):
    @timer
    def audio_normalizer(self, signal, sr):
        """Resample/normalize a loaded signal to the expected format."""
        normalizer = AudioNormalizer()
        return normalizer(signal, sr)

    @timer
    def load_audio(self, path, savedir=None):
        """Fetch an audio file (local or remote) and return a normalized signal."""
        source, fl = split_path(path)
        path = fetch(
            fl,
            source=source,
            savedir=savedir,
            local_strategy=LocalStrategy.SYMLINK,
        )
        signal, sr = torchaudio.load(str(path), channels_first=False)
        return self.audio_normalizer(signal, sr)

    @timer
    def transcribe_file(self, path, **kwargs):
        """Transcribe a single audio file and return the decoded words."""
        waveform = self.load_audio(path, **kwargs)
        audio_length = waveform.shape[0] / self.hparams.sample_rate
        print(f"\033[94mAudio length: {audio_length:.2f} seconds\033[0m")
        # Fake a batch of size 1 and move it to the same device as the model.
        batch = waveform.unsqueeze(0).to(self.device)
        rel_length = torch.tensor([1.0]).to(self.device)
        predictions = self.forward_audio(batch, rel_length)
        return self.compute_objective_single(predictions)
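    # Hypothetical usage sketch (not part of the original recipe): assuming a
    # `TransformerPretrainedASR` instance named `asr` whose checkpoint and
    # tokenizer are already loaded, a single file could be decoded with:
    #
    #     asr.modules.eval()
    #     with torch.no_grad():
    #         words = asr.transcribe_file("example.wav")  # placeholder path
    #
    # The `transcribe()` method defined further below wraps this pattern for a
    # list of files.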
    @timer
    def compute_objective_single(self, predictions):
        """Decode the beam-search hypotheses of a single-utterance batch into words."""
        p_ctc, p_seq, wav_lens, predicted_tokens = predictions
        predicted_words = self.tokenizer(predicted_tokens, task="decode_from_list")
        return predicted_words[0]

    @timer
    def forward_audio(self, wavs, wav_lens):
        """Forward pass used for single-file inference (no labels available)."""
        feats = self.hparams.compute_features(wavs)
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        # Hard-coded decoder BOS/prompt token ids used only for this label-free
        # forward pass; decoding itself relies on the beam search below.
        tokens_bos = torch.tensor([[1, 105, 242, 242]], dtype=torch.int64).to(
            self.device
        )
        feats = self.hparams.normalize(
            feats, wav_lens, epoch=self.hparams.epoch_counter.current
        )

        # Forward modules
        with TimerContext(block_name="CNN", print_=False):
            src = self.modules.CNN(feats)

        with TimerContext("Transformer", print_=False):
            enc_out, pred = self.modules.Transformer(
                src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
            )

        # Compute outputs
        with TimerContext("ctc_lin --- seq_lin", print_=False):
            logits = self.modules.ctc_lin(enc_out)
            p_ctc = self.hparams.log_softmax(logits)
            pred = self.modules.seq_lin(pred)
            p_seq = self.hparams.log_softmax(pred)

        hyps = None
        stage = sb.Stage.TEST
        with TimerContext("SEARCH", print_=False):
            if stage == sb.Stage.TEST:
                hyps, _, _, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)

        return p_ctc, p_seq, wav_lens, hyps
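    # Note: `forward_audio()` above is an inference-only variant and returns a
    # 4-tuple (p_ctc, p_seq, wav_lens, hyps), while `compute_forward()` below is
    # the Brain hook used during training/evaluation and additionally returns
    # the raw CTC logits needed for pseudo-labeling of unlabeled batches.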
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        tokens_bos, _ = batch.tokens_bos

        # Add waveform augmentation if specified.
        if (
            stage == sb.Stage.TRAIN
            and hasattr(self.hparams, "wav_augment")
            and self.optimizer_step > self.hparams.augment_warmup
        ):
            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)

        # Compute features
        feats = self.hparams.compute_features(wavs)
        current_epoch = self.hparams.epoch_counter.current
        feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch)

        # Add feature augmentation if specified.
        if (
            stage == sb.Stage.TRAIN
            and hasattr(self.hparams, "fea_augment")
            and self.optimizer_step > self.hparams.augment_warmup
        ):
            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)

        # Forward modules
        src = self.modules.CNN(feats)
        enc_out, pred = self.modules.Transformer(
            src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
        )

        # Output layer for CTC log-probabilities
        logits = self.modules.ctc_lin(enc_out)
        p_ctc = self.hparams.log_softmax(logits)

        # Output layer for seq2seq log-probabilities
        pred = self.modules.seq_lin(pred)
        p_seq = self.hparams.log_softmax(pred)

        # Compute beam-search outputs when needed
        hyps = None
        current_epoch = self.hparams.epoch_counter.current
        is_valid_search = (
            stage == sb.Stage.VALID
            and current_epoch % self.hparams.valid_search_interval == 0
        )
        is_test_search = stage == sb.Stage.TEST

        if is_valid_search:
            hyps, _, _, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
        elif is_test_search:
            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)

        return p_ctc, p_seq, wav_lens, hyps, logits

    def compute_objectives_unlabeled(self, predictions, batch, stage):
        """Return the raw CTC logits for an unlabeled batch (used for pseudo-labeling)."""
        (p_ctc, p_seq, wav_lens, predicted_tokens, logits) = predictions
        return logits

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""
        (p_ctc, p_seq, wav_lens, predicted_tokens, logits) = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        # Augment labels
        if stage == sb.Stage.TRAIN:
            # Labels must be extended if parallel augmentation or concatenated
            # augmentation was performed on the input (increasing the time dimension)
            if (
                hasattr(self.hparams, "wav_augment")
                and self.optimizer_step > self.hparams.augment_warmup
            ):
                (
                    tokens,
                    tokens_lens,
                    tokens_eos,
                    tokens_eos_lens,
                ) = self.hparams.wav_augment.replicate_multiple_labels(
                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
                )
            if (
                hasattr(self.hparams, "fea_augment")
                and self.optimizer_step > self.hparams.augment_warmup
            ):
                (
                    tokens,
                    tokens_lens,
                    tokens_eos,
                    tokens_eos_lens,
                ) = self.hparams.fea_augment.replicate_multiple_labels(
                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
                )

        loss_seq = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
        loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
        loss = (
            self.hparams.ctc_weight * loss_ctc
            + (1 - self.hparams.ctc_weight) * loss_seq
        )

        if stage != sb.Stage.TRAIN:
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if current_epoch % valid_search_interval == 0 or (stage == sb.Stage.TEST):
                # Decode token terms to words
                predicted_words = self.tokenizer(
                    predicted_tokens, task="decode_from_list"
                )

                # Convert indices to words
                target_words = undo_padding(tokens, tokens_lens)
                target_words = self.tokenizer(target_words, task="decode_from_list")

                self.wer_metric.append(ids, predicted_words, target_words)
                self.cer_metric.append(ids, predicted_words, target_words)

            # Compute the accuracy of the one-step-forward prediction
            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
        return loss
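    # Worked example of the joint objective above (the 0.3 weight is only an
    # illustration; the actual value comes from `ctc_weight` in the hparams YAML):
    # with ctc_weight = 0.3, loss_ctc = 40.0 and loss_seq = 10.0,
    # loss = 0.3 * 40.0 + 0.7 * 10.0 = 19.0.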
    def on_fit_batch_end(self, batch, outputs, loss, should_step):
        """At the end of the optimizer step, apply noam annealing."""
        if should_step:
            self.hparams.noam_annealing(self.optimizer)

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch."""
        if stage != sb.Stage.TRAIN:
            self.acc_metric = self.hparams.acc_computer()
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["ACC"] = self.acc_metric.summarize()
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if current_epoch % valid_search_interval == 0 or stage == sb.Stage.TEST:
                stage_stats["WER"] = self.wer_metric.summarize("error_rate")
                stage_stats["CER"] = self.cer_metric.summarize("error_rate")
                log_step_to_file(
                    f"Epoch {epoch}, WER: {stage_stats['WER']}, CER: {stage_stats['CER']}, ACC: {stage_stats['ACC']}"
                )
                print(
                    f"Epoch {epoch}, WER: {stage_stats['WER']}, CER: {stage_stats['CER']}, ACC: {stage_stats['ACC']}"
                )

        # Log stats and save checkpoint at end-of-epoch
        if stage == sb.Stage.VALID:
            # Report different epoch stages according to the current stage
            current_epoch = self.hparams.epoch_counter.current
            lr = self.hparams.noam_annealing.current_lr
            steps = self.hparams.noam_annealing.n_steps

            epoch_stats = {
                "epoch": epoch,
                "lr": lr,
                "steps": steps,
            }
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            if if_main_process():
                with open(self.hparams.test_wer_file, "w", encoding="utf-8") as w:
                    self.wer_metric.write_stats(w)


class TransformerPretrainedASR(ASR):
    def __init__(self, hparams_path: str):
        # CLI:
        hparams_file, run_opts, overrides = sb.parse_arguments([hparams_path])
        with open(hparams_file, encoding="utf-8") as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Create ddp_group with the right communication protocol
        sb.utils.distributed.ddp_init_group(run_opts)

        # Initialize ASR Brain
        with TimerContext("MODEL LOADING:"):
            tokenizer = SentencePiece(
                model_dir=hparams["save_folder"],
                vocab_size=hparams["output_neurons"],
                annotation_train=hparams["train_csv"],
                annotation_read="wrd",
                model_type=hparams["token_type"],
                character_coverage=hparams["character_coverage"],
                bos_id=hparams["bos_index"],
                eos_id=hparams["eos_index"],
            )
            super().__init__(
                modules=hparams["modules"],
                opt_class=hparams["Adam"],
                hparams=hparams,
                run_opts=run_opts,
                checkpointer=hparams["checkpointer"],
            )
            self.tokenizer = tokenizer

        # Initialize pseudo-labeling attributes
        self.num_classes = hparams.get("output_neurons", 10)  # Example default value
        self.base_threshold = 0.95  # Base confidence threshold
        # self.token_to_idx = token_to_idx
        self.lambda_u = 1.0  # Intended weight for the unsupervised loss (not currently applied)
        # Per-class confidence thresholds (start at the base threshold)
        self.classwise_threshold = (
            torch.ones(self.num_classes).to(self.device) * self.base_threshold
        )
        # Running (exponential moving average) confidence estimate per class
        self.classwise_acc = defaultdict(lambda: 0.0)

    def update_threshold(self, pseudo_labels, confidences):
        """Dynamically adjust per-class thresholds based on the confidence distribution."""
        for cls in range(self.num_classes):
            # Iterate over each sample in the batch
            for i in range(pseudo_labels.size(0)):
                cls_mask = pseudo_labels[i] == cls
                if cls_mask.sum() > 0:
                    # Use the sample-level average confidence for every sample
                    # that contains this class, and fold it into the running
                    # (EMA) estimate for the class.
                    avg_confidence = confidences[i].item()
                    self.classwise_acc[cls] = (
                        0.9 * self.classwise_acc[cls] + 0.1 * avg_confidence
                    )
                    self.classwise_threshold[cls] = min(
                        0.99, self.base_threshold * (1 - self.classwise_acc[cls])
                    )
        return self.classwise_threshold
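    # Worked example (illustrative numbers only): with base_threshold = 0.95 and a
    # running class confidence of classwise_acc[c] = 0.5, the class threshold
    # becomes min(0.99, 0.95 * (1 - 0.5)) = 0.475, so classes the model is already
    # confident about accept pseudo-labels more easily than rarely-seen ones.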
    def fit(
        self,
        epoch_counter: EpochCounter,
        train_set,
        unlabeled_train_set,
        valid_set=None,
        progressbar=None,
        train_loader_kwargs={},
        valid_loader_kwargs={},
        unlabeled_train_loader_kwargs={},
    ):
        """Iterate epochs and datasets to improve the objective.

        Relies on the existence of multiple functions that can (or should) be
        overridden. The following methods are used and expected to have a
        certain behavior:

        * ``fit_batch()``
        * ``evaluate_batch()``
        * ``update_average()``

        If the initialization was done with distributed_count > 0 and the
        distributed_backend is ddp, this will generally handle multiprocess
        logic, like splitting the training data into subsets for each device
        and only saving a checkpoint on the main process.

        Arguments
        ---------
        epoch_counter : iterable
            Each call should return an integer indicating the epoch count.
        train_set : Dataset, DataLoader
            A set of labeled data to use for training. If a Dataset is given,
            a DataLoader is automatically created. If a DataLoader is given,
            it is used directly.
        unlabeled_train_set : Dataset, DataLoader
            A set of unlabeled data iterated alongside ``train_set``; its
            predictions are turned into pseudo-labels for the unsupervised loss.
        valid_set : Dataset, DataLoader
            A set of data to use for validation. If a Dataset is given, a
            DataLoader is automatically created. If a DataLoader is given, it
            is used directly.
        progressbar : bool
            Whether to display the progress of each epoch in a progressbar.
        train_loader_kwargs : dict
            Kwargs passed to `make_dataloader()` for making the train_loader
            (if train_set is a Dataset, not DataLoader). E.g., batch_size,
            num_workers. DataLoader kwargs are all valid.
        valid_loader_kwargs : dict
            Kwargs passed to `make_dataloader()` for making the valid_loader
            (if valid_set is a Dataset, not DataLoader). E.g., batch_size,
            num_workers. DataLoader kwargs are all valid.
        unlabeled_train_loader_kwargs : dict
            Kwargs passed to `make_dataloader()` for the unlabeled loader
            (if unlabeled_train_set is a Dataset, not DataLoader).

        Returns
        -------
        None
        """
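        # Hypothetical call sketch (names are placeholders, not part of this recipe):
        #
        #     asr_brain.fit(
        #         asr_brain.hparams.epoch_counter,
        #         train_set=train_loader,                # labeled DataLoader
        #         unlabeled_train_set=unlabeled_loader,  # unlabeled DataLoader
        #         valid_set=valid_loader,
        #     )
        #
        # Both training loaders are iterated in lock-step below via zip(), so one
        # epoch covers min(len(unlabeled), len(labeled)) steps.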
        if self.test_only:
            logger.info("Test only mode, skipping training and validation stages.")
            return

        # # train set conversion
        # if not (
        #     isinstance(train_set, DataLoader) or isinstance(train_set, LoopedLoader)
        # ):
        #     train_set = self.make_dataloader(
        #         train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
        #     )
        #     print("Made data loader train")

        # # valid set conversion
        # if valid_set is not None and not (
        #     isinstance(valid_set, DataLoader) or isinstance(valid_set, LoopedLoader)
        # ):
        #     valid_set = self.make_dataloader(
        #         valid_set,
        #         stage=sb.Stage.VALID,
        #         ckpt_prefix=None,
        #         **valid_loader_kwargs,
        #     )
        #     print("Made data loader valid")

        # # unlabeled_train_set conversion
        # if unlabeled_train_set is not None and not (
        #     isinstance(unlabeled_train_set, DataLoader)
        #     or isinstance(unlabeled_train_set, LoopedLoader)
        # ):
        #     unlabeled_train_set = self.make_dataloader(
        #         unlabeled_train_set,
        #         stage=sb.Stage.TRAIN,
        #         ckpt_prefix=None,
        #         **unlabeled_train_loader_kwargs,
        #     )
        #     print("Made data loader unlabeled")

        self.on_fit_start()

        if progressbar is None:
            progressbar = not self.noprogressbar

        # Only show progressbar if requested and on the main process
        enable = progressbar and sb.utils.distributed.if_main_process()

        # Iterate epochs
        print(
            f"Going to epoch counter ----- epochs limit: {epoch_counter.limit} ---- current: {epoch_counter.current}"
        )
        for epoch in epoch_counter:
            self._fit_train(
                train_set=train_set,
                unlabeled_train_set=unlabeled_train_set,
                epoch=epoch,
                enable=enable,
            )
            self._fit_valid(valid_set=valid_set, epoch=epoch, enable=enable)

            # Debug mode only runs a few epochs
            # if (
            #     self.debug
            #     and epoch == self.debug_epochs
            #     or self._optimizer_step_limit_exceeded
            # ):
            #     break

    def log_step_to_file(self, message):
        """Append a message to the training log file (delegates to the module-level helper)."""
        log_step_to_file(message)

    def _fit_train(self, train_set, unlabeled_train_set, epoch, enable):
        # Training stage
        self.on_stage_start(Stage.TRAIN, epoch)
        self.modules.train()
        self.zero_grad()

        # Reset nonfinite count to 0 each epoch
        self.nonfinite_count = 0

        if self.train_sampler is not None and hasattr(self.train_sampler, "set_epoch"):
            self.train_sampler.set_epoch(epoch)

        # Time since last intra-epoch checkpoint
        last_ckpt_time = time.time()
        steps_since_ckpt = 0

        # Iterate the labeled and unlabeled loaders in lock-step; zip() stops at
        # the shorter of the two, hence total=min(...) for the progress bar.
        with tqdm(
            zip(unlabeled_train_set, train_set),
            initial=self.step,
            dynamic_ncols=True,
            disable=not enable,
            colour=self.tqdm_barcolor["train"],
            total=min(len(unlabeled_train_set), len(train_set)),
        ) as t:
            if self.profiler is not None:
                self.profiler.start()
            for batch_unlabeled, batch in t:
                # if self._optimizer_step_limit_exceeded:
                #     logger.info("Train iteration limit exceeded")
                self.step += 1
                steps_since_ckpt += 1
                loss = self.fit_batch(batch, batch_unlabeled)
                self.avg_train_loss = self.update_average(loss, self.avg_train_loss)
                t.set_postfix(train_loss=self.avg_train_loss)
                log_step_to_file(
                    f"Epoch {epoch}, Train Step {self.step}, Train Loss: {self.avg_train_loss}"
                )

                if self.profiler is not None:
                    self.profiler.step()
                    if self.profiler.step_num > self.tot_prof_steps:
                        logger.info("The profiler finished, training is stopped.")
                        self.profiler.stop()
                        quit()

                # Debug mode only runs a few batches
                if self.debug and self.step == self.debug_batches:
                    break

                if self._should_save_intra_epoch_ckpt(
                    last_ckpt_time, steps_since_ckpt
                ):
                    # Checkpointer class will handle running this on main only
                    self._save_intra_epoch_ckpt()
                    last_ckpt_time = time.time()
                    steps_since_ckpt = 0

        # Run train "on_stage_end" on all processes
        self.zero_grad(set_to_none=True)  # flush gradients
        self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
        self.avg_train_loss = 0.0
        self.step = 0
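    # The batch-level objective implemented in `fit_batch()` below is, in effect,
    #     total_loss = supervised_loss(labeled batch) + pseudo_label_loss(unlabeled batch)
    # scaled by 1 / grad_accumulation_factor before backpropagation. The
    # unsupervised term is skipped when no unlabeled batch is provided.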
    def fit_batch(self, batch, batch_unlabeled=None):
        """Fit one batch, override to do multiple updates.

        The default implementation depends on a few methods being defined with a
        particular behavior:

        * ``compute_forward()``
        * ``compute_objectives()``
        * ``optimizers_step()``

        Also depends on having optimizers passed at initialization.

        Arguments
        ---------
        batch : list of torch.Tensors
            Labeled batch of data to use for training. Default implementation
            assumes this batch has two elements: inputs and targets.
        batch_unlabeled : list of torch.Tensors, optional
            Unlabeled batch used to compute the pseudo-labeling loss. If None,
            only the supervised loss is used.

        Returns
        -------
        detached loss
        """
        amp = AMPConfig.from_name(self.precision)
        should_step = (self.step % self.grad_accumulation_factor) == 0
        self.on_fit_batch_start(batch, should_step)

        with self.no_sync(not should_step):
            if self.use_amp:
                with torch.autocast(
                    dtype=amp.dtype, device_type=torch.device(self.device).type
                ):
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                    loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

                    # Compute the unlabeled batch (pseudo-labeling loss)
                    if batch_unlabeled:
                        outputs_unlabeled = self.compute_forward(
                            batch_unlabeled, sb.Stage.TRAIN
                        )
                        loss_unlabeled = self.compute_loss_unlabeled(
                            outputs_unlabeled, batch_unlabeled, sb.Stage.TRAIN
                        )
                        total_loss = loss + loss_unlabeled
                    else:
                        total_loss = loss
            else:
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

                # Compute the unlabeled batch (pseudo-labeling loss)
                if batch_unlabeled:
                    outputs_unlabeled = self.compute_forward(
                        batch_unlabeled, sb.Stage.TRAIN
                    )
                    loss_unlabeled = self.compute_loss_unlabeled(
                        outputs_unlabeled, batch_unlabeled, sb.Stage.TRAIN
                    )
                    total_loss = loss + loss_unlabeled
                else:
                    total_loss = loss

            scaled_loss = self.scaler.scale(
                total_loss / self.grad_accumulation_factor
            )
            self.check_loss_isfinite(scaled_loss)
            scaled_loss.backward()

        if should_step:
            self.optimizers_step()

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def compute_loss_unlabeled(self, outputs, batch, stage):
        """Compute the pseudo-labeling (unsupervised) loss for an unlabeled batch."""
        p_ctc_u, p_seq_u, wav_lens_u, hyps_u, logits_u = outputs

        # Unsupervised loss (pseudo-labeling)
        with TimerContext("Unsupervised Loss"):
            probs_u = torch.softmax(logits_u, dim=-1)
            # Highest-confidence prediction per frame
            max_probs, pseudo_labels = torch.max(probs_u, dim=-1)

        with TimerContext("Update Threshold"):
            avg_confidence_per_class = max_probs.mean(dim=1)
            thresholds = self.update_threshold(pseudo_labels, avg_confidence_per_class)

        with TimerContext("Dynamic Thresholding"):
            # Create mask on the same device as the input tensors
            mask = torch.zeros(batch.batchsize, dtype=torch.bool, device=self.device)
            for i in range(batch.batchsize):
                mask[i] = all(
                    avg_confidence_per_class[i] >= thresholds[idx.item()]
                    for idx in pseudo_labels[i]
                )

        with TimerContext("Compute Loss for Unlabeled Data"):
            # Generate pseudo-labels using argmax decoding
            pseudo_tokens_ctc = torch.argmax(p_ctc_u, dim=-1)  # (batch, seq_len)
            pseudo_tokens_seq = torch.argmax(p_seq_u, dim=-1)  # (batch, seq_len)

            # Initialize loss_u
            loss_u = torch.tensor(0.0, device=self.device, requires_grad=True)

            # Compute loss only for high-confidence samples
            if mask.sum() > 0:
                # The cost functions expect relative lengths in [0, 1].
                seq_lens = torch.tensor(
                    [len(seq) for seq in pseudo_tokens_seq[mask]],
                    device=self.device,
                ).float()
                seq_lens = seq_lens / seq_lens.max()
                loss_seq_u = self.hparams.seq_cost(
                    p_seq_u[mask], pseudo_tokens_seq[mask], length=seq_lens
                )

                target_lens = torch.tensor(
                    [len(seq) for seq in pseudo_tokens_ctc[mask]],
                    device=self.device,
                )
                target_lens = target_lens.float() / target_lens.max()
                loss_ctc_u = self.hparams.ctc_cost(
                    p_ctc_u[mask],
                    pseudo_tokens_ctc[mask],
                    wav_lens_u[mask],
                    target_lens,
                )

                loss_u = (
                    self.hparams.ctc_weight * loss_ctc_u
                    + (1 - self.hparams.ctc_weight) * loss_seq_u
                )

        # Scale losses
        # loss_l_scaled = self.scaler.scale(
        #     loss_l / self.grad_accumulation_factor
        # )
        return loss_u
    def transcribe(self, files):
        """Transcribe a list of audio files and print the decoded text."""
        self.on_evaluate_start()
        self.modules.eval()
        with torch.no_grad():
            if self.use_amp:
                amp = AMPConfig.from_name(self.eval_precision)
                with torch.autocast(
                    dtype=amp.dtype, device_type=torch.device(self.device).type
                ):
                    for file in files:
                        output = self.transcribe_file(file)
                        print(output)
            else:
                for file in files:
                    output = self.transcribe_file(file)
                    print(output)
                    # yield self.merge_predictions(output)


# if __name__ == "__main__":
#     transcriber = TransformerPretrainedASR(
#         "/home/hamidovaslon1_gmail_com/stt/stt_workers/uz-transformer/hyperparams.yaml"
#     )
#     files = json.load(open("test/files.json"))
#     for output in transcriber.transcribe([f"test/{file}" for file in files]):
#         print(output)
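# Hypothetical training entry point (a sketch only; this file does not define the
# data pipelines, so `make_labeled_loader` / `make_unlabeled_loader` below are
# placeholders that would have to be provided, e.g. by a CommonVoice data-prep
# step):
#
# if __name__ == "__main__":
#     asr_brain = TransformerPretrainedASR("hparams/conformer_large.yaml")
#     asr_brain.fit(
#         asr_brain.hparams.epoch_counter,
#         train_set=make_labeled_loader(),
#         unlabeled_train_set=make_unlabeled_loader(),
#         valid_set=None,
#     )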