import numpy as np
import pandas as pd
import rasterio
import xgboost as xgb
import shap
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
import os
from scipy.stats import spearmanr
import sys
import io
import locale
import warnings

# ========== Runtime environment setup ==========
# Suppress RuntimeWarnings process-wide (e.g. NumPy overflow / invalid-value
# warnings that raster sanitising can trigger).
warnings.filterwarnings('ignore', category=RuntimeWarning)


def _configure_utf8_console():
    """Make ``sys.stdout`` / ``sys.stderr`` emit UTF-8 so Chinese log text prints safely.

    Streams that already use UTF-8 (under any spelling, e.g. ``utf-8``/``UTF8``)
    are left untouched. Other streams are reconfigured in place, which keeps
    every existing reference to them valid. If a stream cannot be reconfigured
    (e.g. an IDE replaced it with a custom object), its binary buffer is
    rewrapped instead. Missing streams (``pythonw``) are skipped. Best effort:
    failures leave the original stream in place.
    """
    import codecs

    for stream_name in ("stdout", "stderr"):
        stream = getattr(sys, stream_name, None)
        if stream is None:
            continue
        try:
            # codecs.lookup normalises aliases, so "UTF-8", "utf8", "utf_8" all match.
            already_utf8 = codecs.lookup(stream.encoding).name == "utf-8"
        except (AttributeError, LookupError, TypeError):
            already_utf8 = False  # unknown or missing encoding: force UTF-8
        if already_utf8:
            continue
        try:
            stream.reconfigure(encoding="utf-8", errors="replace")
        except (AttributeError, ValueError, OSError):
            try:
                stream.flush()
                setattr(sys, stream_name,
                        io.TextIOWrapper(stream.buffer, encoding="utf-8", errors="replace"))
            except (AttributeError, ValueError, OSError):
                pass  # best effort only: keep the original stream


_configure_utf8_console()

# NOTE: Python reads these variables at interpreter start-up, so they only
# affect child processes launched from here; the current process is handled
# by _configure_utf8_console() above.
os.environ["PYTHONIOENCODING"] = "utf-8"
os.environ["PYTHONUTF8"] = "1"

# Prefer a UTF-8 C locale; fall back to the user's default locale.
try:
    locale.setlocale(locale.LC_ALL, 'C.UTF-8')
except (locale.Error, ValueError):
    try:
        locale.setlocale(locale.LC_ALL, '')
    except (locale.Error, ValueError):
        print("警告: 无法设置UTF-8本地化")

# ========== Global font configuration ==========
# Chinese type size "5" (五号) is 10.5 pt; use it as the global default.
plt.rcParams["font.size"] = 10.5
# Sans-serif fallback chain: SimHei first so CJK glyphs render.
# NOTE(review): SimHei most likely also provides Latin glyphs, in which case it
# renders English text too and Times New Roman is only a fallback. Confirm this
# matches the intended "English in Times New Roman" appearance.
plt.rcParams["font.sans-serif"] = ["SimHei", "Times New Roman", "DejaVu Sans", "Arial"]
plt.rcParams["font.serif"] = ["Times New Roman", "SimHei"]
# Draw the minus sign as ASCII '-' so CJK fonts without U+2212 still show it.
plt.rcParams["axes.unicode_minus"] = False

print(f"XGBoost版本: {xgb.__version__}")
print(f"SHAP版本: {shap.__version__}")


