kronos-trader/trader.py
2025-10-29 14:11:27 +08:00

420 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
BTCUSD 15m 多次采样预测 → 区间聚合可视化
Author: 你(鹅)
"""
import os
import sys
import math
import json
import time
import random
import argparse
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from dotenv import load_dotenv
from pybit.unified_trading import HTTP
# ==== Kronos ====
from model import Kronos, KronosTokenizer, KronosPredictor
# ========== Matplotlib fonts ==========
# CJK-capable fonts first so the Chinese labels render; the ASCII minus sign is
# kept because some of those fonts lack the Unicode minus glyph.
plt.rcParams.update(
    {
        "font.sans-serif": ["Microsoft YaHei", "SimHei", "Arial"],
        "axes.unicode_minus": False,
    }
)
# ========== Default parameters ==========
NUM_SAMPLES = 20  # number of sampled forecasts (10-50 suggested)
QUANTILES = [0.1, 0.5, 0.9]  # lower / median / upper quantile of the band
LOOKBACK = 360  # history window length, in bars
PRED_LEN = 96  # number of future bars to forecast
INTERVAL_MIN = 15  # bar period in minutes
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")  # overridable via KRONOS_DEVICE
TEMPERATURE = 1.0  # passed to predictor.predict as T
TOP_P = 0.9  # passed to predictor.predict as top_p
OUTPUT_DIR = "figures"
# Created at import time so later savefig/to_csv calls can write into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ========== Utilities ==========
def set_global_seed(seed: int):
    """Seed every RNG the sampler might draw from, as far as possible.

    Always seeds Python's ``random`` module and NumPy's global RNG. When torch
    is importable it also seeds torch, and all CUDA devices if CUDA is
    available. A missing torch install is expected and tolerated silently; a
    failure while seeding torch is reported on stderr instead of being
    swallowed, because it means runs are no longer reproducible.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        # torch is optional here: without it only random/NumPy get seeded.
        return
    try:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except RuntimeError as exc:
        # Best effort: keep running, but make the lost reproducibility visible.
        print(f"警告: torch 随机种子设置失败: {exc!r}", file=sys.stderr)
def fetch_kline(
    session: "HTTP", symbol="BTCUSD", interval="15", limit=1000, category="inverse"
):
    """Download recent klines via ``session.get_kline`` and return them oldest-first.

    Args:
        session: pybit ``HTTP`` session; only its ``get_kline`` method is used.
        symbol: instrument symbol, e.g. ``"BTCUSD"``.
        interval: kline interval string, passed straight to ``get_kline``.
        limit: maximum number of bars to request.
        category: product category passed to ``get_kline`` (e.g. ``"inverse"``).

    Returns:
        DataFrame with ``timestamps`` (naive datetimes from epoch milliseconds,
        via ``pd.to_datetime(unit="ms")``), float columns
        open/high/low/close/volume/turnover, and ``amount`` (a copy of
        ``turnover``), sorted ascending by time with a fresh 0..n-1 index.

    Raises:
        RuntimeError: if the response has a non-zero ``retCode`` or contains
            no bars.
    """
    print("正在获取K线数据...")
    resp = session.get_kline(
        category=category, symbol=symbol, interval=interval, limit=limit
    )
    if resp.get("retCode", -1) != 0:
        raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
    rows = (resp.get("result") or {}).get("list") or []
    if not rows:
        # Fail here with a clear message; an empty frame would otherwise only
        # blow up later in the windowing code with an opaque KeyError.
        raise RuntimeError(f"获取数据失败: {symbol} {interval} 未返回任何K线")
    num_cols = ["open", "high", "low", "close", "volume", "turnover"]
    df = pd.DataFrame(rows, columns=["timestamps", *num_cols])
    # The API returns every field as a string; convert in one vectorized pass.
    df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
    df[num_cols] = df[num_cols].astype(float)
    df = df.sort_values("timestamps").reset_index(drop=True)
    df["amount"] = df["turnover"]
    print(f"获取到 {len(df)} 根K线数据")
    return df
