Fix device errors on GPU environments

#1
by iwiwi - opened

Without this fix, we encounter the following error in GPU environments:

```
❯ python3 example.py
Downloading (…)rocessor_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 477/477 [00:00<00:00, 73.9kB/s]
Traceback (most recent call last):
  File "example.py", line 67, in <module>
    image_features = model.get_image_features(**image)
  File ".../.cache/huggingface/modules/transformers_modules/stabilityai/japanese-stable-clip-vit-l-16/815a716c60dd0cf36a5256e6546f35e33077545b/modeling_japanese_stable_clip.py", line 154, in get_image_features
    vision_outputs = self.vision_model(
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 573, in forward
    embedding_output = self.embeddings(
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 122, in forward
    embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 181, in forward
    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
```
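For context, this `RuntimeError` means the input tensor is on the GPU (`torch.cuda.FloatTensor`) while the convolution weights are still on the CPU. The general remedy is to make sure inputs are moved to the same device as the model's parameters before the forward call. A minimal sketch of that pattern (a toy `nn.Conv2d` stands in for the model's patch projection; it is not the actual patch in this PR):

```python
import torch
import torch.nn as nn

# Pick the GPU if one is available; this sketch also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-in for the ViT patch-projection conv that raised the error.
model = nn.Conv2d(3, 8, kernel_size=16, stride=16).to(device)

# Preprocessed inputs are created on the CPU by default.
pixel_values = torch.randn(1, 3, 224, 224)

# The fix: move the input to whatever device the model's weights live on,
# so F.conv2d sees matching input and weight types.
pixel_values = pixel_values.to(next(model.parameters()).device)

out = model(pixel_values)
print(out.device.type, tuple(out.shape))
```

Without the `.to(...)` call, running this on a GPU machine with the model on `cuda` reproduces the same "Input type ... and weight type ... should be the same" error.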

Super helpful! Thank you!

mkshing changed pull request status to merged