# Fine-tune a pre-trained StyleGAN2 (FFHQ) generator and discriminator on a custom image folder.
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import pickle

# Load the pre-trained StyleGAN2 networks from the FFHQ checkpoint.
# NOTE(security): pickle.load executes arbitrary code from the file —
# only load checkpoints from a trusted source.
with open('ffhq.pkl', 'rb') as f:
    data = pickle.load(f)
G = data['G_ema']  # exponential-moving-average generator (usual fine-tuning start)
D = data['D']      # discriminator

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = G.to(device)
D = D.to(device)
# Custom dataset: yields transformed RGB images from a flat directory.
class CustomDataset(Dataset):
    """Image-folder dataset over the .jpg/.png files in ``image_dir``.

    Args:
        image_dir: Directory containing the training images (flat, not recursive).
        transform: Optional callable applied to each PIL image (e.g. torchvision
            transforms); when None, the raw PIL image is returned.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sort for a deterministic index -> file mapping (os.listdir order is arbitrary).
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # Force RGB so grayscale/RGBA files still produce 3-channel tensors.
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
# Data loading: resize to the generator's native resolution and normalize
# to [-1, 1], the input range StyleGAN was trained on.
transform = transforms.Compose([
    transforms.Resize((G.img_resolution, G.img_resolution)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = CustomDataset("/path/to/your/image_dir", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Fine-tuning setup: Adam with StyleGAN2's beta schedule (beta1=0, beta2=0.99)
# and a gentle per-epoch exponential LR decay for both networks.
optimizer_g = Adam(G.parameters(), lr=0.0001, betas=(0, 0.99))
optimizer_d = Adam(D.parameters(), lr=0.0001, betas=(0, 0.99))
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)
num_epochs = 100
for epoch in range(num_epochs):
    for batch in dataloader:
        real_images = batch.to(device)

        # Generate fake images from random latents.
        z = torch.randn(batch.size(0), G.z_dim, device=device)
        fake_images = G(z, None)

        # --- Discriminator step ---
        # StyleGAN2's D outputs raw logits, so log(D(x)) on the raw output is
        # invalid (log of negatives -> NaN). Use the non-saturating softplus
        # form instead: softplus(-x) = -log(sigmoid(x)),
        #               softplus(x)  = -log(1 - sigmoid(x)).
        # fake_images is detached here so backward() does not consume the
        # generator's graph — the original code backpropagated through it
        # twice and crashed with "trying to backward through the graph a
        # second time".
        d_loss = (
            torch.nn.functional.softplus(-D(real_images, None)).mean()
            + torch.nn.functional.softplus(D(fake_images.detach(), None)).mean()
        )
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # --- Generator step (graph through G is still intact) ---
        g_loss = torch.nn.functional.softplus(-D(fake_images, None)).mean()
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

    # Decay the learning rates once per EPOCH — stepping per batch (as the
    # original did) would collapse the LR after a few hundred iterations.
    scheduler_g.step()
    scheduler_d.step()
    print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")
# Persist only the generator's weights (state_dict), not the full module object,
# so they can be reloaded into a freshly constructed network.
torch.save(G.state_dict(), 'fine_tuned_stylegan.pth')