# Read a GeoTIFF band, sanitise invalid cells and optionally crop it.
def read_and_crop_tiff(data_dir, filename, target_shape=None):
    """Read band 1 of a raster as float32, masking invalid cells, optionally cropping.

    Cells equal to the file's nodata value become NaN (``-9999`` is assumed when
    the file declares none), as do non-finite (±inf) cells. Remaining values are
    clipped into the float32 range so the float32 cast cannot overflow.

    When ``target_shape`` is given, the top-left ``(rows, cols)`` window is kept
    and padded with NaN if the source is smaller. The window keeps the raster's
    upper-left corner, so the georeferencing transform stays unchanged; only
    height/width are updated.

    Args:
        data_dir: Directory containing the raster.
        filename: Raster file name inside ``data_dir``.
        target_shape: Optional ``(rows, cols)`` of the returned array.

    Returns:
        ``(data, meta, no_data_value)``:
        - the float32 array;
        - a copy of the source metadata with ``dtype`` (and, when cropping,
          ``height``/``width``) updated to describe ``data``;
        - the nodata value that was masked.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    file_path = os.path.join(data_dir, filename)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件 {file_path} 不存在，请检查路径")

    with rasterio.open(file_path) as src:
        data = src.read(1).astype(np.float64)
        no_data_value = src.nodata if src.nodata is not None else -9999
        meta = src.meta.copy()

    data[data == no_data_value] = np.nan
    # Mask infinities BEFORE clipping: np.clip would otherwise turn ±inf into
    # finite ±float32-max values that are indistinguishable from real data.
    data[np.isinf(data)] = np.nan
    f32 = np.finfo(np.float32)
    data = np.clip(data, f32.min, f32.max).astype(np.float32)  # NaN passes through clip
    meta['dtype'] = 'float32'

    if target_shape is not None:
        target_rows, target_cols = target_shape
        cropped = np.full(target_shape, np.nan, dtype=np.float32)
        rows = min(data.shape[0], target_rows)
        cols = min(data.shape[1], target_cols)
        cropped[:rows, :cols] = data[:rows, :cols]
        data = cropped
        # The kept window starts at the original upper-left pixel, so the copied
        # meta['transform'] is still correct and must not be shifted.
        meta['height'], meta['width'] = target_rows, target_cols

    return data, meta, no_data_value


# Default raster directory used when prepare_data() is called without arguments.
# Replace it with your own data path, or pass data_dir explicitly.
_DEFAULT_DATA_DIR = r'C:\Users\leele\xwechat_files\wxid_7igqqddfzuwr21_cdc8\msg\attach\9e20f478899dc29eb19741386f9343c8\2025-10\Rec\0d525438b500878e\F\1\07_data'


def prepare_data(data_dir=None):
    """Load every raster, crop them to a common shape and build the sample table.

    All rasters are cropped to the smallest height and smallest width found
    among the inputs (top-left window, see ``read_and_crop_tiff``). A pixel
    becomes a sample only if it is non-NaN in every explanatory raster and in
    the LST target.

    Args:
        data_dir: Directory holding the ``.tif`` files. Defaults to
            ``_DEFAULT_DATA_DIR``, the path that used to be hard-coded here.

    Returns:
        ``(df, var_types, meta)``:
        - ``df``: one column per explanatory variable (eco-meteorological, then
          socio-economic, then terrain), followed by ``LST``. One row per valid
          pixel, in row-major pixel order.
        - ``var_types``: maps ``eco_meteor``/``socio_econ``/``terrain``/
          ``categorical`` to variable names (``categorical`` is always empty).
        - ``meta``: the cropped metadata of the first explanatory raster.

    Raises:
        NotADirectoryError: if ``data_dir`` is not an existing directory.
        FileNotFoundError: if an expected raster file is missing.
        ValueError: if no pixel is valid in all rasters, or shapes disagree.
    """
    if data_dir is None:
        data_dir = _DEFAULT_DATA_DIR
    print(f"数据目录: {data_dir}")

    if not os.path.isdir(data_dir):
        raise NotADirectoryError(f"数据目录 {data_dir} 不存在，请检查路径")

    # Eco-meteorological explanatory variables (must match the files in data_dir).
    eco_meteor_files = {
        "NDBI": {"filename": "NDBI.tif", "type": "numeric"},
        "NDVI": {"filename": "NDVI.tif", "type": "numeric"},
        "风速": {"filename": "风速.tif", "type": "numeric"},
        "地表反照率": {"filename": "地表反照率.tif", "type": "numeric"},
        "相对湿度": {"filename": "相对湿度.tif", "type": "numeric"}
    }
    # Socio-economic explanatory variables.
    socio_econ_files = {
        "GDP": {"filename": "GDP.tif", "type": "numeric"},
        "人口密度": {"filename": "人口密度.tif", "type": "numeric"},
        "夜间灯光指数": {"filename": "夜间灯光指数.tif", "type": "numeric"}
    }
    # Terrain explanatory variables.
    terrain_files = {
        "海拔": {"filename": "海拔.tif", "type": "numeric"},
        "坡度": {"filename": "坡度.tif", "type": "numeric"}
    }
    # Dependent (target) variable.
    target_key = "LST"
    target_file = {target_key: {"filename": "LST.tif", "type": "target"}}

    all_files = {**eco_meteor_files, **socio_econ_files, **terrain_files, **target_file}

    print("\n当前使用的变量列表：")
    print(f"生态气象变量: {list(eco_meteor_files.keys())}")
    print(f"社会经济变量: {list(socio_econ_files.keys())}")
    print(f"地形变量: {list(terrain_files.keys())}")
    print(f"目标变量（因变量）: {list(target_file.keys())}")

    # Pass 1: read only raster headers to find the common (smallest) shape.
    shapes = []
    for config in all_files.values():
        file_path = os.path.join(data_dir, config["filename"])
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件 {file_path} 不存在，请检查路径")
        with rasterio.open(file_path) as src:
            shapes.append((src.height, src.width))
            print(f"文件 {config['filename']} 原始尺寸: {src.height} x {src.width}")

    target_shape = (min(rows for rows, _ in shapes), min(cols for _, cols in shapes))
    print(f"全局统一裁剪尺寸: {target_shape[0]} x {target_shape[1]}")

    # Pass 2: read and crop the target, then every explanatory raster.
    target_data, _, _ = read_and_crop_tiff(data_dir, target_file[target_key]["filename"], target_shape)
    print(f"目标变量 {target_key} 裁剪后尺寸: {target_data.shape}")

    data = {}
    meta = None
    all_valid_mask = None
    var_types = {"eco_meteor": [], "socio_econ": [], "terrain": [], "categorical": []}
    groups = (
        ("eco_meteor", eco_meteor_files),
        ("socio_econ", socio_econ_files),
        ("terrain", terrain_files),
    )
    for group_name, group_files in groups:
        for key, config in group_files.items():
            data_arr, file_meta, _ = read_and_crop_tiff(data_dir, config["filename"], target_shape)
            print(f"自变量 {key} 裁剪后尺寸: {data_arr.shape}")
            var_types[group_name].append(key)
            data[key] = data_arr

            current_mask = ~np.isnan(data_arr)
            if meta is None:
                meta = file_meta
                all_valid_mask = current_mask
            else:
                if all_valid_mask.shape != current_mask.shape:
                    raise ValueError(f"变量 {key} 掩码维度 {current_mask.shape} 与全局维度 {all_valid_mask.shape} 不匹配")
                all_valid_mask = np.logical_and(all_valid_mask, current_mask)

    # Combine with the target's validity mask.
    target_mask = ~np.isnan(target_data)
    if all_valid_mask is None:
        all_valid_mask = target_mask  # no explanatory variables configured
    if all_valid_mask.shape != target_mask.shape:
        raise ValueError(f"目标变量掩码维度 {target_mask.shape} 与自变量掩码维度 {all_valid_mask.shape} 不匹配")
    all_valid_mask = np.logical_and(all_valid_mask, target_mask)
    print(f"最终有效数据掩码尺寸: {all_valid_mask.shape}")

    valid_count = int(np.count_nonzero(all_valid_mask))
    if valid_count == 0:
        raise ValueError("未找到所有变量都有数据的有效样本，请检查数据或无数据值设置")
    print(f"有效样本总数: {valid_count}")

    # Boolean-mask indexing returns the valid pixels in row-major order.
    samples = {key: arr[all_valid_mask] for key, arr in data.items()}
    samples[target_key] = target_data[all_valid_mask]

    df = pd.DataFrame(samples)
    print(f"最终数据框形状: {df.shape}")
    return df, var_types, meta


# (var_types key, label used in output file names and titles), in output order.
_CORR_GROUPS = (
    ("eco_meteor", "生态气象变量"),
    ("socio_econ", "社会经济变量"),
    ("terrain", "地形变量"),
)


def _spearman_table(df, columns, target_col):
    """Return Spearman rho and p-value of each column vs ``target_col``, sorted by rho descending."""
    rows = {}
    for col in columns:
        corr, p_val = spearmanr(df[col], df[target_col], nan_policy='omit')
        rows[col] = {"相关系数": corr, "P值": p_val}
    # Explicit columns keep the schema (and make sort_values work) for an empty group.
    table = pd.DataFrame.from_dict(rows, orient='index', columns=["相关系数", "P值"])
    return table.sort_values("相关系数", ascending=False)


def _plot_corr_bar(corr_df, title, png_path):
    """Save a bar chart of ``corr_df['相关系数']``; does nothing for an empty table."""
    if corr_df.empty:
        return
    plt.figure(figsize=(12, 8))
    corr_df["相关系数"].plot(kind='bar')
    plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
    plt.title(title, fontsize=10.5)
    plt.xlabel('变量', fontsize=10.5)
    plt.ylabel('Spearman相关系数', fontsize=10.5)
    plt.tight_layout()
    plt.savefig(png_path, dpi=300)
    plt.close()


def calculate_spearman(df, var_types, target_col, output_dir):
    """Compute, save and plot Spearman correlations of each variable group with the target.

    For each group in ``var_types`` (eco-meteorological, socio-economic, terrain),
    the following are written to ``output_dir``:
    - ``<label>相关性.csv`` with columns 相关系数 (rho) and P值 (p-value), sorted
      by rho descending;
    - ``<label>相关性.png``, a bar chart, skipped when the group is empty.

    Returns:
        ``(eco_meteor_corr_df, socio_econ_corr_df, terrain_corr_df)``.
    """
    tables = []
    for group_key, label in _CORR_GROUPS:
        corr_df = _spearman_table(df, var_types.get(group_key, []), target_col)
        corr_df.to_csv(os.path.join(output_dir, f'{label}相关性.csv'), encoding='utf-8-sig')
        _plot_corr_bar(
            corr_df,
            f'{label}与{target_col}的Spearman相关性',
            os.path.join(output_dir, f'{label}相关性.png'),
        )
        tables.append(corr_df)

    eco_meteor_corr_df, socio_econ_corr_df, terrain_corr_df = tables
    return eco_meteor_corr_df, socio_econ_corr_df, terrain_corr_df


def _top_by_abs_corr(corr_df, n):
    """Return up to ``n`` index labels of ``corr_df`` ranked by |相关系数| descending, NaN last."""
    if n <= 0 or corr_df.empty:
        return []
    # Stable sort keeps the incoming (signed-rho) order among equal magnitudes.
    ranked = corr_df["相关系数"].abs().sort_values(ascending=False, kind="mergesort")
    return ranked.index[:n].tolist()


def select_features(eco_meteor_corr_df, socio_econ_corr_df, terrain_corr_df, top_n=3):
    """Pick the ``top_n`` most strongly correlated variables from each group.

    Variables are ranked by the absolute Spearman coefficient |相关系数|, so a
    strong negative correlation counts as much as an equally strong positive
    one. The input tables are sorted by signed rho, which would otherwise rank
    strong negative correlates last. NaN coefficients rank last. A group with
    fewer than ``top_n`` variables contributes all of them; a negative
    ``top_n`` selects nothing.

    Returns:
        ``(eco_meteor_selected, socio_econ_selected, terrain_selected)``: lists
        of variable names, most strongly correlated first.
    """
    eco_count = len(eco_meteor_corr_df)
    socio_count = len(socio_econ_corr_df)
    terrain_count = len(terrain_corr_df)

    eco_top_n = max(0, min(top_n, eco_count))
    socio_top_n = max(0, min(top_n, socio_count))
    terrain_top_n = max(0, min(top_n, terrain_count))

    print(f"\n特征选择容错处理：")
    print(f"生态气象变量共{eco_count}个，选取前{eco_top_n}个")
    print(f"社会经济变量共{socio_count}个，选取前{socio_top_n}个")
    print(f"地形变量共{terrain_count}个，选取前{terrain_top_n}个")

    eco_meteor_selected = _top_by_abs_corr(eco_meteor_corr_df, eco_top_n)
    socio_econ_selected = _top_by_abs_corr(socio_econ_corr_df, socio_top_n)
    terrain_selected = _top_by_abs_corr(terrain_corr_df, terrain_top_n)

    print(f"选中的生态气象变量: {eco_meteor_selected}")
    print(f"选中的社会经济变量: {socio_econ_selected}")
    print(f"选中的地形变量: {terrain_selected}")

    return eco_meteor_selected, socio_econ_selected, terrain_selected


# Preprocessing: standardise the numeric model inputs.
def create_preprocessing_pipeline(numeric_features):
    """Build an unfitted column preprocessor that z-score standardises the given columns.

    Args:
        numeric_features: Column names routed through ``StandardScaler``.

    Returns:
        A ``ColumnTransformer`` with a single ``"num"`` branch; columns not
        listed are handled by the transformer's default ``remainder`` policy.
    """
    scaling = Pipeline(steps=[("scaler", StandardScaler())])
    branches = [("num", scaling, numeric_features)]
    return ColumnTransformer(transformers=branches)


# Combined XGBoost feature-importance figure (Chinese labels, 10.5 pt text).
def plot_feature_importance_combined(model, feature_names, output_dir):
    """Save a two-panel figure of ``model.feature_importances_``.

    Left panel: horizontal bars for every feature, each annotated with its value.
    Right panel: pie chart of the (up to) five most important features, as shares
    of their combined importance. If that combined importance is not positive
    (e.g. all zeros), every share would be 0/0 = NaN and cannot be drawn, so an
    explanatory note is shown instead.

    Writes ``feature_importance_combined.png`` into ``output_dir``.
    """
    feat_imp = pd.DataFrame({
        'Variable': feature_names,
        'Importance': model.feature_importances_
    }).sort_values(by='Importance', ascending=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 10))

    # Left: importance bar chart with the numeric value printed past each bar.
    bars = ax1.barh(feat_imp['Variable'], feat_imp['Importance'], color='#d62728')
    ax1.set_xlabel('特征重要性值', fontsize=10.5)
    ax1.set_ylabel('变量', fontsize=10.5)
    ax1.set_title('XGBoost模型特征重要性值', fontsize=10.5)
    for bar in bars:
        width = bar.get_width()
        ax1.text(width + 0.001, bar.get_y() + bar.get_height() / 2, f'{width:.3f}', va='center', fontsize=10.5)

    # Right: share of the top features. .copy() gives an independent frame, so
    # adding a column does not write into a derivative of feat_imp.
    top_feats = feat_imp.nlargest(min(5, len(feat_imp)), 'Importance').copy()
    top_total = top_feats['Importance'].sum()
    if top_total > 0:
        top_feats['Percentage'] = (top_feats['Importance'] / top_total) * 100
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
        wedges, _, _ = ax2.pie(
            top_feats['Percentage'],
            labels=top_feats['Variable'],
            colors=colors[:len(top_feats)],
            autopct='%1.1f%%',
            startangle=90,
            textprops={'fontsize': 10.5}
        )
        ax2.legend(wedges, top_feats['Variable'], title="变量", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1),
                   fontsize=10.5)
    else:
        ax2.axis('off')
        ax2.text(0.5, 0.5, '特征重要性均为0，无法绘制占比图', ha='center', va='center',
                 fontsize=10.5, transform=ax2.transAxes)
    ax2.set_title('核心特征占比', fontsize=10.5)

    plt.suptitle('XGBoost模型特征重要性分析', fontsize=10.5)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_importance_combined.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


# SHAP summary (beeswarm) plot with Chinese labels and 10.5 pt text.
def _align_shap_inputs(shap_values, X_test_sampled, feature_names):
    """Trim SHAP values, feature matrix and names to their common (rows, cols)."""
    if hasattr(X_test_sampled, 'toarray'):
        X_array = X_test_sampled.toarray()
    else:
        X_array = np.array(X_test_sampled)
    n_rows = min(shap_values.shape[0], X_array.shape[0])
    n_cols = min(shap_values.shape[1], X_array.shape[1])
    return shap_values[:n_rows, :n_cols], X_array[:n_rows, :n_cols], feature_names[:n_cols]


def _relabel_shap_colorbar(text_x):
    """Replace the colour bar's English labels with 特征值 and 高/低 markers.

    Assumes the colour bar is the last axes of the current figure, which is
    where the shap summary plot's colour bar has been found so far
    (NOTE(review): confirm for new shap versions). Does nothing when the figure
    has a single axes, i.e. no colour bar was drawn.
    """
    axes = plt.gcf().axes
    if len(axes) < 2:
        return
    cbar_ax = axes[-1]
    cbar_ax.set_ylabel('特征值', fontsize=10.5)
    # Axes has no set_ticks(); set_yticks([]) removes shap's "Low"/"High" ticks.
    cbar_ax.set_yticks([])
    cbar_ax.text(text_x, 0.95, '高', ha='center', va='center', fontsize=10.5)
    cbar_ax.text(text_x, 0.05, '低', ha='center', va='center', fontsize=10.5)


def _plot_shap_summary_fallback(shap_values, X_test_sampled, feature_names, output_dir):
    """Retry the summary plot with shap's default sizing; writes shap_summary_plot_fallback.png."""
    shap_aligned, X_aligned, names_aligned = _align_shap_inputs(shap_values, X_test_sampled, feature_names)
    plt.figure(figsize=(10, 12))
    # summary_plot orders features by mean |SHAP| itself, so no manual pre-sort is needed.
    shap.summary_plot(
        shap_aligned,
        X_aligned,
        feature_names=names_aligned,
        plot_type="dot",
        show=False,
        cmap='RdBu_r'
    )
    plt.xlabel('SHAP值（对模型输出的影响）', fontsize=10.5)
    _relabel_shap_colorbar(text_x=2.5)
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.savefig(os.path.join(output_dir, 'shap_summary_plot_fallback.png'), dpi=300)
    plt.close()
    print("✅ SHAP特征重要性汇总图（异常兜底）已导出")


