
ValueError: Expected argument target to be an int or long tensor, but got tensor with dtype torch.float32 #46

Open
X02cinnamondirty opened this issue Feb 28, 2024 · 1 comment
Assignees
Labels
bug Something isn't working

Comments

@X02cinnamondirty

X02cinnamondirty commented Feb 28, 2024

train.fit_sequence_module(
    model=model,
    sdata=sd_train,
    seq_var="ohe_seq",
    target_vars="id_x",
    in_memory=True,
    train_var="train_val",
    epochs=25,
    gpus=1,
    batch_size=9,
    num_workers=4,
    prefetch_factor=2,
    drop_last=False,
    name="LTRidentity",
    version="0.75",
    transforms={"ohe_seq": lambda x: x.swapaxes(1, 2)}
)
Dropping 0 sequences with NaN targets.
Loading ohe_seq and ['id_x'] into memory
No seed set
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type            | Params
-------------------------------------------------
0 | arch         | SmallCNN        | 2.8 K
1 | train_metric | MulticlassAUROC | 0
2 | val_metric   | MulticlassAUROC | 0
3 | test_metric  | MulticlassAUROC | 0
-------------------------------------------------
2.8 K     Trainable params
0         Non-trainable params
2.8 K     Total params
0.011     Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%|                                                                                                                                                     | 0/2 [00:00<?, ?it/s]/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/models/_SequenceModule.py:203: UserWarning: Using a target size (torch.Size([9])) that is different to the input size (torch.Size([9, 9])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  loss = self.loss_fxn(outs, y)  # train
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/train/_fit.py", line 273, in fit_sequence_module
    trainer = fit(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/train/_fit.py", line 123, in fit
    trainer.fit(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
    self._run_sanity_check()
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1062, in _run_sanity_check
    val_loop.run()
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 403, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/models/_SequenceModule.py", line 228, in validation_step
    calculate_metric(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/models/base/_metrics.py", line 53, in calculate_metric
    metric(outs, y)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/metric.py", line 298, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/metric.py", line 367, in _forward_reduce_state_update
    self.update(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/metric.py", line 460, in wrapped_func
    update(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/classification/precision_recall_curve.py", line 345, in update
    _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 394, in _multiclass_precision_recall_curve_tensor_validation
    raise ValueError(
**ValueError: Expected argument `target` to be an int or long tensor, but got tensor with dtype torch.float32**

but

sd_train
<xarray.Dataset>
Dimensions:    (_sequence: 20466, _length: 300, length: 300, _ohe: 4)
Dimensions without coordinates: _sequence, _length, length, _ohe
Data variables:
    qseqid_x   (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    seq        (_sequence, _length) |S1 dask.array<chunksize=(91, 300), meta=np.ndarray>
    set        (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    spe        (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    sseqid     (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    id         (_sequence) <U8 'seq00000' 'seq00001' ... 'seq22738' 'seq22739'
    ohe_seq    (_sequence, length, _ohe) uint8 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0
    id_x       (_sequence) int8 0 1 1 0 1 0 0 0 1 0 0 ... 1 1 1 1 1 1 1 0 1 0 1
    train_val  (_sequence) bool False True True False ... True False True False
    target     (_sequence) int8 0 1 1 0 1 0 0 0 1 0 0 ... 1 1 1 1 1 1 1 0 1 0 1

sd_train['target']
<xarray.DataArray 'target' (_sequence: 20466)>
array([0, 1, 1, ..., 1, 0, 1], dtype=int8)
Dimensions without coordinates: _sequence

My target variable is an int, so why does this error happen?

@adamklie
Collaborator

adamklie commented Mar 4, 2024

By default, the dataloading step casts `target_vars` to `torch.float32`.

You can override this using the `transforms` argument. Try modifying it to:

transforms={"ohe_seq": lambda x: x.swapaxes(1, 2), "id_x": lambda x: torch.tensor(x, dtype=torch.long)}
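To make the dtype issue concrete, here is a minimal sketch (not EUGENe's actual dataloader code) of what goes wrong and what the per-variable transform fixes: labels stored as `int8` in the Dataset get cast to `float32` during loading, but torchmetrics' `MulticlassAUROC` validates that `target` is an int/long tensor.

```python
import numpy as np
import torch

# Labels as they appear in the xarray Dataset (int8), mirroring `id_x` above.
y_raw = np.array([0, 1, 1, 0, 1], dtype=np.int8)

# Default behavior: targets are cast to float32, which MulticlassAUROC rejects.
y_float = torch.tensor(y_raw, dtype=torch.float32)

# The suggested transform casts them back to long (int64), which is accepted.
y_long = torch.tensor(y_raw, dtype=torch.long)

print(y_float.dtype)  # torch.float32
print(y_long.dtype)   # torch.int64 (i.e. torch.long)
```

The same cast is what the `"id_x": lambda x: torch.tensor(x, dtype=torch.long)` entry in `transforms` applies per batch.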

@adamklie adamklie self-assigned this Mar 4, 2024
@adamklie adamklie added the bug Something isn't working label Mar 4, 2024