Fluospark128 commited on
Commit
b5403c0
·
verified ·
1 Parent(s): 01bfd70

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +223 -0
README.md ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ import timm
6
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
7
+ from pathlib import Path
8
+ import pandas as pd
9
+ import numpy as np
10
+ from PIL import Image
11
+ from sklearn.model_selection import train_test_split
12
+ from tqdm.auto import tqdm
13
+ import wandb
14
+
15
+ class PlantDiseaseDataset(Dataset):
16
+     def __init__(self, image_paths, labels, transform=None):
17
+         self.image_paths = image_paths
18
+         self.labels = labels
19
+         self.transform = transform
20
+        
21
+     def __len__(self):
22
+         return len(self.image_paths)
23
+    
24
+     def __getitem__(self, idx):
25
+         image_path = self.image_paths[idx]
26
+         image = Image.open(image_path).convert('RGB')
27
+         label = self.labels[idx]
28
+        
29
+         if self.transform:
30
+             image = self.transform(image)
31
+            
32
+         return image, label
33
+
34
+ class PlantDiseaseClassifier:
35
+     def __init__(self, data_dir, model_name='vit_base_patch16_224', num_classes=38):
36
+         self.data_dir = Path(data_dir)
37
+         self.model_name = model_name
38
+         self.num_classes = num_classes
39
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
+        
41
+         # Initialize wandb
42
+         wandb.init(project="plant-disease-classification")
43
+        
44
+     def prepare_data(self):
45
+         """Prepare dataset and create data loaders"""
46
+         # Data augmentation and normalization for training
47
+         train_transform = transforms.Compose([
48
+             transforms.RandomResizedCrop(224),
49
+             transforms.RandomHorizontalFlip(),
50
+             transforms.RandomVerticalFlip(),
51
+             transforms.RandomRotation(20),
52
+             transforms.ColorJitter(brightness=0.2, contrast=0.2),
53
+             transforms.ToTensor(),
54
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
55
+         ])
56
+        
57
+         # Just normalization for validation/testing
58
+         val_transform = transforms.Compose([
59
+             transforms.Resize(256),
60
+             transforms.CenterCrop(224),
61
+             transforms.ToTensor(),
62
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
63
+         ])
64
+        
65
+         # Collect all image paths and labels
66
+         image_paths = []
67
+         labels = []
68
+         self.class_to_idx = {}
69
+        
70
+         for idx, class_dir in enumerate(sorted(self.data_dir.glob('*'))):
71
+             if class_dir.is_dir():
72
+                 self.class_to_idx[class_dir.name] = idx
73
+                 for img_path in class_dir.glob('*.jpg'):
74
+                     image_paths.append(str(img_path))
75
+                     labels.append(idx)
76
+        
77
+         # Split data
78
+         train_paths, val_paths, train_labels, val_labels = train_test_split(
79
+             image_paths, labels, test_size=0.2, stratify=labels, random_state=42
80
+         )
81
+        
82
+         # Create datasets
83
+         train_dataset = PlantDiseaseDataset(train_paths, train_labels, train_transform)
84
+         val_dataset = PlantDiseaseDataset(val_paths, val_labels, val_transform)
85
+        
86
+         # Create data loaders
87
+         self.train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
88
+         self.val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
89
+        
90
+         return self.train_loader, self.val_loader
91
+    
92
+     def create_model(self):
93
+         """Initialize the Vision Transformer model"""
94
+         self.model = timm.create_model(
95
+             self.model_name,
96
+             pretrained=True,
97
+             num_classes=self.num_classes
98
+         )
99
+         self.model = self.model.to(self.device)
100
+        
101
+         # Loss function and optimizer
102
+         self.criterion = nn.CrossEntropyLoss()
103
+         self.optimizer = torch.optim.AdamW(
104
+             self.model.parameters(),
105
+             lr=2e-5,
106
+             weight_decay=0.01
107
+         )
108
+         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
109
+             self.optimizer,
110
+             T_max=10
111
+         )
112
+        
113
+         return self.model
114
+    
115
+     def train_epoch(self, epoch):
116
+         """Train for one epoch"""
117
+         self.model.train()
118
+         total_loss = 0
119
+         correct = 0
120
+         total = 0
121
+        
122
+         progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
123
+        
124
+         for batch_idx, (inputs, targets) in enumerate(progress_bar):
125
+             inputs, targets = inputs.to(self.device), targets.to(self.device)
126
+            
127
+             self.optimizer.zero_grad()
128
+             outputs = self.model(inputs)
129
+             loss = self.criterion(outputs, targets)
130
+            
131
+             loss.backward()
132
+             self.optimizer.step()
133
+            
134
+             total_loss += loss.item()
135
+             _, predicted = outputs.max(1)
136
+             total += targets.size(0)
137
+             correct += predicted.eq(targets).sum().item()
138
+            
139
+             progress_bar.set_postfix({
140
+                 'Loss': total_loss/(batch_idx+1),
141
+                 'Acc': 100.*correct/total
142
+             })
143
+            
144
+             # Log to wandb
145
+             wandb.log({
146
+                 'train_loss': loss.item(),
147
+                 'train_acc': 100.*correct/total
148
+             })
149
+            
150
+         return total_loss/len(self.train_loader), 100.*correct/total
151
+    
152
+     def validate(self):
153
+         """Validate the model"""
154
+         self.model.eval()
155
+         total_loss = 0
156
+         correct = 0
157
+         total = 0
158
+        
159
+         with torch.no_grad():
160
+             for inputs, targets in tqdm(self.val_loader, desc='Validating'):
161
+                 inputs, targets = inputs.to(self.device), targets.to(self.device)
162
+                 outputs = self.model(inputs)
163
+                 loss = self.criterion(outputs, targets)
164
+                
165
+                 total_loss += loss.item()
166
+                 _, predicted = outputs.max(1)
167
+                 total += targets.size(0)
168
+                 correct += predicted.eq(targets).sum().item()
169
+                
170
+         accuracy = 100.*correct/total
171
+         avg_loss = total_loss/len(self.val_loader)
172
+        
173
+         # Log to wandb
174
+         wandb.log({
175
+             'val_loss': avg_loss,
176
+             'val_acc': accuracy
177
+         })
178
+        
179
+         return avg_loss, accuracy
180
+    
181
+     def train(self, epochs=10):
182
+         """Complete training process"""
183
+         best_acc = 0
184
+        
185
+         for epoch in range(epochs):
186
+             train_loss, train_acc = self.train_epoch(epoch)
187
+             val_loss, val_acc = self.validate()
188
+             self.scheduler.step()
189
+            
190
+             print(f'\nEpoch {epoch}:')
191
+             print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
192
+             print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
193
+            
194
+             # Save best model
195
+             if val_acc > best_acc:
196
+                 best_acc = val_acc
197
+                 torch.save({
198
+                     'model_state_dict': self.model.state_dict(),
199
+                     'optimizer_state_dict': self.optimizer.state_dict(),
200
+                     'class_to_idx': self.class_to_idx
201
+                 }, 'best_model.pth')
202
+        
203
+         wandb.finish()
204
+    
205
+     def save_for_huggingface(self):
206
+         """Save model in HF中国镜像站 format"""
207
+         # Load best model
208
+         checkpoint = torch.load('best_model.pth')
209
+         self.model.load_state_dict(checkpoint['model_state_dict'])
210
+        
211
+         # Save model and config
212
+         self.model.save_pretrained('plant_disease_model')
213
+        
214
+         # Save class mapping
215
+         idx_to_class = {v: k for k, v in self.class_to_idx.items()}
216
+         pd.Series(idx_to_class).to_json('class_mapping.json')
217
+
218
+ if __name__ == "__main__":
219
+     classifier = PlantDiseaseClassifier(data_dir="path/to/dataset")
220
+     classifier.prepare_data()
221
+     classifier.create_model()
222
+     classifier.train(epochs=10)
223
+     classifier.save_for_huggingface()