def prepare_io_windows(
    df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
    """Cut the latest bars into model input and build the future timestamps.

    Args:
        df: kline frame with a ``timestamps`` column and open/high/low/close/
            volume/amount columns, sorted ascending (as ``fetch_kline`` returns).
        lookback: maximum number of most recent bars used as context.
        pred_len: number of future timestamps to generate.
        interval_min: spacing of the future timestamps, in minutes.

    Returns:
        ``(x_df, x_timestamp, y_timestamp, start_idx, end_idx)``:
        ``x_df`` holds the feature columns of positional rows
        ``start_idx .. end_idx - 1``, ``x_timestamp`` their timestamps,
        ``y_timestamp`` the ``pred_len`` future timestamps starting one
        interval after the last bar, and ``start_idx``/``end_idx`` the
        positional window bounds (end exclusive).

    Raises:
        ValueError: if ``df`` is empty.
    """
    if df.empty:
        raise ValueError("prepare_io_windows: 输入的K线数据为空")
    end_idx = len(df)
    start_idx = max(0, end_idx - lookback)
    # Positional slicing keeps the window correct even if df's index is not 0..n-1.
    window = df.iloc[start_idx:end_idx]
    x_df = window[["open", "high", "low", "close", "volume", "amount"]].reset_index(
        drop=True
    )
    x_timestamp = window["timestamps"].reset_index(drop=True)
    last_ts = df["timestamps"].iloc[-1]
    future_timestamps = pd.date_range(
        start=last_ts + pd.Timedelta(minutes=interval_min),
        periods=pred_len,
        freq=f"{interval_min}min",
    )
    # A Series built from a DatetimeIndex already gets a fresh 0..n-1 index.
    y_timestamp = pd.Series(future_timestamps)
    print(f"数据总量: {len(df)} 根K线")
    # Report the rows actually used, which is fewer than lookback on short data.
    print(f"使用最新的 {end_idx - start_idx} 根K线 (索引 {start_idx} ~ {end_idx - 1})")
    print(f"最后一根历史K线时间: {last_ts}")
    print(f"预测未来 {pred_len} 根K线")
    return x_df, x_timestamp, y_timestamp, start_idx, end_idx
def load_kronos(
    device=DEVICE,
    tokenizer_id="NeoQuasar/Kronos-Tokenizer-base",
    model_id="NeoQuasar/Kronos-base",
    max_context=512,
):
    """Load the Kronos tokenizer and model and wrap them in a ``KronosPredictor``.

    The hub identifiers and context length were previously hard-coded. They
    are now keyword parameters whose defaults reproduce the old behavior, so
    other Kronos checkpoints can be tried without editing this function.

    Args:
        device: device string handed to ``KronosPredictor`` (module default
            comes from the ``KRONOS_DEVICE`` env var, else ``"cuda:0"``).
        tokenizer_id: identifier passed to ``KronosTokenizer.from_pretrained``.
        model_id: identifier passed to ``Kronos.from_pretrained``.
        max_context: value passed to ``KronosPredictor(max_context=...)``.

    Returns:
        The constructed ``KronosPredictor``.
    """
    print("正在加载 Kronos 模型...")
    tokenizer = KronosTokenizer.from_pretrained(tokenizer_id)
    model = Kronos.from_pretrained(model_id)
    predictor = KronosPredictor(
        model, tokenizer, device=device, max_context=max_context
    )
    print("模型加载完成!\n")
    return predictor
def run_one_prediction(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    T=TEMPERATURE,
    top_p=TOP_P,
    seed=None,
    verbose=False,
):
    """Draw a single forecast from ``predictor`` over ``y_timestamp``.

    When ``seed`` is given, the global RNGs are re-seeded first so that this
    draw can be repeated. The forecast length always equals
    ``len(y_timestamp)``.

    Returns:
        Whatever ``predictor.predict`` returns; the caller treats it as a
        DataFrame of forecast columns.
    """
    if seed is not None:
        set_global_seed(seed)
    horizon = len(y_timestamp)
    predict_kwargs = dict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=horizon,
        T=T,
        top_p=top_p,
        verbose=verbose,
    )
    return predictor.predict(**predict_kwargs)
