From 0771ca73a08181dd9f466dfa0717527a9f13c62c Mon Sep 17 00:00:00 2001 From: feie9456 Date: Thu, 30 Oct 2025 11:59:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E7=9B=88=E4=BA=8F=E6=AF=94?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E5=92=8C=E6=B6=88=E6=81=AF=E6=8E=A8=E9=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- predict.py | 526 ++++++++++++++++++++++++++++++++++ qq.py | 132 +++++++++ requirements.txt | 4 +- trader.py | 721 +++++++++++++++++++++++------------------------ 4 files changed, 1016 insertions(+), 367 deletions(-) create mode 100644 predict.py create mode 100644 qq.py diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..05b9164 --- /dev/null +++ b/predict.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +""" +BTCUSD 15m 多次采样预测 → 区间聚合可视化 +Author: 你(鹅) +""" + +import os +import sys +import math +import json +import time +import random +import argparse +from datetime import datetime + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm + +from dotenv import load_dotenv +from pybit.unified_trading import HTTP + +# ==== Kronos ==== +from model import Kronos, KronosTokenizer, KronosPredictor + +# ========== Matplotlib 字体 ========== +plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"] +plt.rcParams["axes.unicode_minus"] = False + +# ========== 默认参数 ========== +NUM_SAMPLES = 20 # 多次采样次数(建议 10~50) +QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上) +LOOKBACK = 360 # 历史窗口 +PRED_LEN = 24 # 未来预测点数 +INTERVAL_MIN = 60 # K 线周期(分钟) +DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0") +TEMPERATURE = 1.0 +TOP_P = 0.9 +DRAW_WINDOW_LEN = 240 # 历史用于画图的窗口 + +OUTPUT_DIR = "figures" +os.makedirs(OUTPUT_DIR, exist_ok=True) + + +# ========== 实用函数 ========== +def set_global_seed(seed: int): + """尽可能固定随机性(若底层采样逻辑使用 torch/np/random)""" + try: + import torch + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + except Exception: + pass + random.seed(seed) + np.random.seed(seed) + + +def fetch_kline( + session: HTTP, symbol="BTCUSD", interval="15", limit=500, category="inverse" +): + print("正在获取K线数据...") + resp = session.get_kline( + category=category, symbol=symbol, interval=interval, limit=limit + ) + if resp.get("retCode", -1) != 0: + raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}") + lst = resp["result"]["list"] + + df = pd.DataFrame( + lst, + columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"], + ) + df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms") + for c in ["open", "high", "low", "close", "volume", "turnover"]: + df[c] = df[c].astype(float) + df = df.sort_values("timestamps").reset_index(drop=True) + df["amount"] = df["turnover"] + # 去掉最新一根(通常为未完成的当前K线) + if len(df) > 0: + df = df.iloc[:-1].reset_index(drop=True) + print(f"获取到 {len(df)} 根K线数据") + return df + + +def get_recent_kline( + symbol: str = "BTCUSD", + category: str = "inverse", + interval: str = "60", + limit: int = 500, +): + """便捷封装:基于环境变量创建会话并获取最近K线数据。 + + 返回数据按时间升序,且会移除最后一根未完成K线(与 fetch_kline 保持一致)。 + """ + load_dotenv() + api_key = os.getenv("BYBIT_API_KEY") + api_secret = os.getenv("BYBIT_API_SECRET") + session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret) + return fetch_kline( + session=session, + symbol=symbol, + interval=interval, + limit=limit, + category=category, + ) + + +def prepare_io_windows( + df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN +): + end_idx = len(df) + start_idx = max(0, end_idx - lookback) + x_df = df.loc[ + start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"] + ].reset_index(drop=True) + x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True) + + last_ts = df.loc[end_idx - 1, "timestamps"] + future_timestamps = pd.date_range( + start=last_ts + pd.Timedelta(minutes=interval_min), + periods=pred_len, + freq=f"{interval_min}min", + ) + y_timestamp = pd.Series(future_timestamps) + y_timestamp.index = range(len(y_timestamp)) + + print(f"数据总量: {len(df)} 根K线") + print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})") + print(f"最后一根历史K线时间: {last_ts}") + print(f"预测未来 {pred_len} 根K线") + return x_df, x_timestamp, y_timestamp, start_idx, end_idx + + +def load_kronos(device=DEVICE): + print("正在加载 Kronos 模型...") + tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") + model = Kronos.from_pretrained("NeoQuasar/Kronos-base") + predictor = KronosPredictor(model, tokenizer, device=device, max_context=512) + print("模型加载完成!\n") + return predictor + + +def run_one_prediction( + predictor, + x_df, + x_timestamp, + y_timestamp, + T=TEMPERATURE, + top_p=TOP_P, + seed=None, + verbose=False, +): + if seed is not None: + set_global_seed(seed) + return predictor.predict( + df=x_df, + x_timestamp=x_timestamp, + y_timestamp=y_timestamp, + pred_len=len(y_timestamp), + T=T, + top_p=top_p, + verbose=verbose, + ) + + +def run_multi_predictions( + predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42, + temperature=TEMPERATURE, top_p=TOP_P +): + preds = [] + print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...") + for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100): + seed = base_seed + i # 每次不同 seed + pred_df = run_one_prediction( + predictor, + x_df, + x_timestamp, + y_timestamp, + T=temperature, + top_p=top_p, + seed=seed, + verbose=False, + ) + # 兼容性处理:确保只取需要的列 + cols_present = [ + c + for c in ["open", "high", "low", "close", "volume", "amount"] + if c in pred_df.columns + ] + pred_df = pred_df[cols_present].copy() + pred_df.reset_index(drop=True, inplace=True) + preds.append(pred_df) + print("多次预测完成。\n") + return preds + + +def aggregate_quantiles(pred_list, quantiles=QUANTILES): + """ + 将多次预测列表聚合成分位数 DataFrame: + 输出列命名:_q10, _q50, _q90(按 quantiles 中的值) + """ + # 先把每次预测拼成 3D:time x feature x samples + keys = pred_list[0].columns.tolist() + T_len = len(pred_list[0]) + S = len(pred_list) + data = {k: np.zeros((T_len, S), dtype=float) for k in keys} + for j, pdf in enumerate(pred_list): + for k in keys: + data[k][:, j] = pdf[k].values + + out = {} + for k in keys: + # 极值(用于风险评估) + out[f"{k}_min"] = np.min(data[k], axis=1) + out[f"{k}_max"] = np.max(data[k], axis=1) + for q in quantiles: + qv = np.quantile(data[k], q, axis=1) + out[f"{k}_q{int(q*100):02d}"] = qv + agg_df = pd.DataFrame(out) + return agg_df + + +def kronos_predict_df( + symbol: str = "BTCUSD", + category: str = "inverse", + interval: str = "60", + limit: int = 500, + lookback: int = LOOKBACK, + pred_len: int = PRED_LEN, + samples: int = NUM_SAMPLES, + temperature: float = TEMPERATURE, + top_p: float = TOP_P, + quantiles: list[float] | None = None, + device: str = DEVICE, + seed: int = 42, + verbose: bool = False, +): + """ + 作为模块使用的API:返回聚合后的预测DataFrame(仅未来部分)。 + + 返回DataFrame包含: + - timestamps: 未来各点时间戳 + - _qXX: 各分位数(如 close_q50) + - _min/_max: 样本极值(用于风控/盈亏比) + - base_close: 当前基准价(最后一根历史K线的收盘价,重复列) + """ + # 环境 & HTTP会话 + load_dotenv() + api_key = os.getenv("BYBIT_API_KEY") + api_secret = os.getenv("BYBIT_API_SECRET") + session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret) + + # 加载模型(懒加载一次) + predictor = load_kronos(device=device) + + # 拉取K线并准备窗口 + df = fetch_kline( + session, + symbol=symbol, + interval=interval, + limit=limit, + category=category, + ) + x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows( + df, + lookback=lookback, + pred_len=pred_len, + interval_min=int(interval), + ) + + # 多次预测 + preds = run_multi_predictions( + predictor, + x_df, + x_ts, + y_ts, + num_samples=samples, + base_seed=seed, + temperature=temperature, + top_p=top_p, + ) + + # 聚合 + qs = quantiles if quantiles is not None else QUANTILES + agg_df = aggregate_quantiles(preds, quantiles=qs) + agg_df["timestamps"] = y_ts.values + + # 当前价(基准) + base_close = float(df.loc[end_idx - 1, "close"]) if end_idx > 0 else np.nan + agg_df["base_close"] = base_close + + # 仅返回DataFrame(调用方自行决定是否绘图/保存) + if verbose: + print( + f"完成预测:symbol={symbol} interval={interval}m samples={samples} base_close={base_close}" + ) + return agg_df + + +def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"): + """ + 上下两个子图(共享X轴): + - 上方(高度3):历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50) + - 下方(高度1):历史成交量柱 + 预测成交量中位柱(仅q50,不显示区间) + """ + import matplotlib.gridspec as gridspec + + # 智能推断K线宽度(柱宽) + try: + if len(historical_df) >= 2: + hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2]) + else: + hist_step = pd.Timedelta(minutes=10) + if len(y_timestamp) >= 2: + pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0]) + else: + pred_step = pd.Timedelta(minutes=10) + bar_width_hist = hist_step * 0.8 + bar_width_pred = pred_step * 0.8 + except Exception: + bar_width_hist = pd.Timedelta(minutes=10) + bar_width_pred = pd.Timedelta(minutes=10) + + # 取出预测分位 + close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None + close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None + close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None + + vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None + + # 图形与网格 + fig = plt.figure(figsize=(18, 10)) + gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3 + + # ===== 上:收盘价 ===== + ax_price = fig.add_subplot(gs[0]) + + ax_price.plot( + historical_df["timestamps"], + historical_df["close"], + label="历史收盘价", + linewidth=1.8, + ) + + # 预测收盘价区间与中位线 + if close_q10 is not None and close_q90 is not None: + ax_price.fill_between( + y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)" + ) + if close_q50 is not None: + ax_price.plot( + y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)" + ) + + # 预测起点与阴影 + if len(y_timestamp) > 0: + ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点") + ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08) + + # 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价 + # 合并历史和预测时间范围与价格数据 + all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True) + hist_close = historical_df["close"] + pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp)) + all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True) + + if len(all_timestamps) > 0: + start_time = all_timestamps.min() + end_time = all_timestamps.max() + + # 生成所有16:00时间点(UTC+8) + current_date = start_time.normalize() # 当天零点 + while current_date <= end_time: + target_time = current_date + pd.Timedelta(hours=16) # 16:00 + if start_time <= target_time <= end_time: + # 画虚线 + ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5) + + # 找到最接近16:00的时间点的收盘价 + time_diffs = (all_timestamps - target_time).abs() + closest_idx = time_diffs.idxmin() + closest_time = all_timestamps.iloc[closest_idx] + closest_price = all_close.iloc[closest_idx] + + # 如果时间差不超过1小时,则标注价格 + if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price): + ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7) + ax_price.text( + closest_time, closest_price, + f' ${closest_price:.1f}', + fontsize=9, + color='blue', + verticalalignment='bottom', + horizontalalignment='left', + bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7) + ) + current_date += pd.Timedelta(days=1) + + ax_price.set_ylabel("价格 (USD)", fontsize=11) + ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold") + ax_price.grid(True, alpha=0.3) + ax_price.legend(loc="best", fontsize=9) + + # ===== 下:成交量(仅预测中位量能柱)===== + ax_vol = fig.add_subplot(gs[1], sharex=ax_price) + + # 历史量能柱 + ax_vol.bar( + historical_df["timestamps"], + historical_df["volume"], + width=bar_width_hist, + alpha=0.35, + label="历史成交量", + ) + + # 预测中位量能柱(不画区间) + if vol_q50 is not None: + ax_vol.bar( + y_timestamp.values, + vol_q50, + width=bar_width_pred, + alpha=0.6, + label="预测成交量中位(q50)", + ) + + # 预测起点线 + if len(y_timestamp) > 0: + ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1) + + ax_vol.set_ylabel("成交量", fontsize=11) + ax_vol.set_xlabel("时间", fontsize=11) + ax_vol.grid(True, alpha=0.25) + ax_vol.legend(loc="best", fontsize=9) + + # 避免X轴标签重叠 + plt.setp(ax_price.get_xticklabels(), visible=False) + plt.tight_layout() + return fig + +def main(): + parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化") + parser.add_argument("--symbol", default="BTCUSD") + parser.add_argument("--category", default="inverse") + parser.add_argument("--interval", default="60") + parser.add_argument("--limit", type=int, default=500) + parser.add_argument("--lookback", type=int, default=LOOKBACK) + parser.add_argument("--pred_len", type=int, default=PRED_LEN) + parser.add_argument("--samples", type=int, default=NUM_SAMPLES) + parser.add_argument("--temperature", type=float, default=TEMPERATURE) + parser.add_argument("--top_p", type=float, default=TOP_P) + parser.add_argument("--quantiles", default="0.1,0.5,0.9") + parser.add_argument("--device", default=DEVICE) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + # 载入env & HTTP + load_dotenv() + api_key = os.getenv("BYBIT_API_KEY") + api_secret = os.getenv("BYBIT_API_SECRET") + session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret) + # 为命令行模式执行完整流程 + predictor = load_kronos(device=args.device) + df = fetch_kline( + session, + symbol=args.symbol, + interval=args.interval, + limit=args.limit, + category=args.category, + ) + x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows( + df, + lookback=args.lookback, + pred_len=args.pred_len, + interval_min=int(args.interval), + ) + + plot_start = max(0, end_idx - DRAW_WINDOW_LEN) + historical_df = df.loc[plot_start : end_idx - 1].copy() + + temperature = args.temperature + top_p = args.top_p + qs = [float(x) for x in args.quantiles.split(",") if x.strip()] + + preds = run_multi_predictions( + predictor, + x_df, + x_ts, + y_ts, + num_samples=args.samples, + base_seed=args.seed, + temperature=temperature, + top_p=top_p, + ) + + agg_df = aggregate_quantiles(preds, quantiles=qs) + agg_df["timestamps"] = y_ts.values + + # 输出 CSV(聚合) + out_csv = os.path.join( + OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv" + ) + agg_df.to_csv(out_csv, index=False) + print(f"预测分位数据已保存到: {out_csv}") + + # 作图 + fig = plot_results( + historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟" + ) + out_png = os.path.join( + OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png" + ) + fig.savefig(out_png, dpi=150, bbox_inches="tight") + print(f"预测区间图表已保存到: {out_png}") + + plt.show() + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print("运行出错:", repr(e)) + sys.exit(1) diff --git a/qq.py b/qq.py new file mode 100644 index 0000000..5679283 --- /dev/null +++ b/qq.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +""" +轻量的 QQ Bot 客户端(用于发送私聊文本和媒体) + +使用环境变量: +- QQ_BOT_URL: bot 接收地址(例如 http://localhost:30000/send_private_msg) +- QQ_BOT_TARGET_ID: 默认目标 QQ id(可被函数参数覆盖) + +示例: + from qq import send_msg, send_media_msg + send_msg('hello') + # 发送图片时将自动转换为 "base64://..." 格式 + send_media_msg('figures/btcusd_interval_pred_60m.png', type='image') +""" +from __future__ import annotations + +import os +import json +from typing import Optional +import base64 + +try: + import requests +except Exception: # pragma: no cover - runtime dependency + requests = None + + +def _bot_url() -> str: + return os.getenv("QQ_BOT_URL", "http://localhost:30000/send_private_msg") + + +def _default_target() -> Optional[str]: + return os.getenv("QQ_BOT_TARGET_ID") + + +def _file_to_base64_uri(path: str) -> str: + """将本地文件读取为 base64 uri 字符串(前缀 base64://)。""" + with open(path, "rb") as f: + data = f.read() + b64 = base64.b64encode(data).decode("ascii") + return f"base64://{b64}" + + +def send_msg(msg: str, target_id: Optional[str] = None, timeout: float = 8.0): + """发送文本消息到 QQ 机器人。返回 requests.Response 或 None(失败)。""" + if target_id is None: + target_id = _default_target() + if target_id is None: + raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置") + + payload = { + "user_id": str(target_id), + "message": [ + {"type": "text", "data": {"text": str(msg)}} + ], + } + + url = _bot_url() + print(f"[QQ] 发送文本到 {target_id}: {msg}") + if requests is None: + print("requests 未安装,无法发送消息") + return None + + try: + resp = requests.post(url, json=payload, timeout=timeout) + try: + # 非必要:打印返回的简短信息 + print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}") + except Exception: + pass + return resp + except Exception as e: + print(f"[QQ] 发送文本消息出错: {e}") + return None + + +def send_media_msg(file_path: str, target_id: Optional[str] = None, type: str = "image", timeout: float = 15.0): + """发送媒体消息(image 或 video)。 + + - 当 type == "image" 时:读取文件并以 "base64://..." 形式发送。 + - 当 type == "video" 时:仍使用 "file://" 本地文件引用(保持原有行为)。 + """ + if type not in ("image", "video"): + raise ValueError("type 必须为 'image' 或 'video'") + + if target_id is None: + target_id = _default_target() + if target_id is None: + raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置") + + # 根据类型构造 file 字段 + if type == "image": + try: + file_field = _file_to_base64_uri(file_path) + except Exception as e: + print(f"[QQ] 读取图片失败: {e}") + return None + else: # video + file_field = f"file://{file_path}" + + payload = { + "user_id": str(target_id), + "message": [ + {"type": type, "data": {"file": file_field}} + ], + } + + url = _bot_url() + print(f"[QQ] 发送媒体 {type} 到 {target_id}: {file_path}") + if requests is None: + print("requests 未安装,无法发送媒体消息") + return None + + try: + resp = requests.post(url, json=payload, timeout=timeout) + try: + print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}") + except Exception: + pass + return resp + except Exception as e: + print(f"[QQ] 发送媒体消息出错: {e}") + return None + + +def send_prediction_result(summary: str, png_path: Optional[str] = None, target_id: Optional[str] = None): + """便捷函数:先发送 summary 文本,再发送 png(若提供)。""" + r1 = send_msg(summary, target_id=target_id) + r2 = None + if png_path: + r2 = send_media_msg(png_path, target_id=target_id, type="image") + return r1, r2 diff --git a/requirements.txt b/requirements.txt index d94a8d7..352e7a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ numpy pandas torch - +requests +dotenv einops==0.8.1 huggingface_hub==0.33.1 matplotlib==3.9.3 pandas==2.2.2 tqdm==4.67.1 safetensors==0.6.2 +pybit diff --git a/trader.py b/trader.py index 52d289c..360a61e 100644 --- a/trader.py +++ b/trader.py @@ -1,419 +1,408 @@ # -*- coding: utf-8 -*- """ -BTCUSD 15m 多次采样预测 → 区间聚合可视化 -Author: 你(鹅) +基于 Kronos 多次采样预测,评估多/空两侧的盈亏比,并给出最佳机会与可视化。 + +盈亏比 = 潜在盈利 / 最大亏损 + +多头: +- 潜在盈利:以中位线 close_q50 为准,选取未来涨幅最高点(相对当前价 base_close)的涨幅。 +- 最大亏损:从现在到该目标点区间内,使用多次采样的最差点(close_min)相对当前价的最大跌幅。 + +空头: +- 潜在盈利:以中位线 close_q50 为准,选取未来跌幅最大点(相对当前价 base_close)的跌幅。 +- 最大亏损:从现在到该目标点区间内,使用多次采样的最差点(close_max)相对当前价的最大涨幅。 """ import os -import sys -import math -import json -import time -import random import argparse -from datetime import datetime +from typing import Tuple import numpy as np import pandas as pd import matplotlib.pyplot as plt -from tqdm import tqdm from dotenv import load_dotenv -from pybit.unified_trading import HTTP +from predict import kronos_predict_df, get_recent_kline -# ==== Kronos ==== -from model import Kronos, KronosTokenizer, KronosPredictor -# ========== Matplotlib 字体 ========== plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"] plt.rcParams["axes.unicode_minus"] = False -# ========== 默认参数 ========== -NUM_SAMPLES = 20 # 多次采样次数(建议 10~50) -QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上) -LOOKBACK = 360 # 历史窗口 -PRED_LEN = 96 # 未来预测点数 -INTERVAL_MIN = 15 # K 线周期(分钟) -DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0") -TEMPERATURE = 1.0 -TOP_P = 0.9 -OUTPUT_DIR = "figures" -os.makedirs(OUTPUT_DIR, exist_ok=True) - - -# ========== 实用函数 ========== -def set_global_seed(seed: int): - """尽可能固定随机性(若底层采样逻辑使用 torch/np/random)""" +def _safe_get(series: pd.Series, default=np.nan): try: - import torch - - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + return float(series.iloc[0]) except Exception: - pass - random.seed(seed) - np.random.seed(seed) + return default -def fetch_kline( - session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse" -): - print("正在获取K线数据...") - resp = session.get_kline( - category=category, symbol=symbol, interval=interval, limit=limit +def compute_long_metrics(agg_df: pd.DataFrame) -> dict: + base = float(agg_df["base_close"].iloc[0]) + close_med = agg_df.get("close_q50", agg_df.get("close_q10")) + if close_med is None: + raise ValueError("预测结果中缺少 close_q50(或 close_q10 兜底)列") + + # 相对收益(百分比) + ret = (close_med - base) / base + j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0 + profit = max(float(ret.iloc[j_target]), 0.0) + + # 区间内的“最差的90%分位”(使用下分位线 q10,而非样本极小值) + q10 = agg_df.get("close_q10", close_med) + worst_price = float(q10.iloc[: j_target + 1].min()) + loss = max((base - worst_price) / base, 1e-9) + + return { + "direction": "LONG", + "base": base, + "profit": profit, + "loss": loss, + "rr": profit / loss if loss > 0 else np.inf, + "target_idx": j_target, + "target_price": float(close_med.iloc[j_target]), + "worst_idx": int(np.nanargmin(q10.iloc[: j_target + 1].values)), + "worst_price": worst_price, + } + + +def compute_short_metrics(agg_df: pd.DataFrame) -> dict: + base = float(agg_df["base_close"].iloc[0]) + close_med = agg_df.get("close_q50", agg_df.get("close_q90")) + if close_med is None: + raise ValueError("预测结果中缺少 close_q50(或 close_q90 兜底)列") + + # 相对收益(空头盈利为向下幅度) + ret = (base - close_med) / base + j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0 + profit = max(float(ret.iloc[j_target]), 0.0) + + # 区间内的“最差的90%分位”(使用上分位线 q90,而非样本极大值) + q90 = agg_df.get("close_q90", close_med) + worst_price = float(q90.iloc[: j_target + 1].max()) + loss = max((worst_price - base) / base, 1e-9) + + return { + "direction": "SHORT", + "base": base, + "profit": profit, + "loss": loss, + "rr": profit / loss if loss > 0 else np.inf, + "target_idx": j_target, + "target_price": float(close_med.iloc[j_target]), + "worst_idx": int(np.nanargmax(q90.iloc[: j_target + 1].values)), + "worst_price": worst_price, + } + + +def plot_opportunities( + agg_df: pd.DataFrame, + long_m: dict, + short_m: dict, + symbol: str, + interval: str, + out_dir: str, + hist_df: pd.DataFrame | None = None, +) -> str: + """生成清晰可视化:上下两个子图展示多/空机会与风险。""" + from matplotlib import gridspec + import numpy as np + + ts = ( + pd.to_datetime(agg_df["timestamps"]) + if not np.issubdtype(agg_df["timestamps"].dtype, np.datetime64) + else agg_df["timestamps"] ) - if resp.get("retCode", -1) != 0: - raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}") - lst = resp["result"]["list"] + base = long_m["base"] + q10 = agg_df.get("close_q10", None) + q50 = agg_df.get("close_q50", None) + q90 = agg_df.get("close_q90", None) - df = pd.DataFrame( - lst, - columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"], - ) - df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms") - for c in ["open", "high", "low", "close", "volume", "turnover"]: - df[c] = df[c].astype(float) - df = df.sort_values("timestamps").reset_index(drop=True) - df["amount"] = df["turnover"] - print(f"获取到 {len(df)} 根K线数据") - return df + # 历史24根K线(若提供) + hist_ts = None + hist_close = None + if hist_df is not None and len(hist_df) > 0: + try: + hist_ts = ( + pd.to_datetime(hist_df["timestamps"]) if not np.issubdtype(hist_df["timestamps"].dtype, np.datetime64) else hist_df["timestamps"] + ) + hist_close = hist_df["close"].astype(float).values + except Exception: + hist_ts, hist_close = None, None - -def prepare_io_windows( - df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN -): - end_idx = len(df) - start_idx = max(0, end_idx - lookback) - x_df = df.loc[ - start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"] - ].reset_index(drop=True) - x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True) - - last_ts = df.loc[end_idx - 1, "timestamps"] - future_timestamps = pd.date_range( - start=last_ts + pd.Timedelta(minutes=interval_min), - periods=pred_len, - freq=f"{interval_min}min", - ) - y_timestamp = pd.Series(future_timestamps) - y_timestamp.index = range(len(y_timestamp)) - - print(f"数据总量: {len(df)} 根K线") - print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})") - print(f"最后一根历史K线时间: {last_ts}") - print(f"预测未来 {pred_len} 根K线") - return x_df, x_timestamp, y_timestamp, start_idx, end_idx - - -def load_kronos(device=DEVICE): - print("正在加载 Kronos 模型...") - tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") - model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - predictor = KronosPredictor(model, tokenizer, device=device, max_context=512) - print("模型加载完成!\n") - return predictor - - -def run_one_prediction( - predictor, - x_df, - x_timestamp, - y_timestamp, - T=TEMPERATURE, - top_p=TOP_P, - seed=None, - verbose=False, -): - if seed is not None: - set_global_seed(seed) - return predictor.predict( - df=x_df, - x_timestamp=x_timestamp, - y_timestamp=y_timestamp, - pred_len=len(y_timestamp), - T=T, - top_p=top_p, - verbose=verbose, - ) - - -def run_multi_predictions( - predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42, - temperature=TEMPERATURE, top_p=TOP_P -): - preds = [] - print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...") - for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100): - seed = base_seed + i # 每次不同 seed - pred_df = run_one_prediction( - predictor, - x_df, - x_timestamp, - y_timestamp, - T=temperature, - top_p=top_p, - seed=seed, - verbose=False, - ) - # 兼容性处理:确保只取需要的列 - cols_present = [ - c - for c in ["open", "high", "low", "close", "volume", "amount"] - if c in pred_df.columns - ] - pred_df = pred_df[cols_present].copy() - pred_df.reset_index(drop=True, inplace=True) - preds.append(pred_df) - print("多次预测完成。\n") - return preds - - -def aggregate_quantiles(pred_list, quantiles=QUANTILES): - """ - 将多次预测列表聚合成分位数 DataFrame: - 输出列命名:_q10, _q50, _q90(按 quantiles 中的值) - """ - # 先把每次预测拼成 3D:time x feature x samples - keys = pred_list[0].columns.tolist() - T_len = len(pred_list[0]) - S = len(pred_list) - data = {k: np.zeros((T_len, S), dtype=float) for k in keys} - for j, pdf in enumerate(pred_list): - for k in keys: - data[k][:, j] = pdf[k].values - - out = {} - for k in keys: - for q in quantiles: - qv = np.quantile(data[k], q, axis=1) - out[f"{k}_q{int(q*100):02d}"] = qv - agg_df = pd.DataFrame(out) - return agg_df - - -def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"): - """ - 上下两个子图(共享X轴): - - 上方(高度3):历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50) - - 下方(高度1):历史成交量柱 + 预测成交量中位柱(仅q50,不显示区间) - """ - import matplotlib.gridspec as gridspec - - # 智能推断K线宽度(柱宽) - try: - if len(historical_df) >= 2: - hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2]) - else: - hist_step = pd.Timedelta(minutes=10) - if len(y_timestamp) >= 2: - pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0]) - else: - pred_step = pd.Timedelta(minutes=10) - bar_width_hist = hist_step * 0.8 - bar_width_pred = pred_step * 0.8 - except Exception: - bar_width_hist = pd.Timedelta(minutes=10) - bar_width_pred = pd.Timedelta(minutes=10) - - # 取出预测分位 - close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None - close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None - close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None - - vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None - - # 图形与网格 fig = plt.figure(figsize=(18, 10)) - gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3 + gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15) - # ===== 上:收盘价 ===== - ax_price = fig.add_subplot(gs[0]) + def _draw_panel( + ax, title, target_idx, target_price, worst_idx, worst_price, direction + ): + # 历史收盘价(预测前24根) + if hist_ts is not None and hist_close is not None: + ax.plot( + hist_ts.values, + hist_close, + color="#6c757d", + lw=1.6, + label="历史收盘(近24根)", + ) + # 预测起点标注 + try: + ax.axvline(ts.iloc[0], color="gray", ls=":", lw=1.0, label="预测起点") + except Exception: + pass - ax_price.plot( - historical_df["timestamps"], - historical_df["close"], - label="历史收盘价", - linewidth=1.8, + # 区间带 + if q10 is not None and q90 is not None: + ax.fill_between( + ts.values, + q10.values, + q90.values, + color="#7dbcea", + alpha=0.22, + label="q10~q90", + ) + # 中位线 + if q50 is not None: + ax.plot(ts.values, q50.values, color="#00539c", lw=2.0, label="中位线 q50") + + # 基准线 + ax.axhline(base, color="gray", ls=":", lw=1.2, label=f"当前价 {base:.1f}") + + # 目标点 + ax.plot( + ts.iloc[target_idx], + target_price, + marker="o", + color="#2e7d32", + ms=8, + label="目标点", + ) + ax.annotate( + f"目标 {target_price:.1f}", + (ts.iloc[target_idx], target_price), + textcoords="offset points", + xytext=(8, 8), + bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#2e7d32", alpha=0.8), + color="#2e7d32", + ) + + # 最差点 + ax.plot( + ts.iloc[worst_idx], + worst_price, + marker="v", + color="#c62828", + ms=8, + label="最差点", + ) + ax.annotate( + f"最差 {worst_price:.1f}", + (ts.iloc[worst_idx], worst_price), + textcoords="offset points", + xytext=(8, -14), + bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#c62828", alpha=0.8), + color="#c62828", + ) + + # 连接箭头(方向) + color = "#2e7d32" if direction == "LONG" else "#c62828" + ax.annotate( + "", + xy=(ts.iloc[target_idx], target_price), + xytext=(ts.iloc[0], base), + arrowprops=dict(arrowstyle="->", color=color, lw=2, alpha=0.9), + ) + + # 盈亏比矩形:用填充矩形表示潜在盈利(绿色)与最大亏损(红色) + x_range = ts.iloc[: target_idx + 1] + if direction == "LONG": + # 盈利(若目标价在基准价之上) + if target_price > base: + y_low = np.full_like(x_range, base, dtype=float) + y_high = np.full_like(x_range, target_price, dtype=float) + ax.fill_between( + x_range.values, + y_low, + y_high, + color="#2e7d32", + alpha=0.18, + label="潜在盈利", + ) + # 亏损(q10 代表“最差的90%分位”) + if q10 is not None and np.isfinite(worst_price) and worst_price < base: + y_low = np.full_like(x_range, worst_price, dtype=float) + y_high = np.full_like(x_range, base, dtype=float) + ax.fill_between( + x_range.values, + y_low, + y_high, + color="#c62828", + alpha=0.18, + label="最大亏损(q10)", + ) + else: + # 盈利(空头目标价在基准价之下) + if target_price < base: + y_low = np.full_like(x_range, target_price, dtype=float) + y_high = np.full_like(x_range, base, dtype=float) + ax.fill_between( + x_range.values, + y_low, + y_high, + color="#2e7d32", + alpha=0.18, + label="潜在盈利", + ) + # 亏损(q90 代表“最差的90%分位”) + if q90 is not None and np.isfinite(worst_price) and worst_price > base: + y_low = np.full_like(x_range, base, dtype=float) + y_high = np.full_like(x_range, worst_price, dtype=float) + ax.fill_between( + x_range.values, + y_low, + y_high, + color="#c62828", + alpha=0.18, + label="最大亏损(q90)", + ) + + ax.set_title(title) + ax.grid(True, alpha=0.25) + ax.legend(loc="best", fontsize=9) + + ax1 = fig.add_subplot(gs[0]) + rr_long_pct = long_m["rr"] * 100.0 + profit_long_pct = long_m["profit"] * 100.0 + loss_long_pct = long_m["loss"] * 100.0 + title_long = f"多头机会:盈亏比 RR={long_m['rr']:.2f}(+{profit_long_pct:.2f}% / -{loss_long_pct:.2f}%)" + _draw_panel( + ax1, + title_long, + long_m["target_idx"], + long_m["target_price"], + long_m["worst_idx"], + long_m["worst_price"], + "LONG", ) - # 预测收盘价区间与中位线 - if close_q10 is not None and close_q90 is not None: - ax_price.fill_between( - y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)" - ) - if close_q50 is not None: - ax_price.plot( - y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)" - ) - - # 预测起点与阴影 - if len(y_timestamp) > 0: - ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点") - ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08) - - # 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价 - # 合并历史和预测时间范围与价格数据 - all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True) - hist_close = historical_df["close"] - pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp)) - all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True) - - if len(all_timestamps) > 0: - start_time = all_timestamps.min() - end_time = all_timestamps.max() - - # 生成所有16:00时间点(UTC+8) - current_date = start_time.normalize() # 当天零点 - while current_date <= end_time: - target_time = current_date + pd.Timedelta(hours=16) # 16:00 - if start_time <= target_time <= end_time: - # 画虚线 - ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5) - - # 找到最接近16:00的时间点的收盘价 - time_diffs = (all_timestamps - target_time).abs() - closest_idx = time_diffs.idxmin() - closest_time = all_timestamps.iloc[closest_idx] - closest_price = all_close.iloc[closest_idx] - - # 如果时间差不超过1小时,则标注价格 - if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price): - ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7) - ax_price.text( - closest_time, closest_price, - f' ${closest_price:.1f}', - fontsize=9, - color='blue', - verticalalignment='bottom', - horizontalalignment='left', - bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7) - ) - current_date += pd.Timedelta(days=1) - - ax_price.set_ylabel("价格 (USD)", fontsize=11) - ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold") - ax_price.grid(True, alpha=0.3) - ax_price.legend(loc="best", fontsize=9) - - # ===== 下:成交量(仅预测中位量能柱)===== - ax_vol = fig.add_subplot(gs[1], sharex=ax_price) - - # 历史量能柱 - ax_vol.bar( - historical_df["timestamps"], - historical_df["volume"], - width=bar_width_hist, - alpha=0.35, - label="历史成交量", + ax2 = fig.add_subplot(gs[1], sharex=ax1) + rr_short_pct = short_m["rr"] * 100.0 + profit_short_pct = short_m["profit"] * 100.0 + loss_short_pct = short_m["loss"] * 100.0 + title_short = f"空头机会:盈亏比 RR={short_m['rr']:.2f}(+{profit_short_pct:.2f}% / -{loss_short_pct:.2f}%)" + _draw_panel( + ax2, + title_short, + short_m["target_idx"], + short_m["target_price"], + short_m["worst_idx"], + short_m["worst_price"], + "SHORT", ) - # 预测中位量能柱(不画区间) - if vol_q50 is not None: - ax_vol.bar( - y_timestamp.values, - vol_q50, - width=bar_width_pred, - alpha=0.6, - label="预测成交量中位(q50)", - ) + for lbl in ax1.get_xticklabels(): + lbl.set_visible(False) + ax2.set_xlabel("时间") + ax1.set_ylabel("价格 (USD)") + ax2.set_ylabel("价格 (USD)") + fig.suptitle( + f"{symbol} {interval}分钟线 24小时 预测机会评估", fontsize=16, fontweight="bold" + ) + plt.tight_layout(rect=[0, 0, 1, 0.96]) - # 预测起点线 - if len(y_timestamp) > 0: - ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1) + os.makedirs(out_dir, exist_ok=True) + out_png = os.path.join( + out_dir, f"trade_opportunity_{symbol.lower()}_{interval}m.png" + ) + fig.savefig(out_png, dpi=150, bbox_inches="tight") + return out_png - ax_vol.set_ylabel("成交量", fontsize=11) - ax_vol.set_xlabel("时间", fontsize=11) - ax_vol.grid(True, alpha=0.25) - ax_vol.legend(loc="best", fontsize=9) - - # 避免X轴标签重叠 - plt.setp(ax_price.get_xticklabels(), visible=False) - plt.tight_layout() - return fig def main(): - parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化") - parser.add_argument("--symbol", default="BTCUSD") + # 加载环境变量 + load_dotenv() + + parser = argparse.ArgumentParser(description="计算并可视化高盈亏比交易机会") + parser.add_argument("--symbol", default="BTCUSDT") parser.add_argument("--category", default="inverse") - parser.add_argument("--interval", default="15") - parser.add_argument("--limit", type=int, default=1000) - parser.add_argument("--lookback", type=int, default=LOOKBACK) - parser.add_argument("--pred_len", type=int, default=PRED_LEN) - parser.add_argument("--samples", type=int, default=NUM_SAMPLES) - parser.add_argument("--temperature", type=float, default=TEMPERATURE) - parser.add_argument("--top_p", type=float, default=TOP_P) - parser.add_argument("--quantiles", default="0.1,0.5,0.9") - parser.add_argument("--device", default=DEVICE) + parser.add_argument("--interval", default="60") + parser.add_argument("--limit", type=int, default=500) + parser.add_argument("--lookback", type=int, default=360) + parser.add_argument("--pred_len", type=int, default=24) + parser.add_argument("--samples", type=int, default=30) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--device", default=os.getenv("KRONOS_DEVICE", "cuda:0")) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() - # 载入env & HTTP - load_dotenv() - api_key = os.getenv("BYBIT_API_KEY") - api_secret = os.getenv("BYBIT_API_SECRET") - session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret) - # Kronos - predictor = load_kronos(device=args.device) - - # 拉数据 - df = fetch_kline( - session, + # 获取预测聚合数据(仅未来) + agg_df = kronos_predict_df( symbol=args.symbol, + category=args.category, interval=args.interval, limit=args.limit, - category=args.category, - ) - x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows( - df, lookback=args.lookback, pred_len=args.pred_len, - interval_min=int(args.interval), + samples=args.samples, + temperature=args.temperature, + top_p=args.top_p, + quantiles=[0.1, 0.5, 0.9], + device=args.device, + seed=args.seed, + verbose=True, ) - # 历史用于画图的窗口(最近200根,可按需调整) - plot_start = max(0, end_idx - 200) - historical_df = df.loc[plot_start : end_idx - 1].copy() + # 计算多 / 空的机会 + long_m = compute_long_metrics(agg_df) + short_m = compute_short_metrics(agg_df) - # 采样参数(直接使用局部变量,不需要修改全局变量) - temperature = args.temperature - top_p = args.top_p - qs = [float(x) for x in args.quantiles.split(",") if x.strip()] + # 选择更优机会 + best = long_m if long_m["rr"] >= short_m["rr"] else short_m - # 多次预测(传入局部变量) - preds = run_multi_predictions( - predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed, - temperature=temperature, top_p=top_p + # 生成汇总字符串(增强可读性与可辨识度) + dir_emoji = "📈 多头" if best["direction"] == "LONG" else "📉 空头" + long_arrow = "⬆️" + short_arrow = "⬇️" + summary_lines = [ + "== 24小时 每日交易机会评估 ==\n", + f"📌 标的: {args.symbol} | 周期: {args.interval}m | 当前价: {best['base']:.2f}\n", + f"📈 多头 | RR={long_m['rr']:.2f} | 盈利 +{long_m['profit'] * 100.0:.2f}% | 亏损 -{long_m['loss'] * 100.0:.2f}% | 🎯 目标 {long_m['target_price']:.2f} | 🛡️ 最差 {long_m['worst_price']:.2f}\n", + f"📉 空头 | RR={short_m['rr']:.2f} | 盈利 +{short_m['profit'] * 100.0:.2f}% | 亏损 -{short_m['loss'] * 100.0:.2f}% | 🎯 目标 {short_m['target_price']:.2f} | 🛡️ 最差 {short_m['worst_price']:.2f}\n", + f"✅ 最佳: {dir_emoji} | 盈亏比 RR={best['rr']:.2f}", + ] + if best["rr"] < 2.0: + summary_lines.append("⚠️ 交易机会不佳,建议观望。") + summary = "\n".join(summary_lines) + "\n" + + # 打印汇总 + print(summary) + + # 获取预测前24根K线用于可视化上下文 + try: + full_hist_df = get_recent_kline( + symbol=args.symbol, + category=args.category, + interval=args.interval, + limit=args.limit, + ) + hist24 = full_hist_df.tail(24).reset_index(drop=True) + except Exception: + hist24 = None + + # 可视化 + out_png = plot_opportunities( + agg_df, long_m, short_m, args.symbol, args.interval, out_dir="figures", hist_df=hist24 ) - - # 聚合分位 - agg_df = aggregate_quantiles(preds, quantiles=qs) - agg_df["timestamps"] = y_ts.values - - # 输出 CSV(聚合) - out_csv = os.path.join( - OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv" - ) - agg_df.to_csv(out_csv, index=False) - print(f"预测分位数据已保存到: {out_csv}") - - # 作图 - fig = plot_results( - historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟" - ) - out_png = os.path.join( - OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png" - ) - fig.savefig(out_png, dpi=150, bbox_inches="tight") - print(f"预测区间图表已保存到: {out_png}") - - plt.show() + print(f"可视化已保存: {out_png}") + + # 发送交易机会评估结果到 QQ + try: + from qq import send_prediction_result + target = os.getenv("QQ_BOT_TARGET_ID") + send_prediction_result(summary, png_path=out_png, target_id=target) + except Exception as e: + print(f"发送 QQ 消息失败: {repr(e)}") if __name__ == "__main__": - try: - main() - except Exception as e: - print("运行出错:", repr(e)) - sys.exit(1) + main()