import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义卷积神经网络模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # 第一个卷积层 (输入通道3, 输出通道32, 3x3卷积核)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 第二个卷积层 (输入通道32, 输出通道64, 3x3卷积核)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # 第三个卷积层 (输入通道64, 输出通道128, 3x3卷积核)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 池化层
        self.fc1 = nn.Linear(128 * 4 * 4, 512)  # 全连接层1
        self.fc2 = nn.Linear(512, 10)  # 全连接层2 (输出层)
        self.dropout = nn.Dropout(0.5)  # Dropout层
        
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

def load_data(batch_size=128):
    """
    加载CIFAR-10数据集并创建数据加载器
    
    参数:
        batch_size: 批处理大小
        
    返回:
        train_loader: 训练数据加载器
        test_loader: 测试数据加载器
    """
    # 数据预处理和增强
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    # 加载CIFAR-10数据集
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', 
        train=True, 
        transform=transform_train,
        download=True
    )

    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', 
        train=False, 
        transform=transform_test,
        download=True
    )

    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=0  # 避免Windows上的多进程问题
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=0  # 避免Windows上的多进程问题
    )
    
    return train_loader, test_loader

def train(model, train_loader, criterion, optimizer, device, epoch, total_epochs):
    """
    训练一个epoch的模型
    
    参数:
        model: 要训练的模型
        train_loader: 训练数据加载器
        criterion: 损失函数
        optimizer: 优化器
        device: 使用的计算设备
        epoch: 当前是第几个epoch
        total_epochs: 总共要训练多少个epoch
        
    返回:
        avg_loss: 该epoch的平均损失
    """
    model.train()
    running_loss = 0.0
    total_step = len(train_loader)
    
    for i, (images, labels) in enumerate(train_loader):
        # 将数据移动到设备
        images = images.to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{total_epochs}], Step [{i+1}/{total_step}], Loss: {running_loss/100:.4f}')
            running_loss = 0.0
    
    return running_loss / total_step

def test(model, test_loader, device):
    """
    在测试集上评估模型
    
    参数:
        model: 要评估的模型
        test_loader: 测试数据加载器
        device: 使用的计算设备
        
    返回:
        accuracy: 测试集上的精度
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

def main():
    # 配置训练参数
    batch_size = 128
    learning_rate = 0.001
    num_epochs = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 加载数据
    train_loader, test_loader = load_data(batch_size)
    
    # 实例化模型、损失函数和优化器
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # 开始训练循环
    for epoch in range(num_epochs):
        # 训练一个epoch
        train(model, train_loader, criterion, optimizer, device, epoch, num_epochs)
        
        # 测试模型
        accuracy = test(model, test_loader, device)
        print(f'Epoch [{epoch+1}/{num_epochs}] 测试集精度: {accuracy:.2f}%')
    
    print('训练完成!')
    
    # 保存模型
    torch.save(model.state_dict(), 'cnn_cifar10.pth')
    print('模型已保存')

if __name__ == '__main__':
    main()