def run_multi_predictions(
    predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
    temperature=TEMPERATURE, top_p=TOP_P
):
    """Draw ``num_samples`` forecasts, sample ``i`` seeded with ``base_seed + i``.

    Seeding per sample makes the ensemble repeatable for a given
    ``base_seed``, provided the predictor samples from the RNGs that
    ``set_global_seed`` seeds.

    Returns:
        List of DataFrames, one per sample. Each keeps only those of
        open/high/low/close/volume/amount that the predictor returned, in that
        order, with a fresh 0..n-1 index.

    Raises:
        ValueError: if ``num_samples`` < 1; downstream aggregation needs at
            least one sample.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples 必须 >= 1, 实际为 {num_samples}")
    wanted_cols = ["open", "high", "low", "close", "volume", "amount"]
    preds = []
    print(f"正在进行多次预测: {num_samples} 次 (T={temperature}, top_p={top_p})...")
    for i in tqdm(range(num_samples), desc="预测进度", unit="", ncols=100):
        pred_df = run_one_prediction(
            predictor,
            x_df,
            x_timestamp,
            y_timestamp,
            T=temperature,
            top_p=top_p,
            seed=base_seed + i,  # distinct, reproducible seed per sample
            verbose=False,
        )
        # Keep only the expected feature columns, in canonical order, so every
        # sample lines up column-wise for aggregation.
        present = [c for c in wanted_cols if c in pred_df.columns]
        preds.append(pred_df[present].reset_index(drop=True))
    print("多次预测完成。\n")
    return preds
def aggregate_quantiles(pred_list, quantiles=None):
    """Collapse several sampled forecasts into per-step quantile columns.

    Args:
        pred_list: non-empty list of equally shaped DataFrames (one per
            sample). The first frame's columns define the features.
        quantiles: quantile levels in [0, 1]. ``None`` selects the
            module-level ``QUANTILES`` default.

    Returns:
        DataFrame with one row per forecast step and columns named
        ``<col>_q<percent>`` (e.g. ``close_q10``, ``close_q50``,
        ``close_q90``), grouped by feature in the first frame's column order.
        Names use whole percents, so levels closer than ~1% map to the same
        name and the later one wins.

    Raises:
        ValueError: if ``pred_list`` is empty or the frames differ in shape.
    """
    if quantiles is None:
        quantiles = QUANTILES
    if not pred_list:
        raise ValueError("aggregate_quantiles: pred_list 为空, 至少需要一次预测")
    qs = [float(q) for q in quantiles]
    keys = pred_list[0].columns.tolist()
    # Stack to (time, feature, sample) so a single np.quantile call covers all
    # features; np.stack raises ValueError if the samples differ in shape.
    cube = np.stack([pdf[keys].to_numpy(dtype=float) for pdf in pred_list], axis=-1)
    qvals = np.quantile(cube, qs, axis=-1)  # -> (quantile, time, feature)
    out = {}
    for f_idx, key in enumerate(keys):
        for q_idx, q in enumerate(qs):
            # round(), not int(): 0.29 * 100 == 28.999999999999996 would
            # otherwise truncate to "q28" and mislabel the column.
            out[f"{key}_q{int(round(q * 100)):02d}"] = qvals[q_idx, :, f_idx]
    return pd.DataFrame(out)
def _infer_bar_widths(hist_ts, pred_ts, fallback=pd.Timedelta(minutes=10)):
    """Return ``(hist_width, pred_width)``: 80% of each series' bar spacing.

    Spacing is taken from the last two historical points and the first two
    forecast points. A series with fewer than two points uses ``fallback`` as
    its spacing. If the spacing cannot be computed at all (e.g. non-datetime
    values), both widths become ``fallback``.
    """
    try:
        hist_step = hist_ts.iloc[-1] - hist_ts.iloc[-2] if len(hist_ts) >= 2 else fallback
        pred_step = pred_ts.iloc[1] - pred_ts.iloc[0] if len(pred_ts) >= 2 else fallback
        return hist_step * 0.8, pred_step * 0.8
    except (TypeError, ValueError):
        return fallback, fallback


def _annotate_daily_marks(ax, timestamps, closes, hour=16, tolerance=pd.Timedelta(hours=1)):
    """Draw a dashed line at ``hour``:00 on every day in range and label the close.

    ``timestamps`` and ``closes`` must be aligned Series with a 0..n-1 index.
    The close nearest each marker is annotated only if it lies within
    ``tolerance`` of the marker and is not NaN.
    """
    if len(timestamps) == 0:
        return
    start_time, end_time = timestamps.min(), timestamps.max()
    day = start_time.normalize()  # midnight of the first day
    while day <= end_time:
        target = day + pd.Timedelta(hours=hour)
        if start_time <= target <= end_time:
            ax.axvline(x=target, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
            diffs = (timestamps - target).abs()
            nearest = diffs.idxmin()  # label == position thanks to the 0..n-1 index
            price = closes.iloc[nearest]
            if diffs.iloc[nearest] <= tolerance and not np.isnan(price):
                when = timestamps.iloc[nearest]
                ax.plot(when, price, 'o', color='blue', markersize=6, alpha=0.7)
                ax.text(
                    when, price,
                    f' ${price:.1f}',
                    fontsize=9,
                    color='blue',
                    verticalalignment='bottom',
                    horizontalalignment='left',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
                )
        day += pd.Timedelta(days=1)


def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟", mark_hour=16):
    """Plot history and forecast on two stacked panels that share the x axis.

    - Top (height 3): historical close, forecast close band (q10~q90), the q50
      median line, the forecast start marker, and a daily dashed marker at
      ``mark_hour`` labelled with the nearest close (q50 on the forecast side).
    - Bottom (height 1): historical volume bars plus forecast median (q50)
      volume bars; no band is drawn for volume.

    Args:
        historical_df: frame with ``timestamps``, ``close`` and ``volume``.
        y_timestamp: Series of forecast timestamps aligned with ``agg_df`` rows.
        agg_df: output of ``aggregate_quantiles``. Panels whose ``close_q*`` /
            ``volume_q50`` columns are missing are simply skipped.
        title_prefix: text placed before the figure title.
        mark_hour: hour of day, on the timestamps' own clock, for the daily
            marker. NOTE(review): the original comment called this "16:00
            UTC+8", but ``fetch_kline`` produces naive UTC timestamps
            (``pd.to_datetime(unit="ms")``), so the default 16 marks 16:00 UTC.
            Pass 8 if 16:00 UTC+8 is what is intended.

    Returns:
        The matplotlib Figure (not yet saved or shown).
    """
    import matplotlib.gridspec as gridspec

    bar_width_hist, bar_width_pred = _infer_bar_widths(historical_df["timestamps"], y_timestamp)

    # Forecast quantile columns; None when the requested quantiles omit them.
    close_q10 = agg_df.get("close_q10")
    close_q50 = agg_df.get("close_q50")
    close_q90 = agg_df.get("close_q90")
    vol_q50 = agg_df.get("volume_q50")

    fig = plt.figure(figsize=(18, 10))
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08)  # price : volume = 3 : 1

    # ===== Top panel: close price =====
    ax_price = fig.add_subplot(gs[0])
    ax_price.plot(
        historical_df["timestamps"],
        historical_df["close"],
        label="历史收盘价",
        linewidth=1.8,
    )
    if close_q10 is not None and close_q90 is not None:
        ax_price.fill_between(
            y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
        )
    if close_q50 is not None:
        ax_price.plot(
            y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
        )
    # Forecast start line and shaded forecast region.
    if len(y_timestamp) > 0:
        ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
        ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)

    # Daily markers span history + forecast; the forecast side uses the q50
    # close (or NaN, which suppresses the price label).
    pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
    all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
    all_close = pd.concat(
        [historical_df["close"].reset_index(drop=True), pred_close.reset_index(drop=True)],
        ignore_index=True,
    )
    _annotate_daily_marks(ax_price, all_timestamps, all_close, hour=mark_hour)

    ax_price.set_ylabel("价格 (USD)", fontsize=11)
    ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
    ax_price.grid(True, alpha=0.3)
    ax_price.legend(loc="best", fontsize=9)

    # ===== Bottom panel: volume (median forecast only) =====
    ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
    ax_vol.bar(
        historical_df["timestamps"],
        historical_df["volume"],
        width=bar_width_hist,
        alpha=0.35,
        label="历史成交量",
    )
    if vol_q50 is not None:
        ax_vol.bar(
            y_timestamp.values,
            vol_q50,
            width=bar_width_pred,
            alpha=0.6,
            label="预测成交量中位(q50)",
        )
    if len(y_timestamp) > 0:
        ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
    ax_vol.set_ylabel("成交量", fontsize=11)
    ax_vol.set_xlabel("时间", fontsize=11)
    ax_vol.grid(True, alpha=0.25)
    ax_vol.legend(loc="best", fontsize=9)

    # The panels share the x axis; hide the top tick labels to avoid overlap.
    plt.setp(ax_price.get_xticklabels(), visible=False)
    plt.tight_layout()
    return fig
def main():
    """Command-line entry point: fetch klines, sample Kronos forecasts, save CSV + PNG.

    Flow: parse and validate CLI args -> fetch klines (before the slow model
    load, so API/network errors surface fast) -> load Kronos -> draw
    ``--samples`` seeded forecasts -> aggregate into quantile columns ->
    write ``<OUTPUT_DIR>/<symbol>_interval_pred_<interval>m.csv`` and
    ``.png`` -> show the figure.
    """
    parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
    parser.add_argument("--symbol", default="BTCUSD")
    parser.add_argument("--category", default="inverse")
    parser.add_argument("--interval", default="15", help="K线周期, 正整数分钟数 (如 15)")
    parser.add_argument("--limit", type=int, default=1000)
    parser.add_argument("--lookback", type=int, default=LOOKBACK)
    parser.add_argument("--pred_len", type=int, default=PRED_LEN)
    parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--top_p", type=float, default=TOP_P)
    parser.add_argument("--quantiles", default="0.1,0.5,0.9")
    parser.add_argument("--device", default=DEVICE)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--plot_bars", type=int, default=200, help="图中显示的最近历史K线数量")
    args = parser.parse_args()

    # Validate CLI input up front; parser.error() exits with a usage message
    # before any network or model work is done.
    try:
        interval_min = int(args.interval)
    except ValueError:
        parser.error(f"--interval 必须是正整数分钟数, 收到 {args.interval!r}")
    if interval_min <= 0:
        parser.error(f"--interval 必须是正整数分钟数, 收到 {args.interval!r}")
    try:
        qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
    except ValueError:
        parser.error(f"--quantiles 无法解析: {args.quantiles!r}")
    if not qs or any(not 0.0 <= q <= 1.0 for q in qs):
        parser.error(f"--quantiles 必须是 [0, 1] 内的逗号分隔小数, 收到 {args.quantiles!r}")
    if args.samples < 1:
        parser.error("--samples 必须 >= 1")
    if args.plot_bars < 1:
        parser.error("--plot_bars 必须 >= 1")

    # Load .env and create the HTTP session.
    load_dotenv()
    session = HTTP(
        testnet=False,
        api_key=os.getenv("BYBIT_API_KEY"),
        api_secret=os.getenv("BYBIT_API_SECRET"),
    )
    df = fetch_kline(
        session,
        symbol=args.symbol,
        interval=args.interval,
        limit=args.limit,
        category=args.category,
    )
    predictor = load_kronos(device=args.device)
    x_df, x_ts, y_ts, _start_idx, end_idx = prepare_io_windows(
        df,
        lookback=args.lookback,
        pred_len=args.pred_len,
        interval_min=interval_min,
    )
    # History window drawn on the chart: the most recent --plot_bars bars.
    plot_start = max(0, end_idx - args.plot_bars)
    historical_df = df.iloc[plot_start:end_idx].copy()

    preds = run_multi_predictions(
        predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed,
        temperature=args.temperature, top_p=args.top_p
    )
    agg_df = aggregate_quantiles(preds, quantiles=qs)
    agg_df["timestamps"] = y_ts.values

    # Aggregated quantiles -> CSV.
    out_csv = os.path.join(
        OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
    )
    agg_df.to_csv(out_csv, index=False)
    print(f"预测分位数据已保存到: {out_csv}")

    fig = plot_results(
        historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
    )
    out_png = os.path.join(
        OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
    )
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"预测区间图表已保存到: {out_png}")
    plt.show()
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: keep the one-line summary, but also print the
        # full traceback to stderr so failures are debuggable, then exit 1.
        # (SystemExit/KeyboardInterrupt are not Exception and pass through.)
        import traceback

        print("运行出错:", repr(e), file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)