# -*- coding: utf-8 -*-
"""
BTCUSD 15m multi-sample forecasting -> quantile-band aggregation and plotting.

Workflow:
1. Pull recent klines from Bybit through ``pybit``'s unified-trading HTTP client.
2. Run the Kronos predictor several times, each with a different seed.
3. Aggregate the sampled paths into per-timestep quantiles (q10/q50/q90 by default).
4. Save the quantiles to CSV and draw a price-band + volume chart.

Author: 你(鹅)
"""
import os
import sys
import math
import json
import time
import random
import argparse
import traceback
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from dotenv import load_dotenv
from pybit.unified_trading import HTTP

# ==== Kronos ====
from model import Kronos, KronosTokenizer, KronosPredictor

# ========== Matplotlib fonts ==========
# CJK-capable fonts first so the Chinese labels render; Arial as fallback.
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
# Draw the minus sign as ASCII "-" (many CJK fonts lack U+2212).
plt.rcParams["axes.unicode_minus"] = False

# ========== Default parameters ==========
NUM_SAMPLES = 20  # number of sampled forecasts (10~50 recommended)
QUANTILES = [0.1, 0.5, 0.9]  # band quantiles (lower / median / upper)
LOOKBACK = 360  # history window length, in bars
PRED_LEN = 96  # number of future bars to forecast
INTERVAL_MIN = 15  # kline period in minutes
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
TEMPERATURE = 1.0
TOP_P = 0.9
OUTPUT_DIR = "figures"
os.makedirs(OUTPUT_DIR, exist_ok=True)

_FEATURE_COLUMNS = ["open", "high", "low", "close", "volume", "amount"]


# ========== Utilities ==========
def set_global_seed(seed: int):
    """Seed the RNGs the sampler may use: torch (if installed), NumPy and ``random``."""
    try:
        import torch
    except ImportError:
        # torch is optional here; seed the remaining generators anyway.
        torch = None
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def fetch_kline(
    session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse"
):
    """Fetch klines from Bybit and return them as an ascending-time DataFrame.

    Returns:
        pd.DataFrame with columns ``timestamps`` (naive, decoded from epoch
        milliseconds, i.e. UTC), ``open``, ``high``, ``low``, ``close``,
        ``volume``, ``turnover`` (all float) and ``amount`` (a copy of
        ``turnover``), with a 0..n-1 index.

    Raises:
        RuntimeError: if the API reports a non-zero ``retCode`` or returns no rows.
    """
    print("正在获取K线数据...")
    resp = session.get_kline(
        category=category, symbol=symbol, interval=interval, limit=limit
    )
    if resp.get("retCode", -1) != 0:
        raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")

    rows = resp["result"]["list"]
    if not rows:
        # An empty list would otherwise only fail later with an opaque KeyError.
        raise RuntimeError(f"获取数据失败: {symbol} 返回的K线列表为空")

    df = pd.DataFrame(
        rows,
        columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
    )
    df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
    numeric_cols = ["open", "high", "low", "close", "volume", "turnover"]
    df[numeric_cols] = df[numeric_cols].astype(float)
    # The API list order is not relied upon; sort ascending by time explicitly.
    df = df.sort_values("timestamps").reset_index(drop=True)
    df["amount"] = df["turnover"]
    print(f"获取到 {len(df)} 根K线数据")
    return df


