新增盈亏比计算和消息推送

This commit is contained in:
feie9456 2025-10-30 11:59:03 +08:00
parent 2b5bcc55f8
commit 0771ca73a0
4 changed files with 1016 additions and 367 deletions

526
predict.py Normal file
View File

@ -0,0 +1,526 @@
# -*- coding: utf-8 -*-
"""
BTCUSD 15m 多次采样预测 区间聚合可视化
Author:
"""
import os
import sys
import math
import json
import time
import random
import argparse
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from dotenv import load_dotenv
from pybit.unified_trading import HTTP
# ==== Kronos ====
from model import Kronos, KronosTokenizer, KronosPredictor
# ========== Matplotlib 字体 ==========
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
plt.rcParams["axes.unicode_minus"] = False
# ========== 默认参数 ==========
NUM_SAMPLES = 20 # 多次采样次数(建议 10~50
QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
LOOKBACK = 360 # 历史窗口
PRED_LEN = 24 # 未来预测点数
INTERVAL_MIN = 60 # K 线周期(分钟)
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
TEMPERATURE = 1.0
TOP_P = 0.9
DRAW_WINDOW_LEN = 240 # 历史用于画图的窗口
OUTPUT_DIR = "figures"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ========== 实用函数 ==========
def set_global_seed(seed: int):
"""尽可能固定随机性(若底层采样逻辑使用 torch/np/random"""
try:
import torch
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
except Exception:
pass
random.seed(seed)
np.random.seed(seed)
def fetch_kline(
session: HTTP, symbol="BTCUSD", interval="15", limit=500, category="inverse"
):
print("正在获取K线数据...")
resp = session.get_kline(
category=category, symbol=symbol, interval=interval, limit=limit
)
if resp.get("retCode", -1) != 0:
raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
lst = resp["result"]["list"]
df = pd.DataFrame(
lst,
columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
)
df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
for c in ["open", "high", "low", "close", "volume", "turnover"]:
df[c] = df[c].astype(float)
df = df.sort_values("timestamps").reset_index(drop=True)
df["amount"] = df["turnover"]
# 去掉最新一根通常为未完成的当前K线
if len(df) > 0:
df = df.iloc[:-1].reset_index(drop=True)
print(f"获取到 {len(df)} 根K线数据")
return df
def get_recent_kline(
symbol: str = "BTCUSD",
category: str = "inverse",
interval: str = "60",
limit: int = 500,
):
"""便捷封装基于环境变量创建会话并获取最近K线数据。
返回数据按时间升序且会移除最后一根未完成K线 fetch_kline 保持一致
"""
load_dotenv()
api_key = os.getenv("BYBIT_API_KEY")
api_secret = os.getenv("BYBIT_API_SECRET")
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
return fetch_kline(
session=session,
symbol=symbol,
interval=interval,
limit=limit,
category=category,
)
def prepare_io_windows(
df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
end_idx = len(df)
start_idx = max(0, end_idx - lookback)
x_df = df.loc[
start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
].reset_index(drop=True)
x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
last_ts = df.loc[end_idx - 1, "timestamps"]
future_timestamps = pd.date_range(
start=last_ts + pd.Timedelta(minutes=interval_min),
periods=pred_len,
freq=f"{interval_min}min",
)
y_timestamp = pd.Series(future_timestamps)
y_timestamp.index = range(len(y_timestamp))
print(f"数据总量: {len(df)} 根K线")
print(f"使用最新的 {lookback} 根K线索引 {start_idx}{end_idx-1}")
print(f"最后一根历史K线时间: {last_ts}")
print(f"预测未来 {pred_len} 根K线")
return x_df, x_timestamp, y_timestamp, start_idx, end_idx
def load_kronos(device=DEVICE):
print("正在加载 Kronos 模型...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print("模型加载完成!\n")
return predictor
def run_one_prediction(
predictor,
x_df,
x_timestamp,
y_timestamp,
T=TEMPERATURE,
top_p=TOP_P,
seed=None,
verbose=False,
):
if seed is not None:
set_global_seed(seed)
return predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=len(y_timestamp),
T=T,
top_p=top_p,
verbose=verbose,
)
def run_multi_predictions(
predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
temperature=TEMPERATURE, top_p=TOP_P
):
preds = []
print(f"正在进行多次预测:{num_samples}T={temperature}, top_p={top_p}...")
for i in tqdm(range(num_samples), desc="预测进度", unit="", ncols=100):
seed = base_seed + i # 每次不同 seed
pred_df = run_one_prediction(
predictor,
x_df,
x_timestamp,
y_timestamp,
T=temperature,
top_p=top_p,
seed=seed,
verbose=False,
)
# 兼容性处理:确保只取需要的列
cols_present = [
c
for c in ["open", "high", "low", "close", "volume", "amount"]
if c in pred_df.columns
]
pred_df = pred_df[cols_present].copy()
pred_df.reset_index(drop=True, inplace=True)
preds.append(pred_df)
print("多次预测完成。\n")
return preds
def aggregate_quantiles(pred_list, quantiles=QUANTILES):
"""
将多次预测列表聚合成分位数 DataFrame
输出列命名<col>_q10, <col>_q50, <col>_q90 quantiles 中的值
"""
# 先把每次预测拼成 3Dtime x feature x samples
keys = pred_list[0].columns.tolist()
T_len = len(pred_list[0])
S = len(pred_list)
data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
for j, pdf in enumerate(pred_list):
for k in keys:
data[k][:, j] = pdf[k].values
out = {}
for k in keys:
# 极值(用于风险评估)
out[f"{k}_min"] = np.min(data[k], axis=1)
out[f"{k}_max"] = np.max(data[k], axis=1)
for q in quantiles:
qv = np.quantile(data[k], q, axis=1)
out[f"{k}_q{int(q*100):02d}"] = qv
agg_df = pd.DataFrame(out)
return agg_df
def kronos_predict_df(
symbol: str = "BTCUSD",
category: str = "inverse",
interval: str = "60",
limit: int = 500,
lookback: int = LOOKBACK,
pred_len: int = PRED_LEN,
samples: int = NUM_SAMPLES,
temperature: float = TEMPERATURE,
top_p: float = TOP_P,
quantiles: list[float] | None = None,
device: str = DEVICE,
seed: int = 42,
verbose: bool = False,
):
"""
作为模块使用的API返回聚合后的预测DataFrame仅未来部分
返回DataFrame包含
- timestamps: 未来各点时间戳
- <col>_qXX: 各分位数 close_q50
- <col>_min/<col>_max: 样本极值用于风控/盈亏比
- base_close: 当前基准价最后一根历史K线的收盘价重复列
"""
# 环境 & HTTP会话
load_dotenv()
api_key = os.getenv("BYBIT_API_KEY")
api_secret = os.getenv("BYBIT_API_SECRET")
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
# 加载模型(懒加载一次)
predictor = load_kronos(device=device)
# 拉取K线并准备窗口
df = fetch_kline(
session,
symbol=symbol,
interval=interval,
limit=limit,
category=category,
)
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
df,
lookback=lookback,
pred_len=pred_len,
interval_min=int(interval),
)
# 多次预测
preds = run_multi_predictions(
predictor,
x_df,
x_ts,
y_ts,
num_samples=samples,
base_seed=seed,
temperature=temperature,
top_p=top_p,
)
# 聚合
qs = quantiles if quantiles is not None else QUANTILES
agg_df = aggregate_quantiles(preds, quantiles=qs)
agg_df["timestamps"] = y_ts.values
# 当前价(基准)
base_close = float(df.loc[end_idx - 1, "close"]) if end_idx > 0 else np.nan
agg_df["base_close"] = base_close
# 仅返回DataFrame调用方自行决定是否绘图/保存)
if verbose:
print(
f"完成预测symbol={symbol} interval={interval}m samples={samples} base_close={base_close}"
)
return agg_df
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
"""
上下两个子图共享X轴
- 上方高度3历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
- 下方高度1历史成交量柱 + 预测成交量中位柱仅q50不显示区间
"""
import matplotlib.gridspec as gridspec
# 智能推断K线宽度柱宽
try:
if len(historical_df) >= 2:
hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
else:
hist_step = pd.Timedelta(minutes=10)
if len(y_timestamp) >= 2:
pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
else:
pred_step = pd.Timedelta(minutes=10)
bar_width_hist = hist_step * 0.8
bar_width_pred = pred_step * 0.8
except Exception:
bar_width_hist = pd.Timedelta(minutes=10)
bar_width_pred = pd.Timedelta(minutes=10)
# 取出预测分位
close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
# 图形与网格
fig = plt.figure(figsize=(18, 10))
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
# ===== 上:收盘价 =====
ax_price = fig.add_subplot(gs[0])
ax_price.plot(
historical_df["timestamps"],
historical_df["close"],
label="历史收盘价",
linewidth=1.8,
)
# 预测收盘价区间与中位线
if close_q10 is not None and close_q90 is not None:
ax_price.fill_between(
y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
)
if close_q50 is not None:
ax_price.plot(
y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
)
# 预测起点与阴影
if len(y_timestamp) > 0:
ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
# 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
# 合并历史和预测时间范围与价格数据
all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
hist_close = historical_df["close"]
pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
if len(all_timestamps) > 0:
start_time = all_timestamps.min()
end_time = all_timestamps.max()
# 生成所有16:00时间点UTC+8
current_date = start_time.normalize() # 当天零点
while current_date <= end_time:
target_time = current_date + pd.Timedelta(hours=16) # 16:00
if start_time <= target_time <= end_time:
# 画虚线
ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
# 找到最接近16:00的时间点的收盘价
time_diffs = (all_timestamps - target_time).abs()
closest_idx = time_diffs.idxmin()
closest_time = all_timestamps.iloc[closest_idx]
closest_price = all_close.iloc[closest_idx]
# 如果时间差不超过1小时则标注价格
if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
ax_price.text(
closest_time, closest_price,
f' ${closest_price:.1f}',
fontsize=9,
color='blue',
verticalalignment='bottom',
horizontalalignment='left',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
)
current_date += pd.Timedelta(days=1)
ax_price.set_ylabel("价格 (USD)", fontsize=11)
ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
ax_price.grid(True, alpha=0.3)
ax_price.legend(loc="best", fontsize=9)
# ===== 下:成交量(仅预测中位量能柱)=====
ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
# 历史量能柱
ax_vol.bar(
historical_df["timestamps"],
historical_df["volume"],
width=bar_width_hist,
alpha=0.35,
label="历史成交量",
)
# 预测中位量能柱(不画区间)
if vol_q50 is not None:
ax_vol.bar(
y_timestamp.values,
vol_q50,
width=bar_width_pred,
alpha=0.6,
label="预测成交量中位(q50)",
)
# 预测起点线
if len(y_timestamp) > 0:
ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
ax_vol.set_ylabel("成交量", fontsize=11)
ax_vol.set_xlabel("时间", fontsize=11)
ax_vol.grid(True, alpha=0.25)
ax_vol.legend(loc="best", fontsize=9)
# 避免X轴标签重叠
plt.setp(ax_price.get_xticklabels(), visible=False)
plt.tight_layout()
return fig
def main():
parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
parser.add_argument("--symbol", default="BTCUSD")
parser.add_argument("--category", default="inverse")
parser.add_argument("--interval", default="60")
parser.add_argument("--limit", type=int, default=500)
parser.add_argument("--lookback", type=int, default=LOOKBACK)
parser.add_argument("--pred_len", type=int, default=PRED_LEN)
parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
parser.add_argument("--top_p", type=float, default=TOP_P)
parser.add_argument("--quantiles", default="0.1,0.5,0.9")
parser.add_argument("--device", default=DEVICE)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
# 载入env & HTTP
load_dotenv()
api_key = os.getenv("BYBIT_API_KEY")
api_secret = os.getenv("BYBIT_API_SECRET")
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
# 为命令行模式执行完整流程
predictor = load_kronos(device=args.device)
df = fetch_kline(
session,
symbol=args.symbol,
interval=args.interval,
limit=args.limit,
category=args.category,
)
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
df,
lookback=args.lookback,
pred_len=args.pred_len,
interval_min=int(args.interval),
)
plot_start = max(0, end_idx - DRAW_WINDOW_LEN)
historical_df = df.loc[plot_start : end_idx - 1].copy()
temperature = args.temperature
top_p = args.top_p
qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
preds = run_multi_predictions(
predictor,
x_df,
x_ts,
y_ts,
num_samples=args.samples,
base_seed=args.seed,
temperature=temperature,
top_p=top_p,
)
agg_df = aggregate_quantiles(preds, quantiles=qs)
agg_df["timestamps"] = y_ts.values
# 输出 CSV聚合
out_csv = os.path.join(
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
)
agg_df.to_csv(out_csv, index=False)
print(f"预测分位数据已保存到: {out_csv}")
# 作图
fig = plot_results(
historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
)
out_png = os.path.join(
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
)
fig.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"预测区间图表已保存到: {out_png}")
plt.show()
if __name__ == "__main__":
try:
main()
except Exception as e:
print("运行出错:", repr(e))
sys.exit(1)

