生生不息

生生不息

AI使用二元分段法检测变点

2025-06-23
AI使用二元分段法检测变点

提示词

一个DataFrame中,index是日期,有一列是股票的收盘价,请给出使用ruptures中的Binseg模型进行变点检测的示例代码。

答案

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ruptures as rpt
from matplotlib import dates as mdates

plt.style.use('seaborn-v0_8')
#指定字体,防止中文出现乱码,windows系统指定为‘SimHei’
plt.rcParams['font.sans-serif'] = ['SimHei']
#这行代码让中文的负号“-”可以正常显示
plt.rcParams["axes.unicode_minus"]=False

# 1. 创建示例数据(含异常值,替换为你的实际数据)
df = pd.read_excel('../数据/2-1.xlsx')
df = df[['日期', '收盘']]
df.columns = ['date','Close']
df['date'] = pd.to_datetime(df['date'])
df.set_index('date', inplace = True)

# 2. 使用BinSeg模型检测变点
signal = df["Close"].values  # 提取收盘价数组
model = "l2"  # 检测均值变化[1,6](@ref)
algo = rpt.detection.Binseg(model="l2", min_size=20)
algo.fit(signal)
change_points = algo.predict(pen=3)  # pen: 惩罚系数[1,4](@ref)

# 3. 可视化结果
plt.figure(figsize=(14, 7))
plt.plot(df.index, df["Close"], label="收盘价", lw=1.5, color="steelblue")

# 标记变点(排除最后一个索引)
for cp in change_points[:-1]:
    plt.axvline(
        x=df.index[cp], 
        color="red", 
        linestyle="--", 
        alpha=0.8,
        label="变点" if cp == change_points[0] else None  # 仅第一个添加图例
    )

# 分段着色展示
segment_colors = ["#FFE9E9", "#E9F7FF", "#F0F0FF"]
prev_idx = 0
for i, cp in enumerate(change_points[:-1]):
    plt.axvspan(
        df.index[prev_idx], 
        df.index[cp],
        facecolor=segment_colors[i % 3], 
        alpha=0.2
    )
    prev_idx = cp

# 图表美化
plt.title("股票收盘价变点检测 (BinSeg算法)", fontsize=15, pad=15)
plt.ylabel("价格", fontsize=12)
plt.grid(alpha=0.2, linestyle=":")
plt.legend(loc="upper left", frameon=True)
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))  # 日期格式
plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=1))
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# 4. 输出变点信息
print(f"检测到变点数量: {len(change_points)-1}")
print("变点日期及对应收盘价:")
change_info = pd.DataFrame({
    "变点日期": df.index[change_points[:-1]],
    "收盘价": df["Close"].iloc[change_points[:-1]].values,
    "分段长度": np.diff([0] + change_points[:-1])  # 各段数据点数量
})
print(change_info)

数据导入以及防止画图中文乱码的代码是我加进去的,其余都是ai写的。

图像: