527 lines
17 KiB
Python
527 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
BTCUSD 15m 多次采样预测 → 区间聚合可视化
|
||
Author: 你(鹅)
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import math
|
||
import json
|
||
import time
|
||
import random
|
||
import argparse
|
||
from datetime import datetime
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
from tqdm import tqdm
|
||
|
||
from dotenv import load_dotenv
|
||
from pybit.unified_trading import HTTP
|
||
|
||
# ==== Kronos ====
|
||
from model import Kronos, KronosTokenizer, KronosPredictor
|
||
|
||
# ========== Matplotlib fonts ==========
# Try CJK-capable fonts first so the Chinese titles/labels render; Arial is the last fallback.
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
# Draw minus signs as a plain ASCII hyphen (CJK fonts often lack the Unicode minus glyph).
plt.rcParams["axes.unicode_minus"] = False

# ========== Default parameters ==========
NUM_SAMPLES = 20  # number of sampled forecasts per run (10-50 recommended)
QUANTILES = [0.1, 0.5, 0.9]  # quantile levels aggregated per forecast step (lower / median / upper)
LOOKBACK = 360  # max number of historical candles fed to the model as input
PRED_LEN = 24  # number of future steps to forecast
INTERVAL_MIN = 60  # candle period in minutes (default spacing for generated future timestamps)
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")  # torch device string; override via the KRONOS_DEVICE env var
TEMPERATURE = 1.0  # forwarded to predictor.predict as T (presumably the sampling temperature)
TOP_P = 0.9  # forwarded to predictor.predict as top_p (presumably the nucleus-sampling threshold)
DRAW_WINDOW_LEN = 240  # number of most recent historical candles drawn in the plot

OUTPUT_DIR = "figures"  # directory for the CSV / PNG outputs written by main()
# NOTE: this runs at import time, so merely importing the module creates ./figures.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
|
||
# ========== 实用函数 ==========
|
||
def _seed_torch_best_effort(seed: int) -> None:
    """Seed torch's CPU and (if present) CUDA generators; ignore any failure."""
    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        # Deliberately best-effort: torch is optional here, and a missing or
        # broken install (ImportError, OSError from native libs, ...) must not
        # stop the stdlib / NumPy generators from being seeded below.
        return


def set_global_seed(seed: int):
    """Seed every random generator the sampler might draw from, as far as possible.

    Order: torch (CPU, then all CUDA devices when CUDA is available), then
    Python's ``random`` module, then NumPy's legacy global generator.
    Problems with torch are swallowed so the function also works without it.

    Args:
        seed: Integer seed applied to every generator.
    """
    _seed_torch_best_effort(seed)
    random.seed(seed)
    np.random.seed(seed)
|
||
|
||
|
||
def fetch_kline(
    session: HTTP, symbol="BTCUSD", interval="15", limit=500, category="inverse"
):
    """Fetch recent candles through ``session.get_kline`` as a clean DataFrame.

    Each row of ``result.list`` in the response is read into the columns
    timestamps/open/high/low/close/volume/turnover. Every column is cast to
    float, and timestamps are then parsed as epoch milliseconds with
    ``pd.to_datetime(unit="ms")``, which yields tz-naive (UTC) datetimes.
    Rows are sorted oldest-first, and an ``amount`` column is added as a copy
    of ``turnover`` because the model input expects that name. Finally the
    newest row is dropped, since it is usually the still-forming candle.

    Args:
        session: pybit ``HTTP`` client used to issue the request.
        symbol: Instrument symbol, forwarded to ``get_kline``.
        interval: Candle interval string, forwarded to ``get_kline``.
        limit: Maximum number of candles requested.
        category: Product category, forwarded to ``get_kline``.

    Returns:
        DataFrame with a fresh 0-based index, oldest candle first.

    Raises:
        RuntimeError: if the response ``retCode`` is missing or non-zero.
    """
    field_names = ["timestamps", "open", "high", "low", "close", "volume", "turnover"]

    print("正在获取K线数据...")
    response = session.get_kline(
        category=category, symbol=symbol, interval=interval, limit=limit
    )
    if response.get("retCode", -1) != 0:
        raise RuntimeError(f"获取数据失败: {response.get('retMsg')}")

    candles = pd.DataFrame(response["result"]["list"], columns=field_names)
    candles["timestamps"] = pd.to_datetime(
        candles["timestamps"].astype(float), unit="ms"
    )
    for name in field_names[1:]:
        candles[name] = candles[name].astype(float)
    candles = candles.sort_values("timestamps").reset_index(drop=True)
    candles["amount"] = candles["turnover"]

    # Drop the newest candle: it is usually the still-open current period.
    if not candles.empty:
        candles = candles.iloc[:-1].reset_index(drop=True)

    print(f"获取到 {len(candles)} 根K线数据")
    return candles
