kronos-trader/predict.py

527 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
BTCUSD 15m 多次采样预测 → 区间聚合可视化
Author: 你(鹅)
"""
import os
import sys
import math
import json
import time
import random
import argparse
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from dotenv import load_dotenv
from pybit.unified_trading import HTTP
# ==== Kronos ====
from model import Kronos, KronosTokenizer, KronosPredictor
# ========== Matplotlib fonts ==========
# Prefer CJK-capable fonts so the Chinese chart labels render. Disable the
# Unicode minus sign; this is commonly needed because such fonts may lack
# that glyph.
plt.rcParams.update(
    {
        "font.sans-serif": ["Microsoft YaHei", "SimHei", "Arial"],
        "axes.unicode_minus": False,
    }
)
# ========== Default parameters ==========
NUM_SAMPLES = 20  # number of sampled forecasts (10-50 suggested)
QUANTILES = [0.1, 0.5, 0.9]  # band quantiles: lower / median / upper
LOOKBACK = 360  # history window fed to the model (candles)
PRED_LEN = 24  # number of future points to predict
INTERVAL_MIN = 60  # candle interval in minutes
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")  # overridable via KRONOS_DEVICE
TEMPERATURE = 1.0  # sampling temperature passed to the predictor as T
TOP_P = 0.9  # top_p value passed to the predictor
DRAW_WINDOW_LEN = 240  # number of history candles drawn on the chart
OUTPUT_DIR = "figures"  # where CSV/PNG outputs are written
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ========== Utility functions ==========
def set_global_seed(seed: int) -> None:
    """Seed the RNGs the sampling path may use: ``random``, NumPy and (if installed) torch.

    torch is optional. If it cannot be imported, only ``random`` and NumPy are
    seeded. Any other failure while seeding torch is reported rather than
    silently ignored, and seeding stays best-effort (no exception propagates).

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        # torch is not available: nothing more to seed.
        return
    try:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception as exc:  # best-effort, but don't hide the problem
        print(f"警告: 设置 torch 随机种子失败: {exc!r}")
def fetch_kline(
    session: HTTP, symbol="BTCUSD", interval="15", limit=500, category="inverse"
):
    """Fetch K-lines through ``session.get_kline`` as a time-ascending DataFrame.

    Columns:
    - ``timestamps``: datetimes parsed from millisecond epoch values.
    - ``open``, ``high``, ``low``, ``close``, ``volume``, ``turnover``: all float.
    - ``amount``: a copy of ``turnover``.

    The newest row is dropped because it is usually the still-forming candle.

    Raises:
        RuntimeError: if the response ``retCode`` is missing or non-zero.
    """
    print("正在获取K线数据...")
    resp = session.get_kline(
        category=category, symbol=symbol, interval=interval, limit=limit
    )
    if resp.get("retCode", -1) != 0:
        raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
    numeric_cols = ["open", "high", "low", "close", "volume", "turnover"]
    frame = pd.DataFrame(resp["result"]["list"], columns=["timestamps", *numeric_cols])
    frame["timestamps"] = pd.to_datetime(frame["timestamps"].astype(float), unit="ms")
    frame[numeric_cols] = frame[numeric_cols].astype(float)
    frame = frame.sort_values("timestamps").reset_index(drop=True)
    frame["amount"] = frame["turnover"]
    # The most recent candle is typically still open, so discard it.
    if not frame.empty:
        frame = frame.iloc[:-1].reset_index(drop=True)
    print(f"获取到 {len(frame)} 根K线数据")
    return frame
def get_recent_kline(
    symbol: str = "BTCUSD",
    category: str = "inverse",
    interval: str = "60",
    limit: int = 500,
):
    """Build a mainnet Bybit session from environment credentials and fetch K-lines.

    Loads ``.env`` first, then reads ``BYBIT_API_KEY`` and ``BYBIT_API_SECRET``.
    It delegates to ``fetch_kline``, so the result is time-ascending with the
    last (presumed unfinished) candle removed.
    """
    load_dotenv()
    credentials = {
        "api_key": os.getenv("BYBIT_API_KEY"),
        "api_secret": os.getenv("BYBIT_API_SECRET"),
    }
    session = HTTP(testnet=False, **credentials)
    return fetch_kline(
        session=session, symbol=symbol, interval=interval, limit=limit, category=category
    )
def prepare_io_windows(
    df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
    """Slice the model input window and build the future timestamps to predict.

    Args:
        df: time-ascending K-line frame with a ``timestamps`` column and the
            columns open/high/low/close/volume/amount.
        lookback: maximum number of most recent rows used as model input. Fewer
            rows are used when ``df`` is shorter.
        pred_len: number of future timestamps to generate.
        interval_min: candle interval in minutes, used to space future timestamps.

    Returns:
        A tuple ``(x_df, x_timestamp, y_timestamp, start_idx, end_idx)``:
        - ``x_df`` and ``x_timestamp``: the input rows, re-indexed from 0.
        - ``y_timestamp``: a Series of ``pred_len`` future timestamps.
        - ``start_idx`` and ``end_idx``: the positional bounds of the input
          window in ``df`` (``end_idx`` exclusive).

    Raises:
        ValueError: if ``df`` is empty.
    """
    end_idx = len(df)
    if end_idx == 0:
        raise ValueError("K线数据为空,无法构建预测窗口")
    start_idx = max(0, end_idx - lookback)
    # Positional slicing: correct even if df does not carry a 0..n-1 index.
    window = df.iloc[start_idx:end_idx]
    x_df = window[["open", "high", "low", "close", "volume", "amount"]].reset_index(
        drop=True
    )
    x_timestamp = window["timestamps"].reset_index(drop=True)
    last_ts = df["timestamps"].iloc[end_idx - 1]
    future_timestamps = pd.date_range(
        start=last_ts + pd.Timedelta(minutes=interval_min),
        periods=pred_len,
        freq=f"{interval_min}min",
    )
    # A Series built from a DatetimeIndex gets a fresh 0-based RangeIndex.
    y_timestamp = pd.Series(future_timestamps)
    n_used = end_idx - start_idx  # may be < lookback when history is short
    print(f"数据总量: {len(df)} 根K线")
    print(f"使用最新的 {n_used} 根K线(索引 {start_idx}~{end_idx - 1})")
    print(f"最后一根历史K线时间: {last_ts}")
    print(f"预测未来 {pred_len} 根K线")
    return x_df, x_timestamp, y_timestamp, start_idx, end_idx
def load_kronos(device=DEVICE):
    """Load the pretrained Kronos tokenizer and model and wrap them in a predictor.

    The tokenizer comes from ``NeoQuasar/Kronos-Tokenizer-base`` and the model
    from ``NeoQuasar/Kronos-base``, both loaded via ``from_pretrained``. The
    predictor is placed on ``device`` with ``max_context=512``, presumably the
    longest input it accepts (TODO: confirm against KronosPredictor).
    """
    print("正在加载 Kronos 模型...")
    tok = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    net = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    kronos_predictor = KronosPredictor(net, tok, device=device, max_context=512)
    print("模型加载完成!\n")
    return kronos_predictor
def run_one_prediction(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    T=TEMPERATURE,
    top_p=TOP_P,
    seed=None,
    verbose=False,
):
    """Draw one sampled forecast from ``predictor``.

    If ``seed`` is given, the global RNGs are re-seeded first so the sample is
    reproducible. The prediction length is ``len(y_timestamp)``. Returns
    whatever ``predictor.predict`` returns.
    """
    if seed is not None:
        set_global_seed(seed)
    predict_kwargs = dict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=len(y_timestamp),
        T=T,
        top_p=top_p,
        verbose=verbose,
    )
    return predictor.predict(**predict_kwargs)
def run_multi_predictions(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    num_samples=NUM_SAMPLES,
    base_seed=42,
    temperature=TEMPERATURE,
    top_p=TOP_P,
):
    """Run ``num_samples`` sampled forecasts, seeding sample ``i`` with ``base_seed + i``.

    Returns:
        A list of DataFrames, one per sample. Each keeps only the
        open/high/low/close/volume/amount columns the predictor actually
        returned, re-indexed from 0.
    """
    wanted_cols = ["open", "high", "low", "close", "volume", "amount"]
    samples = []
    print(f"正在进行多次预测:{num_samples}T={temperature}, top_p={top_p}...")
    for offset in tqdm(range(num_samples), desc="预测进度", unit="", ncols=100):
        raw = run_one_prediction(
            predictor,
            x_df,
            x_timestamp,
            y_timestamp,
            T=temperature,
            top_p=top_p,
            seed=base_seed + offset,  # a distinct seed per sample
            verbose=False,
        )
        # Tolerate predictors that omit some columns: keep only those present.
        keep = [col for col in wanted_cols if col in raw.columns]
        samples.append(raw[keep].reset_index(drop=True))
    print("多次预测完成。\n")
    return samples
def aggregate_quantiles(pred_list, quantiles=None):
    """Aggregate several sampled forecasts into per-timestep statistics.

    Args:
        pred_list: non-empty list of equally long DataFrames with the same
            numeric columns (the first frame defines which columns are used).
        quantiles: quantile levels in [0, 1]. ``None`` means the module default
            ``QUANTILES``.

    Returns:
        A DataFrame with one row per timestep. For each column ``<col>`` it has
        ``<col>_min``, ``<col>_max`` and, for each level ``q``, ``<col>_qXX``
        where ``XX = round(q * 100)``, zero-padded to two digits (0.1 -> ``q10``).

    Raises:
        ValueError: if ``pred_list`` is empty or the frames differ in length.
    """
    if quantiles is None:
        quantiles = QUANTILES
    if not pred_list:
        raise ValueError("pred_list 为空,至少需要一次预测结果才能聚合")
    out = {}
    for col in pred_list[0].columns.tolist():
        # (timesteps, samples) matrix for this column.
        samples = np.column_stack([pdf[col].to_numpy(dtype=float) for pdf in pred_list])
        # Sample extremes (used for risk assessment).
        out[f"{col}_min"] = samples.min(axis=1)
        out[f"{col}_max"] = samples.max(axis=1)
        for q in quantiles:
            # round(): int(0.29 * 100) would truncate 28.999... to 28.
            out[f"{col}_q{int(round(q * 100)):02d}"] = np.quantile(samples, q, axis=1)
    return pd.DataFrame(out)
# Predictors loaded by kronos_predict_df, keyed by device string, so repeated
# API calls reuse the model instead of reloading the weights on every call.
_PREDICTOR_CACHE = {}


def _get_cached_predictor(device):
    """Return the Kronos predictor for ``device``, loading it only on first use."""
    predictor = _PREDICTOR_CACHE.get(device)
    if predictor is None:
        predictor = load_kronos(device=device)
        _PREDICTOR_CACHE[device] = predictor
    return predictor


