# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
import random
import copy
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Chinese-capable fonts for plot titles/labels; keep the ASCII minus sign so
# negative tick values still render with CJK fonts.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# ================== 1. Device setup ==================
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== 2. Data loading ==================
# Directory containing this script; data paths are resolved relative to it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Directories searched, in priority order, for the input CSV files.
DATA_DIR_CANDIDATES = [
    os.path.join(HERE, '..', 'data'),  # original project layout
    HERE,                              # data stored next to the script
]

def _resolve_data_file(name):
    for d in DATA_DIR_CANDIDATES:
        p = os.path.join(d, name)
        if os.path.exists(p):
            return p
    raise FileNotFoundError(f"找不到数据文件: {name}")

X_df = pd.read_csv(_resolve_data_file('1and23456_0.10_train.csv'))
y_df = pd.read_csv(_resolve_data_file('1and23456_0.10_label.csv'))

OUTPUT_DIR = os.path.join(HERE, 'cm_results')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# The first column of the feature file is dropped (presumably a row id / index
# column — TODO confirm); every remaining column is used as a feature.
X = X_df.iloc[:, 1:].values
y = y_df['label'].values.astype(np.int64)

print("X shape:", X.shape)
print("y shape:", y.shape)

# Fail early with a clear message instead of deep inside train_test_split.
if len(X) != len(y):
    raise ValueError(f"特征与标签行数不一致: X={len(X)}, y={len(y)}")

# ================== 3. Parameters ==================
FEATURE_DIM = X.shape[1]
SEQ_LEN = 1  # each sample is fed to the LSTM as a length-1 sequence
INPUT_DIM = FEATURE_DIM // SEQ_LEN
# Explicit check rather than `assert`, which is stripped under `python -O`.
if SEQ_LEN * INPUT_DIM != FEATURE_DIM:
    raise ValueError("SEQ_LEN * INPUT_DIM 必须等于特征维度")

_class_ids = np.unique(y)
NUM_CLASSES = len(_class_ids)
# CrossEntropyLoss, np.eye(NUM_CLASSES)[labels] and the "0..K-1" class names
# used for plots all require labels to be exactly 0..NUM_CLASSES-1.
if not np.array_equal(_class_ids, np.arange(NUM_CLASSES)):
    raise ValueError(f"标签必须是从 0 开始的连续整数 0..{NUM_CLASSES - 1}, 实际为: {_class_ids}")
BATCH_SIZE = 32
EPOCHS = 50
LR = 1e-3
K_NEIGHBORS = 5  # neighbours per node in the in-batch kNN graph used by the GNN

