from transformers import GPT2Config
from transformers.models.auto.configuration_auto import CONFIG_MAPPING


class CustomGPTConfig(GPT2Config):
    """GPT-2-style configuration with custom defaults, registered as ``custom_gpt``.

    Exposes the commonly used argument names ``hidden_size`` and ``block_size``
    and maps them onto GPT-2's native ``n_embd`` and ``n_positions`` fields.
    """

    model_type = "custom_gpt"

    def __init__(
        self,
        vocab_size=50304,
        n_layer=24,
        n_head=16,
        hidden_size=1024,
        block_size=1024,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            vocab_size: Token vocabulary size. (50304 = 50257 rounded up to a
                multiple of 64 — presumably for GPU-friendly padding; confirm.)
            n_layer: Number of transformer blocks.
            n_head: Number of attention heads per block.
            hidden_size: Embedding/hidden dimension (GPT-2's ``n_embd``).
            block_size: Maximum sequence length (GPT-2's ``n_positions``).
            **kwargs: Forwarded unchanged to ``GPT2Config``.
        """
        super().__init__(
            vocab_size=vocab_size,
            n_positions=block_size,
            # NOTE(review): ``n_ctx`` was removed as an explicit GPT2Config
            # __init__ parameter in recent transformers releases; there it is
            # absorbed by **kwargs and stored as a plain attribute. Kept for
            # compatibility with older versions — verify against the pinned
            # transformers version.
            n_ctx=block_size,
            n_embd=hidden_size,
            n_layer=n_layer,
            n_head=n_head,
            **kwargs,
        )


# Register the custom configuration so AutoConfig can resolve
# model_type == "custom_gpt" to this class.
CONFIG_MAPPING.register("custom_gpt", CustomGPTConfig)