|
||
|
||
|
||
def get_recent_kline(
    symbol: str = "BTCUSD",
    category: str = "inverse",
    interval: str = "60",
    limit: int = 500,
):
    """Convenience wrapper: open a Bybit session from env vars and fetch candles.

    Loads ``.env`` and reads BYBIT_API_KEY / BYBIT_API_SECRET. It then builds a
    mainnet (``testnet=False``) pybit ``HTTP`` session and delegates to
    ``fetch_kline``. The result is therefore oldest-first, with the last
    (usually unfinished) candle removed.

    Args:
        symbol: Instrument symbol.
        category: Product category.
        interval: Candle interval string.
        limit: Maximum number of candles requested.

    Returns:
        The DataFrame produced by ``fetch_kline``.
    """
    load_dotenv()
    http_session = HTTP(
        testnet=False,
        api_key=os.getenv("BYBIT_API_KEY"),
        api_secret=os.getenv("BYBIT_API_SECRET"),
    )
    return fetch_kline(
        session=http_session,
        symbol=symbol,
        interval=interval,
        limit=limit,
        category=category,
    )
|
||
|
||
|
||
def prepare_io_windows(
    df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
    """Cut the model-input window out of ``df`` and build the future timestamps.

    Rows are selected by position (``iloc``), so the window is correct for any
    index on ``df``. ``df`` must already be sorted oldest-first by
    ``timestamps``.

    Args:
        df: Candle DataFrame with a ``timestamps`` column plus
            open/high/low/close/volume/amount.
        lookback: Maximum number of trailing candles used as model input.
        pred_len: Number of future timestamps to generate.
        interval_min: Candle period in minutes; spacing of the future
            timestamps, which start one period after the last candle.

    Returns:
        Tuple ``(x_df, x_timestamp, y_timestamp, start_idx, end_idx)``:
        the input features, their timestamps and the future timestamps (each
        with a 0-based RangeIndex), plus the positional half-open range
        ``[start_idx, end_idx)`` of the input window inside ``df``.

    Raises:
        ValueError: if ``df`` is empty or ``lookback`` is smaller than 1.
    """
    # Fail early with a clear message instead of a KeyError on df.loc[-1].
    if len(df) == 0:
        raise ValueError("K线数据为空,无法构建预测输入窗口")
    if lookback < 1:
        raise ValueError(f"lookback 必须 >= 1,当前为 {lookback}")

    end_idx = len(df)
    start_idx = max(0, end_idx - lookback)
    window = df.iloc[start_idx:end_idx]

    x_df = window[["open", "high", "low", "close", "volume", "amount"]].reset_index(
        drop=True
    )
    x_timestamp = window["timestamps"].reset_index(drop=True)

    last_ts = df["timestamps"].iloc[end_idx - 1]
    future_timestamps = pd.date_range(
        start=last_ts + pd.Timedelta(minutes=interval_min),
        periods=pred_len,
        freq=f"{interval_min}min",
    )
    y_timestamp = pd.Series(future_timestamps)
    y_timestamp.index = range(len(y_timestamp))

    print(f"数据总量: {len(df)} 根K线")
    # Report the rows actually used: fewer than `lookback` when df is short.
    print(f"使用最新的 {end_idx - start_idx} 根K线(索引 {start_idx} 到 {end_idx-1})")
    print(f"最后一根历史K线时间: {last_ts}")
    print(f"预测未来 {pred_len} 根K线")

    return x_df, x_timestamp, y_timestamp, start_idx, end_idx
|
||
|
||
|
||
# Predictors already built by load_kronos(), keyed by device. Repeated calls in
# one process (e.g. kronos_predict_df in a loop) reuse the loaded weights
# instead of reloading tokenizer + model every time.
_PREDICTOR_CACHE = {}


def load_kronos(device=DEVICE, use_cache=True):
    """Return a KronosPredictor for ``device``, building it on first use.

    Loads the "NeoQuasar/Kronos-Tokenizer-base" tokenizer and the
    "NeoQuasar/Kronos-base" model via ``from_pretrained``. It then wraps them
    in a ``KronosPredictor`` with ``max_context=512``.

    Args:
        device: Device string handed to KronosPredictor (e.g. "cuda:0").
        use_cache: When True (default), reuse a predictor previously built
            for the same ``device`` in this process. Pass False to force a
            fresh load; a forced load does not replace the cached entry.

    Returns:
        The KronosPredictor instance.
    """
    if use_cache and device in _PREDICTOR_CACHE:
        return _PREDICTOR_CACHE[device]

    print("正在加载 Kronos 模型...")
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
    print("模型加载完成!\n")

    if use_cache:
        _PREDICTOR_CACHE[device] = predictor
    return predictor
|
||
|
||
|
||
def run_one_prediction(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    T=TEMPERATURE,
    top_p=TOP_P,
    seed=None,
    verbose=False,
):
    """Draw one sampled forecast from ``predictor``.

    When ``seed`` is given, all RNGs are seeded first via ``set_global_seed``
    so that the sample is reproducible. The forecast horizon equals
    ``len(y_timestamp)``.

    Args:
        predictor: Object exposing ``predict(df=..., x_timestamp=...,
            y_timestamp=..., pred_len=..., T=..., top_p=..., verbose=...)``.
        x_df: Model input features.
        x_timestamp: Timestamps of the input rows.
        y_timestamp: Timestamps to forecast.
        T: Forwarded to ``predict`` as ``T``.
        top_p: Forwarded to ``predict`` as ``top_p``.
        seed: Optional seed; ``None`` leaves the RNG state untouched.
        verbose: Forwarded to ``predict``.

    Returns:
        Whatever ``predictor.predict`` returns (used as a DataFrame by callers).
    """
    if seed is not None:
        set_global_seed(seed)

    predict_kwargs = {
        "df": x_df,
        "x_timestamp": x_timestamp,
        "y_timestamp": y_timestamp,
        "pred_len": len(y_timestamp),
        "T": T,
        "top_p": top_p,
        "verbose": verbose,
    }
    return predictor.predict(**predict_kwargs)
|
||
|
||
|
||
def run_multi_predictions(
    predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
    temperature=TEMPERATURE, top_p=TOP_P
):
    """Run ``num_samples`` seeded forecasts and collect them.

    Sample ``i`` is drawn with seed ``base_seed + i``, so the whole set is
    reproducible. Each result is trimmed to whichever of
    open/high/low/close/volume/amount it actually contains, copied, and given
    a fresh 0-based index.

    Args:
        predictor: Predictor passed through to ``run_one_prediction``.
        x_df, x_timestamp, y_timestamp: Model inputs / forecast timestamps.
        num_samples: Number of forecasts to draw.
        base_seed: Seed of the first sample; later samples increment it.
        temperature: Forwarded as ``T``.
        top_p: Forwarded as ``top_p``.

    Returns:
        List of per-sample forecast DataFrames, in sample order.
    """
    wanted_columns = ("open", "high", "low", "close", "volume", "amount")

    collected = []
    print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
    progress = tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100)
    for offset in progress:
        raw_forecast = run_one_prediction(
            predictor,
            x_df,
            x_timestamp,
            y_timestamp,
            T=temperature,
            top_p=top_p,
            seed=base_seed + offset,  # distinct, reproducible seed per sample
            verbose=False,
        )
        # Keep only the expected columns that this predictor actually produced.
        kept = [name for name in wanted_columns if name in raw_forecast.columns]
        collected.append(raw_forecast[kept].copy().reset_index(drop=True))
    print("多次预测完成。\n")
    return collected
