import torch
import torch.nn as nn
from torchvision import datasets, transforms
import time
import matplotlib.pyplot as plt
import numpy as np

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and an identity/projection shortcut.

    The block outputs ``planes * expansion`` channels (``expansion`` is 1). If the
    block changes the spatial size (``stride != 1``) or the channel count, the
    shortcut is a strided 1x1 convolution followed by batch norm. Otherwise it is
    an empty ``nn.Sequential``, i.e. the identity.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path: conv(stride) -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut: identity unless the output shape differs from the input shape.
        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        h = torch.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        # Add the (possibly projected) input, then apply the final ReLU.
        return torch.relu(h + self.shortcut(x))

class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``planes * expansion`` channels (``expansion`` is 4). If the
    block changes the spatial size (``stride != 1``) or the channel count, the
    shortcut is a strided 1x1 convolution followed by batch norm. Otherwise it is
    an empty ``nn.Sequential``, i.e. the identity.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path. The stride is applied in the middle 3x3 convolution.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Shortcut: identity unless the output shape differs from the input shape.
        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        h = torch.relu(self.bn1(self.conv1(x)))
        h = torch.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        # Add the (possibly projected) input, then apply the final ReLU.
        return torch.relu(h + self.shortcut(x))

# Standard ResNet50 (single GPU)
class StandardResNet50(nn.Module):
    """ResNet-50 ([3, 4, 6, 3] Bottleneck blocks) held entirely on one device.

    The stem is a single 3x3, stride-1 convolution with no max-pool, which
    adapts the network to small (32x32) inputs.

    Args:
        n_classes: number of logits produced by the final linear layer.
        device: device the whole model is moved to at the end of construction.
            Defaults to ``'cuda:0'``, the previously hard-coded placement. Pass
            ``'cpu'`` to build and run the model without a GPU.
    """

    def __init__(self, n_classes=10, device='cuda:0'):
        super(StandardResNet50, self).__init__()
        # Input channel count for the next block; advanced by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 stem for 32x32 inputs
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(Bottleneck, 64, 3, stride=1)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, n_classes)

        self.to(device)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block``; only the first uses ``stride``.

        Side effect: sets ``self.in_planes`` to ``planes * block.expansion`` so
        the next stage is built with the correct input width.
        """
        layers = []
        # Use a distinct loop name so the ``stride`` argument is not shadowed.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape ``(batch, n_classes)``.

        ``x`` is expected to be a ``(batch, 3, H, W)`` image batch on the
        model's device (assumption; the caller is responsible for placement).
        """
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# GPipe-style ResNet50 (two GPUs)
class GPipeResNet50(nn.Module):
    """ResNet-50 split into two pipeline stages on two devices, run over micro-batches.

    Stage 0 (stem + layer1 + layer2) is placed on ``devices[0]``. Stage 1
    (layer3 + layer4 + pooling + classifier) is placed on ``devices[1]``.
    ``forward`` cuts each batch into micro-batches of ``split_size`` samples.
    After handing micro-batch k to stage 1, it immediately feeds micro-batch
    k+1 to stage 0. With asynchronous CUDA kernel launches, this lets the two
    devices overlap their work.

    Note: each stage only ever sees one micro-batch at a time. BatchNorm
    statistics are therefore computed per micro-batch, not per full batch.

    Args:
        n_classes: number of output logits.
        split_size: micro-batch size along dim 0.
        devices: pair ``(stage0_device, stage1_device)``. Defaults to
            ``('cuda:0', 'cuda:1')``, the previously hard-coded placement.

    Raises:
        ValueError: if ``devices`` does not contain exactly two entries.
    """

    def __init__(self, n_classes=10, split_size=8, devices=('cuda:0', 'cuda:1')):
        super(GPipeResNet50, self).__init__()
        devices = tuple(devices)
        if len(devices) != 2:
            raise ValueError(f"devices must contain exactly 2 entries, got {len(devices)}")
        self.n_classes = n_classes
        self.split_size = split_size
        self.devices = devices
        # Input channel count for the next block; advanced by _make_layer.
        self.in_planes = 64

        # Stage 0: stem + layer1 + layer2 (on devices[0])
        self.stage_0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),  # 3x3 stem for 32x32 inputs
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            self._make_layer(Bottleneck, 64, 3, stride=1),   # layer1
            self._make_layer(Bottleneck, 128, 4, stride=2)   # layer2
        ).to(devices[0])

        # Stage 1: layer3 + layer4 + classifier (on devices[1])
        self.stage_1 = nn.Sequential(
            self._make_layer(Bottleneck, 256, 6, stride=2),  # layer3
            self._make_layer(Bottleneck, 512, 3, stride=2),  # layer4
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512 * Bottleneck.expansion, n_classes)
        ).to(devices[1])

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block``; only the first uses ``stride``.

        Side effect: sets ``self.in_planes`` to ``planes * block.expansion`` so
        the next stage is built with the correct input width.
        """
        layers = []
        # Use a distinct loop name so the ``stride`` argument is not shadowed.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both stages micro-batch by micro-batch.

        ``x`` is expected to already be on ``devices[0]`` (assumption; the
        caller handles placement). The returned logits live on ``devices[1]``,
        concatenated in the same sample order as ``x``.
        """
        stage1_device = self.devices[1]
        splits = iter(x.split(self.split_size, dim=0))
        # Prime the pipeline: stage 0 on the first micro-batch.
        s_prev = self.stage_0(next(splits)).to(stage1_device)
        results = []

        for s_next in splits:
            # Stage 1 finishes micro-batch k while stage 0 starts on k+1.
            s_prev = self.stage_1(s_prev)
            results.append(s_prev)
            s_prev = self.stage_0(s_next).to(stage1_device)

        # Drain the pipeline: stage 1 on the last micro-batch.
        results.append(self.stage_1(s_prev))
        return torch.cat(results, dim=0)