def plot_shap_summary(shap_values, X_test_sampled, feature_names, output_dir):
    """Save a SHAP summary dot plot as ``shap_feature_importance_summary.png``.

    SHAP values, features and names are first trimmed to a common shape. If the
    primary plot fails, its figure is closed and a simpler fallback plot is
    written to ``shap_summary_plot_fallback.png`` instead.

    Args:
        shap_values: 2-D array of SHAP values (rows x features).
        X_test_sampled: Feature rows the SHAP values explain (dense or sparse).
        feature_names: Feature names in column order.
        output_dir: Directory for the PNG output.

    Raises:
        Any exception raised by the fallback plot propagates to the caller.
    """
    fig = None
    try:
        shap_aligned, X_aligned, names_aligned = _align_shap_inputs(shap_values, X_test_sampled, feature_names)

        fig = plt.figure(figsize=(10, 12))
        # plot_size=None draws into the figure created above instead of resizing it.
        shap.summary_plot(
            shap_aligned,
            X_aligned,
            feature_names=names_aligned,
            plot_type="dot",
            cmap='RdBu_r',
            show=False,
            plot_size=None,
            color_bar=True
        )
        plt.xlabel('SHAP值（对模型输出的影响）', fontsize=10.5)
        _relabel_shap_colorbar(text_x=1.2)

        plt.tight_layout(rect=[0, 0, 0.9, 1])
        plt.savefig(
            os.path.join(output_dir, 'shap_feature_importance_summary.png'),
            dpi=300,
            bbox_inches='tight'
        )
        plt.close(fig)
        print("✅ SHAP特征重要性汇总图已导出")
    except Exception as e:
        print(f"❌ 绘制SHAP特征重要性汇总图失败: {e}")
        if fig is not None:
            plt.close(fig)  # don't leak the half-drawn figure
        _plot_shap_summary_fallback(shap_values, X_test_sampled, feature_names, output_dir)


