When creating a WordsTagger with device='cpu', the class throws an error:
WordsTagger(basepath, device='cpu')
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\models.py", line 613, in __init__
    self.tagger = WordsTagger(
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\bi_lstm_crf\app\predict.py", line 15, in __init__
    self.model = build_model(self.args, self.preprocessor, load=True, verbose=False)
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\bi_lstm_crf\app\utils.py", line 24, in build_model
    state_dict = torch.load(model_path)
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 789, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 1131, in _load
    result = unpickler.load()
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 1101, in persistent_load
    load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 1083, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 215, in default_restore_location
    result = fn(storage, location)
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 182, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "C:\Users\MarcoOdore\agilelab\MultiLegalSBD-master\venv\lib\site-packages\torch\serialization.py", line 166, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
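The reason is that, in app/utils.py, build_model does not take into account the device passed to WordsTagger: it calls torch.load without map_location, so tensors that were saved on a CUDA device are restored to CUDA, which fails on a CPU-only machine.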
def build_model(args, processor, load=True, verbose=False):
    model = BiRnnCrf(len(processor.vocab), len(processor.tags),
                     embedding_dim=args.embedding_dim, hidden_dim=args.hidden_dim, num_rnn_layers=args.num_rnn_layers)
    # weights
    model_path = model_filepath(args.model_dir)
    if exists(model_path) and load:
        state_dict = torch.load(model_path)  # here
        model.load_state_dict(state_dict)
        if verbose:
            print("load model weights from {}".format(model_path))
    return model
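The error message already points at the remedy: map_location tells torch.load where to put every stored tensor while unpickling. Below is a minimal sketch of that mechanism on a save/load round trip. It uses a plain nn.Linear as a stand-in for BiRnnCrf (an assumption, to keep it self-contained), and since it runs on the CPU it only illustrates the map_location call; it does not reproduce the CUDA checkpoint failure itself.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for BiRnnCrf so the sketch does not depend on bi_lstm_crf.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pth")
    torch.save(model.state_dict(), model_path)

    # map_location remaps every stored tensor to the given device while
    # loading, so a checkpoint saved from a GPU can be read without CUDA.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

# Load the remapped weights into a freshly built model, as build_model does.
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)
```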
I think the problem could be solved by also passing the device from WordsTagger to build_model and forwarding it to torch.load as map_location:
def build_model(args, processor, load=True, verbose=False, device=None):
    model = BiRnnCrf(len(processor.vocab), len(processor.tags),
                     embedding_dim=args.embedding_dim, hidden_dim=args.hidden_dim, num_rnn_layers=args.num_rnn_layers)
    # weights
    model_path = model_filepath(args.model_dir)
    if exists(model_path) and load:
        if device == 'cpu':
            # remap tensors that were saved on a GPU to the CPU
            state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        else:
            # any other value keeps torch.load's default placement
            state_dict = torch.load(model_path)
        model.load_state_dict(state_dict)
        if verbose:
            print("load model weights from {}".format(model_path))
    return model
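A possible refinement, sketched here as an assumption rather than anything bi_lstm_crf provides: instead of special-casing only 'cpu', build_model could compute the map_location value from the requested device and CUDA availability. The helper name resolve_map_location is hypothetical; it falls back to the CPU when no device is given and CUDA is missing, and fails loudly when CUDA is explicitly requested but unavailable.

```python
from typing import Optional


def resolve_map_location(device: Optional[str], cuda_available: bool) -> Optional[str]:
    """Hypothetical helper: choose the map_location argument for torch.load.

    device is the value given to WordsTagger (e.g. 'cpu', 'cuda', 'cuda:0' or None);
    cuda_available would normally be torch.cuda.is_available().
    """
    if device is None:
        # No explicit request: keep the saved placement when CUDA exists,
        # otherwise remap to the CPU instead of raising the RuntimeError above.
        return None if cuda_available else "cpu"
    if device.startswith("cuda") and not cuda_available:
        # An explicit CUDA request cannot be honoured on this machine.
        raise ValueError(f"device={device!r} requested but CUDA is not available")
    # Strings such as 'cpu' or 'cuda:0' can be passed to torch.load as map_location.
    return device
```

Inside build_model the load would then become `torch.load(model_path, map_location=resolve_map_location(device, torch.cuda.is_available()))`, so the if/else on device is no longer needed.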