# ================== 4. GCN Layer ==================
class GCNLayer(nn.Module):
    """One graph-convolution step: aggregate over ``adj``, then project linearly.

    Computes ``linear(adj @ x)``; no non-linearity is applied here (the caller
    applies one).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        # Mix each node's features with its neighbours' according to adj.
        aggregated = adj @ x
        return self.linear(aggregated)

# ================== 5. kNN 图构建 ==================
def build_knn_graph(x, k=5):
    """Build a normalised cosine-kNN adjacency matrix over the rows of ``x``.

    Each row is linked to its ``k + 1`` most cosine-similar rows. This set
    normally contains the row itself, because its self-similarity is maximal.
    An identity matrix is then added as explicit self-loops. The result is
    scaled as ``D^-1/2 A D^-1/2``, where ``D`` holds the row sums of ``A``.

    Args:
        x: 2-D tensor of node features, one node per row.
        k: neighbours per node. The neighbour count is clamped to the number
            of rows, so batches smaller than ``k + 1`` rows (e.g. the last
            DataLoader batch) work instead of making ``topk`` raise.

    Returns:
        (N, N) float tensor on ``x``'s device, computed without gradient.
    """
    with torch.no_grad():
        n = x.size(0)
        x_norm = F.normalize(x, dim=1)
        sim = x_norm @ x_norm.t()
        # topk raises if asked for more entries than exist along the dim.
        k_eff = min(k + 1, n)
        _, idx = sim.topk(k=k_eff, dim=-1)
        adj = torch.zeros((n, n), device=x.device)
        # Vectorised form of `adj[i, idx[i]] = 1.0` for every row i.
        adj.scatter_(1, idx, 1.0)
        adj = adj + torch.eye(n, device=x.device)
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.0
        # Same as diag(d) @ adj @ diag(d), without two dense O(N^3) matmuls.
        adj = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]
    return adj

# ================== 6. LSTM + GNN 模型 ==================
class LSTM_GNN_Classifier(nn.Module):
    """LSTM encoder, two GCN layers over an in-batch kNN graph, then an MLP head.

    Each sample's final top-layer LSTM hidden state becomes a graph node; for a
    bidirectional LSTM the forward and backward states are concatenated. Nodes
    are linked via ``build_knn_graph`` within the current batch. Two GCN
    layers (each followed by ReLU) propagate over that graph, and a two-layer
    MLP produces class logits.

    NOTE(review): because the graph is built per batch, a sample's prediction
    depends on which other samples share its batch (batch size and order).
    """

    def __init__(self, input_dim, lstm_hidden_dim, lstm_layers, gnn_hidden_dim,
                 num_classes, bidirectional=True, dropout=0.3, k=5):
        super().__init__()
        self.k = k
        self.bidirectional = bidirectional

        # nn.LSTM applies dropout only between stacked layers, so it is
        # switched off for a single-layer LSTM.
        inter_layer_dropout = dropout if lstm_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim, lstm_hidden_dim, num_layers=lstm_layers,
            batch_first=True, bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )
        num_directions = 2 if bidirectional else 1
        node_dim = lstm_hidden_dim * num_directions
        self.gcn1 = GCNLayer(node_dim, gnn_hidden_dim)
        self.gcn2 = GCNLayer(gnn_hidden_dim, gnn_hidden_dim)

        head_dim = gnn_hidden_dim // 2
        self.classifier = nn.Sequential(
            nn.Linear(gnn_hidden_dim, head_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, num_classes),
        )

    def forward(self, x):
        # h_n is (num_layers * num_directions, batch, hidden) even with
        # batch_first=True; its last entry (or last two, when bidirectional)
        # are the top layer's final states.
        _, (h_n, _) = self.lstm(x)
        node_feat = (torch.cat([h_n[-2], h_n[-1]], dim=1)
                     if self.bidirectional else h_n[-1])
        adj = build_knn_graph(node_feat, k=self.k)
        hidden = node_feat
        for gcn in (self.gcn1, self.gcn2):
            hidden = F.relu(gcn(hidden, adj))
        return self.classifier(hidden)

# ================== 7a. 混淆矩阵可视化 ==================
def plot_confusion_matrix(cm, class_names, title, save_path, normalize=False):
    """Render ``cm`` as an annotated heat map and save it to ``save_path``.

    Args:
        cm: 2-D confusion-matrix array, drawn with rows on the "true" axis and
            columns on the "predicted" axis.
        class_names: tick labels, one per row/column of ``cm``.
        title: figure title.
        save_path: output image path (saved at dpi=200, tight bounding box).
        normalize: if True, divide each row by its sum and print cells with
            two decimals; an all-zero row stays zero. Otherwise raw counts
            are printed as integers.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    if not normalize:
        values = cm
        cell_fmt = 'd'
    else:
        row_totals = cm.sum(axis=1, keepdims=True).astype(float)
        # Clamp the denominator so an empty row yields zeros instead of NaN.
        values = cm.astype(float) / np.maximum(row_totals, 1e-12)
        cell_fmt = '.2f'

    image = ax.imshow(values, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(image, ax=ax)

    n_rows, n_cols = cm.shape
    ax.set(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        xticklabels=class_names,
        yticklabels=class_names,
        xlabel='预测标签 (Predicted)',
        ylabel='真实标签 (True)',
        title=title,
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

    # White text on cells darker than half the maximum, black text elsewhere.
    threshold = values.max() / 2.0
    for (row, col), value in np.ndenumerate(values):
        ax.text(
            col, row, format(value, cell_fmt),
            ha='center', va='center',
            color='white' if value > threshold else 'black',
        )
    fig.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close(fig)

# ================== 7. 评估指标函数 ==================
def evaluate_metrics(model, loader, num_classes):
    """Evaluate ``model`` over every batch of ``loader`` and return summary metrics.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(x, y)`` batches; ``y`` holds integer class ids.
        num_classes: number of classes; labels must lie in 0..num_classes-1.

    Returns:
        ``(acc, precision, recall, f1, roc_auc, pr_auc, fpr, cm)``:

        - precision, recall and F1 are macro averages over the classes that
          occur in the labels or predictions.
        - ROC-AUC and PR-AUC are macro scores on a one-hot label encoding.
          Both are reported as 0.0 if sklearn rejects the input, e.g. when a
          class has no positive sample.
        - ``fpr`` is the mean per-class false-positive rate over the
          occurring classes.
        - ``cm`` is always a ``(num_classes, num_classes)`` confusion matrix,
          with rows for true classes and columns for predicted classes.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    pred_chunks, label_chunks, prob_chunks = [], [], []

    with torch.no_grad():
        for x, y in loader:
            logits = model(x.to(device))
            prob = F.softmax(logits, dim=1)
            pred_chunks.append(prob.argmax(dim=1).cpu().numpy())
            label_chunks.append(y.cpu().numpy())
            prob_chunks.append(prob.cpu().numpy())

    if not label_chunks:
        raise ValueError("evaluate_metrics: loader 为空, 无法计算指标")

    preds = np.concatenate(pred_chunks)
    labels = np.concatenate(label_chunks)
    probs = np.concatenate(prob_chunks)

    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='macro', zero_division=0)
    recall = recall_score(labels, preds, average='macro', zero_division=0)
    f1 = f1_score(labels, preds, average='macro', zero_division=0)

    # ROC-AUC and PR-AUC on a one-hot encoding of the labels (one column per class).
    y_onehot = np.eye(num_classes)[labels]
    try:
        roc_auc = roc_auc_score(y_onehot, probs, average='macro', multi_class='ovr')
        pr_auc = average_precision_score(y_onehot, probs, average='macro')
    except ValueError:
        # e.g. a class with no positive sample in `labels`.
        roc_auc = pr_auc = 0.0

    # Pin the label set so the matrix is always num_classes x num_classes:
    # the plots use num_classes tick labels and run-level code sums matrices
    # across seeds, both of which break if an absent class shrinks the matrix.
    cm = confusion_matrix(labels, preds, labels=np.arange(num_classes))
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)
    # FPR = FP / (FP + TN), averaged only over classes that occur in labels or
    # predictions. This is the class set an un-pinned confusion_matrix would
    # use, so padding rows/columns for absent classes does not dilute the mean.
    present = (cm.sum(axis=0) + cm.sum(axis=1)) > 0
    FPR = (fp[present] / (fp[present] + tn[present] + 1e-8)).mean()

    return acc, precision, recall, f1, roc_auc, pr_auc, FPR, cm

# ================== 8. Dataset & DataLoader ==================
class TimeSeriesDataset(Dataset):
    """In-memory dataset of (float32 features, int64 label) pairs.

    ``torch.tensor`` copies both inputs at construction, so later changes to
    the source arrays do not affect the dataset.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = self.X[idx]
        label = self.y[idx]
        return features, label

# ================== 9. 主实验函数 ==================
def _set_seed(s):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also selects deterministic cuDNN kernels and disables cuDNN autotuning to
    reduce run-to-run nondeterminism on GPU.
    """
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _split_and_scale(X_data, y_data, seed):
    """Stratified 8:1:1 train/val/test split, standardised with train-only statistics.

    Returns ``(X_train, X_val, X_test, y_train, y_val, y_test)``, with every X
    reshaped to ``(n, SEQ_LEN, INPUT_DIM)``.
    """
    # Split first and fit the scaler on the training rows only. Fitting on the
    # whole dataset leaks validation/test statistics into training. The split
    # indices depend only on (n_samples, labels, random_state), so the
    # partition matches splitting already-scaled data.
    X_temp, X_test, y_temp, y_test = train_test_split(
        X_data, y_data, test_size=0.1, random_state=seed, stratify=y_data
    )
    # 0.1111 of the remaining 90% is about 10% of the full dataset.
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.1111, random_state=seed, stratify=y_temp
    )

    scaler = StandardScaler().fit(X_train)

    def to_seq(a):
        return scaler.transform(a).reshape(-1, SEQ_LEN, INPUT_DIM)

    return to_seq(X_train), to_seq(X_val), to_seq(X_test), y_train, y_val, y_test


def run_experiment(seed, X_data, y_data):
    """Run one full train / validate / test cycle for ``seed``.

    The data are split 8:1:1 (stratified) and standardised using the training
    split only. An LSTM+GNN classifier is then trained for ``EPOCHS`` epochs.
    The weights from the epoch with the highest validation macro-F1 are
    restored and evaluated on the test split. Raw and row-normalised test
    confusion matrices are saved as images under ``OUTPUT_DIR``.

    Args:
        seed: seed for all RNGs and for both data splits.
        X_data: 2-D feature matrix (samples x FEATURE_DIM).
        y_data: integer class labels, one per row of ``X_data``.

    Returns:
        ``(best_val_metrics, test_metrics)``:

        - ``best_val_metrics`` holds the validation metrics of the selected
          epoch, together with ``epoch`` and ``train_loss``.
        - ``test_metrics`` holds the test metrics, including the confusion
          matrix under ``'test_cm'``.
    """
    print(f"\n{'='*60}")
    print(f"开始实验 - 随机种子: {seed}")
    print(f"{'='*60}")

    _set_seed(seed)

    # 8:1:1 split (train:val:test); scaler statistics come from train only.
    X_train, X_val, X_test, y_train, y_val, y_test = _split_and_scale(X_data, y_data, seed)

    print(f"训练集: {X_train.shape}, {y_train.shape}")
    print(f"验证集: {X_val.shape}, {y_val.shape}")
    print(f"测试集: {X_test.shape}, {y_test.shape}")

    train_loader = DataLoader(TimeSeriesDataset(X_train, y_train),
                              batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TimeSeriesDataset(X_val, y_val),
                            batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(TimeSeriesDataset(X_test, y_test),
                             batch_size=BATCH_SIZE, shuffle=False)

    # LSTM hidden size 128, 2 LSTM layers, GCN hidden size 128.
    model = LSTM_GNN_Classifier(INPUT_DIM, 128, 2, 128, NUM_CLASSES, k=K_NEIGHBORS).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Start below any achievable macro-F1 (F1 >= 0) so the first epoch is always
    # checkpointed. With 0.0 and a strict `>`, a run whose validation F1 never
    # rose above 0 kept best_model_state = None and crashed in load_state_dict.
    best_val_f1 = -1.0
    best_model_state = None
    best_val_metrics = None
    best_epoch = 0

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)

        acc, precision, recall, f1, roc_auc, pr_auc, FPR, _ = evaluate_metrics(model, val_loader, NUM_CLASSES)

        # Model selection on validation macro-F1.
        if f1 > best_val_f1:
            best_val_f1 = f1
            best_model_state = copy.deepcopy(model.state_dict())
            best_val_metrics = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_acc': acc,
                'val_precision': precision,
                'val_recall': recall,
                'val_f1': f1,
                'val_roc_auc': roc_auc,
                'val_pr_auc': pr_auc,
                'val_fpr': FPR
            }
            best_epoch = epoch + 1

        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}] "
                  f"Train Loss: {train_loss:.4f} "
                  f"Val F1: {f1:.4f} (Best: {best_val_f1:.4f} @ Epoch {best_epoch})")

    model.load_state_dict(best_model_state)

    test_acc, test_precision, test_recall, test_f1, test_roc_auc, test_pr_auc, test_fpr, test_cm = evaluate_metrics(
        model, test_loader, NUM_CLASSES
    )

    # Save this seed's test confusion matrix (raw counts and row-normalised).
    class_names = [str(c) for c in range(NUM_CLASSES)]
    cm_raw_path = os.path.join(OUTPUT_DIR, f'cm_seed{seed}.png')
    cm_norm_path = os.path.join(OUTPUT_DIR, f'cm_seed{seed}_normalized.png')
    plot_confusion_matrix(test_cm, class_names,
                          f'测试集混淆矩阵 (seed={seed})', cm_raw_path)
    plot_confusion_matrix(test_cm, class_names,
                          f'测试集混淆矩阵-归一化 (seed={seed})', cm_norm_path, normalize=True)

    print(f"\n测试集混淆矩阵 (seed={seed}):")
    print(test_cm)
    print(f"已保存: {cm_raw_path}")
    print(f"已保存: {cm_norm_path}")

    test_metrics = {
        'test_acc': test_acc,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_f1': test_f1,
        'test_roc_auc': test_roc_auc,
        'test_pr_auc': test_pr_auc,
        'test_fpr': test_fpr,
        'test_cm': test_cm,
    }

    print(f"\n{'='*60}")
    print(f"实验完成 - 随机种子: {seed}")
    print(f"最佳验证集结果 @ Epoch {best_epoch}:")
    print(f"  F1: {best_val_metrics['val_f1']:.4f}, "
          f"Accuracy: {best_val_metrics['val_acc']:.4f}, "
          f"Precision: {best_val_metrics['val_precision']:.4f}, "
          f"Recall: {best_val_metrics['val_recall']:.4f}")
    print(f"测试集结果:")
    print(f"  F1: {test_f1:.4f}, "
          f"Accuracy: {test_acc:.4f}, "
          f"Precision: {test_precision:.4f}, "
          f"Recall: {test_recall:.4f}")
    print(f"{'='*60}")

    return best_val_metrics, test_metrics

# ================== 10. 运行多个实验 ==================
# Five independent seeds; each controls its own split, initialisation and shuffling.
seeds = [42, 123, 456, 789, 999]
all_val_results, all_test_results = [], []

for seed in seeds:
    val_metrics, test_metrics = run_experiment(seed, X, y)
    all_val_results += [val_metrics]
    all_test_results += [test_metrics]

# ================== 11. 打印汇总结果 ==================
print("\n" + "=" * 80)
print("实验汇总结果")
print("=" * 80)

_TABLE_RULE = "-" * 100
# Metric keys in the column order shared by the tables and averages below.
_VAL_KEYS = ('val_acc', 'val_precision', 'val_recall', 'val_f1',
             'val_roc_auc', 'val_pr_auc', 'val_fpr')
_TEST_KEYS = ('test_acc', 'test_precision', 'test_recall', 'test_f1',
              'test_roc_auc', 'test_pr_auc', 'test_fpr')

# Per-seed validation metrics at the selected (best validation F1) epoch.
print("\n各随机种子验证集最佳结果:")
print(_TABLE_RULE)
print(f"{'种子':<10} {'Epoch':<10} {'Train Loss':<12} {'Accuracy':<10} {'Precision':<10} "
      f"{'Recall':<10} {'F1':<10} {'ROC-AUC':<10} {'PR-AUC':<10} {'FPR':<10}")
print(_TABLE_RULE)
for seed, metrics in zip(seeds, all_val_results):
    cells = [f"{seed:<10}", f"{metrics['epoch']:<10}", f"{metrics['train_loss']:<12.4f}"]
    cells.extend(f"{metrics[key]:<10.4f}" for key in _VAL_KEYS)
    print(" ".join(cells))

# Per-seed test metrics.
print("\n\n各随机种子测试集结果:")
print(_TABLE_RULE)
print(f"{'种子':<10} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} "
      f"{'F1':<10} {'ROC-AUC':<10} {'PR-AUC':<10} {'FPR':<10}")
print(_TABLE_RULE)
for seed, metrics in zip(seeds, all_test_results):
    cells = [f"{seed:<10}"]
    cells.extend(f"{metrics[key]:<10.4f}" for key in _TEST_KEYS)
    print(" ".join(cells))

# Averages across seeds.
print("\n\n平均结果:")
print(_TABLE_RULE)

val_avg = {key: np.mean([r[key] for r in all_val_results])
           for key in ('train_loss',) + _VAL_KEYS}
test_avg = {key: np.mean([r[key] for r in all_test_results]) for key in _TEST_KEYS}

# Population standard deviation (np.std default, ddof=0), for accuracy and F1 only.
val_std = {key: np.std([r[key] for r in all_val_results]) for key in ('val_acc', 'val_f1')}
test_std = {key: np.std([r[key] for r in all_test_results]) for key in ('test_acc', 'test_f1')}

# (display label, key suffix) for each averaged-metric line.
_AVG_LINES = (('Accuracy', 'acc'), ('Precision', 'precision'), ('Recall', 'recall'),
              ('F1', 'f1'), ('ROC-AUC', 'roc_auc'), ('PR-AUC', 'pr_auc'), ('FPR', 'fpr'))


def _print_average_block(heading, prefix, avg, std):
    """Print ``heading`` followed by mean metrics, adding '± std' where available."""
    print(heading)
    for label, suffix in _AVG_LINES:
        key = f"{prefix}_{suffix}"
        line = f"  {label}: {avg[key]:.4f}"
        if key in std:
            line += f" ± {std[key]:.4f}"
        print(line)


_print_average_block("验证集平均结果:", 'val', val_avg, val_std)
_print_average_block("\n测试集平均结果:", 'test', test_avg, test_std)

# ================== 12. 汇总所有种子的测试集混淆矩阵 ==================
# Element-wise sum of the per-seed test confusion matrices (all seeds pooled).
total_cm = np.sum([r['test_cm'] for r in all_test_results], axis=0)
class_names = [str(c) for c in range(NUM_CLASSES)]

agg_raw_path = os.path.join(OUTPUT_DIR, 'cm_aggregated.png')
agg_norm_path = os.path.join(OUTPUT_DIR, 'cm_aggregated_normalized.png')
_aggregate_plots = (
    ('测试集混淆矩阵-累加 (所有种子)', agg_raw_path, False),
    ('测试集混淆矩阵-累加归一化 (所有种子)', agg_norm_path, True),
)
for _title, _path, _normalize in _aggregate_plots:
    plot_confusion_matrix(total_cm, class_names, _title, _path, normalize=_normalize)

print("\n累加所有种子的测试集混淆矩阵:")
print(total_cm)
for _path in (agg_raw_path, agg_norm_path):
    print(f"已保存: {_path}")

_banner = "=" * 80
print("\n" + _banner)
print("所有实验完成!")
print(f"混淆矩阵图保存目录: {OUTPUT_DIR}")
print(_banner)