# Model factory: XGBoost regressor with fixed hyper-parameters.
def create_xgboost_model():
    """Return an unfitted ``XGBRegressor`` configured with the project's hyper-parameters.

    Squared-error objective, 100 trees of maximum depth 3, learning rate 0.1,
    and ``random_state=42`` for reproducibility.
    """
    hyper_params = dict(
        objective="reg:squarederror",
        random_state=42,
        n_estimators=100,
        max_depth=3,
        learning_rate=0.1,
    )
    return xgb.XGBRegressor(**hyper_params)


# SHAP analysis with a graceful fallback.
def robust_shap_analysis(model, X_train_prep, X_test_prep, feature_names, sample_size=100,
                         background_size=50, nsamples=50):
    """Compute SHAP values for the first ``sample_size`` test rows, falling back on failure.

    Uses ``shap.KernelExplainer`` on ``model.predict``:
    - the first ``background_size`` training rows serve as background data;
    - ``nsamples`` coalition samples are drawn per explained row;
    - inputs exposing ``toarray`` (sparse matrices) are densified first.

    If anything fails, returns ``model.feature_importances_`` repeated once per
    test row as a stand-in. These are NOT true SHAP values. The stand-in is
    returned together with the full ``X_test_prep``.

    Args:
        model: Fitted model exposing ``predict`` and ``feature_importances_``.
        X_train_prep: Preprocessed training features.
        X_test_prep: Preprocessed test features.
        feature_names: Unused; kept for call-site compatibility.
        sample_size: Maximum number of test rows to explain.
        background_size: Number of training rows used as background (was fixed at 50).
        nsamples: KernelExplainer samples per explained row (was fixed at 50).

    Returns:
        ``(shap_values, X_used)``: the value matrix and the feature rows it describes.
    """
    try:
        # Slicing past the end simply returns every row, so no length check is needed.
        # This also avoids len(), which raises on scipy sparse matrices.
        X_test_sampled = X_test_prep[:sample_size]

        print("使用KernelExplainer进行SHAP分析...")
        background_data = X_train_prep[:background_size]

        if hasattr(background_data, 'toarray'):
            background_data = background_data.toarray()
        if hasattr(X_test_sampled, 'toarray'):
            X_test_sampled = X_test_sampled.toarray()

        explainer = shap.KernelExplainer(model.predict, background_data)
        shap_values = explainer.shap_values(X_test_sampled, nsamples=nsamples)

        print("SHAP分析成功完成")
        return shap_values, X_test_sampled

    except Exception as e:
        print(f"SHAP分析失败: {e}")
        print("使用模型内置特征重要性作为备选...")
        n_rows = X_test_prep.shape[0] if hasattr(X_test_prep, 'shape') else len(X_test_prep)
        shap_values = np.tile(model.feature_importances_, (n_rows, 1))
        return shap_values, X_test_prep


