import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import minimize
import warnings
# Suppress every Python warning for the rest of the process (filter action
# 'ignore' with no category/module restriction).
# NOTE(review): presumably meant to quiet NumPy/SciPy RuntimeWarnings during
# optimisation, but it also hides warnings that may flag real numerical
# problems (overflow, invalid values) — consider narrowing by category/module.
warnings.filterwarnings('ignore')

class GaussianMixture:
    """One-dimensional mixture of Gaussian components.

    Parameters
    ----------
    means, stds, weights : array-like
        Per-component location, scale and (non-negative) weight. The weights
        are rescaled to sum to 1, so they need not be normalised on input.
    """

    def __init__(self, means, stds, weights):
        self.means = np.array(means)
        self.stds = np.array(stds)
        self.weights = np.array(weights)
        self.weights = self.weights / np.sum(self.weights)  # normalise to sum to 1

    def pdf(self, x):
        """Evaluate the mixture density at ``x``.

        Parameters
        ----------
        x : scalar or array-like
            Evaluation point(s).

        Returns
        -------
        numpy.ndarray
            Float array with the same shape as ``x`` holding
            sum_k weight_k * N(x; mean_k, std_k).
        """
        # Cast to float first. np.zeros_like on an integer input (e.g. [0, 1])
        # would give an int buffer, and the in-place ``+=`` of float
        # densities below would then fail with a casting error.
        x = np.asarray(x, dtype=float)
        result = np.zeros_like(x)
        for mean, std, weight in zip(self.means, self.stds, self.weights):
            result += weight * norm.pdf(x, loc=mean, scale=std)
        return result

    def sample(self, n_samples):
        """Draw ``n_samples`` points from the mixture using the global NumPy RNG.

        Each draw first picks a component index with probabilities
        ``self.weights``. It then samples from that component's Gaussian.
        """
        # Pick a component for every sample according to the mixture weights.
        component_indices = np.random.choice(len(self.means), size=n_samples, p=self.weights)

        # Fill in the samples component by component. The order of RNG calls
        # is fixed so that results are reproducible under a given seed.
        samples = np.zeros(n_samples)
        for i in range(len(self.means)):
            mask = (component_indices == i)
            n_from_component = np.sum(mask)
            if n_from_component > 0:
                samples[mask] = np.random.normal(self.means[i], self.stds[i], n_from_component)

        return samples

def forward_kld_sampling(p_mixture, q_mixture, n_samples=10000):
    """Monte Carlo estimate of the forward KL divergence D_KL[p || q].

    Draws ``n_samples`` points x ~ p and averages log p(x) - log q(x).
    Both densities are floored at 1e-10 before taking the log, so a zero
    density yields a large finite log-ratio instead of +/-inf.

    Args:
        p_mixture: distribution providing ``sample(n)`` and ``pdf(x)``; samples come from here.
        q_mixture: distribution providing ``pdf(x)``.
        n_samples: number of Monte Carlo draws.

    Returns:
        The sample mean of the log density ratios.
    """
    floor = 1e-10
    draws = p_mixture.sample(n_samples)
    log_p = np.log(np.maximum(p_mixture.pdf(draws), floor))
    log_q = np.log(np.maximum(q_mixture.pdf(draws), floor))
    return np.mean(log_p - log_q)

def reverse_kld_sampling(p_mixture, q_mixture, n_samples=10000):
    """Monte Carlo estimate of the reverse KL divergence D_KL[q || p].

    Draws ``n_samples`` points x ~ q and averages log q(x) - log p(x).
    Both densities are floored at 1e-10 before taking the log, so a zero
    density yields a large finite log-ratio instead of +/-inf.

    Args:
        p_mixture: distribution providing ``pdf(x)``.
        q_mixture: distribution providing ``sample(n)`` and ``pdf(x)``; samples come from here.
        n_samples: number of Monte Carlo draws.

    Returns:
        The sample mean of the log density ratios.
    """
    floor = 1e-10
    draws = q_mixture.sample(n_samples)
    log_p = np.log(np.maximum(p_mixture.pdf(draws), floor))
    log_q = np.log(np.maximum(q_mixture.pdf(draws), floor))
    return np.mean(log_q - log_p)

