import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import random
from torch.multiprocessing import Process
from torchvision import datasets, transforms


class ImageClassifier(nn.Module):
    """Small convolutional network for image classification.

    Three conv -> batch-norm -> ReLU stages (16, 32 and 64 channels) feed a
    two-layer fully connected head. The first two stages end in a 2x2 max
    pool; the third ends in an adaptive average pool to 4x4, so the head
    always receives ``64 * 4 * 4`` features per sample.

    Args:
        input_channels: Number of channels in the input images.
        num_classes: Number of output classes.
    """

    def __init__(self, input_channels=1, num_classes=10):
        super().__init__()

        def conv_stage(in_channels, out_channels, pool):
            # 3x3 conv (padding 1) -> batch norm -> in-place ReLU -> pooling.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                pool,
            ]

        # Feature extraction stack; stages are created in network order.
        stages = []
        stages += conv_stage(input_channels, 16, nn.MaxPool2d(2, 2))
        stages += conv_stage(16, 32, nn.MaxPool2d(2, 2))
        stages += conv_stage(32, 64, nn.AdaptiveAvgPool2d((4, 4)))
        self.feature_extractor = nn.Sequential(*stages)

        # Classification head with dropout before each linear layer.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        feature_maps = self.feature_extractor(x)
        # Collapse everything except the batch dimension.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return F.log_softmax(self.classifier(flattened), dim=1)


class DatasetSplitter:
    """Split a dataset into index subsets according to a list of ratios.

    The sample indices are shuffled once with a ``random.Random`` seeded by
    ``random_seed`` and then cut into consecutive slices, one per entry of
    ``split_ratios``. Each slice holds ``int(ratio * len(dataset))`` indices,
    so fractional sizes are truncated and some samples may land in no subset.

    Args:
        dataset: Any object supporting ``len()``.
        split_ratios: Fraction of the dataset for each subset; defaults to
            an 80/20 split when ``None``.
        random_seed: Seed for the shuffle.
    """

    def __init__(self, dataset, split_ratios=None, random_seed=2024):
        self.dataset = dataset

        ratios = [0.8, 0.2] if split_ratios is None else split_ratios

        n_samples = len(dataset)
        shuffled = list(range(n_samples))
        random.Random(random_seed).shuffle(shuffled)

        # Walk through the shuffled indices, slicing off one subset per ratio.
        self.subsets = []
        cursor = 0
        for ratio in ratios:
            stop = cursor + int(ratio * n_samples)
            self.subsets.append(shuffled[cursor:stop])
            cursor = stop

    def get_subset(self, subset_id):
        """Return a ``DatasetSubset`` over the ``subset_id``-th index slice."""
        return DatasetSubset(self.dataset, self.subsets[subset_id])


class DatasetSubset:
    """View of a parent dataset restricted to a given list of indices."""

    def __init__(self, parent_dataset, indices):
        self.indices = indices
        self.parent_dataset = parent_dataset

    def __len__(self):
        """Return the number of indices in this subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the parent sample referenced by position ``idx``."""
        return self.parent_dataset[self.indices[idx]]


def setup_distributed_data(world_size, rank):
    """Build the MNIST training loader for one process.

    The MNIST training set is split into ``world_size`` equal-ratio
    partitions and the partition at index ``rank`` is wrapped in a shuffling
    ``DataLoader``. A global batch size of 256 is divided (integer division)
    by ``world_size`` to get the per-process batch size.

    NOTE(review): every caller requests ``download=True`` for the same
    ``./mnist_data`` root; if several processes start without the data present
    they may download at the same time -- confirm this is acceptable.

    Args:
        world_size: Total number of processes.
        rank: Index of the partition to load.

    Returns:
        A ``(data_loader, local_batch_size)`` tuple.
    """
    # Preprocessing: convert to tensor, then normalize with fixed mean/std.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.1307], std=[0.3081]),
    ])
    full_train_set = datasets.MNIST(
        root='./mnist_data', train=True, download=True, transform=preprocessing
    )

    # Per-process share of the global batch of 256.
    per_rank_batch = 256 // world_size

    # One equal-ratio partition per process; keep only the one for `rank`.
    equal_shares = [1.0 / world_size for _ in range(world_size)]
    rank_partition = DatasetSplitter(full_train_set, equal_shares).get_subset(rank)

    loader = torch.utils.data.DataLoader(
        rank_partition,
        batch_size=per_rank_batch,
        shuffle=True,
        num_workers=2,
    )
    return loader, per_rank_batch


