# neon-8-qubits / custom_gpt_config.py
from transformers import GPT2Config
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

class CustomGPTConfig(GPT2Config):
    model_type = "custom_gpt"

    def __init__(self, vocab_size=50304, n_layer=24, n_head=16, hidden_size=1024, block_size=1024, **kwargs):
        # Translate the local argument names onto GPT2Config's fields:
        # hidden_size -> n_embd, block_size -> n_positions (n_ctx is the
        # legacy GPT-2 alias for the context length, passed along as well).
        super().__init__(
            vocab_size=vocab_size,
            n_positions=block_size,
            n_ctx=block_size,
            n_embd=hidden_size,
            n_layer=n_layer,
            n_head=n_head,
            **kwargs,
        )

# Register the custom configuration so the Auto classes can resolve the
# "custom_gpt" model_type string to this class.
CONFIG_MAPPING.register("custom_gpt", CustomGPTConfig)
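
# Usage sketch (an illustration, not part of the original file): once this
# module has been imported, the registration above lets AutoConfig build the
# custom config from its model_type string via CONFIG_MAPPING.
if __name__ == "__main__":
    from transformers import AutoConfig

    cfg = AutoConfig.for_model("custom_gpt")
    print(type(cfg).__name__, cfg.n_embd, cfg.n_layer, cfg.n_positions)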