import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import urllib.request
from pathlib import Path

# 定义U-Net模块
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use padding=1 and stride 1, so the spatial size is preserved.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; any falsy value
            (``None``/``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        def conv_bn_relu(c_in, c_out):
            # Conv has no bias because the following BatchNorm supplies the shift.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Sequential indices 0..5 match the original layout, so state_dict keys are stable.
        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels, mid_channels),
            *conv_bn_relu(mid_channels, out_channels),
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling (stride 2, floor) followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder stage: up-sample, align to the skip connection, concatenate, DoubleConv.

    Args:
        in_channels: channels after concatenating the up-sampled input with the skip.
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, up-sample with bilinear interpolation (channel count
            unchanged, DoubleConv's middle width set to ``in_channels // 2``);
            otherwise use a stride-2 transposed conv that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Up-sample ``x1`` (deeper features) and fuse it with skip features ``x2``."""
        upsampled = self.up(x1)

        # Odd input sizes can leave the up-sampled map off by a pixel; pad it
        # (negative amounts crop) so its H/W match x2, split as evenly as possible.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])

        # Skip features first, then the up-sampled ones, along the channel dim.
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to ``out_channels`` output maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

# 完整的U-Net模型
class UNet(nn.Module):
    """U-Net: four down-sampling stages, four up-sampling stages with skip
    connections, and a 1x1 output convolution with ``n_classes`` channels.

    Args:
        n_channels: channels of the input image.
        n_classes: number of output channels (one logit map per class).
        bilinear: if True, decoders up-sample with bilinear interpolation and the
            bottleneck/decoder widths are halved; otherwise transposed convs are used.
    """

    def __init__(self, n_channels=3, n_classes=3, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear up-sampling cannot shrink channels by itself, so the deepest
        # and decoder output widths are divided by this factor instead.
        factor = 2 if bilinear else 1
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024

        # Encoder (assignment order fixes state_dict / parameter order).
        self.inc = DoubleConv(n_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5 // factor)

        # Decoder
        self.up1 = Up(c5, c4 // factor, bilinear)
        self.up2 = Up(c4, c3 // factor, bilinear)
        self.up3 = Up(c3, c2 // factor, bilinear)
        self.up4 = Up(c2, c1, bilinear)
        self.outc = OutConv(c1, n_classes)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Decoder: up-sample and fuse with the matching encoder stage, deepest first.
        y = self.up1(bottom, skip4)
        y = self.up2(y, skip3)
        y = self.up3(y, skip2)
        y = self.up4(y, skip1)
        return self.outc(y)

# 自定义数据集类
class SegmentationDataset(Dataset):
    """(RGB image, segmentation mask) pairs loaded from two parallel path lists.

    Args:
        image_paths: image file paths.
        mask_paths: mask file paths, aligned index-by-index with ``image_paths``.
        img_transform: optional callable applied to the RGB PIL image.
        mask_transform: optional callable applied to the grayscale ("L") PIL mask.
            Its tensor output has dim 0 squeezed, is cast to int64 and shifted
            down by one (trimap labels {1, 2, 3} -> {0, 1, 2}). Without it the raw
            PIL mask is returned unshifted.
    """

    def __init__(self, image_paths, mask_paths,
                 img_transform=None, mask_transform=None):
        self.image_paths, self.mask_paths = image_paths, mask_paths
        self.it, self.mt = img_transform, mask_transform

    def __len__(self):
        # One sample per image path.
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")

        if self.it:
            image = self.it(image)
        if self.mt:
            # Drop the channel dim, make integer labels, then shift them to start at 0.
            mask = self.mt(mask).squeeze(0).long() - 1

        return image, mask

# 计算IoU指标（交并比）
def iou_score(outputs, labels):
    """Return the smoothed binary IoU (intersection over union) of two masks.

    Both tensors are flattened and cast to float, so they may have any shape
    (even non-contiguous layouts) as long as they hold the same number of
    elements. Values are treated as binary membership (0/1 or bool); this is
    NOT a multi-class metric. For class-index masks, compare one class at a time,
    e.g. ``iou_score(pred == c, target == c)``.

    A small smoothing term keeps the result finite. It also makes two empty
    masks score ~1.

    Returns:
        0-d float tensor.
    """
    smooth = 1e-6
    # reshape (not view): view(-1) raises on non-contiguous inputs.
    outputs = outputs.reshape(-1).float()
    labels = labels.reshape(-1).float()

    intersection = (outputs * labels).sum()
    union = outputs.sum() + labels.sum() - intersection

    return (intersection + smooth) / (union + smooth)

# 计算Dice系数
def dice_coef(outputs, labels):
    """Return the smoothed binary Dice coefficient of two masks.

    Dice = 2*|A∩B| / (|A| + |B|). Both tensors are flattened and cast to float,
    so they may have any shape (even non-contiguous layouts) as long as they
    hold the same number of elements. Values are treated as binary membership
    (0/1 or bool); this is NOT a multi-class metric. For class-index masks,
    compare one class at a time, e.g. ``dice_coef(pred == c, target == c)``.

    Returns:
        0-d float tensor; ~1 when both masks are empty (because of the smoothing term).
    """
    smooth = 1e-6
    # reshape (not view): view(-1) raises on non-contiguous inputs.
    outputs = outputs.reshape(-1).float()
    labels = labels.reshape(-1).float()

    intersection = (outputs * labels).sum()
    return (2. * intersection + smooth) / (outputs.sum() + labels.sum() + smooth)

# 训练函数
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device,
                save_path='best_unet_model.pth'):
    """Train ``model`` and evaluate it on ``val_loader`` after every epoch.

    Per epoch, this records:

    - the mean training loss;
    - the class-averaged validation IoU and Dice.

    Each class is scored one-vs-rest with ``iou_score`` / ``dice_coef``, then
    averaged over classes and over batches. Whenever the validation IoU improves
    (the first epoch always counts), the model's ``state_dict`` is saved to
    ``save_path``.

    Args:
        model: network whose output dim 1 indexes classes (it is fed to
            ``criterion`` and arg-maxed over dim 1).
        train_loader / val_loader: iterables of ``(images, masks)`` batches with
            integer class-index masks; a singleton dim 1 is squeezed away.
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        device: device every batch is moved to.
        save_path: file the best checkpoint is written to.

    Returns:
        ``(model, history)``: ``model`` is the model after the LAST epoch (not
        necessarily the best checkpoint). ``history`` maps ``'train_loss'``,
        ``'val_iou'`` and ``'val_dice'`` to per-epoch lists.
    """
    best_val_iou = float('-inf')  # guarantees a checkpoint after the first epoch
    train_loss_history = []
    val_iou_history = []
    val_dice_history = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")

            for images, masks in tepoch:
                images = images.to(device)
                masks = masks.to(device).squeeze(1).long()

                outputs = model(images)
                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                tepoch.set_postfix(loss=loss.item())

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        train_loss_history.append(avg_train_loss)

        # Validation phase
        model.eval()
        val_iou = 0.0
        val_dice = 0.0

        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                # Same label normalization as training so shapes line up with predictions.
                masks = masks.to(device).squeeze(1).long()

                outputs = model(images)
                num_classes = outputs.shape[1]
                pred_mask = outputs.argmax(dim=1)  # predicted class index per pixel

                # iou_score / dice_coef measure binary overlap; multiplying raw
                # class indices (0..C-1) is not an overlap. So score each class
                # one-vs-rest and average.
                batch_iou = 0.0
                batch_dice = 0.0
                for cls in range(num_classes):
                    pred_c = pred_mask == cls
                    true_c = masks == cls
                    batch_iou += iou_score(pred_c, true_c).item()
                    batch_dice += dice_coef(pred_c, true_c).item()

                val_iou += batch_iou / num_classes
                val_dice += batch_dice / num_classes

        n_val_batches = max(len(val_loader), 1)
        avg_val_iou = val_iou / n_val_batches
        avg_val_dice = val_dice / n_val_batches

        val_iou_history.append(avg_val_iou)
        val_dice_history.append(avg_val_dice)

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {avg_train_loss:.4f}, '
              f'Val IoU: {avg_val_iou:.4f}, '
              f'Val Dice: {avg_val_dice:.4f}')

        # Keep the checkpoint with the best validation IoU seen so far.
        if avg_val_iou > best_val_iou:
            best_val_iou = avg_val_iou
            torch.save(model.state_dict(), save_path)
            print(f'Model saved with IoU: {best_val_iou:.4f}')

    history = {
        'train_loss': train_loss_history,
        'val_iou': val_iou_history,
        'val_dice': val_dice_history
    }

    return model, history

# 可视化训练结果
def plot_training_history(history):
    """Plot per-epoch training loss (left) and validation IoU/Dice (right).

    Args:
        history: dict with ``'train_loss'``, ``'val_iou'`` and ``'val_dice'`` lists.

    The figure is saved to ``training_history.png`` and then shown.
    """
    plt.figure(figsize=(12, 4))

    loss_ax = plt.subplot(1, 2, 1)
    loss_ax.plot(history['train_loss'])
    loss_ax.set(title='Training Loss', xlabel='Epoch', ylabel='Loss')

    metric_ax = plt.subplot(1, 2, 2)
    for key, label in (('val_iou', 'IoU'), ('val_dice', 'Dice')):
        metric_ax.plot(history[key], label=label)
    metric_ax.set(title='Validation Metrics', xlabel='Epoch', ylabel='Score')
    metric_ax.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

# 可视化预测结果
def visualize_predictions(model, test_loader, device, num_images=5):
    """Show image, ground-truth mask and predicted mask side by side.

    Uses the first sample of each of the first ``num_images`` batches of
    ``test_loader``. Each figure is saved as ``prediction_{i}.png``, shown, and
    then closed so repeated calls do not accumulate open figures.

    Args:
        model: network whose output dim 1 indexes classes.
        test_loader: iterable of ``(images, masks)`` batches.
        device: device the images are moved to.
        num_images: maximum number of batches (figures) to visualize.
    """
    model.eval()
    with torch.no_grad():
        for i, (images, masks) in enumerate(test_loader):
            if i >= num_images:
                break

            images = images.to(device)
            outputs = model(images)
            # Predicted class index per pixel for the first sample in the batch.
            pred_mask = outputs.argmax(dim=1)[0].cpu().numpy()

            # Pin the grey scale to the class-index range. Otherwise imshow rescales
            # each panel to its own min/max, and the same class can show as
            # different greys in the two mask panels (e.g. if the prediction
            # misses a class).
            max_label = max(outputs.shape[1] - 1, 1)

            # CHW -> HWC and min-max stretch to [0, 1] for display; a constant
            # image (zero range) would otherwise divide by zero.
            img = images[0].cpu().permute(1, 2, 0).numpy()
            lo, hi = img.min(), img.max()
            img = (img - lo) / (hi - lo) if hi > lo else img - lo

            true_mask = masks[0].cpu().numpy()

            fig = plt.figure(figsize=(12, 4))
            mask_style = {'cmap': 'gray', 'vmin': 0, 'vmax': max_label}
            panels = (
                ('Image', img, {}),
                ('True Mask', true_mask, mask_style),
                ('Predicted Mask', pred_mask, mask_style),
            )
            for pos, (title, data, style) in enumerate(panels, start=1):
                plt.subplot(1, 3, pos)
                plt.title(title)
                plt.imshow(data, **style)
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(f'prediction_{i}.png')
            plt.show()
            plt.close(fig)

# 下载并准备Oxford-IIIT Pet数据集
def _download_atomic(url, dest):
    """Download ``url`` to ``dest`` via a temporary ``.part`` file renamed on success."""
    # A download interrupted mid-way would otherwise leave a truncated archive
    # that the exists() checks in download_pet_dataset treat as complete forever.
    partial = dest.with_name(dest.name + ".part")
    urllib.request.urlretrieve(url, partial)
    partial.replace(dest)


def _safe_extract(archive, dest):
    """Extract the tar ``archive`` into ``dest`` without letting members escape ``dest``."""
    import tarfile

    with tarfile.open(archive) as tar:
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ and security releases of 3.8-3.11: the built-in "data"
            # filter rejects absolute paths, ".." traversal and escaping links.
            tar.extractall(dest, filter="data")
            return

        # Older interpreters: validate every member path and link target by hand.
        root = Path(dest).resolve()

        def inside(path):
            resolved = path.resolve()
            return resolved == root or root in resolved.parents

        for member in tar.getmembers():
            member_path = root / member.name
            if not inside(member_path):
                raise RuntimeError(f"Refusing to extract {member.name!r}: outside {root}")
            if member.issym() and not inside(member_path.parent / member.linkname):
                raise RuntimeError(f"Refusing to extract symlink {member.name!r}: target outside {root}")
            if member.islnk() and not inside(root / member.linkname):
                raise RuntimeError(f"Refusing to extract hard link {member.name!r}: target outside {root}")
        tar.extractall(dest)


def download_pet_dataset():
    """Download and unpack the Oxford-IIIT Pet images and trimap annotations.

    Everything is stored under ``./data/oxford-pet``. Archives already on disk,
    and directories already extracted, are reused, so repeated calls skip that work.

    Returns:
        ``(image_paths, mask_paths)``: two parallel lists of string paths. They
        contain only the ``images/*.jpg`` files that have a same-stem ``.png``
        in ``annotations/trimaps``.
    """
    base_dir = Path("./data/oxford-pet")
    base_dir.mkdir(parents=True, exist_ok=True)

    image_url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
    mask_url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"

    image_archive = base_dir / "images.tar.gz"
    mask_archive = base_dir / "annotations.tar.gz"

    if not image_archive.exists():
        print("正在下载图像数据...")
        _download_atomic(image_url, image_archive)

    if not mask_archive.exists():
        print("正在下载掩码数据...")
        _download_atomic(mask_url, mask_archive)

    images_dir = base_dir / "images"
    masks_dir = base_dir / "annotations"
    trimaps_dir = masks_dir / "trimaps"

    if not images_dir.exists():
        print("正在解压图像文件...")
        _safe_extract(image_archive, base_dir)

    if not masks_dir.exists():
        print("正在解压掩码文件...")
        _safe_extract(mask_archive, base_dir)

    image_paths = sorted(images_dir.glob("*.jpg"))
    mask_paths = sorted(trimaps_dir.glob("*.png"))
    print(f"发现 {len(image_paths)} 张图像和 {len(mask_paths)} 个掩码")

    # Keep only images that have a matching trimap (same file stem).
    valid_images = []
    valid_masks = []
    for img_path in image_paths:
        mask_path = trimaps_dir / (img_path.stem + ".png")
        if mask_path.exists():
            valid_images.append(str(img_path))
            valid_masks.append(str(mask_path))

    print(f"有效的图像-掩码对: {len(valid_images)}")

    return valid_images, valid_masks


def main():
    """Run the end-to-end U-Net training pipeline on Oxford-IIIT Pet trimaps.

    The steps are:

    1. Download and prepare the data.
    2. Split it 80/20 into train/validation (fixed seed).
    3. Train for ``num_epochs`` epochs.
    4. Plot the training curves and a few validation predictions.
    """
    # 1. Hyperparameters and device.
    batch_size = 4
    num_epochs = 10
    learning_rate = 1e-3
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # 2. Data: Oxford-IIIT Pet images and their trimap masks.
    print("正在准备数据集...")
    image_paths, mask_paths = download_pet_dataset()

    # Fixed random_state keeps the train/validation split reproducible.
    train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = train_test_split(
        image_paths, mask_paths, test_size=0.2, random_state=42)

    img_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    mask_transform = transforms.Compose([
        # Nearest-neighbour resampling so resizing never blends label values.
        transforms.Resize((256, 256), interpolation=Image.NEAREST),
        transforms.PILToTensor(),
    ])

    train_dataset = SegmentationDataset(train_img_paths, train_mask_paths,
                                        img_transform=img_transform,
                                        mask_transform=mask_transform)
    val_dataset = SegmentationDataset(val_img_paths, val_mask_paths,
                                      img_transform=img_transform,
                                      mask_transform=mask_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # 3. Model.
    print("正在初始化模型...")
    model = UNet(n_channels=3, n_classes=3).to(device)

    # 4. Loss and optimizer.
    # Per the Oxford-IIIT Pet README, trimap pixels are:
    #   1 = pet (foreground), 2 = background, 3 = not classified (boundary band).
    # SegmentationDataset subtracts 1, so class indices are:
    #   0 = foreground, 1 = background, 2 = boundary.
    # Weights are indexed by class: up-weight foreground and boundary,
    # down-weight background.
    class_weights = torch.tensor([2.0, 0.5, 2.0], device=device)  # [foreground, background, boundary]
    print(f"Using class weights: {class_weights.tolist()}")

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 5. Training.
    print("开始训练模型...")
    trained_model, history = train_model(
        model, train_loader, val_loader, criterion, optimizer, num_epochs, device
    )

    # 6. Training curves.
    plot_training_history(history)

    # 7. A few validation predictions.
    print("生成预测可视化...")
    visualize_predictions(trained_model, val_loader, device, 3)

    print("训练和评估完成!")

# Entry point: run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()