# Model training and explainability analysis.
def _write_evaluation_summary(path, target_col, train_r2, test_r2, feature_names, selected_socio_econ):
    """Write the plain-text model evaluation summary (R² scores and feature list) to ``path``."""
    with open(path, 'w', encoding='utf-8') as fh:
        fh.write(f"目标变量: {target_col}\n")
        fh.write(f"训练集R²: {train_r2:.4f}\n")
        fh.write(f"测试集R²: {test_r2:.4f}\n")
        fh.write(f"特征列表: {feature_names}\n")
        fh.write(f"特征总数: {len(feature_names)}\n")
        fh.write(f"\n当前使用的社会经济变量为：{selected_socio_econ}\n")


def _write_shap_report(path, target_col, shap_df, shap_importance, feature_names,
                       selected_socio_econ, train_r2, test_r2):
    """Write the human-readable SHAP report; ``shap_df`` must be sorted by importance, descending."""
    with open(path, 'w', encoding='utf-8') as fh:
        fh.write("XGBoost + SHAP可解释性分析报告\n")
        fh.write("=" * 60 + "\n\n")
        fh.write(f"目标变量（因变量）：{target_col}\n\n")
        fh.write(f"当前使用的社会经济变量为：{selected_socio_econ}\n\n")
        fh.write("模型性能:\n")
        fh.write(f"- 训练集R²: {train_r2:.4f}\n")
        fh.write(f"- 测试集R²: {test_r2:.4f}\n\n")

        fh.write("SHAP特征重要性排名:\n")
        fh.write("-" * 40 + "\n")
        # Number by position in the sorted frame. The index labels still carry the
        # pre-sort order, so using them would print wrong ranks.
        ranked = zip(shap_df['特征'], shap_df['SHAP重要性'])
        for rank, (name, importance) in enumerate(ranked, start=1):
            fh.write(f"{rank}. {name}: {importance:.6f}\n")

        fh.write(f"\n总特征数量: {len(feature_names)}\n")
        fh.write(f"SHAP平均重要性: {np.mean(shap_importance):.6f}\n")

        if feature_names:
            findings = [f"最重要的特征 (SHAP): {shap_df.iloc[0]['特征']}"]
            socio_ranked = (
                [name for name in shap_df['特征'] if name in selected_socio_econ]
                if selected_socio_econ else []
            )
            if socio_ranked:
                findings.append(f"最重要的社会经济因素 (SHAP): {socio_ranked[0]}")
            findings.append(f"模型解释力: {test_r2 * 100:.1f}%")

            fh.write(f"\n关键发现:\n")
            fh.write("-" * 40 + "\n")
            # Sequential numbering even when the socio-economic finding is absent.
            for number, text in enumerate(findings, start=1):
                fh.write(f"{number}. {text}\n")


