
[BUG] Bool data handled as tensors -> can't set batch size with 0-dim data. #1199

Open
alex-bene opened this issue Feb 1, 2025 · 9 comments
Labels
bug Something isn't working

@alex-bene

Describe the bug

When bool data is passed, it is transformed into a 0-dim tensor internally. As a result, auto_batch_size_ can't infer the batch size, and setting the batch size explicitly raises an exception.

To Reproduce

import torch
from tensordict import TensorDict

# Passing a "string" argument (non-tensor)
td = TensorDict(no_bs_arg="True")
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([])
td["tt"] = torch.rand(2, 3)
assert td.batch_size == torch.Size([])
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([2])

# Passing a "bool" argument (still non-tensor input, but transformed to a tensor internally)
td = TensorDict(no_bs_arg=True)
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([])
td["tt"] = torch.rand(2, 3)
assert td.batch_size == torch.Size([])
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([])
try:
    td.batch_size = torch.Size([2])
except RuntimeError as e:
    assert (
        str(e)
        == "the tensor no_bs_arg has shape torch.Size([]) which is incompatible with the batch-size torch.Size([2])."
    )

Expected behavior

Bool values should be handled like strings and other non-tensor data.

Reason and Possible fixes

The reason is that bool arguments are internally transformed into tensor data. While I understand that using tensors as an internal representation might be more efficient, maybe we should ignore tensors with .ndim == 0 both when automatically calculating the batch size and in _check_new_batch_size.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@alex-bene alex-bene added the bug Something isn't working label Feb 1, 2025
@alex-bene alex-bene changed the title [BUG] Bool data handled as tensors -> can't set batch size with 1-dim data. [BUG] Bool data handled as tensors -> can't set batch size with 0-dim data. Feb 1, 2025
@vmoens
Contributor

vmoens commented Feb 6, 2025

It is a design decision in TensorDict that doing things like TensorDict(a=0, b=1.0, c=True) will give you a TensorDict with tensor leaves.

I'm not going to hide from it: it is a "historical" thing, in the sense that when we started, NonTensorData was not remotely on the radar; assuming that everything was always going to be a tensor, we thought that automatic casting was the appropriate thing to do.

The official way to do what you want is to pass them as NonTensorData:
TensorDict(a=NonTensorData(0), b=NonTensorData(1.0), c=NonTensorData(True))
(a bit clunky I suppose?).
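For reference, a quick sketch of what that looks like in practice (key names here are illustrative), assuming the current NonTensorData semantics where reading the key back returns the plain Python value and the entry does not constrain the batch size:

import torch
from tensordict import TensorDict, NonTensorData

# the bool is stored as non-tensor metadata rather than a 0-dim tensor
td = TensorDict(flag=NonTensorData(True), x=torch.rand(2, 3))
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([2])  # the flag does not block batch-size inference
assert td["flag"] is True                # reading back returns the plain Python bool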

Now, that being said, I'd be open to implementing a feature by which these transformations do not occur, but then we'd face the issue of deciding what happens down the line:

td = TensorDict(a=0, autocast=False)
td["a"] # returns an integer
td["b"] = 1 # cast or no cast?

i.e., is autocast=False for the constructor only, or for the TensorDict as a whole?
Another one:

td = TensorDict(autocast=False)
td["root"] = 0 # no cast
td["a", "b"] = 0 # no cast?
td["a"]["b"] = 0 # no cast?
super_td = TensorDict(td=td, autocast=True)
super_td["td", "a", "b"] = 0 # no cast?

(In tensorclass we have the ability to control this via @tensorclass(autocast=True), which will attempt to map every input to its typed annotation, and @tensorclass(nocast=True), which will do what you suggested - i.e. avoid any type of casting.)
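A minimal sketch of those two flags, assuming they behave as described above (class and field names are illustrative):

import torch
from tensordict import tensorclass

@tensorclass(nocast=True)
class NoCastData:
    flag: bool  # left untouched, no casting to a tensor

@tensorclass(autocast=True)
class AutoCastData:
    x: torch.Tensor  # inputs are cast to match the type annotation

nc = NoCastData(flag=True, batch_size=[])
assert isinstance(nc.flag, bool)
ac = AutoCastData(x=1.0, batch_size=[])
assert isinstance(ac.x, torch.Tensor)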

@alex-bene
Author

alex-bene commented Feb 6, 2025

Okay, I see what you mean; not a straightforward change. However, we can probably postpone that decision and still address the initial issue by making a small change in _set_max_batch_size.
For example, change this line:

tensor_data = [val for val in source.values() if not is_non_tensor(val)]

to this:

tensor_data = [val for val in source.values() if not is_non_tensor(val) and val.ndim > 0]

This will ignore 0-dim tensors (cast or not), considering that these do not have a batch size.
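To make the intent concrete, this is roughly the behaviour that filter would enable for the bug-report example (hypothetical: these asserts describe the proposed behaviour, not what the current release does):

import torch
from tensordict import TensorDict

td = TensorDict(no_bs_arg=True)  # internally stored as a 0-dim tensor
td["tt"] = torch.rand(2, 3)
td.auto_batch_size_(1)
# with the proposed filter, the 0-dim entry would be ignored and the
# batch size would be inferred from the remaining leaves
assert td.batch_size == torch.Size([2])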

Does this sound reasonable or does it create another problem?

@vmoens
Contributor

vmoens commented Feb 6, 2025

Hmm, no, because the concept of the batch size is "the common leading dim of all tensors".
But I coded an UnbatchedTensor thingy that could prove useful here if you don't care whether your True is a tensor or not: #1170

You can run this under #1213:

from tensordict import UnbatchedTensor, TensorDict
import torch
td = TensorDict(a=UnbatchedTensor(0), b=torch.randn(10, 11)).auto_batch_size_(batch_dims=2)
assert td.shape == (10, 11)

@alex-bene
Author

alex-bene commented Feb 6, 2025

hmmm no bc the concept of the batch-size is "the common leading dim of all tensors"

Well, it would not be unreasonable to interpret this as "the common leading dim of all tensors that have a dim".

Still, even using NonTensorData is not enough in general, because the non-tensor type does not transfer when assigning from one tensordict to another (which comes back to your comment about casting again). Meaning:

import torch
from tensordict import TensorDict, NonTensorData

td = TensorDict(a=NonTensorData(True))
assert isinstance(td["a"], bool) # when accessing a NonTensorData it returns the underlying type
td["a"] = td["a"] # but setting a boolean -- imagine this for the non-trivial operation of setting td["a"] from td2["a"]
assert isinstance(td["a"], torch.Tensor) # ends up with a tensor

While returning the underlying type for non-tensor data certainly makes sense (and I like it this way, since the only reason I use NonTensorData is to avoid the casting, etc.), it seems weird that running a seemingly no-op changes the underlying type and recreates the batch-size problems explained here. I can still do something like:

td["a"] = NonTensorData(td["a"]) if isinstance(td["a"], bool) else td["a"]

But this seems overly convoluted just to keep a single piece of boolean metadata inside the tensordict without it affecting the batch size.
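For instance, the workaround generalizes to a small helper like the one below (the helper name is made up; it is just to illustrate the kind of boilerplate this forces for scalar metadata):

import torch
from tensordict import TensorDict, NonTensorData

def copy_key(dst, src, key):
    # hypothetical helper: re-wrap plain Python values so they stay non-tensor
    # entries instead of being cast to tensors on assignment
    val = src[key]
    dst[key] = val if isinstance(val, torch.Tensor) else NonTensorData(val)

td1 = TensorDict(a=NonTensorData(True))
td2 = TensorDict()
copy_key(td2, td1, "a")
assert isinstance(td2["a"], bool)  # stays a plain bool instead of becoming a tensor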

Why is it preferred to cast non-tensor (and, even more so, non-sequence) data to tensors in general?

(I haven't tested UnbatchedTensor yet to check whether the same thing happens and we end up with a plain tensor on assignment, though it seems quite probable.)

@vmoens
Contributor

vmoens commented Feb 6, 2025

You raise a good point there, that's something I can fix.

Re "the common leading dim of all tensors that have a dim": the problem with this is that sometimes we want to do

td = TensorDict(a=0)
td = td.expand(100)
td.memmap_()

Which is useful to preallocate data on disk.
UnbatchedTensor is definitely the API to do this!

@alex-bene
Author

alex-bene commented Feb 7, 2025

sometimes we want to do

td = TensorDict(a=0)
td = td.expand(100)
td.memmap_()

Which is useful to preallocate data on disk.

I can understand why this would not be possible if the casting of the integer to a tensor did not happen; however, why would it not be possible with the proposed change in _set_max_batch_size?

Why is it preferred to cast non-tensor (and, even more so, non-sequence) data to tensors in general?

Also, can you share some insight about this? Apart from the historical reasons you mentioned, I can understand casting lists, arrays, etc., but intuitively I would expect the casting to ignore primitive Python types. We can always explicitly write torch.tensor(0) if we want such a cast to happen. Does that make sense?

@alex-bene
Author

Hey @vmoens, is there any update on this?

@vmoens
Contributor

vmoens commented Feb 19, 2025

Let me outline what I have in mind and see if we're on the same page:

  • RE autocasting any compatible type to tensor: I think casting to a tensor when possible is a nice feature; the first and foremost responsibility of the class is to carry tensors. If you pass a numpy array, an int or a bool, unless you specifically ask not to cast it (which you can!), we should IMO use a tensor to represent it. It may look less ambiguous not to do so, but when you start using more advanced features like consolidate, shared memory and such, it becomes harder to really understand what is happening. Because we have NonTensorData, there is a way for you to avoid that side effect at low cost. I still see NonTensorData as a nice add-on, but less of a core feature than TensorDict itself.

  • RE allowing non-empty batch sizes with shapeless (0-dim) tensors: my take is that TensorDict should be a class that relies on a few basic axioms and has the whole API follow from them logically. These axioms are: (1) a TensorDict is a collection of tensors or other tensordicts; (2) a TensorDict has an arbitrary batch size which must match the leading shape of all the tensors it contains; (3) the device is not mandatory, but if it is set then all set ops will cast tensors to the desired device.
    If we amend (2) to "must match the leading shape of all the tensors it contains, except tensors without shape", we quickly end up in complex situations where we need to make opinionated choices (i.e., choices that will not be obvious to a large chunk of the community). For instance: what happens during td.expand if the td has no shape and the tensor has no shape either (see the sketch after this list)? What happens if the td has a shape, the tensor has the same shape, and we vmap over the tensordict? Then the tensor appears not to have a shape - is it a tensor with an empty shape? There are plenty of examples where this refactoring would create situations where we'd need to break the API, introduce multiple shape checks, and probably some bugs too!
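To illustrate that first edge case, here is a sketch of what the current semantics give, assuming the scalar is cast to a 0-dim tensor as today:

import torch
from tensordict import TensorDict

td = TensorDict(a=0)                      # batch_size [], "a" is a 0-dim tensor
td = td.expand(3)                         # the 0-dim leaf expands along the new batch dim
assert td.batch_size == torch.Size([3])
assert td["a"].shape == torch.Size([3])
# if shapeless leaves were exempt from axiom (2), it would no longer be obvious
# whether "a" should expand here or stay a scalar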

We can always explicitly write torch.tensor(0) if we want a cast like this to happen

True, and we can always write NonTensorData(0) or UnbatchedTensor(tensor(0)) if we want to use it as a plain int or an unbatched tensor :)
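Side by side, the three spellings look like this (a sketch assuming the UnbatchedTensor API from #1170/#1213 is available):

import torch
from tensordict import TensorDict, NonTensorData, UnbatchedTensor

td = TensorDict(
    a=0,                                  # default behaviour: cast to a 0-dim tensor
    b=NonTensorData(0),                   # kept as non-tensor metadata, reads back as an int
    c=UnbatchedTensor(torch.tensor(0)),   # a tensor exempt from batch-size constraints
)
assert isinstance(td["a"], torch.Tensor)
assert isinstance(td["b"], int)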

Note that I'm a bit biased; tbh I started working on tensordict thinking that we'd never ever have non-tensor data, never ever use dataclasses, never ever support arithmetic ops, and here we are... I suffer from a pathological psycho-inertia and I'm working on it!

Maybe @shagunsodhani would like to give his two cents here since he's been following the project since its inception!

@alex-bene
Author

I think I understand all your points and I do not inherently disagree with any of them.

True, and we can always write NonTensorData(0) or UnbatchedTensor(tensor(0)) if we want to use it as a plain int or an unbatched tensor :)

Maybe I disagree with this one a bit, but again, that is my own bias about what I would expect to happen, which most probably does not match the intuition of the rest of the community in this case. So using NonTensorData(0) explicitly is not a big problem.

However, where a problem does exist is in the assignment of non-tensor values that "silently" creates tensor data, as shown here; it seems reasonable to expect that not to happen.