|
||
|
||
|
||
def aggregate_quantiles(pred_list, quantiles=None):
    """Collapse several sampled forecasts into per-step statistics.

    The columns of the first DataFrame define what is aggregated. Every
    DataFrame must contain those columns and have the same length. For each
    column ``<col>`` the result holds ``<col>_min`` and ``<col>_max``, then one
    ``<col>_qXX`` per requested quantile. XX is the quantile as a whole
    percent, zero-padded (0.1 -> ``q10``, 0.29 -> ``q29``). Sub-percent
    precision is rounded away in the column name.

    Args:
        pred_list: Non-empty list of forecast DataFrames, one per sample.
        quantiles: Quantile levels in [0, 1]. ``None`` (default) means the
            module-level ``QUANTILES``.

    Returns:
        DataFrame with one row per forecast step.

    Raises:
        ValueError: if ``pred_list`` is empty (or, via NumPy, if a quantile
            lies outside [0, 1]).
    """
    if quantiles is None:
        quantiles = QUANTILES
    if not pred_list:
        raise ValueError("aggregate_quantiles: pred_list 为空,无法聚合")

    keys = pred_list[0].columns.tolist()
    out = {}
    for k in keys:
        # (time, samples) matrix: column j holds sample j's forecast for k.
        samples = np.column_stack(
            [np.asarray(pdf[k], dtype=float) for pdf in pred_list]
        )
        # Extremes across samples (used for risk assessment downstream).
        out[f"{k}_min"] = samples.min(axis=1)
        out[f"{k}_max"] = samples.max(axis=1)
        for q in quantiles:
            # round(), not int(): e.g. 0.29 * 100 == 28.999999999999996 would
            # otherwise be truncated to "q28".
            pct = int(round(q * 100))
            out[f"{k}_q{pct:02d}"] = np.quantile(samples, q, axis=1)
    return pd.DataFrame(out)