132
qq.py Normal file
View File

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
"""
轻量的 QQ Bot 客户端用于发送私聊文本和媒体
使用环境变量
- QQ_BOT_URL: bot 接收地址例如 http://localhost:30000/send_private_msg
- QQ_BOT_TARGET_ID: 默认目标 QQ id可被函数参数覆盖
示例
from qq import send_msg, send_media_msg
send_msg('hello')
# 发送图片时将自动转换为 "base64://..." 格式
send_media_msg('figures/btcusd_interval_pred_60m.png', type='image')
"""
from __future__ import annotations
import os
import json
from typing import Optional
import base64
try:
import requests
except Exception: # pragma: no cover - runtime dependency
requests = None
def _bot_url() -> str:
return os.getenv("QQ_BOT_URL", "http://localhost:30000/send_private_msg")
def _default_target() -> Optional[str]:
return os.getenv("QQ_BOT_TARGET_ID")
def _file_to_base64_uri(path: str) -> str:
"""将本地文件读取为 base64 uri 字符串(前缀 base64://)。"""
with open(path, "rb") as f:
data = f.read()
b64 = base64.b64encode(data).decode("ascii")
return f"base64://{b64}"
def send_msg(msg: str, target_id: Optional[str] = None, timeout: float = 8.0):
"""发送文本消息到 QQ 机器人。返回 requests.Response 或 None失败"""
if target_id is None:
target_id = _default_target()
if target_id is None:
raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置")
payload = {
"user_id": str(target_id),
"message": [
{"type": "text", "data": {"text": str(msg)}}
],
}
url = _bot_url()
print(f"[QQ] 发送文本到 {target_id}: {msg}")
if requests is None:
print("requests 未安装,无法发送消息")
return None
try:
resp = requests.post(url, json=payload, timeout=timeout)
try:
# 非必要:打印返回的简短信息
print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}")
except Exception:
pass
return resp
except Exception as e:
print(f"[QQ] 发送文本消息出错: {e}")
return None
def send_media_msg(file_path: str, target_id: Optional[str] = None, type: str = "image", timeout: float = 15.0):
"""发送媒体消息image 或 video
- type == "image" 读取文件并以 "base64://..." 形式发送
- type == "video" 仍使用 "file://" 本地文件引用保持原有行为
"""
if type not in ("image", "video"):
raise ValueError("type 必须为 'image''video'")
if target_id is None:
target_id = _default_target()
if target_id is None:
raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置")
# 根据类型构造 file 字段
if type == "image":
try:
file_field = _file_to_base64_uri(file_path)
except Exception as e:
print(f"[QQ] 读取图片失败: {e}")
return None
else: # video
file_field = f"file://{file_path}"
payload = {
"user_id": str(target_id),
"message": [
{"type": type, "data": {"file": file_field}}
],
}
url = _bot_url()
print(f"[QQ] 发送媒体 {type}{target_id}: {file_path}")
if requests is None:
print("requests 未安装,无法发送媒体消息")
return None
try:
resp = requests.post(url, json=payload, timeout=timeout)
try:
print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}")
except Exception:
pass
return resp
except Exception as e:
print(f"[QQ] 发送媒体消息出错: {e}")
return None
def send_prediction_result(summary: str, png_path: Optional[str] = None, target_id: Optional[str] = None):
"""便捷函数:先发送 summary 文本,再发送 png若提供"""
r1 = send_msg(summary, target_id=target_id)
r2 = None
if png_path:
r2 = send_media_msg(png_path, target_id=target_id, type="image")
return r1, r2