def kronos_predict_df(
    symbol: str = "BTCUSD",
    category: str = "inverse",
    interval: str = "60",
    limit: int = 500,
    lookback: int = LOOKBACK,
    pred_len: int = PRED_LEN,
    samples: int = NUM_SAMPLES,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    quantiles: list[float] | None = None,
    device: str = DEVICE,
    seed: int = 42,
    verbose: bool = False,
):
    """Module API: return the aggregated forecast DataFrame (future part only).

    The returned DataFrame contains:
    - ``timestamps``: timestamp of each future point.
    - ``<col>_qXX``: per-quantile values (e.g. ``close_q50``).
    - ``<col>_min`` / ``<col>_max``: sample extremes (for risk / reward sizing).
    - ``base_close``: close of the last historical candle, repeated on every row.

    The Kronos predictor is loaded once per ``device`` and reused by later calls.
    No plotting or file output is done; callers decide what to do with the result.
    """
    # Fetch data first: an API failure is then reported before paying for a
    # model load. get_recent_kline also loads .env and builds the session.
    df = get_recent_kline(
        symbol=symbol,
        category=category,
        interval=interval,
        limit=limit,
    )
    predictor = _get_cached_predictor(device)
    x_df, x_ts, y_ts, _, end_idx = prepare_io_windows(
        df,
        lookback=lookback,
        pred_len=pred_len,
        interval_min=int(interval),
    )
    preds = run_multi_predictions(
        predictor,
        x_df,
        x_ts,
        y_ts,
        num_samples=samples,
        base_seed=seed,
        temperature=temperature,
        top_p=top_p,
    )
    qs = quantiles if quantiles is not None else QUANTILES
    agg_df = aggregate_quantiles(preds, quantiles=qs)
    agg_df["timestamps"] = y_ts.values
    # Reference price: close of the last historical candle.
    base_close = float(df["close"].iloc[end_idx - 1]) if end_idx > 0 else np.nan
    agg_df["base_close"] = base_close
    if verbose:
        print(
            f"完成预测symbol={symbol} interval={interval}m samples={samples} base_close={base_close}"
        )
    return agg_df
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
"""
上下两个子图共享X轴
- 上方高度3历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
- 下方高度1历史成交量柱 + 预测成交量中位柱仅q50不显示区间
"""
import matplotlib.gridspec as gridspec
# 智能推断K线宽度柱宽
try:
if len(historical_df) >= 2:
hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
else:
hist_step = pd.Timedelta(minutes=10)
if len(y_timestamp) >= 2:
pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
else:
pred_step = pd.Timedelta(minutes=10)
bar_width_hist = hist_step * 0.8
bar_width_pred = pred_step * 0.8
except Exception:
bar_width_hist = pd.Timedelta(minutes=10)
bar_width_pred = pd.Timedelta(minutes=10)
# 取出预测分位
close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
# 图形与网格
fig = plt.figure(figsize=(18, 10))
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
# ===== 上:收盘价 =====
ax_price = fig.add_subplot(gs[0])
ax_price.plot(
historical_df["timestamps"],
historical_df["close"],
label="历史收盘价",
linewidth=1.8,
)
# 预测收盘价区间与中位线
if close_q10 is not None and close_q90 is not None:
ax_price.fill_between(
y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
)
if close_q50 is not None:
ax_price.plot(
y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
)
# 预测起点与阴影
if len(y_timestamp) > 0:
ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
# 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
# 合并历史和预测时间范围与价格数据
all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
hist_close = historical_df["close"]
pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
if len(all_timestamps) > 0:
start_time = all_timestamps.min()
end_time = all_timestamps.max()
# 生成所有16:00时间点UTC+8
current_date = start_time.normalize() # 当天零点
while current_date <= end_time:
target_time = current_date + pd.Timedelta(hours=16) # 16:00
if start_time <= target_time <= end_time:
# 画虚线
ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
# 找到最接近16:00的时间点的收盘价
time_diffs = (all_timestamps - target_time).abs()
closest_idx = time_diffs.idxmin()
closest_time = all_timestamps.iloc[closest_idx]
closest_price = all_close.iloc[closest_idx]
# 如果时间差不超过1小时则标注价格
if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
ax_price.text(
closest_time, closest_price,
f' ${closest_price:.1f}',
fontsize=9,
color='blue',
verticalalignment='bottom',
horizontalalignment='left',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
)
current_date += pd.Timedelta(days=1)
ax_price.set_ylabel("价格 (USD)", fontsize=11)
ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
ax_price.grid(True, alpha=0.3)
ax_price.legend(loc="best", fontsize=9)
# ===== 下:成交量(仅预测中位量能柱)=====
ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
# 历史量能柱
ax_vol.bar(
historical_df["timestamps"],
historical_df["volume"],
width=bar_width_hist,
alpha=0.35,
label="历史成交量",
)
# 预测中位量能柱(不画区间)
if vol_q50 is not None:
ax_vol.bar(
y_timestamp.values,
vol_q50,
width=bar_width_pred,
alpha=0.6,
label="预测成交量中位(q50)",
)
# 预测起点线
if len(y_timestamp) > 0:
ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
ax_vol.set_ylabel("成交量", fontsize=11)
ax_vol.set_xlabel("时间", fontsize=11)
ax_vol.grid(True, alpha=0.25)
ax_vol.legend(loc="best", fontsize=9)
# 避免X轴标签重叠
plt.setp(ax_price.get_xticklabels(), visible=False)
plt.tight_layout()
return fig
def main():
    """Command-line entry point: fetch K-lines, run sampled forecasts, save CSV and PNG.

    Both outputs go to ``OUTPUT_DIR`` and are then shown:
    - ``<symbol>_interval_pred_<interval>m.csv``: the aggregated quantiles;
    - ``<symbol>_interval_pred_<interval>m.png``: the chart.
    """
    parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
    parser.add_argument("--symbol", default="BTCUSD")
    parser.add_argument("--category", default="inverse")
    parser.add_argument("--interval", default="60")
    parser.add_argument("--limit", type=int, default=500)
    parser.add_argument("--lookback", type=int, default=LOOKBACK)
    parser.add_argument("--pred_len", type=int, default=PRED_LEN)
    parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--top_p", type=float, default=TOP_P)
    parser.add_argument("--quantiles", default="0.1,0.5,0.9")
    parser.add_argument("--device", default=DEVICE)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Validate cheap inputs up front, before any network call or model load.
    if args.samples < 1:
        parser.error("--samples 必须 >= 1")
    try:
        qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
    except ValueError:
        parser.error(f"--quantiles 解析失败: {args.quantiles!r}")
    if not qs or not all(0.0 <= q <= 1.0 for q in qs):
        parser.error("--quantiles 必须是 [0, 1] 范围内的数值")

    # Fetch data before loading the model so an API failure is reported without
    # paying for the model load. get_recent_kline loads .env and builds the session.
    df = get_recent_kline(
        symbol=args.symbol,
        category=args.category,
        interval=args.interval,
        limit=args.limit,
    )
    predictor = load_kronos(device=args.device)
    x_df, x_ts, y_ts, _, end_idx = prepare_io_windows(
        df,
        lookback=args.lookback,
        pred_len=args.pred_len,
        interval_min=int(args.interval),
    )
    # History shown on the chart: at most the last DRAW_WINDOW_LEN candles.
    plot_start = max(0, end_idx - DRAW_WINDOW_LEN)
    historical_df = df.iloc[plot_start:end_idx].copy()

    preds = run_multi_predictions(
        predictor,
        x_df,
        x_ts,
        y_ts,
        num_samples=args.samples,
        base_seed=args.seed,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    agg_df = aggregate_quantiles(preds, quantiles=qs)
    agg_df["timestamps"] = y_ts.values

    stem = f"{args.symbol.lower()}_interval_pred_{args.interval}m"
    # Aggregated quantiles as CSV.
    out_csv = os.path.join(OUTPUT_DIR, f"{stem}.csv")
    agg_df.to_csv(out_csv, index=False)
    print(f"预测分位数据已保存到: {out_csv}")
    # Chart.
    fig = plot_results(
        historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
    )
    out_png = os.path.join(OUTPUT_DIR, f"{stem}.png")
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"预测区间图表已保存到: {out_png}")
    plt.show()
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        import traceback

        # Keep the one-line summary, and also dump the full traceback to stderr
        # so API / model / plotting failures can actually be diagnosed.
        print("运行出错:", repr(e))
        traceback.print_exc()
        sys.exit(1)