|
||
|
||
|
||
def kronos_predict_df(
    symbol: str = "BTCUSD",
    category: str = "inverse",
    interval: str = "60",
    limit: int = 500,
    lookback: int = LOOKBACK,
    pred_len: int = PRED_LEN,
    samples: int = NUM_SAMPLES,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    quantiles: list[float] | None = None,
    device: str = DEVICE,
    seed: int = 42,
    verbose: bool = False,
):
    """Module API: run fetch -> sample -> aggregate and return the forecast only.

    Steps:
        1. Load ``.env`` and open a mainnet Bybit session.
        2. Obtain the predictor via ``load_kronos(device=device)``.
        3. Fetch candles and build the input window; ``interval`` must parse
           as an integer number of minutes.
        4. Draw ``samples`` seeded forecasts and aggregate them.

    Returns:
        DataFrame with one row per future step containing:
        - ``<col>_qXX``: quantiles (e.g. ``close_q50``);
        - ``<col>_min`` / ``<col>_max``: extremes across samples;
        - ``timestamps``: the future timestamps;
        - ``base_close``: close of the last historical candle, repeated.
        Plotting or saving is left to the caller.
    """
    load_dotenv()
    session = HTTP(
        testnet=False,
        api_key=os.getenv("BYBIT_API_KEY"),
        api_secret=os.getenv("BYBIT_API_SECRET"),
    )

    predictor = load_kronos(device=device)

    history = fetch_kline(
        session,
        symbol=symbol,
        interval=interval,
        limit=limit,
        category=category,
    )
    x_df, x_ts, y_ts, _, end_idx = prepare_io_windows(
        history,
        lookback=lookback,
        pred_len=pred_len,
        interval_min=int(interval),
    )

    sample_frames = run_multi_predictions(
        predictor,
        x_df,
        x_ts,
        y_ts,
        num_samples=samples,
        base_seed=seed,
        temperature=temperature,
        top_p=top_p,
    )

    levels = QUANTILES if quantiles is None else quantiles
    result = aggregate_quantiles(sample_frames, quantiles=levels)
    result["timestamps"] = y_ts.values

    # Reference price: close of the most recent completed candle.
    base_close = np.nan
    if end_idx > 0:
        base_close = float(history.loc[end_idx - 1, "close"])
    result["base_close"] = base_close

    if verbose:
        print(
            f"完成预测:symbol={symbol} interval={interval}m samples={samples} base_close={base_close}"
        )
    return result
|
||
|
||
|
||
def _infer_bar_widths(historical_df, y_timestamp):
    """Return ``(hist_width, pred_width)`` for the volume bars.

    Each width is 80% of the spacing between the last two historical candles
    (respectively the first two forecast steps). With fewer than two points
    the spacing defaults to 10 minutes. If the spacing cannot be computed at
    all, both widths fall back to a flat 10 minutes.
    """
    default_step = pd.Timedelta(minutes=10)
    try:
        if len(historical_df) >= 2:
            hist_step = (
                historical_df["timestamps"].iloc[-1]
                - historical_df["timestamps"].iloc[-2]
            )
        else:
            hist_step = default_step
        if len(y_timestamp) >= 2:
            pred_step = y_timestamp.iloc[1] - y_timestamp.iloc[0]
        else:
            pred_step = default_step
        return hist_step * 0.8, pred_step * 0.8
    except Exception:
        # Bar width is purely cosmetic; never fail the whole plot over it.
        return default_step, default_step


