|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
import timm |
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
from pathlib import Path |
|
import pandas as pd |
|
import numpy as np |
|
from PIL import Image |
|
from sklearn.model_selection import train_test_split |
|
from tqdm.auto import tqdm |
|
import wandb |
|
|
|
class PlantDiseaseDataset(Dataset):
    """Map-style dataset pairing image files with class labels.

    Args:
        image_paths: Sequence of paths that ``PIL.Image.open`` can read.
        labels: Sequence of labels, index-aligned with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned. ``None`` returns the image untouched.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per entry in ``image_paths``)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            An ``(image, label)`` tuple. ``image`` is the RGB-converted image,
            passed through ``transform`` when one was supplied.
        """
        # Image.open() is lazy and keeps the file open; convert() yields an
        # independent, fully loaded image, so the handle can be closed as soon
        # as the conversion is done instead of lingering until garbage
        # collection.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]

        # Compare against None rather than relying on truthiness: a callable
        # that defines __len__ could otherwise be skipped by accident.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
|
|
class PlantDiseaseClassifier:
    """Fine-tuning pipeline for a ``timm`` image classifier.

    The intended call order is ``prepare_data`` -> ``create_model`` ->
    ``train`` -> ``save_for_huggingface``; each step stores what the next one
    needs on the instance. Constructing the object starts a Weights & Biases
    run, which ``train`` finishes.

    Args:
        data_dir: Dataset root containing one sub-directory per class.
        model_name: Architecture name handed to ``timm.create_model``.
        num_classes: Number of outputs requested from ``timm.create_model``.
    """

    # Written by ``train`` whenever validation accuracy improves and read
    # back by ``save_for_huggingface``.
    CHECKPOINT_PATH = 'best_model.pth'

    # File suffixes treated as images when scanning class directories. They
    # are compared case-insensitively so that ``.JPG`` files are not silently
    # dropped on case-sensitive filesystems.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, data_dir, model_name='vit_base_patch16_224', num_classes=38):
        self.data_dir = Path(data_dir)
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize wandb
        wandb.init(project="plant-disease-classification")

    def _collect_samples(self):
        """Scan ``data_dir`` and return ``(image_paths, labels)``.

        Also sets ``self.class_to_idx``. Indices are assigned to directories
        only, in sorted order, so stray files next to the class folders cannot
        leave gaps in the label range. Image paths are sorted as well, which
        keeps the seeded train/validation split independent of the order in
        which the filesystem happens to list files.
        """
        class_dirs = sorted(p for p in self.data_dir.glob('*') if p.is_dir())
        self.class_to_idx = {d.name: i for i, d in enumerate(class_dirs)}

        image_paths = []
        labels = []
        for class_dir in class_dirs:
            label = self.class_to_idx[class_dir.name]
            image_files = sorted(
                p for p in class_dir.glob('*')
                if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
            )
            image_paths.extend(str(p) for p in image_files)
            labels.extend([label] * len(image_files))

        return image_paths, labels

    def prepare_data(self, batch_size=32, num_workers=4):
        """Build the training and validation data loaders.

        Images are split 80/20 with a stratified, seeded
        ``train_test_split``. The loaders are also stored as
        ``self.train_loader`` and ``self.val_loader``.

        Args:
            batch_size: Batch size used by both loaders.
            num_workers: Worker process count used by both loaders.

        Returns:
            The ``(train_loader, val_loader)`` pair.

        Raises:
            ValueError: If no images are found, or if there are more class
                directories than ``num_classes``.
        """
        # Normalisation constants are presumably the ImageNet channel
        # statistics expected by the pretrained weights -- confirm if the
        # backbone is changed.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Data augmentation and normalization for training
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            normalize,
        ])

        # Just normalization for validation/testing
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        image_paths, labels = self._collect_samples()

        # Fail early with a readable message instead of an opaque error from
        # the splitter or, later, from the loss function.
        if not image_paths:
            raise ValueError(f'No images found under {self.data_dir}')
        if len(self.class_to_idx) > self.num_classes:
            raise ValueError(
                f'Found {len(self.class_to_idx)} class directories under '
                f'{self.data_dir}, but num_classes is {self.num_classes}'
            )

        # Split data
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            image_paths, labels, test_size=0.2, stratify=labels, random_state=42
        )

        # Create datasets
        train_dataset = PlantDiseaseDataset(train_paths, train_labels, train_transform)
        val_dataset = PlantDiseaseDataset(val_paths, val_labels, val_transform)

        # Create data loaders
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        return self.train_loader, self.val_loader

    def create_model(self, lr=2e-5, weight_decay=0.01, t_max=10):
        """Create the pretrained model plus its loss, optimizer and scheduler.

        Args:
            lr: Learning rate for ``AdamW``.
            weight_decay: Weight decay for ``AdamW``.
            t_max: ``T_max`` for ``CosineAnnealingLR``. ``train`` steps the
                scheduler once per epoch, so pass the planned epoch count
                when training for something other than 10 epochs.

        Returns:
            The model, already moved to ``self.device``.
        """
        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=self.num_classes
        )
        self.model = self.model.to(self.device)

        # Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=t_max
        )

        return self.model

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        The per-batch loss and the running accuracy are logged to wandb after
        every batch.

        Args:
            epoch: Epoch number, used only in the progress-bar label.

        Returns:
            ``(mean batch loss, accuracy in percent)`` for the epoch.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')

        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            loss.backward()
            self.optimizer.step()

            # Read the scalar once; it is reused for the running total and
            # for the wandb log entry.
            batch_loss = loss.item()
            total_loss += batch_loss
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            running_acc = 100. * correct / total

            progress_bar.set_postfix({
                'Loss': total_loss / (batch_idx + 1),
                'Acc': running_acc
            })

            # Log to wandb
            wandb.log({
                'train_loss': batch_loss,
                'train_acc': running_acc
            })

        return total_loss / len(self.train_loader), 100. * correct / total

    def validate(self):
        """Evaluate the model on ``self.val_loader`` without gradients.

        The result is logged to wandb as ``val_loss`` and ``val_acc``.

        Returns:
            ``(mean batch loss, accuracy in percent)``.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in tqdm(self.val_loader, desc='Validating'):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                total_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        accuracy = 100. * correct / total
        avg_loss = total_loss / len(self.val_loader)

        # Log to wandb
        wandb.log({
            'val_loss': avg_loss,
            'val_acc': accuracy
        })

        return avg_loss, accuracy

    def train(self, epochs=10):
        """Train for ``epochs`` epochs, checkpointing the best model.

        After each epoch the model is validated and the scheduler stepped.
        Whenever validation accuracy improves, the model state, optimizer
        state and class mapping are written to ``CHECKPOINT_PATH``. The wandb
        run is finished when training ends, whether or not it succeeded.

        Args:
            epochs: Number of epochs to run.
        """
        best_acc = 0

        try:
            for epoch in range(epochs):
                train_loss, train_acc = self.train_epoch(epoch)
                val_loss, val_acc = self.validate()
                self.scheduler.step()

                print(f'\nEpoch {epoch}:')
                print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
                print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

                # Save best model
                if val_acc > best_acc:
                    best_acc = val_acc
                    torch.save({
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'class_to_idx': self.class_to_idx
                    }, self.CHECKPOINT_PATH)
        except BaseException:
            # Close the run even on failure or Ctrl-C so it is not left open,
            # and mark it as failed before letting the error propagate.
            wandb.finish(exit_code=1)
            raise
        else:
            wandb.finish()

    def save_for_huggingface(self):
        """Export the best checkpoint and the index-to-class mapping.

        Restores the weights stored in ``CHECKPOINT_PATH``, writes the model
        to the ``plant_disease_model`` directory and writes the
        index -> class-name mapping to ``class_mapping.json``.
        """
        # map_location lets a checkpoint written on a GPU machine be restored
        # on a CPU-only one.
        checkpoint = torch.load(self.CHECKPOINT_PATH, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Save model and config
        output_dir = Path('plant_disease_model')
        if hasattr(self.model, 'save_pretrained'):
            self.model.save_pretrained(str(output_dir))
        else:
            # A plain nn.Module has no save_pretrained(); store the weights
            # as a state dict so the export step does not crash.
            output_dir.mkdir(parents=True, exist_ok=True)
            torch.save(self.model.state_dict(), output_dir / 'pytorch_model.bin')

        # Save class mapping. Prefer the mapping stored next to the weights:
        # it is the one that matches the restored model, and it is available
        # even when prepare_data() was not run in this process.
        class_to_idx = checkpoint.get('class_to_idx')
        if class_to_idx is None:
            class_to_idx = self.class_to_idx
        idx_to_class = {v: k for k, v in class_to_idx.items()}
        pd.Series(idx_to_class).to_json('class_mapping.json')
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full pipeline in order -- load the data,
    # build the network, fine-tune it, then export the best checkpoint.
    pipeline = PlantDiseaseClassifier(data_dir="path/to/dataset")
    pipeline.prepare_data()
    pipeline.create_model()
    pipeline.train(epochs=10)
    pipeline.save_for_huggingface()