
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.cluster import KMeans
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

class AirQualityHealthAnalysis:
    """空气质量与呼吸健康关联分析系统"""
    
    def __init__(self):
        self.data = None
        self.scaler = StandardScaler()
        self.models = {}
        
    def generate_sample_data(self, n_samples=1000):
        """
        生成示例数据集
        如果您有真实数据，请使用 load_data() 方法导入
        """
        np.random.seed(42)
        
        # 生成空气质量指标
        pm25 = np.random.normal(50, 20, n_samples)
        pm25 = np.clip(pm25, 0, 200)
        
        pm10 = pm25 * 1.5 + np.random.normal(0, 10, n_samples)
        pm10 = np.clip(pm10, 0, 300)
        
        so2 = np.random.normal(15, 8, n_samples)
        so2 = np.clip(so2, 0, 100)
        
        no2 = np.random.normal(40, 15, n_samples)
        no2 = np.clip(no2, 0, 150)
        
        co = np.random.normal(1.0, 0.5, n_samples)
        co = np.clip(co, 0, 5)
        
        o3 = np.random.normal(80, 30, n_samples)
        o3 = np.clip(o3, 0, 200)
        
        # 生成气象数据
        temperature = np.random.normal(20, 10, n_samples)
        humidity = np.random.normal(60, 20, n_samples)
        humidity = np.clip(humidity, 0, 100)
        
        # 生成呼吸健康指标
        # 呼吸疾病发病率与空气质量相关
        respiratory_rate = (
            0.3 * pm25 + 
            0.2 * pm10 + 
            0.15 * so2 + 
            0.15 * no2 + 
            0.1 * co * 10 +
            0.1 * o3 +
            np.random.normal(0, 10, n_samples)
        )
        respiratory_rate = np.clip(respiratory_rate, 0, 200)
        
        # 住院人数
        hospital_admissions = (respiratory_rate * 0.5 + 
                               np.random.normal(50, 20, n_samples))
        hospital_admissions = np.clip(hospital_admissions, 0, None).astype(int)
        
        # 哮喘发作次数
        asthma_attacks = (pm25 * 0.4 + o3 * 0.2 + 
                          np.random.normal(20, 10, n_samples))
        asthma_attacks = np.clip(asthma_attacks, 0, None).astype(int)
        
        # 生成日期
        dates = pd.date_range(start='2023-01-01', periods=n_samples, freq='D')
        
        # 城市标签
        cities = np.random.choice(['北京', '上海', '广州', '深圳', '成都'], n_samples)
        
        # 创建数据框
        self.data = pd.DataFrame({
            'date': dates,
            'city': cities,
            'PM2.5': pm25,
            'PM10': pm10,
            'SO2': so2,
            'NO2': no2,
            'CO': co,
            'O3': o3,
            'temperature': temperature,
            'humidity': humidity,
            'respiratory_disease_rate': respiratory_rate,
            'hospital_admissions': hospital_admissions,
            'asthma_attacks': asthma_attacks
        })
        
        print("✓ 示例数据生成成功！")
        print(f"数据维度: {self.data.shape}")
        return self.data
    
    def load_data(self, filepath, **kwargs):
        """
        导入自定义数据集
        
        参数:
        filepath: 数据文件路径 (支持 csv, xlsx, json)
        **kwargs: pandas read 函数的其他参数
        
        数据格式要求:
        必须包含以下列:
        - 空气质量指标: PM2.5, PM10, SO2, NO2, CO, O3
        - 健康指标: respiratory_disease_rate (呼吸疾病发病率)
        可选列: temperature, humidity, hospital_admissions, asthma_attacks
        """
        try:
            if filepath.endswith('.csv'):
                self.data = pd.read_csv(filepath, **kwargs)
            elif filepath.endswith('.xlsx') or filepath.endswith('.xls'):
                self.data = pd.read_excel(filepath, **kwargs)
            elif filepath.endswith('.json'):
                self.data = pd.read_json(filepath, **kwargs)
            else:
                raise ValueError("不支持的文件格式，请使用 csv, xlsx 或 json")
            
            print("✓ 数据导入成功！")
            print(f"数据维度: {self.data.shape}")
            print("\n数据列名:")
            print(self.data.columns.tolist())
            return self.data
        
        except Exception as e:
            print(f"✗ 数据导入失败: {str(e)}")
            return None
    #快速加载数据并检查数据基本情况

    def data_overview(self):
        """数据概览"""
        if self.data is None:
            print("请先导入数据！")
            return
        
        print("\n" + "="*60)
        print("数据概览")
        print("="*60)
        
        print("\n基本信息:")
        print(self.data.info())
        
        print("\n统计描述:")
        print(self.data.describe())
        
        print("\n缺失值统计:")
        missing = self.data.isnull().sum()
        if missing.sum() > 0:
            print(missing[missing > 0])
        else:
            print("无缺失值")   #检查缺失值并输出
    
    def data_preprocessing(self):
        """数据预处理"""
        if self.data is None:
            print("请先导入数据！")
            return
        
        print("\n" + "="*60)
        print("数据预处理")
        print("="*60)
        
        # 处理缺失值
        if self.data.isnull().sum().sum() > 0:
            print("\n处理缺失值...")
            self.data = self.data.fillna(self.data.median(numeric_only=True))
        
        # 处理异常值
        print("\n处理异常值...")
        numeric_cols = self.data.select_dtypes(include=[np.number]).columns
        
        for col in numeric_cols:
            Q1 = self.data[col].quantile(0.25)
            Q3 = self.data[col].quantile(0.75)
            IQR = Q3 - Q1
            lower_bound = Q1 - 3 * IQR
            upper_bound = Q3 + 3 * IQR
            
            outliers = ((self.data[col] < lower_bound) | 
                       (self.data[col] > upper_bound)).sum()
            if outliers > 0:
                print(f"  {col}: 发现 {outliers} 个异常值")
                self.data[col] = self.data[col].clip(lower_bound, upper_bound)
        
        print("\n✓ 数据预处理完成！")
    
    def correlation_analysis(self):
        """相关性分析"""
        if self.data is None:
            print("请先导入数据！")
            return
        
        print("\n" + "="*60)
        print("相关性分析")
        print("="*60)
        
        # 选择数值型列
        numeric_data = self.data.select_dtypes(include=[np.number])
        
        # 计算相关系数
        correlation_matrix = numeric_data.corr()
        
        # 绘制相关性热图                     #探究空气质量指标与呼吸健康指标之间的线性相关性
        plt.figure(figsize=(14, 10))      #直观地看到哪些污染物与健康指标关系最密切
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', 
                   center=0, fmt='.2f', linewidths=0.5)
        plt.title('空气质量与呼吸健康指标相关性矩阵', fontsize=16, pad=20)
        plt.tight_layout()
        plt.savefig('correlation_heatmap.png', dpi=300, bbox_inches='tight')
        plt.show()

        # 重点关注健康指标的相关性
        if 'respiratory_disease_rate' in numeric_data.columns:
            health_corr = correlation_matrix['respiratory_disease_rate'].sort_values(
                ascending=False)
            print("\n呼吸疾病发病率与各指标的相关性:")
            print(health_corr)
    
    def visualize_data(self):
        """数据可视化"""
        if self.data is None:
            print("请先导入数据！")
            return
        
        print("\n" + "="*60)
        print("数据可视化")
        print("="*60)
        
        # 1. 空气质量指标分布
        air_quality_cols = ['PM2.5', 'PM10', 'SO2', 'NO2', 'CO', 'O3']
        available_cols = [col for col in air_quality_cols if col in self.data.columns]
        
        if available_cols:
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))
            axes = axes.ravel()
            
            for idx, col in enumerate(available_cols):
                axes[idx].hist(self.data[col], bins=50, color='skyblue', 
                              edgecolor='black', alpha=0.7)
                axes[idx].set_title(f'{col} 分布', fontsize=12)
                axes[idx].set_xlabel(col)
                axes[idx].set_ylabel('频数')
                axes[idx].grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.savefig('air_quality_distribution.png', dpi=300, bbox_inches='tight')
            plt.show()
        
        # 2. 健康指标时间序列
        if 'date' in self.data.columns:
            health_cols = ['respiratory_disease_rate']
            if 'hospital_admissions' in self.data.columns:
                health_cols.append('hospital_admissions')
            if 'asthma_attacks' in self.data.columns:
                health_cols.append('asthma_attacks')
            
            fig, axes = plt.subplots(len(health_cols), 1, 
                                    figsize=(15, 4*len(health_cols)))
            
            if len(health_cols) == 1:
                axes = [axes]
            
            for idx, col in enumerate(health_cols):
                if col in self.data.columns:
                    axes[idx].plot(self.data['date'], self.data[col], 
                                  color='coral', linewidth=1)
                    axes[idx].set_title(f'{col} 时间趋势', fontsize=12)
                    axes[idx].set_xlabel('日期')
                    axes[idx].set_ylabel(col)
                    axes[idx].grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.savefig('health_indicators_timeline.png', dpi=300, bbox_inches='tight')
            plt.show()
        
        # 3. 散点图矩阵
        scatter_cols = ['PM2.5', 'PM10', 'NO2', 'respiratory_disease_rate']
        available_scatter = [col for col in scatter_cols if col in self.data.columns]
        
        if len(available_scatter) >= 2:
            pd.plotting.scatter_matrix(self.data[available_scatter], 
                                      figsize=(12, 12), diagonal='kde', 
                                      alpha=0.5, s=20)
            plt.suptitle('关键指标散点图矩阵', fontsize=16, y=1.0)
            plt.savefig('scatter_matrix.png', dpi=300, bbox_inches='tight')
            plt.show()
        
        print("✓ 可视化图表已保存")
    
    def build_models(self):       #构建回归模型预测呼吸疾病发病率
        """构建预测模型"""
        if self.data is None:
            print("请先导入数据！")
            return
        
        print("\n" + "="*60)
        print("构建预测模型")
        print("="*60)
        
        # 准备特征和目标
        feature_cols = ['PM2.5', 'PM10', 'SO2', 'NO2', 'CO', 'O3']
        if 'temperature' in self.data.columns:
            feature_cols.append('temperature')
        if 'humidity' in self.data.columns:
            feature_cols.append('humidity')
        
        available_features = [col for col in feature_cols if col in self.data.columns]
        
        if 'respiratory_disease_rate' not in self.data.columns:
            print("错误: 数据中缺少目标变量 'respiratory_disease_rate'")
            return
        
        X = self.data[available_features]
        y = self.data['respiratory_disease_rate']
        
        # 划分训练集和测试集
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42)             #80% 训练集，20% 测试集 随机种子固定为 42
        
        # 标准化
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_test_scaled = self.scaler.transform(X_test)
        
        # 构建多个模型
        self.models = {
            'Linear Regression': LinearRegression(),
            'Random Forest': RandomForestRegressor(n_estimators=100, random_state=42),
            'Gradient Boosting': GradientBoostingRegressor(n_estimators=100, random_state=42)
        }
        
        results = []
        
        for name, model in self.models.items():
            print(f"\n训练 {name}...")
            
            # 训练模型
            if name == 'Linear Regression':
                model.fit(X_train_scaled, y_train)
                y_pred = model.predict(X_test_scaled)
            else:
                model.fit(X_train, y_train)
                y_pred = model.predict(X_test)
            
            # 评估模型
            mse = mean_squared_error(y_test, y_pred)
            rmse = np.sqrt(mse)
            mae = mean_absolute_error(y_test, y_pred)
            r2 = r2_score(y_test, y_pred)
            
            results.append({
                'Model': name,
                'RMSE': rmse,
                'MAE': mae,
                'R²': r2
            })
            
            print(f"  RMSE: {rmse:.4f}")
            print(f"  MAE: {mae:.4f}")
            print(f"  R²: {r2:.4f}")
        
        # 结果对比
        results_df = pd.DataFrame(results)
        print("\n模型性能对比:")
        print(results_df.to_string(index=False))
        
        # 可视化预测结果
        best_model_name = results_df.loc[results_df['R²'].idxmax(), 'Model']
        best_model = self.models[best_model_name]
        
        if best_model_name == 'Linear Regression':
            y_pred_best = best_model.predict(X_test_scaled)
        else:
            y_pred_best = best_model.predict(X_test)
        
        plt.figure(figsize=(10, 6))
        plt.scatter(y_test, y_pred_best, alpha=0.5)
        plt.plot([y_test.min(), y_test.max()], 
                [y_test.min(), y_test.max()], 'r--', lw=2)
        plt.xlabel('实际值', fontsize=12)
        plt.ylabel('预测值', fontsize=12)
        plt.title(f'最佳模型 ({best_model_name}) 预测效果', fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('prediction_results.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        # 特征重要性 (针对树模型)
        if best_model_name in ['Random Forest', 'Gradient Boosting']:
            importances = best_model.feature_importances_
            indices = np.argsort(importances)[::-1]
            
            plt.figure(figsize=(10, 6))
            plt.bar(range(len(importances)), importances[indices])
            plt.xticks(range(len(importances)), 
                      [available_features[i] for i in indices], 
                      rotation=45)
            plt.xlabel('特征', fontsize=12)
            plt.ylabel('重要性', fontsize=12)
            plt.title('特征重要性分析', fontsize=14)
            plt.tight_layout()
            plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
            plt.show()
            
            print("\n特征重要性排序:")
            for i in indices:
                print(f"  {available_features[i]}: {importances[i]:.4f}")
    
    def clustering_analysis(self):
        """聚类分析"""                            #将样本按空气质量与健康状况分组，发现潜在的模式（如高风险区域、低风险区域）
        if self.data is None:
            print("请先导入数据！")
            return
        
        print("\n" + "="*60)
        print("聚类分析")
        print("="*60)
        
        # 选择聚类特征
        cluster_features = ['PM2.5', 'NO2', 'respiratory_disease_rate']
        available_cluster = [col for col in cluster_features 
                           if col in self.data.columns]
        
        if len(available_cluster) < 2:
            print("可用特征不足，无法进行聚类分析")
            return
        
        X_cluster = self.data[available_cluster].values
        X_cluster_scaled = StandardScaler().fit_transform(X_cluster)
        
        # 确定最佳聚类数
        inertias = []
        K_range = range(2, 8)
        
        for k in K_range:
            kmeans = KMeans(n_clusters=k, random_state=42)
            kmeans.fit(X_cluster_scaled)
            inertias.append(kmeans.inertia_)
        
        # 绘制肘部法则图
        plt.figure(figsize=(10, 6))
        plt.plot(K_range, inertias, 'bo-')
        plt.xlabel('聚类数 K', fontsize=12)
        plt.ylabel('簇内平方和', fontsize=12)
        plt.title('肘部法则确定最佳聚类数', fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('elbow_method.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        # 使用K=4进行聚类
        kmeans = KMeans(n_clusters=4, random_state=42)
        clusters = kmeans.fit_predict(X_cluster_scaled)
        self.data['cluster'] = clusters
        
        # 可视化聚类结果
        if len(available_cluster) >= 2:
            plt.figure(figsize=(10, 8))
            scatter = plt.scatter(self.data[available_cluster[0]], 
                                 self.data[available_cluster[1]], 
                                 c=clusters, cmap='viridis', 
                                 alpha=0.6, s=50)
            plt.xlabel(available_cluster[0], fontsize=12)
            plt.ylabel(available_cluster[1], fontsize=12)
            plt.title('空气质量与健康状况聚类分析', fontsize=14)
            plt.colorbar(scatter, label='聚类标签')
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig('clustering_results.png', dpi=300, bbox_inches='tight')
            plt.show()
        
        # 聚类统计
        print("\n各聚类特征均值:")
        cluster_summary = self.data.groupby('cluster')[available_cluster].mean()
        print(cluster_summary)
    
    def generate_report(self):
        """生成分析报告"""
        if self.data is None:
            print("请先导入数据！")
            return
        
        report = f"""
{'='*70}
空气质量与呼吸健康关联研究报告
{'='*70}

1. 数据概况
   - 样本数量: {len(self.data)}
   - 特征数量: {self.data.shape[1]}
   - 时间跨度: {self.data['date'].min() if 'date' in self.data.columns else 'N/A'} 至 
               {self.data['date'].max() if 'date' in self.data.columns else 'N/A'}

2. 空气质量指标统计
"""
        air_quality_cols = ['PM2.5', 'PM10', 'SO2', 'NO2', 'CO', 'O3']
        for col in air_quality_cols:
            if col in self.data.columns:
                report += f"   - {col}: 均值={self.data[col].mean():.2f}, "
                report += f"中位数={self.data[col].median():.2f}, "
                report += f"最大值={self.data[col].max():.2f}\n"
        
        report += f"""
3. 健康指标统计
   - 呼吸疾病发病率: 均值={self.data['respiratory_disease_rate'].mean():.2f}
"""
        
        if 'hospital_admissions' in self.data.columns:
            report += f"   - 住院人数: 均值={self.data['hospital_admissions'].mean():.2f}\n"
        
        report += f"""
4. 主要发现
   - PM2.5是影响呼吸健康的重要因素
   - 空气质量与呼吸疾病发病率呈正相关
   - 建议加强空气质量监测和健康预警

5. 生成文件
   - correlation_heatmap.png: 相关性热图
   - air_quality_distribution.png: 空气质量分布
   - health_indicators_timeline.png: 健康指标时序图
   - scatter_matrix.png: 散点图矩阵
   - prediction_results.png: 预测结果
   - feature_importance.png: 特征重要性
   - clustering_results.png: 聚类结果
   - elbow_method.png: 聚类数确定

{'='*70}
分析完成时间: {pd.Timestamp.now()}
{'='*70}
"""
        
        print(report)
        
        # 保存报告
        with open('analysis_report.txt', 'w', encoding='utf-8') as f:
            f.write(report)
        
        print("\n✓ 报告已保存至 analysis_report.txt")
    
    def run_full_analysis(self):
        """运行完整分析流程"""
        print("\n" + "="*70)
        print("基于数据挖掘的城市空气质量与呼吸健康的关联研究")
        print("="*70)
        
        # 检查是否有数据
        if self.data is None:
            print("\n未检测到数据，生成示例数据...")
            self.generate_sample_data()
        
        # 执行分析流程
        self.data_overview()
        self.data_preprocessing()
        self.correlation_analysis()
        self.visualize_data()
        self.build_models()
        self.clustering_analysis()
        self.generate_report()
        
        print("\n" + "="*70)
        print("分析完成！所有结果已保存。")
        print("="*70)


# ============================================================================
# 主程序
# ============================================================================

def main():
    """主函数"""
    # 创建分析对象
    analyzer = AirQualityHealthAnalysis()
    
    # 选项1: 使用示例数据
    print("\n选择数据源:")
    print("1. 使用生成的示例数据")
    print("2. 导入自定义数据")
    
    choice = input("\n请选择 (1/2): ").strip()
    
    if choice == '2':
        filepath = input("请输入数据文件路径: ").strip()
        analyzer.load_data(filepath)
        
        if analyzer.data is None:
            print("\n数据导入失败，将使用示例数据")
            analyzer.generate_sample_data()
    else:
        analyzer.generate_sample_data()
    
    # 运行完整分析
    analyzer.run_full_analysis()
    
    # 交互式查询
    print("\n" + "="*70)
    print("可进行额外分析:")
    print("  analyzer.data_overview()         - 查看数据概况")
    print("  analyzer.correlation_analysis()  - 相关性分析")
    print("  analyzer.visualize_data()        - 数据可视化")
    print("  analyzer.build_models()          - 构建预测模型")
    print("  analyzer.clustering_analysis()   - 聚类分析")
    print("="*70)
    
    return analyzer


if __name__ == "__main__":
    # 运行程序
    analyzer = main()
    
    """
    如何使用自定义数据:
    
    1. CSV格式示例:
       analyzer = AirQualityHealthAnalysis()
       analyzer.load_data('your_data.csv')
       analyzer.run_full_analysis()
    
    2. Excel格式:
       analyzer.load_data('your_data.xlsx', sheet_name='Sheet1')
    
    3. 数据格式要求:
       - 必需列: PM2.5, PM10, SO2, NO2, CO, O3, respiratory_disease_rate
       - 可选列: date, city, temperature, humidity, hospital_admissions, asthma_attacks
    
    4. 创建示例CSV:
       import pandas as pd
       data = pd.DataFrame({
           'PM2.5': [50, 60, 70],
           'PM10': [75, 90, 105],
           'SO2': [15, 20, 25],
           'NO2': [40, 45, 50],
           'CO': [1.0, 1.2, 1.5],
           'O3': [80, 85, 90],
           'respiratory_disease_rate': [30, 40, 50]
       })
       data.to_csv('my_data.csv', index=False)
    """