def _mark_daily_closes(ax, timestamps, closes, mark_hour=16, tz_offset_hours=8):
    """Mark ``mark_hour`` o'clock in UTC+``tz_offset_hours`` on every day of ``ax``.

    For each such instant inside the data range, this draws a dashed blue
    vertical line. If a data point lies within one hour of the instant and
    its close is not NaN, that point is circled and labelled with its close.

    Args:
        ax: Matplotlib axes to draw on.
        timestamps: Timestamps aligned positionally with ``closes``. tz-naive
            values are interpreted as UTC, which is what
            ``pd.to_datetime(..., unit="ms")`` produces for exchange data.
        closes: Close prices aligned positionally with ``timestamps``.
        mark_hour: Local wall-clock hour to mark.
        tz_offset_hours: Offset of the local zone from UTC, in hours.
    """
    timestamps = pd.Series(timestamps).reset_index(drop=True)
    closes = pd.Series(closes).reset_index(drop=True)
    if len(timestamps) == 0:
        return

    start_time = timestamps.min()
    end_time = timestamps.max()
    # Local mark expressed as an offset from UTC midnight, normalised into
    # [0, 24h): e.g. 16:00 UTC+8 -> 08:00 UTC.
    utc_offset = pd.Timedelta(hours=(mark_hour - tz_offset_hours) % 24)
    if start_time.tzinfo is None:
        day_start = start_time.normalize()
    else:
        day_start = start_time.tz_convert("UTC").normalize()

    target_time = day_start + utc_offset
    while target_time <= end_time:
        if target_time >= start_time:
            ax.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)

            # Nearest data point (historical or forecast median) to the mark.
            time_diffs = (timestamps - target_time).abs()
            pos = int(time_diffs.argmin())
            closest_time = timestamps.iloc[pos]
            closest_price = closes.iloc[pos]

            # Only label when a point lies within 1 hour of the mark.
            if time_diffs.iloc[pos] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
                ax.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
                ax.text(
                    closest_time, closest_price,
                    f' ${closest_price:.1f}',
                    fontsize=9,
                    color='blue',
                    verticalalignment='bottom',
                    horizontalalignment='left',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7),
                )
        target_time += pd.Timedelta(days=1)


def plot_results(
    historical_df,
    y_timestamp,
    agg_df,
    title_prefix="BTCUSD 15分钟",
    *,
    daily_mark_hour=16,
    tz_offset_hours=8,
):
    """Plot history and forecast in two stacked panels sharing the x-axis.

    - Top panel (height ratio 3): historical close, the q10~q90 forecast band
      and the q50 median line. Every day it also shows a dashed marker at
      ``daily_mark_hour`` local time (UTC+``tz_offset_hours``), with the
      nearest close labelled.
    - Bottom panel (height ratio 1): historical volume bars and forecast
      median (q50) volume bars; no band is drawn for volume.

    Args:
        historical_df: History with ``timestamps``, ``close`` and ``volume``.
            tz-naive timestamps are treated as UTC.
        y_timestamp: Series of forecast timestamps.
        agg_df: Output of ``aggregate_quantiles``. The ``close_q10``,
            ``close_q50``, ``close_q90`` and ``volume_q50`` columns are drawn
            when present.
        title_prefix: Text prepended to the figure title.
        daily_mark_hour: Local hour of the daily reference marker.
        tz_offset_hours: UTC offset (hours) of the marker's local time zone.

    Returns:
        The matplotlib Figure (not shown or saved here).
    """
    import matplotlib.gridspec as gridspec

    bar_width_hist, bar_width_pred = _infer_bar_widths(historical_df, y_timestamp)

    # Forecast quantile columns; any that are missing are simply not drawn.
    close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
    close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
    close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
    vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None

    fig = plt.figure(figsize=(18, 10))
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08)  # price : volume = 3 : 1

    # ===== Top panel: close price =====
    ax_price = fig.add_subplot(gs[0])
    ax_price.plot(
        historical_df["timestamps"],
        historical_df["close"],
        label="历史收盘价",
        linewidth=1.8,
    )

    if close_q10 is not None and close_q90 is not None:
        ax_price.fill_between(
            y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
        )
    if close_q50 is not None:
        ax_price.plot(
            y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
        )

    # Forecast start line and shaded forecast span.
    if len(y_timestamp) > 0:
        ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
        ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)

    # Daily reference markers over history + forecast median (NaN when no q50).
    all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
    pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
    all_close = pd.concat(
        [historical_df["close"].reset_index(drop=True), pred_close.reset_index(drop=True)],
        ignore_index=True,
    )
    _mark_daily_closes(
        ax_price,
        all_timestamps,
        all_close,
        mark_hour=daily_mark_hour,
        tz_offset_hours=tz_offset_hours,
    )

    ax_price.set_ylabel("价格 (USD)", fontsize=11)
    ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
    ax_price.grid(True, alpha=0.3)
    ax_price.legend(loc="best", fontsize=9)

    # ===== Bottom panel: volume (forecast median bars only) =====
    ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
    ax_vol.bar(
        historical_df["timestamps"],
        historical_df["volume"],
        width=bar_width_hist,
        alpha=0.35,
        label="历史成交量",
    )
    if vol_q50 is not None:
        ax_vol.bar(
            y_timestamp.values,
            vol_q50,
            width=bar_width_pred,
            alpha=0.6,
            label="预测成交量中位(q50)",
        )
    if len(y_timestamp) > 0:
        ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)

    ax_vol.set_ylabel("成交量", fontsize=11)
    ax_vol.set_xlabel("时间", fontsize=11)
    ax_vol.grid(True, alpha=0.25)
    ax_vol.legend(loc="best", fontsize=9)

    # Shared x-axis: hide the top panel's tick labels to avoid overlap.
    plt.setp(ax_price.get_xticklabels(), visible=False)
    plt.tight_layout()
    return fig
