From d98f74da34b78c0587b8f9d4ed59f4ed2aa96afd Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Tue, 11 Jun 2024 11:44:18 +0200 Subject: [PATCH] update --- .../LBS-960/Aux-Loss/returnn.config | 788 ++++++++++++++++ .../LBS-960/Iterative_Zero_Out/returnn.config | 850 ++++++++++++++++++ .../LBS-960/Simple_Top_k/returnn.config | 798 ++++++++++++++++ 2024-dynamic-encoder-size/README.md | 15 + 4 files changed, 2451 insertions(+) create mode 100644 2024-dynamic-encoder-size/LBS-960/Aux-Loss/returnn.config create mode 100644 2024-dynamic-encoder-size/LBS-960/Iterative_Zero_Out/returnn.config create mode 100644 2024-dynamic-encoder-size/LBS-960/Simple_Top_k/returnn.config diff --git a/2024-dynamic-encoder-size/LBS-960/Aux-Loss/returnn.config b/2024-dynamic-encoder-size/LBS-960/Aux-Loss/returnn.config new file mode 100644 index 00000000..04a07d09 --- /dev/null +++ b/2024-dynamic-encoder-size/LBS-960/Aux-Loss/returnn.config @@ -0,0 +1,788 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +debug_print_layer_output_template = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.h7DH1ILPAElF/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.vESciFN5It1u/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4.1e-06, + 5.199999999999999e-06, + 6.299999999999999e-06, + 7.399999999999999e-06, + 8.499999999999998e-06, + 9.6e-06, + 1.07e-05, + 1.1799999999999999e-05, + 1.2899999999999998e-05, + 1.3999999999999998e-05, + 1.5099999999999998e-05, + 1.6199999999999997e-05, + 1.7299999999999997e-05, + 1.8399999999999997e-05, + 1.9499999999999996e-05, + 2.0599999999999996e-05, + 2.1699999999999996e-05, + 2.2799999999999995e-05, + 2.3899999999999995e-05, + 2.4999999999999994e-05, + 2.6099999999999994e-05, + 2.7199999999999994e-05, + 2.8299999999999993e-05, + 2.9399999999999996e-05, + 3.0499999999999996e-05, + 3.1599999999999996e-05, + 3.2699999999999995e-05, + 3.3799999999999995e-05, + 3.4899999999999995e-05, + 3.5999999999999994e-05, + 3.7099999999999994e-05, + 3.8199999999999993e-05, + 3.929999999999999e-05, + 4.039999999999999e-05, + 4.149999999999999e-05, + 4.259999999999999e-05, + 4.369999999999999e-05, + 4.479999999999999e-05, + 4.589999999999999e-05, + 4.699999999999999e-05, + 4.809999999999999e-05, + 4.919999999999999e-05, + 5.029999999999999e-05, + 5.139999999999999e-05, + 5.249999999999999e-05, + 5.359999999999999e-05, + 5.469999999999999e-05, + 5.5799999999999994e-05, + 5.6899999999999994e-05, + 5.7999999999999994e-05, + 5.909999999999999e-05, + 6.019999999999999e-05, + 6.13e-05, + 6.24e-05, + 6.35e-05, + 6.46e-05, + 6.57e-05, + 6.68e-05, + 6.79e-05, + 6.9e-05, + 7.01e-05, + 7.12e-05, + 7.23e-05, + 7.34e-05, + 7.45e-05, + 7.56e-05, + 7.67e-05, + 7.78e-05, + 7.89e-05, + 7.999999999999999e-05, + 8.109999999999999e-05, + 8.219999999999999e-05, + 8.329999999999999e-05, + 8.439999999999999e-05, + 8.549999999999999e-05, + 8.659999999999999e-05, + 8.769999999999999e-05, + 8.879999999999999e-05, + 8.989999999999999e-05, + 9.099999999999999e-05, + 9.209999999999999e-05, + 9.319999999999999e-05, + 9.429999999999999e-05, + 9.539999999999999e-05, + 9.649999999999999e-05, + 9.759999999999999e-05, + 9.869999999999999e-05, + 9.979999999999999e-05, + 0.00010089999999999999, + 0.00010199999999999999, + 0.00010309999999999999, + 0.00010419999999999998, + 0.00010529999999999998, + 0.00010639999999999998, + 0.00010749999999999998, + 0.0001086, + 0.0001097, + 0.0001108, + 0.0001119, + 0.000113, + 0.0001141, + 0.0001152, + 0.0001163, + 0.0001174, + 0.0001185, + 0.0001196, + 0.00012069999999999999, + 0.00012179999999999999, + 0.00012289999999999998, + 0.00012399999999999998, + 0.00012509999999999998, + 0.00012619999999999998, + 0.00012729999999999998, + 0.00012839999999999998, + 0.00012949999999999998, + 0.00013059999999999998, + 0.00013169999999999998, + 0.00013279999999999998, + 0.00013389999999999997, + 0.00013499999999999997, + 0.00013609999999999997, + 0.00013719999999999997, + 0.00013829999999999997, + 0.00013939999999999997, + 0.00014049999999999997, + 0.00014159999999999997, + 0.00014269999999999997, + 0.00014379999999999997, + 0.00014489999999999997, + 0.00014599999999999997, + 0.00014709999999999997, + 0.00014819999999999997, + 0.00014929999999999997, + 0.00015039999999999997, + 0.00015149999999999997, + 0.00015259999999999997, + 0.00015369999999999997, + 0.00015479999999999997, + 0.00015589999999999997, + 0.00015699999999999997, + 0.00015809999999999997, + 0.00015919999999999997, + 0.00016029999999999997, + 0.00016139999999999997, + 0.00016249999999999997, + 0.00016359999999999997, + 0.00016469999999999996, + 0.00016579999999999996, + 0.00016689999999999996, + 0.00016799999999999996, + 0.00016909999999999996, + 0.00017019999999999996, + 0.00017129999999999996, + 0.00017239999999999996, + 0.00017349999999999996, + 0.00017459999999999996, + 0.00017569999999999996, + 0.00017679999999999996, + 0.00017789999999999996, + 0.00017899999999999996, + 0.00018009999999999996, + 0.00018119999999999996, + 0.00018229999999999996, + 0.00018339999999999996, + 0.00018449999999999996, + 0.00018559999999999996, + 0.00018669999999999996, + 0.00018779999999999996, + 0.00018889999999999996, + 0.00018999999999999996, + 0.00019109999999999996, + 0.00019219999999999996, + 0.00019329999999999996, + 0.00019439999999999995, + 0.00019549999999999995, + 0.00019659999999999995, + 0.00019769999999999995, + 0.00019879999999999995, + 0.00019989999999999995, + 0.00020099999999999995, + 0.00020209999999999995, + 0.00020319999999999995, + 0.00020429999999999995, + 0.00020539999999999995, + 0.00020649999999999995, + 0.00020759999999999995, + 0.00020869999999999995, + 0.00020979999999999995, + 0.00021089999999999995, + 0.00021199999999999995, + 0.00021309999999999995, + 0.00021419999999999998, + 0.00021529999999999997, + 0.00021639999999999997, + 0.00021749999999999997, + 0.00021859999999999997, + 0.00021969999999999997, + 0.00022079999999999997, + 0.00022189999999999997, + 0.00022299999999999997, + 0.00022409999999999997, + 0.00022519999999999997, + 0.00022629999999999997, + 0.00022739999999999997, + 0.00022849999999999997, + 0.00022959999999999997, + 0.00023069999999999997, + 0.00023179999999999997, + 0.00023289999999999997, + 0.00023399999999999997, + 0.00023509999999999997, + 0.00023619999999999997, + 0.00023729999999999997, + 0.00023839999999999997, + 0.00023949999999999997, + 0.00024059999999999997, + 0.00024169999999999997, + 0.00024279999999999997, + 0.00024389999999999997, + 0.000245, + 0.00024609999999999996, + 0.0002472, + 0.00024829999999999996, + 0.0002494, + 0.00025049999999999996, + 0.0002516, + 0.00025269999999999996, + 0.0002538, + 0.00025489999999999996, + 0.000256, + 0.00025709999999999996, + 0.0002582, + 0.00025929999999999996, + 0.0002604, + 0.00026149999999999996, + 0.0002626, + 0.00026369999999999996, + 0.0002648, + 0.00026589999999999996, + 0.000267, + 0.00026809999999999996, + 0.0002692, + 0.00027029999999999996, + 0.0002714, + 0.00027249999999999996, + 0.0002736, + 0.00027469999999999996, + 0.0002758, + 0.00027689999999999995, + 0.000278, + 0.00027909999999999995, + 0.0002802, + 0.00028129999999999995, + 0.0002824, + 0.00028349999999999995, + 0.0002846, + 0.0002857, + 0.0002868, + 0.0002879, + 0.000289, + 0.0002901, + 0.0002912, + 0.0002923, + 0.0002934, + 0.0002945, + 0.0002956, + 0.0002967, + 0.0002978, + 0.0002989, + 0.0003, + 0.00029889999999999995, + 0.0002978, + 0.00029669999999999995, + 0.0002956, + 0.00029449999999999995, + 0.0002934, + 0.00029229999999999995, + 0.0002912, + 0.00029009999999999995, + 0.000289, + 0.00028789999999999995, + 0.0002868, + 0.00028569999999999995, + 0.0002846, + 0.00028349999999999995, + 0.0002824, + 0.00028129999999999995, + 0.0002802, + 0.00027909999999999995, + 0.000278, + 0.00027689999999999995, + 0.0002758, + 0.00027469999999999996, + 0.0002736, + 0.00027249999999999996, + 0.0002714, + 0.00027029999999999996, + 0.0002692, + 0.00026809999999999996, + 0.000267, + 0.00026589999999999996, + 0.0002648, + 0.00026369999999999996, + 0.0002626, + 0.00026149999999999996, + 0.0002604, + 0.00025929999999999996, + 0.0002582, + 0.00025709999999999996, + 0.000256, + 0.00025489999999999996, + 0.0002538, + 0.00025269999999999996, + 0.0002516, + 0.00025049999999999996, + 0.0002494, + 0.00024829999999999996, + 0.0002472, + 0.00024609999999999996, + 0.000245, + 0.00024389999999999997, + 0.0002428, + 0.00024169999999999997, + 0.0002406, + 0.00023949999999999997, + 0.0002384, + 0.00023729999999999997, + 0.0002362, + 0.00023509999999999997, + 0.000234, + 0.00023289999999999997, + 0.0002318, + 0.00023069999999999997, + 0.0002296, + 0.00022849999999999997, + 0.0002274, + 0.00022629999999999997, + 0.0002252, + 0.00022409999999999997, + 0.000223, + 0.00022189999999999997, + 0.0002208, + 0.00021969999999999997, + 0.0002186, + 0.00021749999999999997, + 0.0002164, + 0.00021529999999999997, + 0.0002142, + 0.00021309999999999998, + 0.000212, + 0.00021089999999999998, + 0.0002098, + 0.00020869999999999998, + 0.0002076, + 0.00020649999999999998, + 0.0002054, + 0.00020429999999999998, + 0.0002032, + 0.00020209999999999998, + 0.000201, + 0.00019989999999999998, + 0.0001988, + 0.00019769999999999998, + 0.0001966, + 0.00019549999999999998, + 0.00019439999999999998, + 0.00019329999999999998, + 0.00019219999999999998, + 0.00019109999999999998, + 0.00018999999999999998, + 0.00018889999999999998, + 0.00018779999999999998, + 0.00018669999999999998, + 0.00018559999999999998, + 0.00018449999999999999, + 0.00018339999999999999, + 0.00018229999999999999, + 0.00018119999999999999, + 0.0001801, + 0.000179, + 0.0001779, + 0.0001768, + 0.0001757, + 0.0001746, + 0.0001735, + 0.0001724, + 0.0001713, + 0.0001702, + 0.0001691, + 0.000168, + 0.0001669, + 0.0001658, + 0.0001647, + 0.0001636, + 0.0001625, + 0.0001614, + 0.0001603, + 0.0001592, + 0.0001581, + 0.000157, + 0.0001559, + 0.0001548, + 0.0001537, + 0.0001526, + 0.0001515, + 0.0001504, + 0.0001493, + 0.0001482, + 0.0001471, + 0.000146, + 0.0001449, + 0.0001438, + 0.0001427, + 0.0001416, + 0.0001405, + 0.0001394, + 0.0001383, + 0.0001372, + 0.0001361, + 0.000135, + 0.0001339, + 0.0001328, + 0.0001317, + 0.0001306, + 0.0001295, + 0.0001284, + 0.0001273, + 0.0001262, + 0.0001251, + 0.000124, + 0.0001229, + 0.0001218, + 0.0001207, + 0.00011960000000000001, + 0.00011850000000000001, + 0.00011740000000000001, + 0.00011630000000000001, + 0.00011520000000000001, + 0.00011410000000000001, + 0.00011300000000000001, + 0.00011190000000000001, + 0.00011080000000000001, + 0.00010970000000000001, + 0.00010860000000000001, + 0.00010750000000000001, + 0.00010640000000000001, + 0.00010530000000000001, + 0.00010420000000000001, + 0.00010310000000000001, + 0.00010200000000000001, + 0.00010090000000000001, + 9.980000000000001e-05, + 9.870000000000001e-05, + 9.760000000000001e-05, + 9.650000000000001e-05, + 9.540000000000001e-05, + 9.430000000000002e-05, + 9.320000000000002e-05, + 9.210000000000002e-05, + 9.100000000000002e-05, + 8.990000000000002e-05, + 8.879999999999999e-05, + 8.769999999999999e-05, + 8.659999999999999e-05, + 8.549999999999999e-05, + 8.439999999999999e-05, + 8.329999999999999e-05, + 8.219999999999999e-05, + 8.109999999999999e-05, + 7.999999999999999e-05, + 7.89e-05, + 7.78e-05, + 7.67e-05, + 7.56e-05, + 7.45e-05, + 7.34e-05, + 7.23e-05, + 7.12e-05, + 7.01e-05, + 6.9e-05, + 6.79e-05, + 6.68e-05, + 6.57e-05, + 6.46e-05, + 6.35e-05, + 6.24e-05, + 6.13e-05, + 6.02e-05, + 5.91e-05, + 5.8e-05, + 5.69e-05, + 5.58e-05, + 5.470000000000003e-05, + 5.36e-05, + 5.250000000000003e-05, + 5.14e-05, + 5.030000000000003e-05, + 4.92e-05, + 4.810000000000003e-05, + 4.7000000000000004e-05, + 4.590000000000003e-05, + 4.4800000000000005e-05, + 4.370000000000003e-05, + 4.2600000000000005e-05, + 4.150000000000003e-05, + 4.0400000000000006e-05, + 3.9300000000000034e-05, + 3.820000000000001e-05, + 3.7100000000000034e-05, + 3.600000000000001e-05, + 3.4900000000000035e-05, + 3.380000000000001e-05, + 3.2700000000000036e-05, + 3.160000000000001e-05, + 3.0500000000000037e-05, + 2.940000000000001e-05, + 2.8300000000000037e-05, + 2.720000000000001e-05, + 2.6100000000000038e-05, + 2.500000000000001e-05, + 2.390000000000004e-05, + 2.2800000000000012e-05, + 2.170000000000004e-05, + 2.0600000000000013e-05, + 1.950000000000004e-05, + 1.8400000000000014e-05, + 1.7299999999999987e-05, + 1.6200000000000014e-05, + 1.5099999999999988e-05, + 1.4000000000000015e-05, + 1.2899999999999988e-05, + 1.1800000000000016e-05, + 1.0699999999999989e-05, + 9.600000000000017e-06, + 8.49999999999999e-06, + 7.400000000000017e-06, + 6.2999999999999905e-06, + 5.200000000000018e-06, + 4.099999999999991e-06, + 2.9999999999999997e-06, + 2.9501666666666663e-06, + 2.900333333333333e-06, + 2.8504999999999996e-06, + 2.8006666666666663e-06, + 2.750833333333333e-06, + 2.7009999999999996e-06, + 2.6511666666666663e-06, + 2.601333333333333e-06, + 2.5514999999999996e-06, + 2.5016666666666663e-06, + 2.451833333333333e-06, + 2.4019999999999996e-06, + 2.3521666666666663e-06, + 2.302333333333333e-06, + 2.2524999999999996e-06, + 2.2026666666666663e-06, + 2.152833333333333e-06, + 2.1029999999999996e-06, + 2.0531666666666663e-06, + 2.0033333333333334e-06, + 1.9535e-06, + 1.9036666666666665e-06, + 1.8538333333333331e-06, + 1.8039999999999998e-06, + 1.7541666666666665e-06, + 1.7043333333333331e-06, + 1.6544999999999998e-06, + 1.6046666666666665e-06, + 1.5548333333333331e-06, + 1.5049999999999998e-06, + 1.4551666666666664e-06, + 1.4053333333333331e-06, + 1.3554999999999998e-06, + 1.3056666666666664e-06, + 1.255833333333333e-06, + 1.2059999999999998e-06, + 1.1561666666666664e-06, + 1.106333333333333e-06, + 1.0565e-06, + 1.0066666666666666e-06, + 9.568333333333333e-07, + 9.07e-07, + 8.571666666666666e-07, + 8.073333333333333e-07, + 7.575e-07, + 7.076666666666666e-07, + 6.578333333333333e-07, + 6.079999999999999e-07, + 5.581666666666666e-07, + 5.083333333333333e-07, + 4.5849999999999993e-07, + 4.086666666666666e-07, + 3.5883333333333326e-07, + 3.089999999999999e-07, + 2.591666666666666e-07, + 2.0933333333333325e-07, + 1.594999999999999e-07, + 1.0966666666666658e-07, + 5.983333333333324e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 60 +model = "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/training/ReturnnTrainingJob.wdbAMfELLBp2/output/models/epoch" +num_epochs = 600 +num_inputs = 1 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.VZM5dHZhqlnJ/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.SYt8A5fOy2ta/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.baseline.conformer_ctc_d_model_512_num_layers_12_new_frontend_raw_wave_with_aux_loss import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.baseline.conformer_ctc_d_model_512_num_layers_12_new_frontend_raw_wave_with_aux_loss import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer.conformer_v1 import ConformerEncoderV1Config +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer.conformer_v1 import ConformerBlockV1Config +from i6_models.parts.conformer.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer.convolution import ConformerConvolutionV1Config +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=5, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderV1Config( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfg=ConformerBlockV1Config( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC(512), + ), + ), + ), + target_size=79, + aux_losses={"4": 0.3, "8": 0.3, "12": 1}, + recog_num_layer=12, +) +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.baseline.conformer_ctc_d_model_512_num_layers_12_new_frontend_raw_wave_with_aux_loss import ( + train_step, +) diff --git a/2024-dynamic-encoder-size/LBS-960/Iterative_Zero_Out/returnn.config b/2024-dynamic-encoder-size/LBS-960/Iterative_Zero_Out/returnn.config new file mode 100644 index 00000000..ea7536b6 --- /dev/null +++ b/2024-dynamic-encoder-size/LBS-960/Iterative_Zero_Out/returnn.config @@ -0,0 +1,850 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +debug_print_layer_output_template = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.h7DH1ILPAElF/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.vESciFN5It1u/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 5.466666666666666e-06, + 6.933333333333334e-06, + 8.400000000000001e-06, + 9.866666666666668e-06, + 1.1333333333333336e-05, + 1.28e-05, + 1.4266666666666667e-05, + 1.5733333333333334e-05, + 1.72e-05, + 1.866666666666667e-05, + 2.0133333333333336e-05, + 2.16e-05, + 2.3066666666666667e-05, + 2.4533333333333334e-05, + 2.6000000000000002e-05, + 2.746666666666667e-05, + 2.8933333333333336e-05, + 3.0400000000000004e-05, + 3.186666666666667e-05, + 3.333333333333334e-05, + 3.4800000000000006e-05, + 3.6266666666666676e-05, + 3.773333333333334e-05, + 3.9200000000000004e-05, + 4.0666666666666675e-05, + 4.213333333333334e-05, + 4.360000000000001e-05, + 4.506666666666667e-05, + 4.6533333333333344e-05, + 4.800000000000001e-05, + 4.946666666666668e-05, + 5.093333333333334e-05, + 5.2400000000000007e-05, + 5.386666666666668e-05, + 5.533333333333334e-05, + 5.680000000000001e-05, + 5.8266666666666676e-05, + 5.9733333333333346e-05, + 6.120000000000001e-05, + 6.266666666666668e-05, + 6.413333333333334e-05, + 6.560000000000001e-05, + 6.706666666666668e-05, + 6.853333333333335e-05, + 7.000000000000001e-05, + 7.146666666666668e-05, + 7.293333333333335e-05, + 7.44e-05, + 7.586666666666668e-05, + 7.733333333333335e-05, + 7.880000000000002e-05, + 8.026666666666668e-05, + 8.173333333333335e-05, + 8.320000000000002e-05, + 8.466666666666669e-05, + 8.613333333333334e-05, + 8.760000000000002e-05, + 8.906666666666669e-05, + 9.053333333333334e-05, + 9.200000000000001e-05, + 9.346666666666668e-05, + 9.493333333333336e-05, + 9.640000000000001e-05, + 9.786666666666668e-05, + 9.933333333333335e-05, + 0.00010080000000000001, + 0.00010226666666666668, + 0.00010373333333333335, + 0.00010520000000000002, + 0.00010666666666666668, + 0.00010813333333333335, + 0.00010960000000000002, + 0.00011106666666666668, + 0.00011253333333333335, + 0.00011400000000000002, + 0.00011546666666666669, + 0.00011693333333333335, + 0.00011840000000000002, + 0.00011986666666666669, + 0.00012133333333333336, + 0.0001228, + 0.0001242666666666667, + 0.00012573333333333334, + 0.0001272, + 0.00012866666666666669, + 0.00013013333333333334, + 0.0001316, + 0.00013306666666666668, + 0.00013453333333333334, + 0.000136, + 0.00013746666666666668, + 0.00013893333333333334, + 0.0001404, + 0.00014186666666666668, + 0.00014333333333333334, + 0.0001448, + 0.00014626666666666668, + 0.00014773333333333334, + 0.00014920000000000002, + 0.00015066666666666668, + 0.00015213333333333334, + 0.00015360000000000002, + 0.00015506666666666668, + 0.00015653333333333333, + 0.00015800000000000002, + 0.00015946666666666668, + 0.00016093333333333333, + 0.00016240000000000002, + 0.00016386666666666667, + 0.00016533333333333336, + 0.00016680000000000002, + 0.00016826666666666667, + 0.00016973333333333336, + 0.00017120000000000001, + 0.00017266666666666667, + 0.00017413333333333336, + 0.0001756, + 0.00017706666666666667, + 0.00017853333333333335, + 0.00018, + 0.00018146666666666667, + 0.00018293333333333335, + 0.0001844, + 0.0001858666666666667, + 0.00018733333333333335, + 0.0001888, + 0.0001902666666666667, + 0.00019173333333333335, + 0.0001932, + 0.0001946666666666667, + 0.00019613333333333335, + 0.0001976, + 0.0001990666666666667, + 0.00020053333333333335, + 0.00020200000000000003, + 0.0002034666666666667, + 0.00020493333333333335, + 0.00020640000000000003, + 0.0002078666666666667, + 0.00020933333333333334, + 0.00021080000000000003, + 0.00021226666666666669, + 0.00021373333333333334, + 0.00021520000000000003, + 0.00021666666666666668, + 0.00021813333333333334, + 0.00021960000000000003, + 0.00022106666666666668, + 0.00022253333333333337, + 0.00022400000000000002, + 0.00022546666666666668, + 0.00022693333333333337, + 0.00022840000000000002, + 0.00022986666666666668, + 0.00023133333333333336, + 0.00023280000000000002, + 0.00023426666666666668, + 0.00023573333333333336, + 0.00023720000000000002, + 0.0002386666666666667, + 0.00024013333333333336, + 0.00024160000000000002, + 0.0002430666666666667, + 0.0002445333333333334, + 0.000246, + 0.0002474666666666667, + 0.0002489333333333334, + 0.0002504, + 0.0002518666666666667, + 0.0002533333333333334, + 0.0002548, + 0.0002562666666666667, + 0.0002577333333333334, + 0.0002592, + 0.0002606666666666667, + 0.0002621333333333334, + 0.0002636, + 0.0002650666666666667, + 0.0002665333333333334, + 0.000268, + 0.0002694666666666667, + 0.0002709333333333334, + 0.0002724, + 0.0002738666666666667, + 0.0002753333333333334, + 0.0002768, + 0.0002782666666666667, + 0.0002797333333333334, + 0.0002812, + 0.0002826666666666667, + 0.0002841333333333334, + 0.0002856, + 0.0002870666666666667, + 0.00028853333333333337, + 0.00029000000000000006, + 0.0002914666666666667, + 0.00029293333333333337, + 0.00029440000000000005, + 0.0002958666666666667, + 0.00029733333333333337, + 0.00029880000000000005, + 0.0003002666666666667, + 0.00030173333333333337, + 0.00030320000000000005, + 0.0003046666666666667, + 0.00030613333333333337, + 0.00030760000000000005, + 0.0003090666666666667, + 0.00031053333333333336, + 0.00031200000000000005, + 0.0003134666666666667, + 0.00031493333333333336, + 0.00031640000000000005, + 0.0003178666666666667, + 0.00031933333333333336, + 0.00032080000000000005, + 0.0003222666666666667, + 0.00032373333333333336, + 0.00032520000000000004, + 0.00032666666666666673, + 0.00032813333333333336, + 0.00032960000000000004, + 0.0003310666666666667, + 0.00033253333333333336, + 0.00033400000000000004, + 0.0003354666666666667, + 0.00033693333333333336, + 0.00033840000000000004, + 0.0003398666666666667, + 0.00034133333333333335, + 0.00034280000000000004, + 0.0003442666666666667, + 0.00034573333333333335, + 0.00034720000000000004, + 0.0003486666666666667, + 0.00035013333333333335, + 0.00035160000000000004, + 0.0003530666666666667, + 0.00035453333333333335, + 0.00035600000000000003, + 0.0003574666666666667, + 0.00035893333333333335, + 0.00036040000000000003, + 0.0003618666666666667, + 0.0003633333333333334, + 0.00036480000000000003, + 0.0003662666666666667, + 0.0003677333333333334, + 0.00036920000000000003, + 0.0003706666666666667, + 0.0003721333333333334, + 0.00037360000000000003, + 0.0003750666666666667, + 0.0003765333333333334, + 0.000378, + 0.0003794666666666667, + 0.0003809333333333334, + 0.0003824, + 0.0003838666666666667, + 0.0003853333333333334, + 0.0003868, + 0.0003882666666666667, + 0.0003897333333333334, + 0.0003912, + 0.0003926666666666667, + 0.0003941333333333334, + 0.0003956, + 0.0003970666666666667, + 0.0003985333333333334, + 0.0004, + 0.00039853333333333333, + 0.0003970666666666667, + 0.0003956, + 0.00039413333333333334, + 0.0003926666666666667, + 0.0003912, + 0.00038973333333333334, + 0.0003882666666666667, + 0.0003868, + 0.00038533333333333334, + 0.0003838666666666667, + 0.0003824, + 0.00038093333333333334, + 0.0003794666666666667, + 0.000378, + 0.00037653333333333334, + 0.00037506666666666666, + 0.00037360000000000003, + 0.00037213333333333334, + 0.00037066666666666666, + 0.00036920000000000003, + 0.00036773333333333335, + 0.00036626666666666666, + 0.00036480000000000003, + 0.00036333333333333335, + 0.00036186666666666666, + 0.00036040000000000003, + 0.00035893333333333335, + 0.00035746666666666666, + 0.00035600000000000003, + 0.00035453333333333335, + 0.00035306666666666667, + 0.00035160000000000004, + 0.00035013333333333335, + 0.00034866666666666667, + 0.0003472, + 0.00034573333333333335, + 0.00034426666666666667, + 0.0003428, + 0.00034133333333333335, + 0.00033986666666666667, + 0.0003384, + 0.00033693333333333336, + 0.00033546666666666667, + 0.000334, + 0.00033253333333333336, + 0.00033106666666666667, + 0.0003296, + 0.00032813333333333336, + 0.0003266666666666667, + 0.0003252, + 0.00032373333333333336, + 0.0003222666666666667, + 0.0003208, + 0.00031933333333333336, + 0.0003178666666666667, + 0.0003164, + 0.00031493333333333336, + 0.0003134666666666667, + 0.000312, + 0.00031053333333333336, + 0.0003090666666666667, + 0.0003076, + 0.00030613333333333337, + 0.0003046666666666667, + 0.0003032, + 0.00030173333333333337, + 0.0003002666666666667, + 0.0002988, + 0.00029733333333333337, + 0.0002958666666666667, + 0.0002944, + 0.00029293333333333337, + 0.0002914666666666667, + 0.00029, + 0.0002885333333333333, + 0.0002870666666666667, + 0.0002856, + 0.0002841333333333333, + 0.00028266666666666663, + 0.0002812, + 0.0002797333333333333, + 0.00027826666666666664, + 0.0002768, + 0.0002753333333333333, + 0.00027386666666666664, + 0.0002724, + 0.0002709333333333333, + 0.00026946666666666664, + 0.000268, + 0.0002665333333333333, + 0.00026506666666666664, + 0.0002636, + 0.0002621333333333333, + 0.00026066666666666664, + 0.0002592, + 0.00025773333333333333, + 0.00025626666666666664, + 0.0002548, + 0.00025333333333333333, + 0.00025186666666666664, + 0.0002504, + 0.00024893333333333333, + 0.00024746666666666665, + 0.000246, + 0.00024453333333333333, + 0.00024306666666666668, + 0.0002416, + 0.00024013333333333333, + 0.00023866666666666665, + 0.0002372, + 0.00023573333333333334, + 0.00023426666666666665, + 0.0002328, + 0.00023133333333333334, + 0.00022986666666666665, + 0.0002284, + 0.00022693333333333334, + 0.00022546666666666665, + 0.000224, + 0.00022253333333333334, + 0.00022106666666666666, + 0.0002196, + 0.00021813333333333331, + 0.00021666666666666666, + 0.0002152, + 0.00021373333333333332, + 0.00021226666666666666, + 0.0002108, + 0.00020933333333333332, + 0.00020786666666666666, + 0.0002064, + 0.00020493333333333332, + 0.00020346666666666666, + 0.00020199999999999998, + 0.00020053333333333332, + 0.00019906666666666666, + 0.00019759999999999998, + 0.00019613333333333332, + 0.00019466666666666666, + 0.00019319999999999998, + 0.00019173333333333332, + 0.00019026666666666667, + 0.00018879999999999998, + 0.00018733333333333332, + 0.00018586666666666667, + 0.00018439999999999998, + 0.00018293333333333333, + 0.00018146666666666664, + 0.00017999999999999998, + 0.00017853333333333333, + 0.00017706666666666664, + 0.00017559999999999999, + 0.00017413333333333333, + 0.00017266666666666664, + 0.0001712, + 0.00016973333333333333, + 0.00016826666666666665, + 0.0001668, + 0.0001653333333333333, + 0.00016386666666666665, + 0.0001624, + 0.0001609333333333333, + 0.00015946666666666665, + 0.000158, + 0.0001565333333333333, + 0.00015506666666666662, + 0.0001536, + 0.0001521333333333333, + 0.00015066666666666662, + 0.0001492, + 0.0001477333333333333, + 0.00014626666666666663, + 0.0001448, + 0.0001433333333333333, + 0.00014186666666666663, + 0.0001404, + 0.0001389333333333333, + 0.00013746666666666663, + 0.000136, + 0.00013453333333333331, + 0.00013306666666666663, + 0.0001316, + 0.00013013333333333332, + 0.00012866666666666663, + 0.0001272, + 0.00012573333333333332, + 0.00012426666666666663, + 0.0001228, + 0.00012133333333333332, + 0.00011986666666666663, + 0.0001184, + 0.00011693333333333332, + 0.00011546666666666664, + 0.00011399999999999995, + 0.00011253333333333332, + 0.00011106666666666664, + 0.00010959999999999995, + 0.00010813333333333332, + 0.00010666666666666664, + 0.00010519999999999996, + 0.00010373333333333332, + 0.00010226666666666664, + 0.00010079999999999996, + 9.933333333333333e-05, + 9.786666666666664e-05, + 9.639999999999996e-05, + 9.493333333333333e-05, + 9.346666666666664e-05, + 9.199999999999996e-05, + 9.053333333333333e-05, + 8.906666666666665e-05, + 8.759999999999996e-05, + 8.613333333333333e-05, + 8.466666666666665e-05, + 8.319999999999996e-05, + 8.173333333333333e-05, + 8.026666666666665e-05, + 7.879999999999996e-05, + 7.733333333333328e-05, + 7.586666666666665e-05, + 7.439999999999997e-05, + 7.293333333333328e-05, + 7.146666666666665e-05, + 6.999999999999997e-05, + 6.853333333333328e-05, + 6.706666666666665e-05, + 6.559999999999997e-05, + 6.413333333333328e-05, + 6.266666666666665e-05, + 6.119999999999997e-05, + 5.9733333333333285e-05, + 5.8266666666666655e-05, + 5.679999999999997e-05, + 5.533333333333329e-05, + 5.386666666666666e-05, + 5.239999999999997e-05, + 5.093333333333329e-05, + 4.946666666666666e-05, + 4.7999999999999974e-05, + 4.653333333333329e-05, + 4.506666666666666e-05, + 4.3599999999999976e-05, + 4.213333333333329e-05, + 4.066666666666661e-05, + 3.919999999999998e-05, + 3.773333333333329e-05, + 3.626666666666661e-05, + 3.479999999999998e-05, + 3.3333333333333294e-05, + 3.186666666666661e-05, + 3.039999999999998e-05, + 2.8933333333333296e-05, + 2.746666666666661e-05, + 2.599999999999998e-05, + 2.4533333333333297e-05, + 2.3066666666666613e-05, + 2.1599999999999983e-05, + 2.01333333333333e-05, + 1.8666666666666614e-05, + 1.7199999999999984e-05, + 1.57333333333333e-05, + 1.4266666666666616e-05, + 1.2799999999999986e-05, + 1.1333333333333302e-05, + 9.866666666666617e-06, + 8.399999999999987e-06, + 6.933333333333303e-06, + 5.466666666666619e-06, + 4e-06, + 3.9335e-06, + 3.8669999999999996e-06, + 3.8005e-06, + 3.7339999999999997e-06, + 3.6675e-06, + 3.601e-06, + 3.5344999999999998e-06, + 3.4679999999999997e-06, + 3.4015e-06, + 3.335e-06, + 3.2685e-06, + 3.202e-06, + 3.1355e-06, + 3.0689999999999998e-06, + 3.0024999999999996e-06, + 2.936e-06, + 2.8695000000000002e-06, + 2.803e-06, + 2.7365e-06, + 2.67e-06, + 2.6034999999999997e-06, + 2.537e-06, + 2.4705e-06, + 2.404e-06, + 2.3375e-06, + 2.271e-06, + 2.2045e-06, + 2.138e-06, + 2.0715e-06, + 2.005e-06, + 1.9385e-06, + 1.872e-06, + 1.8055e-06, + 1.7390000000000002e-06, + 1.6725e-06, + 1.606e-06, + 1.5395000000000003e-06, + 1.4730000000000001e-06, + 1.4065e-06, + 1.3399999999999999e-06, + 1.2735000000000002e-06, + 1.207e-06, + 1.1405e-06, + 1.0740000000000002e-06, + 1.0075e-06, + 9.41e-07, + 8.745000000000003e-07, + 8.080000000000001e-07, + 7.415e-07, + 6.750000000000003e-07, + 6.085000000000002e-07, + 5.420000000000001e-07, + 4.7550000000000036e-07, + 4.0900000000000023e-07, + 3.425000000000001e-07, + 2.760000000000004e-07, + 2.0950000000000028e-07, + 1.4300000000000016e-07, + 7.650000000000003e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 60 +model = "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/training/ReturnnTrainingJob.pEKVEYrMRNj4/output/models/epoch" +num_epochs = 600 +num_inputs = 1 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.VZM5dHZhqlnJ/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.SYt8A5fOy2ta/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.zeroout.joint_train_three_model_zeroout_modwise import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.zeroout.joint_train_three_model_zeroout_modwise import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer.conformer_v1 import ConformerEncoderV1Config +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer.conformer_v1 import ConformerBlockV1Config +from i6_models.parts.conformer.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer.convolution import ConformerConvolutionV1Config +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=5, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderV1Config( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfg=ConformerBlockV1Config( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC(512), + ), + ), + ), + target_size=79, + stage_args={ + "stage_1_layer_dropout_on_large": False, + "stage_1_num_steps_per_iter": [ + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + 27450.0, + ], + "stage_1_num_zero_per_iter": [ + 2, + 4, + 6, + 8, + 10, + 12, + 14, + 16, + 18, + 20, + 22, + 24, + 26, + 28, + 30, + 32, + ], + "stage_1_expected_sparsity_per_iter": [ + 0.041666666666666664, + 0.08333333333333333, + 0.125, + 0.16666666666666666, + 0.20833333333333334, + 0.25, + 0.2916666666666667, + 0.3333333333333333, + 0.375, + 0.4166666666666667, + 0.4583333333333333, + 0.5, + 0.5416666666666666, + 0.5833333333333334, + 0.625, + 0.6666666666666666, + ], + "gate_activation": "sigmoid", + "zeroout_val": -5, + }, + small_model_num_mods=16, + medium_model_num_mods=32, + layer_dropout=0.3, + sparsity_loss_scale=5, + recog_num_mods=48, +) +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.zeroout.joint_train_three_model_zeroout_modwise import ( + train_step, +) diff --git a/2024-dynamic-encoder-size/LBS-960/Simple_Top_k/returnn.config b/2024-dynamic-encoder-size/LBS-960/Simple_Top_k/returnn.config new file mode 100644 index 00000000..3be170ea --- /dev/null +++ b/2024-dynamic-encoder-size/LBS-960/Simple_Top_k/returnn.config @@ -0,0 +1,798 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +debug_print_layer_output_template = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.h7DH1ILPAElF/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.vESciFN5It1u/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 5.466666666666666e-06, + 6.933333333333334e-06, + 8.400000000000001e-06, + 9.866666666666668e-06, + 1.1333333333333336e-05, + 1.28e-05, + 1.4266666666666667e-05, + 1.5733333333333334e-05, + 1.72e-05, + 1.866666666666667e-05, + 2.0133333333333336e-05, + 2.16e-05, + 2.3066666666666667e-05, + 2.4533333333333334e-05, + 2.6000000000000002e-05, + 2.746666666666667e-05, + 2.8933333333333336e-05, + 3.0400000000000004e-05, + 3.186666666666667e-05, + 3.333333333333334e-05, + 3.4800000000000006e-05, + 3.6266666666666676e-05, + 3.773333333333334e-05, + 3.9200000000000004e-05, + 4.0666666666666675e-05, + 4.213333333333334e-05, + 4.360000000000001e-05, + 4.506666666666667e-05, + 4.6533333333333344e-05, + 4.800000000000001e-05, + 4.946666666666668e-05, + 5.093333333333334e-05, + 5.2400000000000007e-05, + 5.386666666666668e-05, + 5.533333333333334e-05, + 5.680000000000001e-05, + 5.8266666666666676e-05, + 5.9733333333333346e-05, + 6.120000000000001e-05, + 6.266666666666668e-05, + 6.413333333333334e-05, + 6.560000000000001e-05, + 6.706666666666668e-05, + 6.853333333333335e-05, + 7.000000000000001e-05, + 7.146666666666668e-05, + 7.293333333333335e-05, + 7.44e-05, + 7.586666666666668e-05, + 7.733333333333335e-05, + 7.880000000000002e-05, + 8.026666666666668e-05, + 8.173333333333335e-05, + 8.320000000000002e-05, + 8.466666666666669e-05, + 8.613333333333334e-05, + 8.760000000000002e-05, + 8.906666666666669e-05, + 9.053333333333334e-05, + 9.200000000000001e-05, + 9.346666666666668e-05, + 9.493333333333336e-05, + 9.640000000000001e-05, + 9.786666666666668e-05, + 9.933333333333335e-05, + 0.00010080000000000001, + 0.00010226666666666668, + 0.00010373333333333335, + 0.00010520000000000002, + 0.00010666666666666668, + 0.00010813333333333335, + 0.00010960000000000002, + 0.00011106666666666668, + 0.00011253333333333335, + 0.00011400000000000002, + 0.00011546666666666669, + 0.00011693333333333335, + 0.00011840000000000002, + 0.00011986666666666669, + 0.00012133333333333336, + 0.0001228, + 0.0001242666666666667, + 0.00012573333333333334, + 0.0001272, + 0.00012866666666666669, + 0.00013013333333333334, + 0.0001316, + 0.00013306666666666668, + 0.00013453333333333334, + 0.000136, + 0.00013746666666666668, + 0.00013893333333333334, + 0.0001404, + 0.00014186666666666668, + 0.00014333333333333334, + 0.0001448, + 0.00014626666666666668, + 0.00014773333333333334, + 0.00014920000000000002, + 0.00015066666666666668, + 0.00015213333333333334, + 0.00015360000000000002, + 0.00015506666666666668, + 0.00015653333333333333, + 0.00015800000000000002, + 0.00015946666666666668, + 0.00016093333333333333, + 0.00016240000000000002, + 0.00016386666666666667, + 0.00016533333333333336, + 0.00016680000000000002, + 0.00016826666666666667, + 0.00016973333333333336, + 0.00017120000000000001, + 0.00017266666666666667, + 0.00017413333333333336, + 0.0001756, + 0.00017706666666666667, + 0.00017853333333333335, + 0.00018, + 0.00018146666666666667, + 0.00018293333333333335, + 0.0001844, + 0.0001858666666666667, + 0.00018733333333333335, + 0.0001888, + 0.0001902666666666667, + 0.00019173333333333335, + 0.0001932, + 0.0001946666666666667, + 0.00019613333333333335, + 0.0001976, + 0.0001990666666666667, + 0.00020053333333333335, + 0.00020200000000000003, + 0.0002034666666666667, + 0.00020493333333333335, + 0.00020640000000000003, + 0.0002078666666666667, + 0.00020933333333333334, + 0.00021080000000000003, + 0.00021226666666666669, + 0.00021373333333333334, + 0.00021520000000000003, + 0.00021666666666666668, + 0.00021813333333333334, + 0.00021960000000000003, + 0.00022106666666666668, + 0.00022253333333333337, + 0.00022400000000000002, + 0.00022546666666666668, + 0.00022693333333333337, + 0.00022840000000000002, + 0.00022986666666666668, + 0.00023133333333333336, + 0.00023280000000000002, + 0.00023426666666666668, + 0.00023573333333333336, + 0.00023720000000000002, + 0.0002386666666666667, + 0.00024013333333333336, + 0.00024160000000000002, + 0.0002430666666666667, + 0.0002445333333333334, + 0.000246, + 0.0002474666666666667, + 0.0002489333333333334, + 0.0002504, + 0.0002518666666666667, + 0.0002533333333333334, + 0.0002548, + 0.0002562666666666667, + 0.0002577333333333334, + 0.0002592, + 0.0002606666666666667, + 0.0002621333333333334, + 0.0002636, + 0.0002650666666666667, + 0.0002665333333333334, + 0.000268, + 0.0002694666666666667, + 0.0002709333333333334, + 0.0002724, + 0.0002738666666666667, + 0.0002753333333333334, + 0.0002768, + 0.0002782666666666667, + 0.0002797333333333334, + 0.0002812, + 0.0002826666666666667, + 0.0002841333333333334, + 0.0002856, + 0.0002870666666666667, + 0.00028853333333333337, + 0.00029000000000000006, + 0.0002914666666666667, + 0.00029293333333333337, + 0.00029440000000000005, + 0.0002958666666666667, + 0.00029733333333333337, + 0.00029880000000000005, + 0.0003002666666666667, + 0.00030173333333333337, + 0.00030320000000000005, + 0.0003046666666666667, + 0.00030613333333333337, + 0.00030760000000000005, + 0.0003090666666666667, + 0.00031053333333333336, + 0.00031200000000000005, + 0.0003134666666666667, + 0.00031493333333333336, + 0.00031640000000000005, + 0.0003178666666666667, + 0.00031933333333333336, + 0.00032080000000000005, + 0.0003222666666666667, + 0.00032373333333333336, + 0.00032520000000000004, + 0.00032666666666666673, + 0.00032813333333333336, + 0.00032960000000000004, + 0.0003310666666666667, + 0.00033253333333333336, + 0.00033400000000000004, + 0.0003354666666666667, + 0.00033693333333333336, + 0.00033840000000000004, + 0.0003398666666666667, + 0.00034133333333333335, + 0.00034280000000000004, + 0.0003442666666666667, + 0.00034573333333333335, + 0.00034720000000000004, + 0.0003486666666666667, + 0.00035013333333333335, + 0.00035160000000000004, + 0.0003530666666666667, + 0.00035453333333333335, + 0.00035600000000000003, + 0.0003574666666666667, + 0.00035893333333333335, + 0.00036040000000000003, + 0.0003618666666666667, + 0.0003633333333333334, + 0.00036480000000000003, + 0.0003662666666666667, + 0.0003677333333333334, + 0.00036920000000000003, + 0.0003706666666666667, + 0.0003721333333333334, + 0.00037360000000000003, + 0.0003750666666666667, + 0.0003765333333333334, + 0.000378, + 0.0003794666666666667, + 0.0003809333333333334, + 0.0003824, + 0.0003838666666666667, + 0.0003853333333333334, + 0.0003868, + 0.0003882666666666667, + 0.0003897333333333334, + 0.0003912, + 0.0003926666666666667, + 0.0003941333333333334, + 0.0003956, + 0.0003970666666666667, + 0.0003985333333333334, + 0.0004, + 0.00039853333333333333, + 0.0003970666666666667, + 0.0003956, + 0.00039413333333333334, + 0.0003926666666666667, + 0.0003912, + 0.00038973333333333334, + 0.0003882666666666667, + 0.0003868, + 0.00038533333333333334, + 0.0003838666666666667, + 0.0003824, + 0.00038093333333333334, + 0.0003794666666666667, + 0.000378, + 0.00037653333333333334, + 0.00037506666666666666, + 0.00037360000000000003, + 0.00037213333333333334, + 0.00037066666666666666, + 0.00036920000000000003, + 0.00036773333333333335, + 0.00036626666666666666, + 0.00036480000000000003, + 0.00036333333333333335, + 0.00036186666666666666, + 0.00036040000000000003, + 0.00035893333333333335, + 0.00035746666666666666, + 0.00035600000000000003, + 0.00035453333333333335, + 0.00035306666666666667, + 0.00035160000000000004, + 0.00035013333333333335, + 0.00034866666666666667, + 0.0003472, + 0.00034573333333333335, + 0.00034426666666666667, + 0.0003428, + 0.00034133333333333335, + 0.00033986666666666667, + 0.0003384, + 0.00033693333333333336, + 0.00033546666666666667, + 0.000334, + 0.00033253333333333336, + 0.00033106666666666667, + 0.0003296, + 0.00032813333333333336, + 0.0003266666666666667, + 0.0003252, + 0.00032373333333333336, + 0.0003222666666666667, + 0.0003208, + 0.00031933333333333336, + 0.0003178666666666667, + 0.0003164, + 0.00031493333333333336, + 0.0003134666666666667, + 0.000312, + 0.00031053333333333336, + 0.0003090666666666667, + 0.0003076, + 0.00030613333333333337, + 0.0003046666666666667, + 0.0003032, + 0.00030173333333333337, + 0.0003002666666666667, + 0.0002988, + 0.00029733333333333337, + 0.0002958666666666667, + 0.0002944, + 0.00029293333333333337, + 0.0002914666666666667, + 0.00029, + 0.0002885333333333333, + 0.0002870666666666667, + 0.0002856, + 0.0002841333333333333, + 0.00028266666666666663, + 0.0002812, + 0.0002797333333333333, + 0.00027826666666666664, + 0.0002768, + 0.0002753333333333333, + 0.00027386666666666664, + 0.0002724, + 0.0002709333333333333, + 0.00026946666666666664, + 0.000268, + 0.0002665333333333333, + 0.00026506666666666664, + 0.0002636, + 0.0002621333333333333, + 0.00026066666666666664, + 0.0002592, + 0.00025773333333333333, + 0.00025626666666666664, + 0.0002548, + 0.00025333333333333333, + 0.00025186666666666664, + 0.0002504, + 0.00024893333333333333, + 0.00024746666666666665, + 0.000246, + 0.00024453333333333333, + 0.00024306666666666668, + 0.0002416, + 0.00024013333333333333, + 0.00023866666666666665, + 0.0002372, + 0.00023573333333333334, + 0.00023426666666666665, + 0.0002328, + 0.00023133333333333334, + 0.00022986666666666665, + 0.0002284, + 0.00022693333333333334, + 0.00022546666666666665, + 0.000224, + 0.00022253333333333334, + 0.00022106666666666666, + 0.0002196, + 0.00021813333333333331, + 0.00021666666666666666, + 0.0002152, + 0.00021373333333333332, + 0.00021226666666666666, + 0.0002108, + 0.00020933333333333332, + 0.00020786666666666666, + 0.0002064, + 0.00020493333333333332, + 0.00020346666666666666, + 0.00020199999999999998, + 0.00020053333333333332, + 0.00019906666666666666, + 0.00019759999999999998, + 0.00019613333333333332, + 0.00019466666666666666, + 0.00019319999999999998, + 0.00019173333333333332, + 0.00019026666666666667, + 0.00018879999999999998, + 0.00018733333333333332, + 0.00018586666666666667, + 0.00018439999999999998, + 0.00018293333333333333, + 0.00018146666666666664, + 0.00017999999999999998, + 0.00017853333333333333, + 0.00017706666666666664, + 0.00017559999999999999, + 0.00017413333333333333, + 0.00017266666666666664, + 0.0001712, + 0.00016973333333333333, + 0.00016826666666666665, + 0.0001668, + 0.0001653333333333333, + 0.00016386666666666665, + 0.0001624, + 0.0001609333333333333, + 0.00015946666666666665, + 0.000158, + 0.0001565333333333333, + 0.00015506666666666662, + 0.0001536, + 0.0001521333333333333, + 0.00015066666666666662, + 0.0001492, + 0.0001477333333333333, + 0.00014626666666666663, + 0.0001448, + 0.0001433333333333333, + 0.00014186666666666663, + 0.0001404, + 0.0001389333333333333, + 0.00013746666666666663, + 0.000136, + 0.00013453333333333331, + 0.00013306666666666663, + 0.0001316, + 0.00013013333333333332, + 0.00012866666666666663, + 0.0001272, + 0.00012573333333333332, + 0.00012426666666666663, + 0.0001228, + 0.00012133333333333332, + 0.00011986666666666663, + 0.0001184, + 0.00011693333333333332, + 0.00011546666666666664, + 0.00011399999999999995, + 0.00011253333333333332, + 0.00011106666666666664, + 0.00010959999999999995, + 0.00010813333333333332, + 0.00010666666666666664, + 0.00010519999999999996, + 0.00010373333333333332, + 0.00010226666666666664, + 0.00010079999999999996, + 9.933333333333333e-05, + 9.786666666666664e-05, + 9.639999999999996e-05, + 9.493333333333333e-05, + 9.346666666666664e-05, + 9.199999999999996e-05, + 9.053333333333333e-05, + 8.906666666666665e-05, + 8.759999999999996e-05, + 8.613333333333333e-05, + 8.466666666666665e-05, + 8.319999999999996e-05, + 8.173333333333333e-05, + 8.026666666666665e-05, + 7.879999999999996e-05, + 7.733333333333328e-05, + 7.586666666666665e-05, + 7.439999999999997e-05, + 7.293333333333328e-05, + 7.146666666666665e-05, + 6.999999999999997e-05, + 6.853333333333328e-05, + 6.706666666666665e-05, + 6.559999999999997e-05, + 6.413333333333328e-05, + 6.266666666666665e-05, + 6.119999999999997e-05, + 5.9733333333333285e-05, + 5.8266666666666655e-05, + 5.679999999999997e-05, + 5.533333333333329e-05, + 5.386666666666666e-05, + 5.239999999999997e-05, + 5.093333333333329e-05, + 4.946666666666666e-05, + 4.7999999999999974e-05, + 4.653333333333329e-05, + 4.506666666666666e-05, + 4.3599999999999976e-05, + 4.213333333333329e-05, + 4.066666666666661e-05, + 3.919999999999998e-05, + 3.773333333333329e-05, + 3.626666666666661e-05, + 3.479999999999998e-05, + 3.3333333333333294e-05, + 3.186666666666661e-05, + 3.039999999999998e-05, + 2.8933333333333296e-05, + 2.746666666666661e-05, + 2.599999999999998e-05, + 2.4533333333333297e-05, + 2.3066666666666613e-05, + 2.1599999999999983e-05, + 2.01333333333333e-05, + 1.8666666666666614e-05, + 1.7199999999999984e-05, + 1.57333333333333e-05, + 1.4266666666666616e-05, + 1.2799999999999986e-05, + 1.1333333333333302e-05, + 9.866666666666617e-06, + 8.399999999999987e-06, + 6.933333333333303e-06, + 5.466666666666619e-06, + 4e-06, + 3.9335e-06, + 3.8669999999999996e-06, + 3.8005e-06, + 3.7339999999999997e-06, + 3.6675e-06, + 3.601e-06, + 3.5344999999999998e-06, + 3.4679999999999997e-06, + 3.4015e-06, + 3.335e-06, + 3.2685e-06, + 3.202e-06, + 3.1355e-06, + 3.0689999999999998e-06, + 3.0024999999999996e-06, + 2.936e-06, + 2.8695000000000002e-06, + 2.803e-06, + 2.7365e-06, + 2.67e-06, + 2.6034999999999997e-06, + 2.537e-06, + 2.4705e-06, + 2.404e-06, + 2.3375e-06, + 2.271e-06, + 2.2045e-06, + 2.138e-06, + 2.0715e-06, + 2.005e-06, + 1.9385e-06, + 1.872e-06, + 1.8055e-06, + 1.7390000000000002e-06, + 1.6725e-06, + 1.606e-06, + 1.5395000000000003e-06, + 1.4730000000000001e-06, + 1.4065e-06, + 1.3399999999999999e-06, + 1.2735000000000002e-06, + 1.207e-06, + 1.1405e-06, + 1.0740000000000002e-06, + 1.0075e-06, + 9.41e-07, + 8.745000000000003e-07, + 8.080000000000001e-07, + 7.415e-07, + 6.750000000000003e-07, + 6.085000000000002e-07, + 5.420000000000001e-07, + 4.7550000000000036e-07, + 4.0900000000000023e-07, + 3.425000000000001e-07, + 2.760000000000004e-07, + 2.0950000000000028e-07, + 1.4300000000000016e-07, + 7.650000000000003e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 60 +model = "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/training/ReturnnTrainingJob.730j0wz4gkwk/output/models/epoch" +num_epochs = 600 +num_inputs = 1 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.VZM5dHZhqlnJ/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.SYt8A5fOy2ta/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.simple_topk.joint_train_three_model_simple_topk_modwise import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.simple_topk.joint_train_three_model_simple_topk_modwise import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer.conformer_v1 import ConformerEncoderV1Config +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer.conformer_v1 import ConformerBlockV1Config +from i6_models.parts.conformer.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer.convolution import ConformerConvolutionV1Config +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=5, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderV1Config( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfg=ConformerBlockV1Config( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC(512), + ), + ), + ), + target_size=79, + start_select_step=0, + small_model_num_mods=16, + medium_model_num_mods=32, + tau_args={ + "initial_tau": 2, + "annealing": 0.999992, + "min_tau": 0.1, + "gumbel_scale": 0.05, + }, + layer_dropout={"layer_dropout_mod_select": 0, "layer_dropout_fix_mod": 0.3}, + recog_num_mods=48, + k_anneal_args={"k_anneal_num_steps_per_iter": 14640, "k_reduction_per_iter": 1}, +) +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.simple_topk.joint_train_three_model_simple_topk_modwise import ( + train_step, +) diff --git a/2024-dynamic-encoder-size/README.md b/2024-dynamic-encoder-size/README.md index 28f6fd28..5e79bff8 100644 --- a/2024-dynamic-encoder-size/README.md +++ b/2024-dynamic-encoder-size/README.md @@ -15,3 +15,18 @@ ConformerCTCModel, ConformerCTCConfig and train_step in returnn config is define ### TED-LIUM-v2 Iterative-Zero-Out ConformerCTCModel, ConformerCTCConfig and train_step in returnn config is defined in [here](https://github.com/rwth-i6/i6_experiments/blob/main/users/jxu/experiments/ctc/tedlium2/pytorch_networks/dynamic_encoder_size/iterative_zero_out_refactored/jointly_train_iterative_zero_out_layerwise.py) + + +### LBS-960 Simple-Top-K + +ConformerCTCModel, ConformerCTCConfig and train_step in returnn config is defined in [here](https://github.com/rwth-i6/i6_experiments/blob/main/users/jxu/experiments/ctc/lbs_960/pytorch_networks/dynamic_encoder_size/simple_topk/joint_train_three_model_simple_topk_modwise.py) + + +### LBS-960 Iterative-Zero-Out + +ConformerCTCModel, ConformerCTCConfig and train_step in returnn config is defined in [here](https://github.com/rwth-i6/i6_experiments/blob/main/users/jxu/experiments/ctc/lbs_960/pytorch_networks/dynamic_encoder_size/zeroout/joint_train_three_model_zeroout_modwise.py) + + +### LBS-960 Aux-Loss + +ConformerCTCModel, ConformerCTCConfig and train_step in returnn config is defined in [here](https://github.com/rwth-i6/i6_experiments/blob/main/users/jxu/experiments/ctc/lbs_960/pytorch_networks/baseline/conformer_ctc_d_model_512_num_layers_12_new_frontend_raw_wave_with_aux_loss.py) \ No newline at end of file