def prepare_io_windows(
    df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
    """Split the most recent ``lookback`` bars into model inputs and build future stamps.

    Args:
        df: ascending-time kline frame as returned by ``fetch_kline``.
        lookback: maximum number of most recent bars used as context.
        pred_len: number of future timestamps to generate.
        interval_min: spacing of the future timestamps, in minutes.

    Returns:
        (x_df, x_timestamp, y_timestamp, start_idx, end_idx), where ``x_df``
        holds the feature columns of the window, ``x_timestamp`` its timestamps,
        ``y_timestamp`` the ``pred_len`` future timestamps, and
        ``[start_idx, end_idx)`` the window's positional range in ``df``.

    Raises:
        ValueError: if ``df`` is empty.
    """
    if df.empty:
        raise ValueError("没有可用的K线数据")

    end_idx = len(df)
    start_idx = max(0, end_idx - lookback)
    window = df.iloc[start_idx:end_idx]

    x_df = window[_FEATURE_COLUMNS].reset_index(drop=True)
    x_timestamp = window["timestamps"].reset_index(drop=True)

    last_ts = window["timestamps"].iloc[-1]
    future_timestamps = pd.date_range(
        start=last_ts + pd.Timedelta(minutes=interval_min),
        periods=pred_len,
        freq=f"{interval_min}min",
    )
    y_timestamp = pd.Series(future_timestamps).reset_index(drop=True)

    print(f"数据总量: {len(df)} 根K线")
    # Report the bars actually used; there may be fewer than `lookback`.
    print(f"使用最新的 {end_idx - start_idx} 根K线(索引 {start_idx} 到 {end_idx-1})")
    print(f"最后一根历史K线时间: {last_ts}")
    print(f"预测未来 {pred_len} 根K线")
    return x_df, x_timestamp, y_timestamp, start_idx, end_idx


def load_kronos(device=DEVICE):
    """Load the pretrained Kronos tokenizer + model and wrap them in a predictor."""
    print("正在加载 Kronos 模型...")
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
    print("模型加载完成!\n")
    return predictor


def run_one_prediction(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    T=TEMPERATURE,
    top_p=TOP_P,
    seed=None,
    verbose=False,
):
    """Run a single forecast of ``len(y_timestamp)`` steps.

    When ``seed`` is given, the global RNGs are re-seeded first so that this
    sample is reproducible (assuming the predictor draws from those RNGs).
    """
    if seed is not None:
        set_global_seed(seed)
    return predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=len(y_timestamp),
        T=T,
        top_p=top_p,
        verbose=verbose,
    )


def run_multi_predictions(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    num_samples=NUM_SAMPLES,
    base_seed=42,
    temperature=TEMPERATURE,
    top_p=TOP_P,
):
    """Run ``num_samples`` forecasts, seeding sample ``i`` with ``base_seed + i``.

    Returns:
        list[pd.DataFrame]: one frame per sample, restricted to whichever of
        open/high/low/close/volume/amount the predictor returned, in that
        order, each with a 0..n-1 index.
    """
    preds = []
    print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
    for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100):
        pred_df = run_one_prediction(
            predictor,
            x_df,
            x_timestamp,
            y_timestamp,
            T=temperature,
            top_p=top_p,
            seed=base_seed + i,  # a different seed for every sample
            verbose=False,
        )
        # Keep only the expected columns so every sample aggregates identically.
        cols_present = [c for c in _FEATURE_COLUMNS if c in pred_df.columns]
        preds.append(pred_df[cols_present].reset_index(drop=True))
    print("多次预测完成。\n")
    return preds


def _quantile_suffix(q):
    """Return the column suffix for quantile ``q``: 0.1 -> "q10", 0.05 -> "q05"."""
    # round(), not int(): int(0.29 * 100) == 28 because of float representation.
    return f"q{int(round(q * 100)):02d}"


def aggregate_quantiles(pred_list, quantiles=QUANTILES):
    """Collapse a list of sampled forecasts into per-timestep quantiles.

    Args:
        pred_list: non-empty list of equally shaped DataFrames (one per sample).
        quantiles: quantile levels in [0, 1].

    Returns:
        pd.DataFrame with one column ``<feature>_qNN`` per feature of
        ``pred_list[0]`` and per quantile (NN = quantile * 100, rounded,
        zero-padded to two digits), one row per forecast step.

    Raises:
        ValueError: if ``pred_list`` is empty.
    """
    if not pred_list:
        raise ValueError("pred_list is empty: nothing to aggregate")

    out = {}
    for feature in pred_list[0].columns:
        # Shape (steps, samples): one column per sampled path.
        samples = np.column_stack(
            [pdf[feature].to_numpy(dtype=float) for pdf in pred_list]
        )
        for q in quantiles:
            out[f"{feature}_{_quantile_suffix(q)}"] = np.quantile(samples, q, axis=1)
    return pd.DataFrame(out)


def _quantile_columns(agg_df, feature):
    """Return the ``<feature>_qNN`` columns of ``agg_df`` sorted by NN ascending."""
    prefix = f"{feature}_q"
    cols = [
        c for c in agg_df.columns if c.startswith(prefix) and c[len(prefix):].isdigit()
    ]
    return sorted(cols, key=lambda c: int(c[len(prefix):]))