|
||
|
||
def main():
    """Command-line entry point: fetch candles, sample forecasts, save CSV + PNG.

    Parses CLI options, loads ``.env`` for BYBIT_API_KEY / BYBIT_API_SECRET,
    runs the multi-sample forecast and writes
    ``<symbol>_interval_pred_<interval>m.csv`` / ``.png`` under OUTPUT_DIR.
    Finally it opens the figure with ``plt.show()``. ``--quantiles`` is a
    comma-separated list of levels (blank entries ignored), and
    ``--interval`` must parse as an integer number of minutes.
    """
    parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
    cli_options = (
        ("--symbol", dict(default="BTCUSD")),
        ("--category", dict(default="inverse")),
        ("--interval", dict(default="60")),
        ("--limit", dict(type=int, default=500)),
        ("--lookback", dict(type=int, default=LOOKBACK)),
        ("--pred_len", dict(type=int, default=PRED_LEN)),
        ("--samples", dict(type=int, default=NUM_SAMPLES)),
        ("--temperature", dict(type=float, default=TEMPERATURE)),
        ("--top_p", dict(type=float, default=TOP_P)),
        ("--quantiles", dict(default="0.1,0.5,0.9")),
        ("--device", dict(default=DEVICE)),
        ("--seed", dict(type=int, default=42)),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    load_dotenv()
    session = HTTP(
        testnet=False,
        api_key=os.getenv("BYBIT_API_KEY"),
        api_secret=os.getenv("BYBIT_API_SECRET"),
    )

    predictor = load_kronos(device=args.device)
    df = fetch_kline(
        session,
        symbol=args.symbol,
        interval=args.interval,
        limit=args.limit,
        category=args.category,
    )
    x_df, x_ts, y_ts, _, end_idx = prepare_io_windows(
        df,
        lookback=args.lookback,
        pred_len=args.pred_len,
        interval_min=int(args.interval),
    )

    # Only the most recent DRAW_WINDOW_LEN candles are drawn as history.
    historical_df = df.loc[max(0, end_idx - DRAW_WINDOW_LEN) : end_idx - 1].copy()

    quantile_levels = [float(tok) for tok in args.quantiles.split(",") if tok.strip()]

    preds = run_multi_predictions(
        predictor,
        x_df,
        x_ts,
        y_ts,
        num_samples=args.samples,
        base_seed=args.seed,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    agg_df = aggregate_quantiles(preds, quantiles=quantile_levels)
    agg_df["timestamps"] = y_ts.values

    file_stem = f"{args.symbol.lower()}_interval_pred_{args.interval}m"

    out_csv = os.path.join(OUTPUT_DIR, f"{file_stem}.csv")
    agg_df.to_csv(out_csv, index=False)
    print(f"预测分位数据已保存到: {out_csv}")

    fig = plot_results(
        historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
    )
    out_png = os.path.join(OUTPUT_DIR, f"{file_stem}.png")
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"预测区间图表已保存到: {out_png}")

    plt.show()
|
||
|
||
|
||
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        import traceback

        # Keep the one-line summary on stdout, but also emit the full
        # traceback (stderr) so failures can actually be located.
        print("运行出错:", repr(e))
        traceback.print_exc()
        sys.exit(1)
|