def synchronize_gradients(model):
    """Average the gradients of ``model`` across all processes in the group.

    For every parameter that has a gradient, the gradient tensor is summed
    over all processes with an all-reduce and then divided in place by the
    world size. Parameters whose ``grad`` is ``None`` are skipped.
    """
    num_processes = float(dist.get_world_size())

    for parameter in model.parameters():
        if parameter.grad is None:
            continue
        grad = parameter.grad.data
        # Sum across processes, then turn the sum into a mean.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad.div_(num_processes)


def train_worker(process_rank, world_size):
    """Run the training loop for a single process.

    Seeds torch's RNG, builds this process's data loader, model and Adam
    optimizer, then trains for five epochs. After each backward pass the
    gradients are averaged across processes before the optimizer step. At
    the end of every epoch the local mean batch loss is printed.

    Args:
        process_rank: Rank of this process; selects its data partition.
        world_size: Total number of processes.
    """
    # Fixed seed so every process calling this starts from the same RNG state.
    torch.manual_seed(42)

    train_loader, batch_size = setup_distributed_data(world_size, process_rank)
    model = ImageClassifier(input_channels=1, num_classes=10)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

    # Number of batches per epoch, counting a final partial batch.
    steps_per_epoch = math.ceil(len(train_loader.dataset) / batch_size)
    num_epochs = 5

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, targets in train_loader:
            optimizer.zero_grad()

            # Forward pass; the model emits log-probabilities, hence NLL loss.
            loss = F.nll_loss(model(inputs), targets)
            loss.backward()

            # Average gradients across processes before applying the update.
            synchronize_gradients(model)
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / steps_per_epoch

        print(f'Worker {process_rank} | Epoch {epoch + 1}/{num_epochs} | Loss: {avg_loss:.4f}')


def initialize_distributed_training(rank, size, worker_fn, backend='gloo'):
    """Join the process group, run ``worker_fn``, then tear the group down.

    Sets ``MASTER_ADDR``/``MASTER_PORT`` for a rendezvous on the local
    machine, initializes the default process group and calls
    ``worker_fn(rank, size)``. The process group is destroyed afterwards,
    also when the worker raises, so its resources are not left dangling.

    Args:
        rank: Rank of the calling process within the group.
        size: Total number of processes in the group (world size).
        worker_fn: Callable invoked as ``worker_fn(rank, size)`` once the
            group is initialized.
        backend: Name of the ``torch.distributed`` backend to use.

    Raises:
        Whatever ``dist.init_process_group`` or ``worker_fn`` raises; the
        exception propagates after cleanup.
    """
    # Rendezvous address and port of the master process.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=size,
    )

    try:
        worker_fn(rank, size)
    finally:
        # Release the process group's resources even if the worker failed.
        dist.destroy_process_group()


def main():
    """Launch the training processes and wait for all of them to finish.

    Starts four worker processes, each running
    ``initialize_distributed_training`` with ``train_worker`` as the worker
    function, and joins them. The completion message is printed only when
    every worker exited cleanly.

    Raises:
        RuntimeError: If any worker process ends with a non-zero exit code.
    """
    world_size = 4  # total number of processes

    print(f"启动 {world_size} 个训练进程...")

    # Create and start one training process per rank.
    processes = []
    for rank in range(world_size):
        process = Process(
            target=initialize_distributed_training,
            args=(rank, world_size, train_worker),
        )
        process.start()
        processes.append(process)

    # Wait for every process to finish.
    for process in processes:
        process.join()

    # A non-zero exit code means the worker crashed or was killed; report it
    # instead of claiming success.
    failed = {
        rank: process.exitcode
        for rank, process in enumerate(processes)
        if process.exitcode != 0
    }
    if failed:
        raise RuntimeError(f"训练进程异常退出 (rank: exitcode): {failed}")

    print("分布式训练完成!")


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()