Configuring Command-R for long context tasks

#32
by beam-me-up-scotty - opened

Apologies for the duplicate post, but the previous related discussion was unclear to me.

saurabhdash mentions:

"This implementation is based on the Llama implementation which materializes this huge buffer which would not be feasible for 128k context. The model does support 128k context with a better implementation."

and then gives the following line of Python:

causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool)
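
If I'm understanding why that buffer is a problem: a torch.bool tensor uses one byte per element, so at 128k context that single mask is already enormous. A quick back-of-envelope (my own rough estimate, not from the thread above):

# Rough size of a fully materialized 128k x 128k boolean attention mask
seq_len = 131072
print(seq_len * seq_len * 1 / 1e9)  # 1 byte per torch.bool element -> ~17.2 GB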

What exact steps do we need to follow to implement this?

I've tried editing max_position_embeddings directly in config.json, and can only run a ~25k-token prompt with max_position_embeddings=32768 and 8-bit quantization on a machine with 2x A100 (approx. 160 GB VRAM).
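
For reference, I assume the equivalent of that hand edit, done at load time instead, would be something like this (sketch only):

from transformers import AutoConfig, AutoModelForCausalLM

# Override the context length in the config object rather than editing config.json by hand
config = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r-v01", trust_remote_code=True)
config.max_position_embeddings = 32768
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", config=config, trust_remote_code=True, device_map="auto")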

Can someone indicate how this default implementation needs to change to use the better implementation mentioned above:

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# 8-bit quantization, with the model sharded across the available GPUs
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="auto", quantization_config=bnb_config)

Hi! Apart from the materialized attention mask, there is another problem: the logits are upcast to fp32. With a sequence length of 128k, the logits alone would take 128k × 256k × 4 bytes ≈ 131 GB. If the goal is generation, one could get rid of this and just do a log-softmax over the last token's logits.
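
To sketch that generation-time idea (illustrative only -- the variable names here are not the exact code in modeling_cohere.py):

# Keep only the final position before projecting to the vocabulary,
# so the fp32 upcast is over [batch, 1, vocab] instead of [batch, seq_len, vocab].
hidden_states = outputs[0]                      # [batch, seq_len, hidden]
last_hidden = hidden_states[:, -1:, :]          # last token only
logits = self.lm_head(last_hidden).float()      # cheap upcast: [batch, 1, vocab]
log_probs = torch.log_softmax(logits, dim=-1)   # log-softmax over the last token's logits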

Thanks for your answer @saurabhdash! In terms of implementation:

  • Would the implementation of causal_mask at line 614 of modeling_cohere.py in forward() need to change to your above implementation?
  • Where would you change the implementation of the logits? Any tips about how to do so?
  • What's a reasonable VRAM usage to expect for a 128k task with these optimisations? Am I over-optimistic to think that we can fit a context of that size on 2x A100s?

Apologies if these are silly questions; I'm still a little new to all this.

Cohere For AI org

I'd recommend waiting for (or using) the vLLM implementation. That should help you scale the context to the maximum.
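
Once that's available, usage would look roughly like this (a sketch only, assuming the model is supported and using vLLM's usual knobs for context length and multi-GPU, max_model_len and tensor_parallel_size):

from vllm import LLM, SamplingParams

# Sketch: assumes Command-R support in vLLM and its standard engine arguments
llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",
    max_model_len=131072,      # target context window
    tensor_parallel_size=2,    # shard across the 2x A100s
)
out = llm.generate(["<your long prompt>"], SamplingParams(max_tokens=256))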

alexrs changed discussion status to closed