import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


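# Thin wrappers that attach weight normalization to 1-D convolutions.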
def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


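# Snake activation, computed elementwise as x + (1 / alpha) * sin^2(alpha * x).
# The input is flattened to (batch, channels, -1) for the pointwise math and
# reshaped back afterwards; the 1e-9 term guards against division by zero.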
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


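# Module wrapper around snake() with one learnable alpha per channel.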
class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


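# Residual block: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv, with the
# input added back through a skip connection.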
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-crop the input if the block output is shorter, so the
        # residual addition lines up.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


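# Weight init for Conv1d layers: truncated-normal weights, zero biases.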
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
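

# Illustrative usage (a minimal sketch, not part of the module above); the
# shapes and hyperparameters are arbitrary example values.
if __name__ == "__main__":
    x = torch.randn(1, 16, 4096)  # (batch, channels, time)

    act = Snake1d(channels=16)
    unit = ResidualUnit(dim=16, dilation=3)  # "same" padding: (7 - 1) * 3 // 2 = 9

    print(act(x).shape)   # torch.Size([1, 16, 4096])
    print(unit(x).shape)  # torch.Size([1, 16, 4096]); no cropping needed here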