# my-stylegan-model / train.py
# edemana's picture
# Update train.py
# 73e4fea verified
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import pickle
# Pick the compute device first: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained checkpoint.
# NOTE: pickle.load() can execute arbitrary code embedded in the file, so
# only point this at checkpoints obtained from a trusted source.
with open('ffhq.pkl', 'rb') as checkpoint_file:
    data = pickle.load(checkpoint_file)

# Pull the generator (stored under 'G_ema') and the discriminator (stored
# under 'D') out of the checkpoint and move both onto the chosen device.
G = data['G_ema'].to(device)
D = data['D'].to(device)
# Custom dataset class
class CustomDataset(Dataset):
    """Image dataset backed by the files directly inside one directory.

    Entries of ``image_dir`` whose name ends in ``.jpg``, ``.jpeg`` or
    ``.png`` (compared case-insensitively) are indexed in sorted filename
    order, so a given index always refers to the same file between runs.
    Sub-directories are not scanned.

    Args:
        image_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded RGB PIL image;
            when given, its return value is what ``__getitem__`` returns.
    """

    # Accepted file extensions, compared against the lower-cased filename.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # os.listdir() returns names in arbitrary order; sorting keeps the
        # index -> file mapping reproducible.
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested.

        Raises:
            IndexError: If ``idx`` is outside the indexed file list.
        """
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # The context manager releases the underlying file as soon as the
        # pixel data has been converted into an independent RGB image.
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container-like transform) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image
# Data loading
# Preprocessing pipeline: resize every image to the generator's square
# resolution, convert it to a tensor, then normalize each of the three
# channels with mean 0.5 and standard deviation 0.5.
resolution = G.img_resolution
preprocessing_steps = [
    transforms.Resize((resolution, resolution)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocessing_steps)

# Shuffled mini-batches of four images drawn from the image directory.
dataset = CustomDataset("/path/to/your/image_dir", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Fine-tuning setup
# NOTE(review): networks restored from NVIDIA StyleGAN snapshots are normally
# stored in eval mode with requires_grad disabled (presumably true for this
# checkpoint as well -- verify). backward() fails on such parameters, so
# training mode and gradients are re-enabled explicitly; this is a no-op when
# they are already enabled.
G.train().requires_grad_(True)
D.train().requires_grad_(True)

optimizer_g = Adam(G.parameters(), lr=0.0001, betas=(0, 0.99))
optimizer_d = Adam(D.parameters(), lr=0.0001, betas=(0, 0.99))
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)
num_epochs = 100

# Fail early with a clear message instead of running empty epochs and then
# hitting undefined loss variables when reporting.
if len(dataloader) == 0:
    raise RuntimeError("dataloader yields no batches; check the image directory")

# The discriminator is treated as returning raw logits (as the StyleGAN
# reference implementation does), so the losses use softplus:
#   -log(sigmoid(x))     == softplus(-x)
#   -log(1 - sigmoid(x)) == softplus(x)
# This is the same objective as log(D(.)) on probabilities, but it stays
# finite for negative logits instead of producing NaN.
softplus = torch.nn.functional.softplus

for epoch in range(num_epochs):
    for batch in dataloader:
        real_images = batch.to(device)

        # Sample latents directly on the target device.
        z = torch.randn([real_images.size(0), G.z_dim], device=device)

        # ---- Generator update ----
        fake_images = G(z, None)
        g_loss = softplus(-D(fake_images, None)).mean()
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        # ---- Discriminator update ----
        # detach() cuts the link to the generator graph, which was already
        # consumed by g_loss.backward() and whose parameters were modified by
        # optimizer_g.step(); backpropagating through it again would raise.
        # Gradients that the generator step left on D's parameters are
        # cleared by optimizer_d.zero_grad() below.
        d_loss_real = softplus(-D(real_images, None)).mean()
        d_loss_fake = softplus(D(fake_images.detach(), None)).mean()
        d_loss = d_loss_real + d_loss_fake
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

    # NOTE(review): the original indentation was lost; the learning-rate decay
    # and the progress report are assumed to run once per epoch -- confirm.
    scheduler_g.step()
    scheduler_d.step()
    # Reports the losses of the last batch of the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")
# Write the fine-tuned generator's weights (its state dict only, not the
# whole module object) to disk.
generator_weights = G.state_dict()
torch.save(generator_weights, 'fine_tuned_stylegan.pth')