Device incorrect when importing functions from Torch

#85
by catwell - opened

For instance:

```python
import spaces
import torch
from torch import ones

@spaces.GPU
def foo():
    print(torch.ones(3, device="cuda"))
    print(ones(3, device="cuda"))

foo()
```

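This prints the following. The call through `torch.ones` returns a CUDA tensor, but the call through the directly imported `ones` silently returns a CPU tensor: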

```
tensor([1., 1., 1.], device='cuda:0')
tensor([1., 1., 1.])
```
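The likely cause (an assumption, not checked against the `spaces` source) is ordinary Python name binding. If `@spaces.GPU` swaps the module attribute `torch.ones` for a different function while the decorated function runs, then `torch.ones(...)` sees the swap because the attribute is looked up at call time, while `ones` was bound once, at `from torch import ones`, and keeps pointing at the old function. The sketch below reproduces this with a stand-in module; `fake_torch`, `cpu_ones`, `cuda_ones` and `gpu_context` are hypothetical names, not part of torch or spaces.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for the torch module (NOT the real torch).
fake_torch = types.ModuleType("fake_torch")


def cpu_ones(n, device=None):
    # Assumed behaviour outside the GPU context: the device argument is ignored.
    return ("cpu", [1.0] * n)


def cuda_ones(n, device=None):
    # Assumed behaviour inside the GPU context: the "real" function.
    return (device, [1.0] * n)


fake_torch.ones = cpu_ones


@contextmanager
def gpu_context():
    # Hypothetical model of @spaces.GPU: swap the module attribute on entry
    # and put the previous one back on exit.
    saved = fake_torch.ones
    fake_torch.ones = cuda_ones
    try:
        yield
    finally:
        fake_torch.ones = saved


# Equivalent of `from torch import ones`: copies the current binding once.
ones = fake_torch.ones

with gpu_context():
    via_attribute = fake_torch.ones(3, device="cuda")  # looked up at call time
    via_name = ones(3, device="cuda")                   # still cpu_ones

print(via_attribute)  # ('cuda', [1.0, 1.0, 1.0])
print(via_name)       # ('cpu', [1.0, 1.0, 1.0])
```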

This is pretty confusing and will cause issues with many existing codebases (any code that does `from torch import ...`), so I think it would be a good idea to fix it.

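If you need a workaround, add this at the top of app.py, right after importing spaces and torch (and before any `from torch import ones`, since that line copies whatever `torch.ones` is at that moment):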

```python
def my_ones(*args, **kwargs):
    return torch.ones(*args, **kwargs)

torch.ones = my_ones
```

I know it looks weird, but it works, provided you never call `ones` outside a function decorated with `@spaces.GPU`.

(I used `ones` as an example, but the same thing happens with all the other Torch functions, such as `arange`.)
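A slightly more general version of the same workaround, as a sketch: `late_bind` is a hypothetical helper (not part of torch or spaces) that replaces a module function with a wrapper that looks the function up again on every call, so it can be applied to `ones`, `arange` and so on. It also saves the function installed before the wrapper and calls that one when the attribute still points at the wrapper. Presumably the simple version above must not be called outside `@spaces.GPU` because `torch.ones` is then `my_ones` itself and it recurses forever; the fallback is meant to avoid that. This assumes `@spaces.GPU` works by swapping the module attribute, which has not been verified, and the demo uses a stand-in module rather than torch.

```python
import functools
import types


def late_bind(module, name):
    """Replace module.<name> with a wrapper that re-reads the attribute on every call.

    Hypothetical helper (not part of torch or spaces). Any `from module import name`
    executed after this call binds the wrapper, so it follows later swaps of the
    module attribute.
    """
    fallback = getattr(module, name)  # whatever is installed at this moment

    @functools.wraps(fallback)
    def wrapper(*args, **kwargs):
        current = getattr(module, name)
        # If nothing has swapped the attribute, it still points at this wrapper:
        # call the saved function instead of recursing forever.
        if current is wrapper:
            current = fallback
        return current(*args, **kwargs)

    setattr(module, name, wrapper)
    return wrapper


# Demo with a stand-in module (hypothetical; NOT the real torch).
demo = types.ModuleType("demo")
demo.ones = lambda n, device=None: ("cpu", n)
late_bind(demo, "ones")
ones = demo.ones  # what `from demo import ones` would bind at this point

first = ones(3)  # attribute not swapped: uses the saved function
demo.ones = lambda n, device=None: (device, n)  # simulate the swap inside @spaces.GPU
second = ones(3, device="cuda")  # follows the swapped attribute

print(first)   # ('cpu', 3)
print(second)  # ('cuda', 3)
```

In app.py this would be `late_bind(torch, "ones")`, `late_bind(torch, "arange")` and so on, right after `import spaces` and `import torch` and before anything runs `from torch import ...`.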
