# Data parameters:
# With data_parallel, batch_size is split across the N jobs (each GPU sees batch_size / N).
# With DDP, the effective batch size is batch_size multiplied by the N processes.
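# Worked example (illustrative numbers only, not part of the recipe): with batch_size: 6
# and 2 GPUs, DDP trains on an effective batch of 6 x 2 = 12 utterances per step, whereas
# data_parallel keeps the effective batch at 6, i.e. 6 / 2 = 3 utterances per GPU.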
batch_size: 6
test_batch_size: 2
# We remove utterances longer than 90s from the train/dev/test sets, as
# such long utterances most likely correspond to "open microphones".
avoid_if_longer_than: 90.0
avoid_if_smaller_than: 0.0
dataloader_options:
  batch_size: 6
  num_workers: 6
  shuffle: true
test_dataloader_options:
  batch_size: 2
  num_workers: 3
# Feature parameters:
sample_rate: 16000
feats_dim: 1024
# Training parameters:
number_of_epochs: 80
lr: 1
lr_wav2vec: 0.0001
annealing_factor: 0.8
annealing_factor_wav2vec: 0.9
improvement_threshold: 0.0025
improvement_threshold_wav2vec: 0.0025
patient: 0
patient_wav2vec: 0
sorting: random
# Model parameters:
activation: &id001 !name:torch.nn.LeakyReLU
dropout: 0.15
cnn_blocks: 0
rnn_layers: 0
dnn_blocks: 1
rnn_neurons: 0
dnn_neurons: 1024
# Wav2Vec parameters:
freeze: false
# Decoding parameters:
blank_index: 0
# Outputs:
output_neurons: 113
# ------ Functions and classes
epoch_counter: &id008 !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: 80
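# Reminder on the HyperPyYAML notation used below (general HyperPyYAML behaviour, noted
# here only as a memo): "!new:" instantiates the named class with the nested keys as
# constructor arguments, "!name:" stores the callable without calling it, and
# "&idNNN" / "*idNNN" are plain YAML anchors/aliases so several keys share one object,
# for example:
#   act: &act !name:torch.nn.LeakyReLU        # a callable, not yet instantiated
#   dec: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
#     activation: *act                        # re-uses the anchored callable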
wav2vec: &id002 !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
  source: microsoft/wavlm-large
  output_norm: true
  freeze: false
  save_path: results/TARIC_SLU_wav2vec_wavLM_with_intent_criterion_a100_copie/1212/save/wav2vec.pt
dec: &id003 !new:speechbrain.lobes.models.VanillaNN.VanillaNN
  input_shape: [null, null, 1024]
  activation: *id001
  dnn_blocks: 1
  dnn_neurons: 1024
output_lin: &id004 !new:speechbrain.nnet.linear.Linear
  input_size: 1024
  n_neurons: 113
  bias: true
softmax: !new:speechbrain.nnet.activations.Softmax
  apply_log: true
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
  blank_index: 0
modules:
  wav2vec: *id002
  dec: *id003
  output_lin: *id004
model: &id005 !new:torch.nn.ModuleList
- [*id003, *id004]
model_wav2vec: !new:torch.nn.ModuleList
- [*id002]
opt_class: !name:torch.optim.Adadelta
  lr: 1
  rho: 0.95
  eps: 1.e-8
opt_class_wav2vec: !name:torch.optim.Adam
  lr: 0.0001
lr_annealing: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
  initial_value: 1
  improvement_threshold: 0.0025
  annealing_factor: 0.8
  patient: 0
lr_annealing_wav2vec: &id007 !new:speechbrain.nnet.schedulers.NewBobScheduler
  initial_value: 0.0001
  improvement_threshold: 0.0025
  annealing_factor: 0.9
  patient: 0
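# How the NewBob schedulers above behave (illustrative values only): after each epoch the
# validation loss is compared with the previous one, and if the relative improvement falls
# below improvement_threshold (0.0025) the learning rate is multiplied by annealing_factor,
# e.g. 1 -> 0.8 -> 0.64 for the Adadelta optimizer and 1e-4 -> 9e-5 -> 8.1e-5 for the
# wav2vec Adam optimizer. With patient: 0, the decay is applied as soon as a single epoch
# fails to improve enough.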
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: results/TARIC_SLU_wav2vec_wavLM_with_intent_criterion_a100_copie/1212/save
  recoverables:
    model: *id005
    wav2vec: *id002
    lr_annealing: *id006
    lr_annealing_wav2vec: *id007
    counter: *id008
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: results/TARIC_SLU_wav2vec_wavLM_with_intent_criterion_a100_copie/1212/train_log.txt
ctc_computer: !name:speechbrain.utils.metric_stats.MetricStats
  metric: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: 0
    reduction: batch
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
  merge_tokens: true
coer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
  extract_concepts_values: true
  keep_values: false
  tag_in: '<'
  tag_out: '>'
cver_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
  extract_concepts_values: true
  keep_values: true
  tag_in: '<'
  tag_out: '>'
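# Rough sketch of what the concept metrics are meant to capture (made-up example; the
# exact extraction is handled by ErrorRateStats with extract_concepts_values): given a
# hypothesis such as "< city_name_departure tunis > < departure_time dix heures >",
# the COER scores the concept sequence only ("city_name_departure departure_time",
# keep_values: false), while the CVER also keeps the values attached to each concept
# (keep_values: true); both are then edit-distance metrics, like a regular WER.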
tokenizer: !new:speechbrain.dataio.encoder.CTCTextEncoder
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    model: !ref <model>
    wav2vec: !ref <wav2vec>
    tokenizer: !ref <tokenizer>
  paths:
    model: !ref /content/sample_data/SLU/model.cpkt
    wav2vec: !ref /content/sample_data/SLU/wav2vec.cpkt
    tokenizer: !ref /content/sample_data/SLU/label_encoder.txt
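# The Pretrainer only declares what to load and from where; in the accompanying Python
# script it is typically driven with something along these lines (hedged sketch, not part
# of this YAML):
#   run_on_main(hparams["pretrainer"].collect_files)
#   hparams["pretrainer"].load_collected()
# after which <model>, <wav2vec> and <tokenizer> hold the pretrained parameters.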
decoding_function: !name:speechbrain.decoders.ctc_greedy_decode
  blank_id: 0
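# Greedy CTC decoding picks the argmax token per frame, collapses repeats, and removes
# the blank (index 0). Toy example (made-up frame sequence):
#   frame argmax [0, 5, 5, 0, 7, 7, 0]  ->  decoded tokens [5, 7]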
# Tag list:
tag_list: <politeness>, <directives_query>, <directives_answer>, <age>, <age_req>,
  <age_ticket>, <an>, <answer>, <arrival_time>, <card_price>, <card_type>, <city>,
  <city_name_arrival>, <city_name_before>, <city_name_departure>, <city_name_direction>,
  <class_number>, <class_type>, <command_task>, <comparatif_age>, <comparatif_distance>,
  <comparatif_price>, <comparatif_time>, <coreference_city>, <coreference_departure>,
  <date>, <day>, <departure_time>, <discount_gain>, <discount_pourcent>, <duration>,
  <duration_req>, <existance>, <existance_req>, <hour_req>, <money_exchange>, <month>,
  <negation>, <number>, <number_class>, <number_of_train>, <number_req>, <object>,
  <option>, <other_transport>, <part_price>, <part_time>, <period_day>, <period_year>,
  <person_name>, <price_req>, <rang>, <ref_object>, <ref_person>, <ref_time>, <relative_day>,
  <relative_time>, <state>, <tarif>, <task>, <ticket_number>, <ticket_price>, <ticket_type>,
  <time>, <train_type>