Spaces:
Configuration error
Configuration error
from transformers import AutoConfig, GPT2Config
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
class CustomGPTConfig(GPT2Config):
    """GPT-2 configuration registered under its own ``model_type``.

    Adds ``hidden_size`` and ``block_size`` as friendlier constructor aliases
    for GPT-2's ``n_embd`` and ``n_positions`` / ``n_ctx``.

    Args:
        vocab_size: Vocabulary size (default 50304).
        n_layer: Number of transformer layers (default 24).
        n_head: Number of attention heads (default 16).
        hidden_size: Embedding width; stored as ``n_embd`` (default 1024).
        block_size: Maximum sequence length; stored as ``n_positions`` and
            ``n_ctx`` (default 1024).
        **kwargs: Forwarded to ``GPT2Config``. The native names ``n_embd``,
            ``n_positions`` and ``n_ctx`` are also accepted here and take
            precedence over the aliases above.
    """

    model_type = "custom_gpt"

    def __init__(
        self,
        vocab_size=50304,
        n_layer=24,
        n_head=16,
        hidden_size=1024,
        block_size=1024,
        **kwargs,
    ):
        # A serialized config stores GPT-2's native field names, so reloading
        # (from_pretrained / from_dict -> cls(**config_dict)) passes n_embd,
        # n_positions and n_ctx back in through **kwargs. Forwarding them next
        # to the explicit keywords below would raise "got multiple values for
        # keyword argument", so pop them and let them override the aliases.
        n_embd = kwargs.pop("n_embd", hidden_size)
        n_positions = kwargs.pop("n_positions", block_size)
        n_ctx = kwargs.pop("n_ctx", n_positions)
        super().__init__(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_ctx=n_ctx,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            **kwargs,
        )
# Register the custom configuration with the public AutoConfig API so that the
# "custom_gpt" model_type resolves to CustomGPTConfig. AutoConfig.register also
# verifies that the class's model_type matches the key being registered.
AutoConfig.register("custom_gpt", CustomGPTConfig)