diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
index a8e89e3607..76216debd8 100755
--- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
@@ -69,7 +69,7 @@ task:
             loss: mse
             n_jobs: 20
             GPU: 0
-            rnn_type: GRU
+            rnn_type: LSTM
     dataset:
         class: TSDatasetH
         module_path: qlib.data.dataset
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
index 3aa8147fcf..7d5bd792fc 100644
--- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
@@ -60,7 +60,7 @@ task:
             metric: loss
             loss: mse
             GPU: 0
-            rnn_type: GRU
+            rnn_type: LSTM
     dataset:
         class: DatasetH
         module_path: qlib.data.dataset
diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py
index b0770e2bdd..6dca7622a1 100644
--- a/qlib/contrib/model/pytorch_alstm.py
+++ b/qlib/contrib/model/pytorch_alstm.py
@@ -52,6 +52,7 @@ def __init__(
         optimizer="adam",
         GPU=0,
         seed=None,
+        rnn_type="GRU",
         **kwargs
     ):
         # Set logger.
@@ -103,6 +104,7 @@ def __init__(
                 self.device,
                 self.use_gpu,
                 seed,
             )
         )
+        self.rnn_type = rnn_type
 
@@ -115,6 +117,7 @@ def __init__(
             hidden_size=self.hidden_size,
             num_layers=self.num_layers,
             dropout=self.dropout,
+            rnn_type=self.rnn_type,
         )
         self.logger.info("model:\n{:}".format(self.ALSTM_model))
         self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py
index 3ab8ed8ab5..f46a6fa5c5 100644
--- a/qlib/contrib/model/pytorch_alstm_ts.py
+++ b/qlib/contrib/model/pytorch_alstm_ts.py
@@ -56,6 +56,7 @@ def __init__(
         n_jobs=10,
         GPU=0,
         seed=None,
+        rnn_type="GRU",
         **kwargs
     ):
         # Set logger.
@@ -77,6 +78,7 @@ def __init__(
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
         self.n_jobs = n_jobs
         self.seed = seed
+        self.rnn_type = rnn_type
 
         self.logger.info(
             "ALSTM parameters setting:"
@@ -122,6 +124,7 @@ def __init__(
             hidden_size=self.hidden_size,
             num_layers=self.num_layers,
             dropout=self.dropout,
+            rnn_type=self.rnn_type,
         )
         self.logger.info("model:\n{:}".format(self.ALSTM_model))
         self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
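Note: this patch only threads `rnn_type` from the YAML configs through the `ALSTM` wrapper `__init__` into the network constructor; it does not show how the network consumes the value. The sketch below is a minimal illustration of one plausible consumption pattern, resolving the recurrent layer class by name via `getattr(torch.nn, ...)`. The class name `ALSTMNetSketch`, its defaults, and the getattr dispatch are assumptions for illustration, not a quote of `ALSTMModel` in the qlib source.

```python
import torch.nn as nn


class ALSTMNetSketch(nn.Module):
    """Hypothetical sketch of a network that accepts the new `rnn_type` kwarg."""

    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU"):
        super().__init__()
        if rnn_type.upper() not in ("RNN", "LSTM", "GRU"):
            raise ValueError(f"unsupported rnn_type: {rnn_type}")
        # Resolve the recurrent backbone class (nn.RNN / nn.LSTM / nn.GRU) by name,
        # so the YAML value `rnn_type: LSTM` switches the backbone without code changes.
        rnn_cls = getattr(nn, rnn_type.upper())
        self.rnn = rnn_cls(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
```

With the `rnn_type="GRU"` default added to both `__init__` signatures, existing callers keep the GRU backbone; the two YAML changes only flip the benchmark configs to LSTM.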