def _unpack_mixture_params(params):
    """Split a flat optimiser vector into (means, stds, weights).

    ``params`` is laid out as [mean_1, log_std_1, log_weight_1, mean_2, ...].
    Standard deviations and weights live in log space so they stay positive.
    The weights are renormalised to sum to 1.
    """
    params = np.asarray(params, dtype=float)
    means = params[0::3]
    stds = np.exp(params[1::3])
    weights = np.exp(params[2::3])
    return means, stds, weights / np.sum(weights)


def _initial_mixture_params(target_peaks, n_components):
    """Draw a random starting vector near randomly chosen target peaks.

    Each component starts at a target peak (distinct peaks while enough are
    available) plus N(0, 0.3) jitter, with std around 0.4 and roughly equal
    weights. Uses the global NumPy RNG.
    """
    n_peaks = len(target_peaks)
    if n_components <= n_peaks:
        # Randomise which peaks are used, so restarts explore different
        # subsets of modes instead of always the first n_components.
        start_means = np.random.choice(target_peaks, size=n_components, replace=False)
    else:
        extra = np.random.choice(target_peaks, size=n_components - n_peaks)
        start_means = np.concatenate([target_peaks, extra])

    init_params = []
    for start in start_means:
        init_params.extend([
            start + np.random.normal(0, 0.3),        # mean
            np.log(0.4) + np.random.normal(0, 0.2),  # log std
            np.random.normal(0, 0.2),                # log weight
        ])
    return np.array(init_params)


def fit_mixture_to_mixture(target_mixture, n_components=2, kld_type='forward', n_samples=10000,
                           n_trials=50, seed=42):
    """Fit an ``n_components`` Gaussian mixture to ``target_mixture`` by
    minimising a sampling-based KL divergence with multi-start L-BFGS-B.

    Args:
        target_mixture: mixture to approximate. Its ``means`` seed the
            initialisation, and ``means``/``stds``/``weights`` are used for
            the fallback result.
        n_components: number of components in the fitted mixture.
        kld_type: 'forward' minimises D_KL[target || q] (mass-covering);
            'reverse' minimises D_KL[q || target] (mode-seeking).
        n_samples: Monte Carlo samples per objective evaluation.
        n_trials: number of random restarts.
        seed: seed for the Monte Carlo samples inside the objective. It is
            reused on every evaluation (common random numbers), so the
            objective is deterministic and comparable across restarts.

    Returns:
        (fitted GaussianMixture, best objective value). If no restart
        produces a finite objective, returns a mixture built from the first
        ``n_components`` target components together with ``inf``.

    Raises:
        ValueError: if ``kld_type`` is neither 'forward' nor 'reverse'.
    """
    if kld_type not in ('forward', 'reverse'):
        raise ValueError(f"kld_type must be 'forward' or 'reverse', got {kld_type!r}")
    kld_estimator = forward_kld_sampling if kld_type == 'forward' else reverse_kld_sampling

    def objective(params):
        q_mixture = GaussianMixture(*_unpack_mixture_params(params))
        # Common random numbers: the samplers use the global NumPy RNG, so
        # re-seed it before every evaluation. Otherwise each call sees a fresh
        # Monte Carlo sample, and L-BFGS-B's finite-difference gradients are
        # swamped by sampling noise. The caller's RNG state is restored so
        # initialisation draws are not disturbed.
        # NOTE(review): in 'reverse' mode q is sampled via a discrete component
        # draw, so the objective is likely piecewise-constant in the
        # log-weights; their numerical gradients may be uninformative.
        saved_state = np.random.get_state()
        np.random.seed(seed)
        try:
            return kld_estimator(target_mixture, q_mixture, n_samples)
        finally:
            np.random.set_state(saved_state)

    # (mean, log_std, log_weight) bounds per component; identical for every restart.
    bounds = [(-4, 4), (-1.5, 1), (-1, 1)] * n_components
    target_peaks = np.asarray(target_mixture.means, dtype=float)

    best_result = None
    best_loss = float('inf')

    for trial in range(n_trials):
        init_params = _initial_mixture_params(target_peaks, n_components)
        try:
            result = minimize(objective, init_params,
                              method='L-BFGS-B',
                              bounds=bounds,
                              options={'maxiter': 500, 'ftol': 1e-6})
        except Exception as e:  # best effort: one failed restart must not abort the search
            print(f"  Trial {trial+1}: Failed - {e}")
            continue

        # Accept any restart with a finite objective value, not only those
        # flagged ``success``. Hitting maxiter or an abnormal line-search stop
        # still leaves a valid point whose loss is result.fun.
        if np.isfinite(result.fun) and result.fun < best_loss:
            best_loss = result.fun
            best_result = result
            print(f"  Trial {trial+1}: Loss = {result.fun:.6f}")

    if best_result is None:
        print(f"警告: {kld_type} KLD 优化失败，使用默认参数")
        # Fall back to the first n_components target components, renormalised.
        means = target_mixture.means[:n_components]
        stds = target_mixture.stds[:n_components]
        weights = target_mixture.weights[:n_components]
        return GaussianMixture(means, stds, weights / np.sum(weights)), float('inf')

    return GaussianMixture(*_unpack_mixture_params(best_result.x)), best_loss

