File size: 623 Bytes
706c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import AutoConfig, GPT2Config
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

class CustomGPTConfig(GPT2Config):
    """GPT-2 configuration with renamed sizing hyperparameters.

    Exposes ``hidden_size``/``block_size`` as friendly constructor names and
    maps them onto GPT-2's native ``n_embd``/``n_positions``/``n_ctx`` fields.

    Args:
        vocab_size: Token vocabulary size (default 50304).
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        hidden_size: Embedding width; mapped to GPT-2's ``n_embd``.
        block_size: Maximum sequence length; mapped to ``n_positions``/``n_ctx``.
        **kwargs: Forwarded to ``GPT2Config``.
    """

    model_type = "custom_gpt"

    def __init__(self, vocab_size=50304, n_layer=24, n_head=16, hidden_size=1024, block_size=1024, **kwargs):
        # When a saved config is reloaded (``from_pretrained`` -> ``from_dict``),
        # the serialized dict already contains the native GPT-2 keys
        # (``n_positions``, ``n_embd``, ``n_ctx``) inside ``kwargs``. Passing
        # them again explicitly below would raise
        # ``TypeError: __init__() got multiple values for keyword argument``.
        # Pop them first so explicit kwargs win and the friendly-name
        # parameters act only as fallbacks.
        n_positions = kwargs.pop("n_positions", block_size)
        n_embd = kwargs.pop("n_embd", hidden_size)
        n_ctx = kwargs.pop("n_ctx", block_size)
        super().__init__(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_ctx=n_ctx,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            **kwargs,
        )

# Register the custom configuration with the Auto* factory machinery.
# AutoConfig.register is the supported public API (it delegates to
# CONFIG_MAPPING.register and additionally validates that the class's
# ``model_type`` matches the registered key).
AutoConfig.register("custom_gpt", CustomGPTConfig)