def _annotate_daily_marks(
    ax, hist_ts, hist_close, y_timestamp, pred_close, mark_hour, utc_offset_hours
):
    """Draw a dashed line every day at ``mark_hour``:00 (UTC+``utc_offset_hours``).

    The nearest close price (historical close, or forecast median when
    ``pred_close`` is given) is labelled when it lies within one hour of the line.
    Timestamps are assumed to be naive UTC, as produced by ``fetch_kline``.
    """
    all_ts = pd.concat([hist_ts, y_timestamp], ignore_index=True)
    if all_ts.empty:
        return
    if pred_close is None:
        pred_close = pd.Series([np.nan] * len(y_timestamp))
    all_close = pd.concat(
        [hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)],
        ignore_index=True,
    )

    start_time, end_time = all_ts.min(), all_ts.max()
    # The axis is UTC, so local mark_hour:00 sits (mark_hour - offset) hours
    # after UTC midnight (16:00 UTC+8 == 08:00 UTC).
    mark_offset = pd.Timedelta(hours=mark_hour - utc_offset_hours)
    one_day = pd.Timedelta(days=1)
    # Start one day early so a negative offset cannot skip the first mark.
    target_time = start_time.normalize() - one_day + mark_offset
    while target_time <= end_time:
        if target_time >= start_time:
            ax.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
            time_diffs = (all_ts - target_time).abs()
            pos = int(np.argmin(time_diffs.to_numpy()))
            closest_time = all_ts.iloc[pos]
            closest_price = all_close.iloc[pos]
            # Only label when a bar is within an hour of the mark and has a price.
            if time_diffs.iloc[pos] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
                ax.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
                ax.text(
                    closest_time,
                    closest_price,
                    f' ${closest_price:.1f}',
                    fontsize=9,
                    color='blue',
                    verticalalignment='bottom',
                    horizontalalignment='left',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7),
                )
        target_time += one_day


def plot_results(
    historical_df,
    y_timestamp,
    agg_df,
    title_prefix="BTCUSD 15分钟",
    mark_hour=16,
    utc_offset_hours=8,
):
    """Draw a price panel (top) and a volume panel (bottom) that share the x-axis.

    Top panel (height 3):
    - the historical close;
    - the forecast close band between the lowest and highest ``close_qNN``
      columns of ``agg_df`` (q10~q90 with the default quantiles);
    - the ``close_q50`` median line;
    - a daily marker at ``mark_hour``:00 UTC+``utc_offset_hours``.

    Bottom panel (height 1): historical volume bars plus ``volume_q50``
    forecast bars (median only, no band).

    Args:
        historical_df: frame with ``timestamps``, ``close`` and ``volume``.
        y_timestamp: pd.Series of forecast timestamps aligned with ``agg_df`` rows.
        agg_df: output of ``aggregate_quantiles``.
        title_prefix: text prepended to the chart title.
        mark_hour: local hour of the daily marker line.
        utc_offset_hours: UTC offset of that local hour (timestamps are naive UTC).

    Returns:
        matplotlib.figure.Figure (neither saved nor shown here).
    """
    import matplotlib.gridspec as gridspec

    hist_ts = historical_df["timestamps"]

    # Bar width = 80% of the bar spacing; 10-minute spacing when it can't be inferred.
    default_step = pd.Timedelta(minutes=10)
    hist_step = hist_ts.iloc[-1] - hist_ts.iloc[-2] if len(hist_ts) >= 2 else default_step
    pred_step = (
        y_timestamp.iloc[1] - y_timestamp.iloc[0] if len(y_timestamp) >= 2 else default_step
    )
    bar_width_hist = hist_step * 0.8
    bar_width_pred = pred_step * 0.8

    # Forecast quantile columns.
    close_cols = _quantile_columns(agg_df, "close")
    close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
    vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None

    fig = plt.figure(figsize=(18, 10))
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08)  # heights 3:1

    # ===== Top: close price =====
    ax_price = fig.add_subplot(gs[0])
    ax_price.plot(hist_ts, historical_df["close"], label="历史收盘价", linewidth=1.8)

    # Band between the outermost available quantiles (honours --quantiles).
    if len(close_cols) >= 2:
        lo_col, hi_col = close_cols[0], close_cols[-1]
        lo_tag, hi_tag = lo_col.split("_")[-1], hi_col.split("_")[-1]
        ax_price.fill_between(
            y_timestamp.values,
            agg_df[lo_col],
            agg_df[hi_col],
            alpha=0.25,
            label=f"预测收盘区间({lo_tag}~{hi_tag})",
        )
    if close_q50 is not None:
        ax_price.plot(
            y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
        )

    # Forecast start line and shaded forecast region.
    if len(y_timestamp) > 0:
        ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
        ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)

    _annotate_daily_marks(
        ax_price,
        hist_ts,
        historical_df["close"],
        y_timestamp,
        close_q50,
        mark_hour,
        utc_offset_hours,
    )

    ax_price.set_ylabel("价格 (USD)", fontsize=11)
    ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
    ax_price.grid(True, alpha=0.3)
    ax_price.legend(loc="best", fontsize=9)

    # ===== Bottom: volume (forecast median only) =====
    ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
    ax_vol.bar(
        hist_ts,
        historical_df["volume"],
        width=bar_width_hist,
        alpha=0.35,
        label="历史成交量",
    )
    if vol_q50 is not None:
        ax_vol.bar(
            y_timestamp.values,
            vol_q50,
            width=bar_width_pred,
            alpha=0.6,
            label="预测成交量中位(q50)",
        )
    if len(y_timestamp) > 0:
        ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)

    ax_vol.set_ylabel("成交量", fontsize=11)
    ax_vol.set_xlabel("时间", fontsize=11)
    ax_vol.grid(True, alpha=0.25)
    ax_vol.legend(loc="best", fontsize=9)

    # The panels share the x-axis; hide the top panel's tick labels to avoid overlap.
    plt.setp(ax_price.get_xticklabels(), visible=False)
    plt.tight_layout()
    return fig