# Experiment setup
np.random.seed(42)  # global seed so the sampling-based fits are reproducible

# Target distribution: a three-peaked Gaussian mixture
target = GaussianMixture(means=[-2, 0, 2], stds=[0.3, 0.4, 0.3], weights=[0.3, 0.4, 0.3])

n_components = 2    # approximate the 3-component target with only 2 Gaussians
n_samples = 20000   # Monte Carlo samples per KLD evaluation (more = less noise)

_fit_kwargs = {'n_components': n_components, 'n_samples': n_samples}


def _print_fit(prefix, fitted, loss):
    """Print the loss and fitted parameters, each line prefixed with *prefix*."""
    print(f"\n{prefix} - Loss: {loss:.6f}")
    for field, values in (("Means", fitted.means),
                          ("Stds", fitted.stds),
                          ("Weights", fitted.weights)):
        print(f"{prefix} - {field}: {values}")


# Forward KLD fit (2 Gaussians fitting 3)
print("开始Forward KLD优化...")
mixture_fwd, loss_fwd = fit_mixture_to_mixture(target, kld_type='forward', **_fit_kwargs)
_print_fit("Forward KLD", mixture_fwd, loss_fwd)

print("\n" + "=" * 50)

# Reverse KLD fit
print("开始Reverse KLD优化...")
mixture_rev, loss_rev = fit_mixture_to_mixture(target, kld_type='reverse', **_fit_kwargs)
_print_fit("Reverse KLD", mixture_rev, loss_rev)

# Visualisation: six panels comparing the target with the two fits.
x = np.linspace(-5, 5, 1000)

# Densities are reused by several panels; evaluate each once.
target_pdf = target.pdf(x)
fwd_pdf = mixture_fwd.pdf(x)
rev_pdf = mixture_rev.pdf(x)


def _style_axes(xlabel, ylabel, title):
    """Apply labels, title, legend and light grid to the current axes."""
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)


plt.figure(figsize=(15, 10))

# Panels 1-2: densities on linear and logarithmic y-axes.
for panel, draw, target_label, fwd_label, rev_label, ylabel, title in (
    (1, plt.plot, 'Target (3 Gaussians)', 'Forward KLD fit (2 Gaussians)',
     'Reverse KLD fit (2 Gaussians)', 'Probability Density', 'PDF Comparison'),
    (2, plt.semilogy, 'Target', 'Forward KLD fit', 'Reverse KLD fit',
     'Log Probability Density', 'Log-scale Comparison'),
):
    plt.subplot(2, 3, panel)
    draw(x, target_pdf, 'k-', linewidth=3, label=target_label)
    draw(x, fwd_pdf, 'r--', linewidth=2, label=fwd_label)
    draw(x, rev_pdf, 'b--', linewidth=2, label=rev_label)
    _style_axes('x', ylabel, title)