def train_model(model, model_name, epochs=3, batch_size=64):
    """Train ``model`` on the CIFAR-10 training split and record per-epoch statistics.

    Inputs are moved to the device of the model's first parameter. Labels are
    moved to whatever device the model's output is on. This covers both the
    single-device model and the two-device pipelined model, where the output
    comes back on the second device.

    Args:
        model: network to train; updated in place with Adam (lr=0.001).
        model_name: label used in progress messages.
        epochs: number of passes over the training set.
        batch_size: mini-batch size for the DataLoader.

    Returns:
        dict with keys:
            ``'total_time'``: wall-clock seconds for all epochs (excludes data loading setup).
            ``'epoch_times'``: seconds per epoch.
            ``'epoch_losses'``: mean training loss over all batches of each epoch.
            ``'epoch_accuracies'``: training accuracy (%) per epoch.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    print(f"Loading CIFAR-10 dataset for {model_name}...")
    train_dataset = datasets.CIFAR10('data', train=True, download=True, transform=transform)
    print(f"Dataset loaded. Total samples: {len(train_dataset)}")

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # CrossEntropyLoss holds no tensors, so it needs no device placement.
    criterion = nn.CrossEntropyLoss()
    # The first registered parameter belongs to the layer that consumes the input.
    input_device = next(model.parameters()).device

    model.train()
    epoch_times = []
    epoch_losses = []
    epoch_accuracies = []

    total_start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()
        running_loss = 0.0    # reset every 100 batches; used only for progress output
        epoch_loss_sum = 0.0  # accumulated over the whole epoch
        num_batches = 0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(input_device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # Compare against labels on the device where the logits ended up.
            labels = labels.to(outputs.device)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() copies to the host, which also waits for pending GPU work,
            # so the wall-clock timings below include the device computation.
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss_sum += batch_loss
            num_batches += 1

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if i % 100 == 99:
                print(f'{model_name} - Epoch: {epoch+1}, Batch: {i+1}, Loss: {running_loss/100:.3f}, '
                      f'Acc: {100.*correct/total:.3f}%')
                running_loss = 0.0

        epoch_time = time.time() - epoch_start_time
        epoch_accuracy = 100. * correct / total if total else 0.0
        epoch_times.append(epoch_time)
        # Previously this stored the leftover `running_loss` since the last
        # progress print (un-normalised); store the true epoch mean instead.
        epoch_losses.append(epoch_loss_sum / num_batches if num_batches else 0.0)
        epoch_accuracies.append(epoch_accuracy)

        print(f'{model_name} - Epoch {epoch+1} completed in {epoch_time:.2f}s, Accuracy: {epoch_accuracy:.3f}%')

    total_time = time.time() - total_start_time
    print(f'{model_name} - Total training time: {total_time:.2f}s')

    return {
        'total_time': total_time,
        'epoch_times': epoch_times,
        'epoch_losses': epoch_losses,
        'epoch_accuracies': epoch_accuracies
    }

def plot_comparison(standard_results, gpipe_results):
    """Plot and print a side-by-side comparison of two training runs.

    Draws a 2x2 figure:
        - total training time;
        - time per epoch;
        - training accuracy per epoch;
        - accuracy divided by epoch time.
    The figure is saved to ``resnet50_comparison.png``, shown, and then closed,
    after which a text summary is printed.

    Args:
        standard_results: result dict (as returned by ``train_model``) for the
            single-GPU model.
        gpipe_results: result dict (as returned by ``train_model``) for the
            pipelined model.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Give each run its own epoch axis so runs with different epoch counts still plot.
    standard_epochs = list(range(1, len(standard_results['epoch_times']) + 1))
    gpipe_epochs = list(range(1, len(gpipe_results['epoch_times']) + 1))

    # Total training time comparison (colours match the line plots below).
    ax1.bar(['Standard ResNet50', 'GPipe ResNet50'],
            [standard_results['total_time'], gpipe_results['total_time']],
            color=['blue', 'red'])
    ax1.set_ylabel('Total Training Time (s)')
    ax1.set_title('Total Training Time Comparison')

    # Time taken by each epoch
    ax2.plot(standard_epochs, standard_results['epoch_times'], 'b-o', label='Standard ResNet50')
    ax2.plot(gpipe_epochs, gpipe_results['epoch_times'], 'r-o', label='GPipe ResNet50')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Time per Epoch (s)')
    ax2.set_title('Time per Epoch')
    ax2.legend()

    # Training accuracy per epoch
    ax3.plot(standard_epochs, standard_results['epoch_accuracies'], 'b-o', label='Standard ResNet50')
    ax3.plot(gpipe_epochs, gpipe_results['epoch_accuracies'], 'r-o', label='GPipe ResNet50')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Accuracy (%)')
    ax3.set_title('Training Accuracy')
    ax3.legend()

    # Efficiency: accuracy divided by that epoch's wall-clock time.
    # (`secs` rather than `time`, to avoid shadowing the time module.)
    standard_efficiency = [acc / secs for acc, secs in zip(standard_results['epoch_accuracies'], standard_results['epoch_times'])]
    gpipe_efficiency = [acc / secs for acc, secs in zip(gpipe_results['epoch_accuracies'], gpipe_results['epoch_times'])]

    ax4.plot(standard_epochs, standard_efficiency, 'b-o', label='Standard ResNet50')
    ax4.plot(gpipe_epochs, gpipe_efficiency, 'r-o', label='GPipe ResNet50')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Accuracy/Time (% per second)')
    ax4.set_title('Training Efficiency')
    ax4.legend()

    plt.tight_layout()
    plt.savefig('resnet50_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure once it has been saved and displayed.
    plt.close(fig)

    # Print a detailed textual comparison
    print("\n" + "="*50)
    print("DETAILED COMPARISON RESULTS")
    print("="*50)
    print(f"Standard ResNet50 total time: {standard_results['total_time']:.2f}s")
    print(f"GPipe ResNet50 total time: {gpipe_results['total_time']:.2f}s")
    print(f"Time difference: {gpipe_results['total_time'] - standard_results['total_time']:.2f}s")
    print(f"GPipe speedup: {standard_results['total_time'] / gpipe_results['total_time']:.2f}x")
    print(f"Standard ResNet50 final accuracy: {standard_results['epoch_accuracies'][-1]:.2f}%")
    print(f"GPipe ResNet50 final accuracy: {gpipe_results['epoch_accuracies'][-1]:.2f}%")

# Run the comparison experiment.
if __name__ == "__main__":
    print("Starting ResNet50 comparison experiment...")
    rule = "=" * 30

    def announce(title):
        # Print ``title`` framed by '=' rules, preceded by a blank line.
        print("\n" + rule)
        print(title)
        print(rule)

    # Baseline: the full ResNet50 on a single GPU, mini-batches of 64.
    announce("Training Standard ResNet50")
    standard_model = StandardResNet50(n_classes=10)
    standard_results = train_model(
        standard_model, "Standard ResNet50", epochs=3, batch_size=64
    )

    # Drop the baseline model and return cached GPU memory before the next run.
    del standard_model
    torch.cuda.empty_cache()

    # Pipelined model: batches of 512 cut into micro-batches of 128.
    # NOTE(review): the two runs use different batch sizes (64 vs 512), so the
    # reported time difference / "speedup" also reflects the batch-size change,
    # not only pipelining.
    announce("Training GPipe ResNet50")
    gpipe_model = GPipeResNet50(n_classes=10, split_size=128)
    gpipe_results = train_model(
        gpipe_model, "GPipe ResNet50", epochs=3, batch_size=512
    )

    # Draw the comparison figure and print the summary.
    plot_comparison(standard_results, gpipe_results)