import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义多层感知机模型
class MLP(nn.Module):
    def __init__(self, input_size=784, hidden_sizes=[512, 256], output_size=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])        # 隐藏层1
        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])   # 隐藏层2
        self.output_layer = nn.Linear(hidden_sizes[1], output_size)  # 输出层
        
    def forward(self, x):
        x = x.view(x.size(0), -1)  # 将图像展平为向量
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.output_layer(x)
        return x

def load_data(batch_size=100):
    """
    加载MNIST数据集并创建数据加载器
    
    参数:
        batch_size: 批处理大小
        
    返回:
        train_loader: 训练数据加载器
        test_loader: 测试数据加载器
    """
    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
    ])

    # 加载MNIST数据集
    train_dataset = torchvision.datasets.MNIST(
        root='./data', 
        train=True, 
        transform=transform,
        download=True
    )

    test_dataset = torchvision.datasets.MNIST(
        root='./data', 
        train=False, 
        transform=transform,
        download=True
    )

    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=0  # 避免Windows上的多进程问题
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=0  # 避免Windows上的多进程问题
    )
    
    return train_loader, test_loader

def train(model, train_loader, criterion, optimizer, device, epoch, total_epochs):
    """
    训练一个epoch的模型
    
    参数:
        model: 要训练的模型
        train_loader: 训练数据加载器
        criterion: 损失函数
        optimizer: 优化器
        device: 使用的计算设备
        epoch: 当前是第几个epoch
        total_epochs: 总共要训练多少个epoch
    """
    model.train()
    total_step = len(train_loader)
    for i, (images, labels) in enumerate(train_loader):
        # 将数据移动到设备
        images = images.to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{total_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

def test(model, test_loader, device):
    """
    在测试集上评估模型
    
    参数:
        model: 要评估的模型
        test_loader: 测试数据加载器
        device: 使用的计算设备
        
    返回:
        accuracy: 测试集上的精度
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'测试集精度: {accuracy:.2f}%')
    return accuracy

def main():
    """主函数，协调整个训练和测试过程"""
    # 配置训练参数
    batch_size = 100
    learning_rate = 0.001
    num_epochs = 5
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 加载数据
    train_loader, test_loader = load_data(batch_size)
    
    # 实例化模型、损失函数和优化器
    model = MLP().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # 训练循环
    for epoch in range(num_epochs):
        train(model, train_loader, criterion, optimizer, device, epoch, num_epochs)
    
    # 测试模型
    test(model, test_loader, device)
    
    # 保存模型
    torch.save(model.state_dict(), 'mlp_mnist.pth')
    print('模型已保存')

if __name__ == '__main__':
    main()