def _interval_to_minutes(interval):
    """Convert a Bybit kline interval ("1", "15", "60", "D", "W") to minutes.

    Raises:
        ValueError: for intervals without a fixed length (e.g. monthly "M").
    """
    text = str(interval).strip().upper()
    if text.isdigit():
        return int(text)
    named = {"D": 24 * 60, "W": 7 * 24 * 60}
    if text in named:
        return named[text]
    raise ValueError(f"不支持的K线周期: {interval}")


def _parse_quantiles(text):
    """Parse a comma-separated quantile list such as "0.1,0.5,0.9".

    Raises:
        ValueError: if the list is empty or any value lies outside [0, 1].
    """
    qs = [float(x) for x in text.split(",") if x.strip()]
    if not qs or any(not 0.0 <= q <= 1.0 for q in qs):
        raise ValueError(f"分位数必须在 [0, 1] 之间: {text}")
    return qs


def main():
    """CLI entry point: fetch klines, sample forecasts, and save quantile CSV + chart."""
    parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
    parser.add_argument("--symbol", default="BTCUSD")
    parser.add_argument("--category", default="inverse")
    parser.add_argument("--interval", default="15")
    parser.add_argument("--limit", type=int, default=1000)
    parser.add_argument("--lookback", type=int, default=LOOKBACK)
    parser.add_argument("--pred_len", type=int, default=PRED_LEN)
    parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--top_p", type=float, default=TOP_P)
    parser.add_argument("--quantiles", default="0.1,0.5,0.9")
    parser.add_argument("--device", default=DEVICE)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Validate cheap arguments before the slow model load.
    qs = _parse_quantiles(args.quantiles)
    interval_min = _interval_to_minutes(args.interval)

    # Environment variables & HTTP session.
    load_dotenv()
    api_key = os.getenv("BYBIT_API_KEY")
    api_secret = os.getenv("BYBIT_API_SECRET")
    session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)

    predictor = load_kronos(device=args.device)

    df = fetch_kline(
        session,
        symbol=args.symbol,
        interval=args.interval,
        limit=args.limit,
        category=args.category,
    )
    x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
        df,
        lookback=args.lookback,
        pred_len=args.pred_len,
        interval_min=interval_min,
    )

    # History shown in the chart: the most recent 200 bars (adjust as needed).
    plot_start = max(0, end_idx - 200)
    historical_df = df.iloc[plot_start:end_idx].copy()

    preds = run_multi_predictions(
        predictor,
        x_df,
        x_ts,
        y_ts,
        num_samples=args.samples,
        base_seed=args.seed,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    agg_df = aggregate_quantiles(preds, quantiles=qs)
    agg_df["timestamps"] = y_ts.values

    out_csv = os.path.join(
        OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
    )
    agg_df.to_csv(out_csv, index=False)
    print(f"预测分位数据已保存到: {out_csv}")

    fig = plot_results(
        historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
    )
    out_png = os.path.join(
        OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
    )
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"预测区间图表已保存到: {out_png}")
    plt.show()


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print("运行出错:", repr(e))
        # repr(e) alone hides where the failure happened; keep the traceback.
        traceback.print_exc()
        sys.exit(1)