# Panel 3: pointwise absolute density error of each fit.
plt.subplot(2, 3, 3)
plt.plot(x, np.abs(target_pdf - fwd_pdf), 'r-', linewidth=2, label='Forward KLD error')
plt.plot(x, np.abs(target_pdf - rev_pdf), 'b-', linewidth=2, label='Reverse KLD error')
_style_axes('x', 'Absolute Error', 'Fitting Errors')

# Panels 4-5: each fit decomposed into its weighted components.
for panel, fitted, fitted_pdf, total_style, comp_style, total_label, title in (
    (4, mixture_fwd, fwd_pdf, 'r--', 'r:', 'Forward fit (total)',
     'Forward KLD - Component Breakdown'),
    (5, mixture_rev, rev_pdf, 'b--', 'b:', 'Reverse fit (total)',
     'Reverse KLD - Component Breakdown'),
):
    plt.subplot(2, 3, panel)
    plt.plot(x, target_pdf, 'k-', linewidth=2, label='Target')
    plt.plot(x, fitted_pdf, total_style, linewidth=2, label=total_label)
    for k, (mu, sigma, w) in enumerate(zip(fitted.means, fitted.stds, fitted.weights)):
        plt.plot(x, w * norm.pdf(x, loc=mu, scale=sigma), comp_style, alpha=0.7,
                 label=f'Component {k+1}' if k < 2 else None)
    _style_axes('x', 'Probability Density', title)

# Panel 6: component positions against their weights.
plt.subplot(2, 3, 6)
for positions, heights, size, colour, marker, label in (
    (target.means, target.weights, 200, 'black', 'o', 'Target peaks'),
    (mixture_fwd.means, mixture_fwd.weights, 150, 'red', '^', 'Forward KLD fit'),
    (mixture_rev.means, mixture_rev.weights, 150, 'blue', 's', 'Reverse KLD fit'),
):
    plt.scatter(positions, heights, s=size, c=colour, marker=marker,
                label=label, alpha=0.8)
_style_axes('Peak Position', 'Weight', 'Peak Positions vs Weights')

plt.tight_layout()
plt.savefig('KLD_sampling.png', dpi=300, bbox_inches='tight')
plt.show()

# Extra analysis: heuristic checks of the expected forward/reverse KLD behaviour.
print("\n=== 拟合行为分析 ===")
print(f"目标分布峰值位置: {target.means}")
print(f"目标分布权重: {target.weights}")

# Evaluate at float positions; target.means may be integer-valued.
mode_positions = np.asarray(target.means, dtype=float)

# Forward KLD is mass-covering. A mode counts as covered when the fit puts at
# least 10% of the target's density there (the 10% threshold is a heuristic).
# The previous check only asked whether the fitted means were distinct, which
# reports True even when a mode is missed entirely (e.g. means [-2, 0]).
fwd_covers_all = bool(np.all(
    mixture_fwd.pdf(mode_positions) >= 0.1 * target.pdf(mode_positions)))

# Reverse KLD is mode-seeking. The fit counts as selective when every fitted
# component mean lies within 0.5 of *some* target mode, not only the middle one.
rev_means = np.asarray(mixture_rev.means, dtype=float)
nearest_mode_gap = np.min(np.abs(rev_means[:, None] - mode_positions[None, :]), axis=1)
rev_is_selective = bool(np.all(nearest_mode_gap < 0.5))

print(f"\n前向KLD结果分析:")
print(f"  - 峰值位置: {mixture_fwd.means}")
print(f"  - 是否覆盖所有模式: {fwd_covers_all}")
print(f"\n反向KLD结果分析:")
print(f"  - 峰值位置: {mixture_rev.means}")
print(f"  - 是否选择性拟合: {rev_is_selective}")