def train_and_explain(df, var_types, selected_eco_meteor, selected_socio_econ, selected_terrain, output_dir):
    """Train the XGBoost LST model on the selected features and explain it with SHAP.

    Steps:
    1. Split the data 80/20 (``random_state=42``) and standardise the features.
    2. Fit the model and report train/test R².
    3. Plot feature importances, run the SHAP analysis and plot its summary.
    4. Write to ``output_dir``: ``模型评估结果.txt``, ``shap_importance.csv``
       (mean |SHAP| per feature) and ``SHAP分析报告.txt``.

    Args:
        df: Sample table with the selected feature columns and ``LST``.
        var_types: Unused here; kept for call-site compatibility.
        selected_eco_meteor, selected_socio_econ, selected_terrain: Feature
            name lists, concatenated in this order to form the model inputs.
        output_dir: Directory for all outputs.

    Returns:
        ``(model, preprocessor, feature_names)``.
    """
    target_col = "LST"
    feature_names = selected_eco_meteor + selected_socio_econ + selected_terrain
    print(f"\n最终特征列表: {feature_names}")
    print(f"特征总数: {len(feature_names)}")

    X = df[feature_names]
    y = df[target_col]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    preprocessor = create_preprocessing_pipeline(feature_names)
    X_train_prep = preprocessor.fit_transform(X_train)
    X_test_prep = preprocessor.transform(X_test)

    model = create_xgboost_model()
    model.fit(X_train_prep, y_train)

    train_r2 = r2_score(y_train, model.predict(X_train_prep))
    test_r2 = r2_score(y_test, model.predict(X_test_prep))
    print(f"训练集R² (LST): {train_r2:.4f}")
    print(f"测试集R² (LST): {test_r2:.4f}")

    _write_evaluation_summary(
        os.path.join(output_dir, '模型评估结果.txt'),
        target_col, train_r2, test_r2, feature_names, selected_socio_econ,
    )

    plot_feature_importance_combined(model, feature_names, output_dir)

    print("执行SHAP可解释性分析...")
    shap_values, X_test_sampled = robust_shap_analysis(model, X_train_prep, X_test_prep, feature_names)

    print("绘制SHAP特征重要性汇总图...")
    plot_shap_summary(shap_values, X_test_sampled, feature_names, output_dir)

    # Global importance = mean absolute SHAP value per feature.
    shap_importance = np.mean(np.abs(shap_values), axis=0)
    shap_df = pd.DataFrame({
        '特征': feature_names,
        'SHAP重要性': shap_importance
    }).sort_values('SHAP重要性', ascending=False)
    shap_df.to_csv(os.path.join(output_dir, 'shap_importance.csv'), encoding='utf-8-sig', index=False)

    _write_shap_report(
        os.path.join(output_dir, 'SHAP分析报告.txt'),
        target_col, shap_df, shap_importance, feature_names,
        selected_socio_econ, train_r2, test_r2,
    )

    print("分析完成!")
    return model, preprocessor, feature_names


# Entry point.
def main():
    """Run the whole LST pipeline: load rasters, Spearman screening, XGBoost + SHAP.

    Outputs go to the ``LST分析结果`` directory, which is created if missing.
    Exceptions raised during the analysis are printed with their traceback
    instead of propagating.
    """
    output_dir = "LST分析结果"
    os.makedirs(output_dir, exist_ok=True)
    target_col = "LST"

    try:
        df, var_types, _meta = prepare_data()
        print(f"\n数据加载完成，有效样本数: {len(df)}，目标变量（因变量）: {target_col}")

        print("\n计算Spearman相关性...")
        corr_tables = calculate_spearman(df, var_types, target_col, output_dir)

        print("\n选择高相关性指标...")
        selections = select_features(*corr_tables, top_n=2)

        print("\n训练模型并执行可解释性分析...")
        train_and_explain(df, var_types, *selections, output_dir)

        print(f"\n✅ 所有分析完成，结果保存至 {output_dir} 目录")
    except Exception as exc:
        import traceback

        print(f"\n❌ 程序执行过程中发生错误: {exc}")
        traceback.print_exc()


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