View File

@ -1,10 +1,12 @@
numpy numpy
pandas pandas
torch torch
requests
dotenv
einops==0.8.1 einops==0.8.1
huggingface_hub==0.33.1 huggingface_hub==0.33.1
matplotlib==3.9.3 matplotlib==3.9.3
pandas==2.2.2 pandas==2.2.2
tqdm==4.67.1 tqdm==4.67.1
safetensors==0.6.2 safetensors==0.6.2
pybit

719
trader.py
View File

@ -1,419 +1,408 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
BTCUSD 15m 多次采样预测 区间聚合可视化 基于 Kronos 多次采样预测评估多/空两侧的盈亏比并给出最佳机会与可视化
Author:
盈亏比 = 潜在盈利 / 最大亏损
多头
- 潜在盈利以中位线 close_q50 为准选取未来涨幅最高点相对当前价 base_close的涨幅
- 最大亏损从现在到该目标点区间内使用多次采样的最差点close_min相对当前价的最大跌幅
空头
- 潜在盈利以中位线 close_q50 为准选取未来跌幅最大点相对当前价 base_close的跌幅
- 最大亏损从现在到该目标点区间内使用多次采样的最差点close_max相对当前价的最大涨幅
""" """
import os import os
import sys
import math
import json
import time
import random
import argparse import argparse
from datetime import datetime from typing import Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from tqdm import tqdm
from dotenv import load_dotenv from dotenv import load_dotenv
from pybit.unified_trading import HTTP from predict import kronos_predict_df, get_recent_kline
# ==== Kronos ====
from model import Kronos, KronosTokenizer, KronosPredictor
# ========== Matplotlib 字体 ==========
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"] plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
plt.rcParams["axes.unicode_minus"] = False plt.rcParams["axes.unicode_minus"] = False
# ========== 默认参数 ==========
NUM_SAMPLES = 20 # 多次采样次数(建议 10~50
QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
LOOKBACK = 360 # 历史窗口
PRED_LEN = 96 # 未来预测点数
INTERVAL_MIN = 15 # K 线周期(分钟)
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
TEMPERATURE = 1.0
TOP_P = 0.9
OUTPUT_DIR = "figures" def _safe_get(series: pd.Series, default=np.nan):
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ========== 实用函数 ==========
def set_global_seed(seed: int):
"""尽可能固定随机性(若底层采样逻辑使用 torch/np/random"""
try: try:
import torch return float(series.iloc[0])
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
except Exception: except Exception:
pass return default
random.seed(seed)
np.random.seed(seed)
def fetch_kline( def compute_long_metrics(agg_df: pd.DataFrame) -> dict:
session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse" base = float(agg_df["base_close"].iloc[0])
): close_med = agg_df.get("close_q50", agg_df.get("close_q10"))
print("正在获取K线数据...") if close_med is None:
resp = session.get_kline( raise ValueError("预测结果中缺少 close_q50或 close_q10 兜底)列")
category=category, symbol=symbol, interval=interval, limit=limit
# 相对收益(百分比)
ret = (close_med - base) / base
j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0
profit = max(float(ret.iloc[j_target]), 0.0)
# 区间内的“最差的90%分位”(使用下分位线 q10而非样本极小值
q10 = agg_df.get("close_q10", close_med)
worst_price = float(q10.iloc[: j_target + 1].min())
loss = max((base - worst_price) / base, 1e-9)
return {
"direction": "LONG",
"base": base,
"profit": profit,
"loss": loss,
"rr": profit / loss if loss > 0 else np.inf,
"target_idx": j_target,
"target_price": float(close_med.iloc[j_target]),
"worst_idx": int(np.nanargmin(q10.iloc[: j_target + 1].values)),
"worst_price": worst_price,
}
def compute_short_metrics(agg_df: pd.DataFrame) -> dict:
base = float(agg_df["base_close"].iloc[0])
close_med = agg_df.get("close_q50", agg_df.get("close_q90"))
if close_med is None:
raise ValueError("预测结果中缺少 close_q50或 close_q90 兜底)列")
# 相对收益(空头盈利为向下幅度)
ret = (base - close_med) / base
j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0
profit = max(float(ret.iloc[j_target]), 0.0)
# 区间内的“最差的90%分位”(使用上分位线 q90而非样本极大值
q90 = agg_df.get("close_q90", close_med)
worst_price = float(q90.iloc[: j_target + 1].max())
loss = max((worst_price - base) / base, 1e-9)
return {
"direction": "SHORT",
"base": base,
"profit": profit,
"loss": loss,
"rr": profit / loss if loss > 0 else np.inf,
"target_idx": j_target,
"target_price": float(close_med.iloc[j_target]),
"worst_idx": int(np.nanargmax(q90.iloc[: j_target + 1].values)),
"worst_price": worst_price,
}
def plot_opportunities(
agg_df: pd.DataFrame,
long_m: dict,
short_m: dict,
symbol: str,
interval: str,
out_dir: str,
hist_df: pd.DataFrame | None = None,
) -> str:
"""生成清晰可视化:上下两个子图展示多/空机会与风险。"""
from matplotlib import gridspec
import numpy as np
ts = (
pd.to_datetime(agg_df["timestamps"])
if not np.issubdtype(agg_df["timestamps"].dtype, np.datetime64)
else agg_df["timestamps"]
) )
if resp.get("retCode", -1) != 0: base = long_m["base"]
raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}") q10 = agg_df.get("close_q10", None)
lst = resp["result"]["list"] q50 = agg_df.get("close_q50", None)
q90 = agg_df.get("close_q90", None)
df = pd.DataFrame( # 历史24根K线若提供
lst, hist_ts = None
columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"], hist_close = None
) if hist_df is not None and len(hist_df) > 0:
df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms") try:
for c in ["open", "high", "low", "close", "volume", "turnover"]: hist_ts = (
df[c] = df[c].astype(float) pd.to_datetime(hist_df["timestamps"]) if not np.issubdtype(hist_df["timestamps"].dtype, np.datetime64) else hist_df["timestamps"]
df = df.sort_values("timestamps").reset_index(drop=True) )
df["amount"] = df["turnover"] hist_close = hist_df["close"].astype(float).values
print(f"获取到 {len(df)} 根K线数据") except Exception:
return df hist_ts, hist_close = None, None
def prepare_io_windows(
df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
end_idx = len(df)
start_idx = max(0, end_idx - lookback)
x_df = df.loc[
start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
].reset_index(drop=True)
x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
last_ts = df.loc[end_idx - 1, "timestamps"]
future_timestamps = pd.date_range(
start=last_ts + pd.Timedelta(minutes=interval_min),
periods=pred_len,
freq=f"{interval_min}min",
)
y_timestamp = pd.Series(future_timestamps)
y_timestamp.index = range(len(y_timestamp))
print(f"数据总量: {len(df)} 根K线")
print(f"使用最新的 {lookback} 根K线索引 {start_idx}{end_idx-1}")
print(f"最后一根历史K线时间: {last_ts}")
print(f"预测未来 {pred_len} 根K线")
return x_df, x_timestamp, y_timestamp, start_idx, end_idx
def load_kronos(device=DEVICE):
print("正在加载 Kronos 模型...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print("模型加载完成!\n")
return predictor
def run_one_prediction(
predictor,
x_df,
x_timestamp,
y_timestamp,
T=TEMPERATURE,
top_p=TOP_P,
seed=None,
verbose=False,
):
if seed is not None:
set_global_seed(seed)
return predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=len(y_timestamp),
T=T,
top_p=top_p,
verbose=verbose,
)
def run_multi_predictions(
predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
temperature=TEMPERATURE, top_p=TOP_P
):
preds = []
print(f"正在进行多次预测:{num_samples}T={temperature}, top_p={top_p}...")
for i in tqdm(range(num_samples), desc="预测进度", unit="", ncols=100):
seed = base_seed + i # 每次不同 seed
pred_df = run_one_prediction(
predictor,
x_df,
x_timestamp,
y_timestamp,
T=temperature,
top_p=top_p,
seed=seed,
verbose=False,
)
# 兼容性处理:确保只取需要的列
cols_present = [
c
for c in ["open", "high", "low", "close", "volume", "amount"]
if c in pred_df.columns
]
pred_df = pred_df[cols_present].copy()
pred_df.reset_index(drop=True, inplace=True)
preds.append(pred_df)
print("多次预测完成。\n")
return preds
def aggregate_quantiles(pred_list, quantiles=QUANTILES):
"""
将多次预测列表聚合成分位数 DataFrame
输出列命名<col>_q10, <col>_q50, <col>_q90 quantiles 中的值
"""
# 先把每次预测拼成 3Dtime x feature x samples
keys = pred_list[0].columns.tolist()
T_len = len(pred_list[0])
S = len(pred_list)
data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
for j, pdf in enumerate(pred_list):
for k in keys:
data[k][:, j] = pdf[k].values
out = {}
for k in keys:
for q in quantiles:
qv = np.quantile(data[k], q, axis=1)
out[f"{k}_q{int(q*100):02d}"] = qv
agg_df = pd.DataFrame(out)
return agg_df
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
"""
上下两个子图共享X轴
- 上方高度3历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
- 下方高度1历史成交量柱 + 预测成交量中位柱仅q50不显示区间
"""
import matplotlib.gridspec as gridspec
# 智能推断K线宽度柱宽
try:
if len(historical_df) >= 2:
hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
else:
hist_step = pd.Timedelta(minutes=10)
if len(y_timestamp) >= 2:
pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
else:
pred_step = pd.Timedelta(minutes=10)
bar_width_hist = hist_step * 0.8
bar_width_pred = pred_step * 0.8
except Exception:
bar_width_hist = pd.Timedelta(minutes=10)
bar_width_pred = pd.Timedelta(minutes=10)
# 取出预测分位
close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
# 图形与网格
fig = plt.figure(figsize=(18, 10)) fig = plt.figure(figsize=(18, 10))
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3 gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)
# ===== 上:收盘价 ===== def _draw_panel(
ax_price = fig.add_subplot(gs[0]) ax, title, target_idx, target_price, worst_idx, worst_price, direction
):
# 历史收盘价预测前24根
if hist_ts is not None and hist_close is not None:
ax.plot(
hist_ts.values,
hist_close,
color="#6c757d",
lw=1.6,
label="历史收盘(近24根)",
)
# 预测起点标注
try:
ax.axvline(ts.iloc[0], color="gray", ls=":", lw=1.0, label="预测起点")
except Exception:
pass
ax_price.plot( # 区间带
historical_df["timestamps"], if q10 is not None and q90 is not None:
historical_df["close"], ax.fill_between(
label="历史收盘价", ts.values,
linewidth=1.8, q10.values,
q90.values,
color="#7dbcea",
alpha=0.22,
label="q10~q90",
)
# 中位线
if q50 is not None:
ax.plot(ts.values, q50.values, color="#00539c", lw=2.0, label="中位线 q50")
# 基准线
ax.axhline(base, color="gray", ls=":", lw=1.2, label=f"当前价 {base:.1f}")
# 目标点
ax.plot(
ts.iloc[target_idx],
target_price,
marker="o",
color="#2e7d32",
ms=8,
label="目标点",
)
ax.annotate(
f"目标 {target_price:.1f}",
(ts.iloc[target_idx], target_price),
textcoords="offset points",
xytext=(8, 8),
bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#2e7d32", alpha=0.8),
color="#2e7d32",
)
# 最差点
ax.plot(
ts.iloc[worst_idx],
worst_price,
marker="v",
color="#c62828",
ms=8,
label="最差点",
)
ax.annotate(
f"最差 {worst_price:.1f}",
(ts.iloc[worst_idx], worst_price),
textcoords="offset points",
xytext=(8, -14),
bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#c62828", alpha=0.8),
color="#c62828",
)
# 连接箭头(方向)
color = "#2e7d32" if direction == "LONG" else "#c62828"
ax.annotate(
"",
xy=(ts.iloc[target_idx], target_price),
xytext=(ts.iloc[0], base),
arrowprops=dict(arrowstyle="->", color=color, lw=2, alpha=0.9),
)
# 盈亏比矩形:用填充矩形表示潜在盈利(绿色)与最大亏损(红色)
x_range = ts.iloc[: target_idx + 1]
if direction == "LONG":
# 盈利(若目标价在基准价之上)
if target_price > base:
y_low = np.full_like(x_range, base, dtype=float)
y_high = np.full_like(x_range, target_price, dtype=float)
ax.fill_between(
x_range.values,
y_low,
y_high,
color="#2e7d32",
alpha=0.18,
label="潜在盈利",
)
# 亏损q10 代表“最差的90%分位”)
if q10 is not None and np.isfinite(worst_price) and worst_price < base:
y_low = np.full_like(x_range, worst_price, dtype=float)
y_high = np.full_like(x_range, base, dtype=float)
ax.fill_between(
x_range.values,
y_low,
y_high,
color="#c62828",
alpha=0.18,
label="最大亏损(q10)",
)
else:
# 盈利(空头目标价在基准价之下)
if target_price < base:
y_low = np.full_like(x_range, target_price, dtype=float)
y_high = np.full_like(x_range, base, dtype=float)
ax.fill_between(
x_range.values,
y_low,
y_high,
color="#2e7d32",
alpha=0.18,
label="潜在盈利",
)
# 亏损q90 代表“最差的90%分位”)
if q90 is not None and np.isfinite(worst_price) and worst_price > base:
y_low = np.full_like(x_range, base, dtype=float)
y_high = np.full_like(x_range, worst_price, dtype=float)
ax.fill_between(
x_range.values,
y_low,
y_high,
color="#c62828",
alpha=0.18,
label="最大亏损(q90)",
)
ax.set_title(title)
ax.grid(True, alpha=0.25)
ax.legend(loc="best", fontsize=9)
ax1 = fig.add_subplot(gs[0])
rr_long_pct = long_m["rr"] * 100.0
profit_long_pct = long_m["profit"] * 100.0
loss_long_pct = long_m["loss"] * 100.0
title_long = f"多头机会:盈亏比 RR={long_m['rr']:.2f}+{profit_long_pct:.2f}% / -{loss_long_pct:.2f}%"
_draw_panel(
ax1,
title_long,
long_m["target_idx"],
long_m["target_price"],
long_m["worst_idx"],
long_m["worst_price"],
"LONG",
) )
# 预测收盘价区间与中位线 ax2 = fig.add_subplot(gs[1], sharex=ax1)
if close_q10 is not None and close_q90 is not None: rr_short_pct = short_m["rr"] * 100.0
ax_price.fill_between( profit_short_pct = short_m["profit"] * 100.0
y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)" loss_short_pct = short_m["loss"] * 100.0
) title_short = f"空头机会:盈亏比 RR={short_m['rr']:.2f}+{profit_short_pct:.2f}% / -{loss_short_pct:.2f}%"
if close_q50 is not None: _draw_panel(
ax_price.plot( ax2,
y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)" title_short,
) short_m["target_idx"],
short_m["target_price"],
# 预测起点与阴影 short_m["worst_idx"],
if len(y_timestamp) > 0: short_m["worst_price"],
ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点") "SHORT",
ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
# 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
# 合并历史和预测时间范围与价格数据
all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
hist_close = historical_df["close"]
pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
if len(all_timestamps) > 0:
start_time = all_timestamps.min()
end_time = all_timestamps.max()
# 生成所有16:00时间点UTC+8
current_date = start_time.normalize() # 当天零点
while current_date <= end_time:
target_time = current_date + pd.Timedelta(hours=16) # 16:00
if start_time <= target_time <= end_time:
# 画虚线
ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
# 找到最接近16:00的时间点的收盘价
time_diffs = (all_timestamps - target_time).abs()
closest_idx = time_diffs.idxmin()
closest_time = all_timestamps.iloc[closest_idx]
closest_price = all_close.iloc[closest_idx]
# 如果时间差不超过1小时则标注价格
if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
ax_price.text(
closest_time, closest_price,
f' ${closest_price:.1f}',
fontsize=9,
color='blue',
verticalalignment='bottom',
horizontalalignment='left',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
)
current_date += pd.Timedelta(days=1)
ax_price.set_ylabel("价格 (USD)", fontsize=11)
ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
ax_price.grid(True, alpha=0.3)
ax_price.legend(loc="best", fontsize=9)
# ===== 下:成交量(仅预测中位量能柱)=====
ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
# 历史量能柱
ax_vol.bar(
historical_df["timestamps"],
historical_df["volume"],
width=bar_width_hist,
alpha=0.35,
label="历史成交量",
) )
# 预测中位量能柱(不画区间) for lbl in ax1.get_xticklabels():
if vol_q50 is not None: lbl.set_visible(False)
ax_vol.bar( ax2.set_xlabel("时间")
y_timestamp.values, ax1.set_ylabel("价格 (USD)")
vol_q50, ax2.set_ylabel("价格 (USD)")
width=bar_width_pred, fig.suptitle(
alpha=0.6, f"{symbol} {interval}分钟线 24小时 预测机会评估", fontsize=16, fontweight="bold"
label="预测成交量中位(q50)", )
) plt.tight_layout(rect=[0, 0, 1, 0.96])
# 预测起点线 os.makedirs(out_dir, exist_ok=True)
if len(y_timestamp) > 0: out_png = os.path.join(
ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1) out_dir, f"trade_opportunity_{symbol.lower()}_{interval}m.png"
)
fig.savefig(out_png, dpi=150, bbox_inches="tight")
return out_png
ax_vol.set_ylabel("成交量", fontsize=11)
ax_vol.set_xlabel("时间", fontsize=11)
ax_vol.grid(True, alpha=0.25)
ax_vol.legend(loc="best", fontsize=9)
# 避免X轴标签重叠
plt.setp(ax_price.get_xticklabels(), visible=False)
plt.tight_layout()
return fig
def main(): def main():
parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化") # 加载环境变量
parser.add_argument("--symbol", default="BTCUSD") load_dotenv()
parser = argparse.ArgumentParser(description="计算并可视化高盈亏比交易机会")
parser.add_argument("--symbol", default="BTCUSDT")
parser.add_argument("--category", default="inverse") parser.add_argument("--category", default="inverse")
parser.add_argument("--interval", default="15") parser.add_argument("--interval", default="60")
parser.add_argument("--limit", type=int, default=1000) parser.add_argument("--limit", type=int, default=500)
parser.add_argument("--lookback", type=int, default=LOOKBACK) parser.add_argument("--lookback", type=int, default=360)
parser.add_argument("--pred_len", type=int, default=PRED_LEN) parser.add_argument("--pred_len", type=int, default=24)
parser.add_argument("--samples", type=int, default=NUM_SAMPLES) parser.add_argument("--samples", type=int, default=30)
parser.add_argument("--temperature", type=float, default=TEMPERATURE) parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=TOP_P) parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--quantiles", default="0.1,0.5,0.9") parser.add_argument("--device", default=os.getenv("KRONOS_DEVICE", "cuda:0"))
parser.add_argument("--device", default=DEVICE)
parser.add_argument("--seed", type=int, default=42) parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args() args = parser.parse_args()
# 载入env & HTTP # 获取预测聚合数据(仅未来)
load_dotenv() agg_df = kronos_predict_df(
api_key = os.getenv("BYBIT_API_KEY")
api_secret = os.getenv("BYBIT_API_SECRET")
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
# Kronos
predictor = load_kronos(device=args.device)
# 拉数据
df = fetch_kline(
session,
symbol=args.symbol, symbol=args.symbol,
category=args.category,
interval=args.interval, interval=args.interval,
limit=args.limit, limit=args.limit,
category=args.category,
)
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
df,
lookback=args.lookback, lookback=args.lookback,
pred_len=args.pred_len, pred_len=args.pred_len,
interval_min=int(args.interval), samples=args.samples,
temperature=args.temperature,
top_p=args.top_p,
quantiles=[0.1, 0.5, 0.9],
device=args.device,
seed=args.seed,
verbose=True,
) )
# 历史用于画图的窗口最近200根可按需调整 # 计算多 / 空的机会
plot_start = max(0, end_idx - 200) long_m = compute_long_metrics(agg_df)
historical_df = df.loc[plot_start : end_idx - 1].copy() short_m = compute_short_metrics(agg_df)
# 采样参数(直接使用局部变量,不需要修改全局变量) # 选择更优机会
temperature = args.temperature best = long_m if long_m["rr"] >= short_m["rr"] else short_m
top_p = args.top_p
qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
# 多次预测(传入局部变量) # 生成汇总字符串(增强可读性与可辨识度)
preds = run_multi_predictions( dir_emoji = "📈 多头" if best["direction"] == "LONG" else "📉 空头"
predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed, long_arrow = "⬆️"
temperature=temperature, top_p=top_p short_arrow = "⬇️"
summary_lines = [
"== 24小时 每日交易机会评估 ==\n",
f"📌 标的: {args.symbol} | 周期: {args.interval}m | 当前价: {best['base']:.2f}\n",
f"📈 多头 | RR={long_m['rr']:.2f} | 盈利 +{long_m['profit'] * 100.0:.2f}% | 亏损 -{long_m['loss'] * 100.0:.2f}% | 🎯 目标 {long_m['target_price']:.2f} | 🛡️ 最差 {long_m['worst_price']:.2f}\n",
f"📉 空头 | RR={short_m['rr']:.2f} | 盈利 +{short_m['profit'] * 100.0:.2f}% | 亏损 -{short_m['loss'] * 100.0:.2f}% | 🎯 目标 {short_m['target_price']:.2f} | 🛡️ 最差 {short_m['worst_price']:.2f}\n",
f"✅ 最佳: {dir_emoji} | 盈亏比 RR={best['rr']:.2f}",
]
if best["rr"] < 2.0:
summary_lines.append("⚠️ 交易机会不佳,建议观望。")
summary = "\n".join(summary_lines) + "\n"
# 打印汇总
print(summary)
# 获取预测前24根K线用于可视化上下文
try:
full_hist_df = get_recent_kline(
symbol=args.symbol,
category=args.category,
interval=args.interval,
limit=args.limit,
)
hist24 = full_hist_df.tail(24).reset_index(drop=True)
except Exception:
hist24 = None
# 可视化
out_png = plot_opportunities(
agg_df, long_m, short_m, args.symbol, args.interval, out_dir="figures", hist_df=hist24
) )
print(f"可视化已保存: {out_png}")
# 聚合分位 # 发送交易机会评估结果到 QQ
agg_df = aggregate_quantiles(preds, quantiles=qs) try:
agg_df["timestamps"] = y_ts.values from qq import send_prediction_result
target = os.getenv("QQ_BOT_TARGET_ID")
# 输出 CSV聚合 send_prediction_result(summary, png_path=out_png, target_id=target)
out_csv = os.path.join( except Exception as e:
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv" print(f"发送 QQ 消息失败: {repr(e)}")
)
agg_df.to_csv(out_csv, index=False)
print(f"预测分位数据已保存到: {out_csv}")
# 作图
fig = plot_results(
historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
)
out_png = os.path.join(
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
)
fig.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"预测区间图表已保存到: {out_png}")
plt.show()
if __name__ == "__main__": if __name__ == "__main__":
try: main()
main()
except Exception as e:
print("运行出错:", repr(e))
sys.exit(1)