diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..05b9164
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,526 @@
+# -*- coding: utf-8 -*-
+"""
+BTCUSD 15m 多次采样预测 → 区间聚合可视化
+Author: 你(鹅)
+"""
+
+import os
+import sys
+import math
+import json
+import time
+import random
+import argparse
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+from dotenv import load_dotenv
+from pybit.unified_trading import HTTP
+
+# ==== Kronos ====
+from model import Kronos, KronosTokenizer, KronosPredictor
+
+# ========== Matplotlib 字体 ==========
+plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
+plt.rcParams["axes.unicode_minus"] = False
+
+# ========== 默认参数 ==========
+NUM_SAMPLES = 20 # 多次采样次数(建议 10~50)
+QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
+LOOKBACK = 360 # 历史窗口
+PRED_LEN = 24 # 未来预测点数
+INTERVAL_MIN = 60 # K 线周期(分钟)
+DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
+TEMPERATURE = 1.0
+TOP_P = 0.9
+DRAW_WINDOW_LEN = 240 # 历史用于画图的窗口
+
+OUTPUT_DIR = "figures"
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+
+# ========== 实用函数 ==========
+def set_global_seed(seed: int):
+ """尽可能固定随机性(若底层采样逻辑使用 torch/np/random)"""
+ try:
+ import torch
+
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ except Exception:
+ pass
+ random.seed(seed)
+ np.random.seed(seed)
+
+
+def fetch_kline(
+ session: HTTP, symbol="BTCUSD", interval="15", limit=500, category="inverse"
+):
+ print("正在获取K线数据...")
+ resp = session.get_kline(
+ category=category, symbol=symbol, interval=interval, limit=limit
+ )
+ if resp.get("retCode", -1) != 0:
+ raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
+ lst = resp["result"]["list"]
+
+ df = pd.DataFrame(
+ lst,
+ columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
+ )
+ df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
+ for c in ["open", "high", "low", "close", "volume", "turnover"]:
+ df[c] = df[c].astype(float)
+ df = df.sort_values("timestamps").reset_index(drop=True)
+ df["amount"] = df["turnover"]
+ # 去掉最新一根(通常为未完成的当前K线)
+ if len(df) > 0:
+ df = df.iloc[:-1].reset_index(drop=True)
+ print(f"获取到 {len(df)} 根K线数据")
+ return df
+
+
+def get_recent_kline(
+ symbol: str = "BTCUSD",
+ category: str = "inverse",
+ interval: str = "60",
+ limit: int = 500,
+):
+ """便捷封装:基于环境变量创建会话并获取最近K线数据。
+
+ 返回数据按时间升序,且会移除最后一根未完成K线(与 fetch_kline 保持一致)。
+ """
+ load_dotenv()
+ api_key = os.getenv("BYBIT_API_KEY")
+ api_secret = os.getenv("BYBIT_API_SECRET")
+ session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
+ return fetch_kline(
+ session=session,
+ symbol=symbol,
+ interval=interval,
+ limit=limit,
+ category=category,
+ )
+
+
+def prepare_io_windows(
+ df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
+):
+ end_idx = len(df)
+ start_idx = max(0, end_idx - lookback)
+ x_df = df.loc[
+ start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
+ ].reset_index(drop=True)
+ x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
+
+ last_ts = df.loc[end_idx - 1, "timestamps"]
+ future_timestamps = pd.date_range(
+ start=last_ts + pd.Timedelta(minutes=interval_min),
+ periods=pred_len,
+ freq=f"{interval_min}min",
+ )
+ y_timestamp = pd.Series(future_timestamps)
+ y_timestamp.index = range(len(y_timestamp))
+
+ print(f"数据总量: {len(df)} 根K线")
+ print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})")
+ print(f"最后一根历史K线时间: {last_ts}")
+ print(f"预测未来 {pred_len} 根K线")
+ return x_df, x_timestamp, y_timestamp, start_idx, end_idx
+
+
+def load_kronos(device=DEVICE):
+ print("正在加载 Kronos 模型...")
+ tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
+ model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
+ print("模型加载完成!\n")
+ return predictor
+
+
+def run_one_prediction(
+ predictor,
+ x_df,
+ x_timestamp,
+ y_timestamp,
+ T=TEMPERATURE,
+ top_p=TOP_P,
+ seed=None,
+ verbose=False,
+):
+ if seed is not None:
+ set_global_seed(seed)
+ return predictor.predict(
+ df=x_df,
+ x_timestamp=x_timestamp,
+ y_timestamp=y_timestamp,
+ pred_len=len(y_timestamp),
+ T=T,
+ top_p=top_p,
+ verbose=verbose,
+ )
+
+
+def run_multi_predictions(
+ predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
+ temperature=TEMPERATURE, top_p=TOP_P
+):
+ preds = []
+ print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
+ for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100):
+ seed = base_seed + i # 每次不同 seed
+ pred_df = run_one_prediction(
+ predictor,
+ x_df,
+ x_timestamp,
+ y_timestamp,
+ T=temperature,
+ top_p=top_p,
+ seed=seed,
+ verbose=False,
+ )
+ # 兼容性处理:确保只取需要的列
+ cols_present = [
+ c
+ for c in ["open", "high", "low", "close", "volume", "amount"]
+ if c in pred_df.columns
+ ]
+ pred_df = pred_df[cols_present].copy()
+ pred_df.reset_index(drop=True, inplace=True)
+ preds.append(pred_df)
+ print("多次预测完成。\n")
+ return preds
+
+
+def aggregate_quantiles(pred_list, quantiles=QUANTILES):
+ """
+ 将多次预测列表聚合成分位数 DataFrame:
+ 输出列命名:
_q10, _q50, _q90(按 quantiles 中的值)
+ """
+ # 先把每次预测拼成 3D:time x feature x samples
+ keys = pred_list[0].columns.tolist()
+ T_len = len(pred_list[0])
+ S = len(pred_list)
+ data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
+ for j, pdf in enumerate(pred_list):
+ for k in keys:
+ data[k][:, j] = pdf[k].values
+
+ out = {}
+ for k in keys:
+ # 极值(用于风险评估)
+ out[f"{k}_min"] = np.min(data[k], axis=1)
+ out[f"{k}_max"] = np.max(data[k], axis=1)
+ for q in quantiles:
+ qv = np.quantile(data[k], q, axis=1)
+ out[f"{k}_q{int(q*100):02d}"] = qv
+ agg_df = pd.DataFrame(out)
+ return agg_df
+
+
+def kronos_predict_df(
+ symbol: str = "BTCUSD",
+ category: str = "inverse",
+ interval: str = "60",
+ limit: int = 500,
+ lookback: int = LOOKBACK,
+ pred_len: int = PRED_LEN,
+ samples: int = NUM_SAMPLES,
+ temperature: float = TEMPERATURE,
+ top_p: float = TOP_P,
+ quantiles: list[float] | None = None,
+ device: str = DEVICE,
+ seed: int = 42,
+ verbose: bool = False,
+):
+ """
+ 作为模块使用的API:返回聚合后的预测DataFrame(仅未来部分)。
+
+ 返回DataFrame包含:
+ - timestamps: 未来各点时间戳
+ - _qXX: 各分位数(如 close_q50)
+ - _min/_max: 样本极值(用于风控/盈亏比)
+ - base_close: 当前基准价(最后一根历史K线的收盘价,重复列)
+ """
+ # 环境 & HTTP会话
+ load_dotenv()
+ api_key = os.getenv("BYBIT_API_KEY")
+ api_secret = os.getenv("BYBIT_API_SECRET")
+ session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
+
+ # 加载模型(懒加载一次)
+ predictor = load_kronos(device=device)
+
+ # 拉取K线并准备窗口
+ df = fetch_kline(
+ session,
+ symbol=symbol,
+ interval=interval,
+ limit=limit,
+ category=category,
+ )
+ x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
+ df,
+ lookback=lookback,
+ pred_len=pred_len,
+ interval_min=int(interval),
+ )
+
+ # 多次预测
+ preds = run_multi_predictions(
+ predictor,
+ x_df,
+ x_ts,
+ y_ts,
+ num_samples=samples,
+ base_seed=seed,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ # 聚合
+ qs = quantiles if quantiles is not None else QUANTILES
+ agg_df = aggregate_quantiles(preds, quantiles=qs)
+ agg_df["timestamps"] = y_ts.values
+
+ # 当前价(基准)
+ base_close = float(df.loc[end_idx - 1, "close"]) if end_idx > 0 else np.nan
+ agg_df["base_close"] = base_close
+
+ # 仅返回DataFrame(调用方自行决定是否绘图/保存)
+ if verbose:
+ print(
+ f"完成预测:symbol={symbol} interval={interval}m samples={samples} base_close={base_close}"
+ )
+ return agg_df
+
+
+def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
+ """
+ 上下两个子图(共享X轴):
+ - 上方(高度3):历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
+ - 下方(高度1):历史成交量柱 + 预测成交量中位柱(仅q50,不显示区间)
+ """
+ import matplotlib.gridspec as gridspec
+
+ # 智能推断K线宽度(柱宽)
+ try:
+ if len(historical_df) >= 2:
+ hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
+ else:
+ hist_step = pd.Timedelta(minutes=10)
+ if len(y_timestamp) >= 2:
+ pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
+ else:
+ pred_step = pd.Timedelta(minutes=10)
+ bar_width_hist = hist_step * 0.8
+ bar_width_pred = pred_step * 0.8
+ except Exception:
+ bar_width_hist = pd.Timedelta(minutes=10)
+ bar_width_pred = pd.Timedelta(minutes=10)
+
+ # 取出预测分位
+ close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
+ close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
+ close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
+
+ vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
+
+ # 图形与网格
+ fig = plt.figure(figsize=(18, 10))
+ gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
+
+ # ===== 上:收盘价 =====
+ ax_price = fig.add_subplot(gs[0])
+
+ ax_price.plot(
+ historical_df["timestamps"],
+ historical_df["close"],
+ label="历史收盘价",
+ linewidth=1.8,
+ )
+
+ # 预测收盘价区间与中位线
+ if close_q10 is not None and close_q90 is not None:
+ ax_price.fill_between(
+ y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
+ )
+ if close_q50 is not None:
+ ax_price.plot(
+ y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
+ )
+
+ # 预测起点与阴影
+ if len(y_timestamp) > 0:
+ ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
+ ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
+
+ # 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
+ # 合并历史和预测时间范围与价格数据
+ all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
+ hist_close = historical_df["close"]
+ pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
+ all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
+
+ if len(all_timestamps) > 0:
+ start_time = all_timestamps.min()
+ end_time = all_timestamps.max()
+
+ # 生成所有16:00时间点(UTC+8)
+ current_date = start_time.normalize() # 当天零点
+ while current_date <= end_time:
+ target_time = current_date + pd.Timedelta(hours=16) # 16:00
+ if start_time <= target_time <= end_time:
+ # 画虚线
+ ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
+
+ # 找到最接近16:00的时间点的收盘价
+ time_diffs = (all_timestamps - target_time).abs()
+ closest_idx = time_diffs.idxmin()
+ closest_time = all_timestamps.iloc[closest_idx]
+ closest_price = all_close.iloc[closest_idx]
+
+ # 如果时间差不超过1小时,则标注价格
+ if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
+ ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
+ ax_price.text(
+ closest_time, closest_price,
+ f' ${closest_price:.1f}',
+ fontsize=9,
+ color='blue',
+ verticalalignment='bottom',
+ horizontalalignment='left',
+ bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
+ )
+ current_date += pd.Timedelta(days=1)
+
+ ax_price.set_ylabel("价格 (USD)", fontsize=11)
+ ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
+ ax_price.grid(True, alpha=0.3)
+ ax_price.legend(loc="best", fontsize=9)
+
+ # ===== 下:成交量(仅预测中位量能柱)=====
+ ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
+
+ # 历史量能柱
+ ax_vol.bar(
+ historical_df["timestamps"],
+ historical_df["volume"],
+ width=bar_width_hist,
+ alpha=0.35,
+ label="历史成交量",
+ )
+
+ # 预测中位量能柱(不画区间)
+ if vol_q50 is not None:
+ ax_vol.bar(
+ y_timestamp.values,
+ vol_q50,
+ width=bar_width_pred,
+ alpha=0.6,
+ label="预测成交量中位(q50)",
+ )
+
+ # 预测起点线
+ if len(y_timestamp) > 0:
+ ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
+
+ ax_vol.set_ylabel("成交量", fontsize=11)
+ ax_vol.set_xlabel("时间", fontsize=11)
+ ax_vol.grid(True, alpha=0.25)
+ ax_vol.legend(loc="best", fontsize=9)
+
+ # 避免X轴标签重叠
+ plt.setp(ax_price.get_xticklabels(), visible=False)
+ plt.tight_layout()
+ return fig
+
+def main():
+ parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
+ parser.add_argument("--symbol", default="BTCUSD")
+ parser.add_argument("--category", default="inverse")
+ parser.add_argument("--interval", default="60")
+ parser.add_argument("--limit", type=int, default=500)
+ parser.add_argument("--lookback", type=int, default=LOOKBACK)
+ parser.add_argument("--pred_len", type=int, default=PRED_LEN)
+ parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
+ parser.add_argument("--temperature", type=float, default=TEMPERATURE)
+ parser.add_argument("--top_p", type=float, default=TOP_P)
+ parser.add_argument("--quantiles", default="0.1,0.5,0.9")
+ parser.add_argument("--device", default=DEVICE)
+ parser.add_argument("--seed", type=int, default=42)
+ args = parser.parse_args()
+
+ # 载入env & HTTP
+ load_dotenv()
+ api_key = os.getenv("BYBIT_API_KEY")
+ api_secret = os.getenv("BYBIT_API_SECRET")
+ session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
+ # 为命令行模式执行完整流程
+ predictor = load_kronos(device=args.device)
+ df = fetch_kline(
+ session,
+ symbol=args.symbol,
+ interval=args.interval,
+ limit=args.limit,
+ category=args.category,
+ )
+ x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
+ df,
+ lookback=args.lookback,
+ pred_len=args.pred_len,
+ interval_min=int(args.interval),
+ )
+
+ plot_start = max(0, end_idx - DRAW_WINDOW_LEN)
+ historical_df = df.loc[plot_start : end_idx - 1].copy()
+
+ temperature = args.temperature
+ top_p = args.top_p
+ qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
+
+ preds = run_multi_predictions(
+ predictor,
+ x_df,
+ x_ts,
+ y_ts,
+ num_samples=args.samples,
+ base_seed=args.seed,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ agg_df = aggregate_quantiles(preds, quantiles=qs)
+ agg_df["timestamps"] = y_ts.values
+
+ # 输出 CSV(聚合)
+ out_csv = os.path.join(
+ OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
+ )
+ agg_df.to_csv(out_csv, index=False)
+ print(f"预测分位数据已保存到: {out_csv}")
+
+ # 作图
+ fig = plot_results(
+ historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
+ )
+ out_png = os.path.join(
+ OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
+ )
+ fig.savefig(out_png, dpi=150, bbox_inches="tight")
+ print(f"预测区间图表已保存到: {out_png}")
+
+ plt.show()
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except Exception as e:
+ print("运行出错:", repr(e))
+ sys.exit(1)
diff --git a/qq.py b/qq.py
new file mode 100644
index 0000000..5679283
--- /dev/null
+++ b/qq.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+"""
+轻量的 QQ Bot 客户端(用于发送私聊文本和媒体)
+
+使用环境变量:
+- QQ_BOT_URL: bot 接收地址(例如 http://localhost:30000/send_private_msg)
+- QQ_BOT_TARGET_ID: 默认目标 QQ id(可被函数参数覆盖)
+
+示例:
+ from qq import send_msg, send_media_msg
+ send_msg('hello')
+ # 发送图片时将自动转换为 "base64://..." 格式
+ send_media_msg('figures/btcusd_interval_pred_60m.png', type='image')
+"""
+from __future__ import annotations
+
+import os
+import json
+from typing import Optional
+import base64
+
+try:
+ import requests
+except Exception: # pragma: no cover - runtime dependency
+ requests = None
+
+
+def _bot_url() -> str:
+ return os.getenv("QQ_BOT_URL", "http://localhost:30000/send_private_msg")
+
+
+def _default_target() -> Optional[str]:
+ return os.getenv("QQ_BOT_TARGET_ID")
+
+
+def _file_to_base64_uri(path: str) -> str:
+ """将本地文件读取为 base64 uri 字符串(前缀 base64://)。"""
+ with open(path, "rb") as f:
+ data = f.read()
+ b64 = base64.b64encode(data).decode("ascii")
+ return f"base64://{b64}"
+
+
+def send_msg(msg: str, target_id: Optional[str] = None, timeout: float = 8.0):
+ """发送文本消息到 QQ 机器人。返回 requests.Response 或 None(失败)。"""
+ if target_id is None:
+ target_id = _default_target()
+ if target_id is None:
+ raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置")
+
+ payload = {
+ "user_id": str(target_id),
+ "message": [
+ {"type": "text", "data": {"text": str(msg)}}
+ ],
+ }
+
+ url = _bot_url()
+ print(f"[QQ] 发送文本到 {target_id}: {msg}")
+ if requests is None:
+ print("requests 未安装,无法发送消息")
+ return None
+
+ try:
+ resp = requests.post(url, json=payload, timeout=timeout)
+ try:
+ # 非必要:打印返回的简短信息
+ print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}")
+ except Exception:
+ pass
+ return resp
+ except Exception as e:
+ print(f"[QQ] 发送文本消息出错: {e}")
+ return None
+
+
+def send_media_msg(file_path: str, target_id: Optional[str] = None, type: str = "image", timeout: float = 15.0):
+ """发送媒体消息(image 或 video)。
+
+ - 当 type == "image" 时:读取文件并以 "base64://..." 形式发送。
+ - 当 type == "video" 时:仍使用 "file://" 本地文件引用(保持原有行为)。
+ """
+ if type not in ("image", "video"):
+ raise ValueError("type 必须为 'image' 或 'video'")
+
+ if target_id is None:
+ target_id = _default_target()
+ if target_id is None:
+ raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置")
+
+ # 根据类型构造 file 字段
+ if type == "image":
+ try:
+ file_field = _file_to_base64_uri(file_path)
+ except Exception as e:
+ print(f"[QQ] 读取图片失败: {e}")
+ return None
+ else: # video
+ file_field = f"file://{file_path}"
+
+ payload = {
+ "user_id": str(target_id),
+ "message": [
+ {"type": type, "data": {"file": file_field}}
+ ],
+ }
+
+ url = _bot_url()
+ print(f"[QQ] 发送媒体 {type} 到 {target_id}: {file_path}")
+ if requests is None:
+ print("requests 未安装,无法发送媒体消息")
+ return None
+
+ try:
+ resp = requests.post(url, json=payload, timeout=timeout)
+ try:
+ print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}")
+ except Exception:
+ pass
+ return resp
+ except Exception as e:
+ print(f"[QQ] 发送媒体消息出错: {e}")
+ return None
+
+
+def send_prediction_result(summary: str, png_path: Optional[str] = None, target_id: Optional[str] = None):
+ """便捷函数:先发送 summary 文本,再发送 png(若提供)。"""
+ r1 = send_msg(summary, target_id=target_id)
+ r2 = None
+ if png_path:
+ r2 = send_media_msg(png_path, target_id=target_id, type="image")
+ return r1, r2
diff --git a/requirements.txt b/requirements.txt
index d94a8d7..352e7a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,12 @@
numpy
pandas
torch
-
+requests
+dotenv
einops==0.8.1
huggingface_hub==0.33.1
matplotlib==3.9.3
pandas==2.2.2
tqdm==4.67.1
safetensors==0.6.2
+pybit
diff --git a/trader.py b/trader.py
index 52d289c..360a61e 100644
--- a/trader.py
+++ b/trader.py
@@ -1,419 +1,408 @@
# -*- coding: utf-8 -*-
"""
-BTCUSD 15m 多次采样预测 → 区间聚合可视化
-Author: 你(鹅)
+基于 Kronos 多次采样预测,评估多/空两侧的盈亏比,并给出最佳机会与可视化。
+
+盈亏比 = 潜在盈利 / 最大亏损
+
+多头:
+- 潜在盈利:以中位线 close_q50 为准,选取未来涨幅最高点(相对当前价 base_close)的涨幅。
+- 最大亏损:从现在到该目标点区间内,使用多次采样的最差点(close_min)相对当前价的最大跌幅。
+
+空头:
+- 潜在盈利:以中位线 close_q50 为准,选取未来跌幅最大点(相对当前价 base_close)的跌幅。
+- 最大亏损:从现在到该目标点区间内,使用多次采样的最差点(close_max)相对当前价的最大涨幅。
"""
import os
-import sys
-import math
-import json
-import time
-import random
import argparse
-from datetime import datetime
+from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
-from tqdm import tqdm
from dotenv import load_dotenv
-from pybit.unified_trading import HTTP
+from predict import kronos_predict_df, get_recent_kline
-# ==== Kronos ====
-from model import Kronos, KronosTokenizer, KronosPredictor
-# ========== Matplotlib 字体 ==========
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
plt.rcParams["axes.unicode_minus"] = False
-# ========== 默认参数 ==========
-NUM_SAMPLES = 20 # 多次采样次数(建议 10~50)
-QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
-LOOKBACK = 360 # 历史窗口
-PRED_LEN = 96 # 未来预测点数
-INTERVAL_MIN = 15 # K 线周期(分钟)
-DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
-TEMPERATURE = 1.0
-TOP_P = 0.9
-OUTPUT_DIR = "figures"
-os.makedirs(OUTPUT_DIR, exist_ok=True)
-
-
-# ========== 实用函数 ==========
-def set_global_seed(seed: int):
- """尽可能固定随机性(若底层采样逻辑使用 torch/np/random)"""
+def _safe_get(series: pd.Series, default=np.nan):
try:
- import torch
-
- torch.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(seed)
+ return float(series.iloc[0])
except Exception:
- pass
- random.seed(seed)
- np.random.seed(seed)
+ return default
-def fetch_kline(
- session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse"
-):
- print("正在获取K线数据...")
- resp = session.get_kline(
- category=category, symbol=symbol, interval=interval, limit=limit
+def compute_long_metrics(agg_df: pd.DataFrame) -> dict:
+ base = float(agg_df["base_close"].iloc[0])
+ close_med = agg_df.get("close_q50", agg_df.get("close_q10"))
+ if close_med is None:
+ raise ValueError("预测结果中缺少 close_q50(或 close_q10 兜底)列")
+
+ # 相对收益(百分比)
+ ret = (close_med - base) / base
+ j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0
+ profit = max(float(ret.iloc[j_target]), 0.0)
+
+ # 区间内的“最差的90%分位”(使用下分位线 q10,而非样本极小值)
+ q10 = agg_df.get("close_q10", close_med)
+ worst_price = float(q10.iloc[: j_target + 1].min())
+ loss = max((base - worst_price) / base, 1e-9)
+
+ return {
+ "direction": "LONG",
+ "base": base,
+ "profit": profit,
+ "loss": loss,
+ "rr": profit / loss if loss > 0 else np.inf,
+ "target_idx": j_target,
+ "target_price": float(close_med.iloc[j_target]),
+ "worst_idx": int(np.nanargmin(q10.iloc[: j_target + 1].values)),
+ "worst_price": worst_price,
+ }
+
+
+def compute_short_metrics(agg_df: pd.DataFrame) -> dict:
+ base = float(agg_df["base_close"].iloc[0])
+ close_med = agg_df.get("close_q50", agg_df.get("close_q90"))
+ if close_med is None:
+ raise ValueError("预测结果中缺少 close_q50(或 close_q90 兜底)列")
+
+ # 相对收益(空头盈利为向下幅度)
+ ret = (base - close_med) / base
+ j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0
+ profit = max(float(ret.iloc[j_target]), 0.0)
+
+ # 区间内的“最差的90%分位”(使用上分位线 q90,而非样本极大值)
+ q90 = agg_df.get("close_q90", close_med)
+ worst_price = float(q90.iloc[: j_target + 1].max())
+ loss = max((worst_price - base) / base, 1e-9)
+
+ return {
+ "direction": "SHORT",
+ "base": base,
+ "profit": profit,
+ "loss": loss,
+ "rr": profit / loss if loss > 0 else np.inf,
+ "target_idx": j_target,
+ "target_price": float(close_med.iloc[j_target]),
+ "worst_idx": int(np.nanargmax(q90.iloc[: j_target + 1].values)),
+ "worst_price": worst_price,
+ }
+
+
+def plot_opportunities(
+ agg_df: pd.DataFrame,
+ long_m: dict,
+ short_m: dict,
+ symbol: str,
+ interval: str,
+ out_dir: str,
+ hist_df: pd.DataFrame | None = None,
+) -> str:
+ """生成清晰可视化:上下两个子图展示多/空机会与风险。"""
+ from matplotlib import gridspec
+ import numpy as np
+
+ ts = (
+ pd.to_datetime(agg_df["timestamps"])
+ if not np.issubdtype(agg_df["timestamps"].dtype, np.datetime64)
+ else agg_df["timestamps"]
)
- if resp.get("retCode", -1) != 0:
- raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
- lst = resp["result"]["list"]
+ base = long_m["base"]
+ q10 = agg_df.get("close_q10", None)
+ q50 = agg_df.get("close_q50", None)
+ q90 = agg_df.get("close_q90", None)
- df = pd.DataFrame(
- lst,
- columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
- )
- df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
- for c in ["open", "high", "low", "close", "volume", "turnover"]:
- df[c] = df[c].astype(float)
- df = df.sort_values("timestamps").reset_index(drop=True)
- df["amount"] = df["turnover"]
- print(f"获取到 {len(df)} 根K线数据")
- return df
+ # 历史24根K线(若提供)
+ hist_ts = None
+ hist_close = None
+ if hist_df is not None and len(hist_df) > 0:
+ try:
+ hist_ts = (
+ pd.to_datetime(hist_df["timestamps"]) if not np.issubdtype(hist_df["timestamps"].dtype, np.datetime64) else hist_df["timestamps"]
+ )
+ hist_close = hist_df["close"].astype(float).values
+ except Exception:
+ hist_ts, hist_close = None, None
-
-def prepare_io_windows(
- df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
-):
- end_idx = len(df)
- start_idx = max(0, end_idx - lookback)
- x_df = df.loc[
- start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
- ].reset_index(drop=True)
- x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
-
- last_ts = df.loc[end_idx - 1, "timestamps"]
- future_timestamps = pd.date_range(
- start=last_ts + pd.Timedelta(minutes=interval_min),
- periods=pred_len,
- freq=f"{interval_min}min",
- )
- y_timestamp = pd.Series(future_timestamps)
- y_timestamp.index = range(len(y_timestamp))
-
- print(f"数据总量: {len(df)} 根K线")
- print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})")
- print(f"最后一根历史K线时间: {last_ts}")
- print(f"预测未来 {pred_len} 根K线")
- return x_df, x_timestamp, y_timestamp, start_idx, end_idx
-
-
-def load_kronos(device=DEVICE):
- print("正在加载 Kronos 模型...")
- tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
- model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
- predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
- print("模型加载完成!\n")
- return predictor
-
-
-def run_one_prediction(
- predictor,
- x_df,
- x_timestamp,
- y_timestamp,
- T=TEMPERATURE,
- top_p=TOP_P,
- seed=None,
- verbose=False,
-):
- if seed is not None:
- set_global_seed(seed)
- return predictor.predict(
- df=x_df,
- x_timestamp=x_timestamp,
- y_timestamp=y_timestamp,
- pred_len=len(y_timestamp),
- T=T,
- top_p=top_p,
- verbose=verbose,
- )
-
-
-def run_multi_predictions(
- predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
- temperature=TEMPERATURE, top_p=TOP_P
-):
- preds = []
- print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
- for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100):
- seed = base_seed + i # 每次不同 seed
- pred_df = run_one_prediction(
- predictor,
- x_df,
- x_timestamp,
- y_timestamp,
- T=temperature,
- top_p=top_p,
- seed=seed,
- verbose=False,
- )
- # 兼容性处理:确保只取需要的列
- cols_present = [
- c
- for c in ["open", "high", "low", "close", "volume", "amount"]
- if c in pred_df.columns
- ]
- pred_df = pred_df[cols_present].copy()
- pred_df.reset_index(drop=True, inplace=True)
- preds.append(pred_df)
- print("多次预测完成。\n")
- return preds
-
-
-def aggregate_quantiles(pred_list, quantiles=QUANTILES):
- """
- 将多次预测列表聚合成分位数 DataFrame:
- 输出列命名:_q10, _q50, _q90(按 quantiles 中的值)
- """
- # 先把每次预测拼成 3D:time x feature x samples
- keys = pred_list[0].columns.tolist()
- T_len = len(pred_list[0])
- S = len(pred_list)
- data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
- for j, pdf in enumerate(pred_list):
- for k in keys:
- data[k][:, j] = pdf[k].values
-
- out = {}
- for k in keys:
- for q in quantiles:
- qv = np.quantile(data[k], q, axis=1)
- out[f"{k}_q{int(q*100):02d}"] = qv
- agg_df = pd.DataFrame(out)
- return agg_df
-
-
-def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
- """
- 上下两个子图(共享X轴):
- - 上方(高度3):历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
- - 下方(高度1):历史成交量柱 + 预测成交量中位柱(仅q50,不显示区间)
- """
- import matplotlib.gridspec as gridspec
-
- # 智能推断K线宽度(柱宽)
- try:
- if len(historical_df) >= 2:
- hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
- else:
- hist_step = pd.Timedelta(minutes=10)
- if len(y_timestamp) >= 2:
- pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
- else:
- pred_step = pd.Timedelta(minutes=10)
- bar_width_hist = hist_step * 0.8
- bar_width_pred = pred_step * 0.8
- except Exception:
- bar_width_hist = pd.Timedelta(minutes=10)
- bar_width_pred = pd.Timedelta(minutes=10)
-
- # 取出预测分位
- close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
- close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
- close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
-
- vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
-
- # 图形与网格
fig = plt.figure(figsize=(18, 10))
- gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
+ gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)
- # ===== 上:收盘价 =====
- ax_price = fig.add_subplot(gs[0])
+ def _draw_panel(
+ ax, title, target_idx, target_price, worst_idx, worst_price, direction
+ ):
+ # 历史收盘价(预测前24根)
+ if hist_ts is not None and hist_close is not None:
+ ax.plot(
+ hist_ts.values,
+ hist_close,
+ color="#6c757d",
+ lw=1.6,
+ label="历史收盘(近24根)",
+ )
+ # 预测起点标注
+ try:
+ ax.axvline(ts.iloc[0], color="gray", ls=":", lw=1.0, label="预测起点")
+ except Exception:
+ pass
- ax_price.plot(
- historical_df["timestamps"],
- historical_df["close"],
- label="历史收盘价",
- linewidth=1.8,
+ # 区间带
+ if q10 is not None and q90 is not None:
+ ax.fill_between(
+ ts.values,
+ q10.values,
+ q90.values,
+ color="#7dbcea",
+ alpha=0.22,
+ label="q10~q90",
+ )
+ # 中位线
+ if q50 is not None:
+ ax.plot(ts.values, q50.values, color="#00539c", lw=2.0, label="中位线 q50")
+
+ # 基准线
+ ax.axhline(base, color="gray", ls=":", lw=1.2, label=f"当前价 {base:.1f}")
+
+ # 目标点
+ ax.plot(
+ ts.iloc[target_idx],
+ target_price,
+ marker="o",
+ color="#2e7d32",
+ ms=8,
+ label="目标点",
+ )
+ ax.annotate(
+ f"目标 {target_price:.1f}",
+ (ts.iloc[target_idx], target_price),
+ textcoords="offset points",
+ xytext=(8, 8),
+ bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#2e7d32", alpha=0.8),
+ color="#2e7d32",
+ )
+
+ # 最差点
+ ax.plot(
+ ts.iloc[worst_idx],
+ worst_price,
+ marker="v",
+ color="#c62828",
+ ms=8,
+ label="最差点",
+ )
+ ax.annotate(
+ f"最差 {worst_price:.1f}",
+ (ts.iloc[worst_idx], worst_price),
+ textcoords="offset points",
+ xytext=(8, -14),
+ bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#c62828", alpha=0.8),
+ color="#c62828",
+ )
+
+ # 连接箭头(方向)
+ color = "#2e7d32" if direction == "LONG" else "#c62828"
+ ax.annotate(
+ "",
+ xy=(ts.iloc[target_idx], target_price),
+ xytext=(ts.iloc[0], base),
+ arrowprops=dict(arrowstyle="->", color=color, lw=2, alpha=0.9),
+ )
+
+ # 盈亏比矩形:用填充矩形表示潜在盈利(绿色)与最大亏损(红色)
+ x_range = ts.iloc[: target_idx + 1]
+ if direction == "LONG":
+ # 盈利(若目标价在基准价之上)
+ if target_price > base:
+ y_low = np.full_like(x_range, base, dtype=float)
+ y_high = np.full_like(x_range, target_price, dtype=float)
+ ax.fill_between(
+ x_range.values,
+ y_low,
+ y_high,
+ color="#2e7d32",
+ alpha=0.18,
+ label="潜在盈利",
+ )
+ # 亏损(q10 代表“最差的90%分位”)
+ if q10 is not None and np.isfinite(worst_price) and worst_price < base:
+ y_low = np.full_like(x_range, worst_price, dtype=float)
+ y_high = np.full_like(x_range, base, dtype=float)
+ ax.fill_between(
+ x_range.values,
+ y_low,
+ y_high,
+ color="#c62828",
+ alpha=0.18,
+ label="最大亏损(q10)",
+ )
+ else:
+ # 盈利(空头目标价在基准价之下)
+ if target_price < base:
+ y_low = np.full_like(x_range, target_price, dtype=float)
+ y_high = np.full_like(x_range, base, dtype=float)
+ ax.fill_between(
+ x_range.values,
+ y_low,
+ y_high,
+ color="#2e7d32",
+ alpha=0.18,
+ label="潜在盈利",
+ )
+ # 亏损(q90 代表“最差的90%分位”)
+ if q90 is not None and np.isfinite(worst_price) and worst_price > base:
+ y_low = np.full_like(x_range, base, dtype=float)
+ y_high = np.full_like(x_range, worst_price, dtype=float)
+ ax.fill_between(
+ x_range.values,
+ y_low,
+ y_high,
+ color="#c62828",
+ alpha=0.18,
+ label="最大亏损(q90)",
+ )
+
+ ax.set_title(title)
+ ax.grid(True, alpha=0.25)
+ ax.legend(loc="best", fontsize=9)
+
+ ax1 = fig.add_subplot(gs[0])
+ rr_long_pct = long_m["rr"] * 100.0
+ profit_long_pct = long_m["profit"] * 100.0
+ loss_long_pct = long_m["loss"] * 100.0
+ title_long = f"多头机会:盈亏比 RR={long_m['rr']:.2f}(+{profit_long_pct:.2f}% / -{loss_long_pct:.2f}%)"
+ _draw_panel(
+ ax1,
+ title_long,
+ long_m["target_idx"],
+ long_m["target_price"],
+ long_m["worst_idx"],
+ long_m["worst_price"],
+ "LONG",
)
- # 预测收盘价区间与中位线
- if close_q10 is not None and close_q90 is not None:
- ax_price.fill_between(
- y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
- )
- if close_q50 is not None:
- ax_price.plot(
- y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
- )
-
- # 预测起点与阴影
- if len(y_timestamp) > 0:
- ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
- ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
-
- # 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
- # 合并历史和预测时间范围与价格数据
- all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
- hist_close = historical_df["close"]
- pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
- all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
-
- if len(all_timestamps) > 0:
- start_time = all_timestamps.min()
- end_time = all_timestamps.max()
-
- # 生成所有16:00时间点(UTC+8)
- current_date = start_time.normalize() # 当天零点
- while current_date <= end_time:
- target_time = current_date + pd.Timedelta(hours=16) # 16:00
- if start_time <= target_time <= end_time:
- # 画虚线
- ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
-
- # 找到最接近16:00的时间点的收盘价
- time_diffs = (all_timestamps - target_time).abs()
- closest_idx = time_diffs.idxmin()
- closest_time = all_timestamps.iloc[closest_idx]
- closest_price = all_close.iloc[closest_idx]
-
- # 如果时间差不超过1小时,则标注价格
- if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
- ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
- ax_price.text(
- closest_time, closest_price,
- f' ${closest_price:.1f}',
- fontsize=9,
- color='blue',
- verticalalignment='bottom',
- horizontalalignment='left',
- bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
- )
- current_date += pd.Timedelta(days=1)
-
- ax_price.set_ylabel("价格 (USD)", fontsize=11)
- ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
- ax_price.grid(True, alpha=0.3)
- ax_price.legend(loc="best", fontsize=9)
-
- # ===== 下:成交量(仅预测中位量能柱)=====
- ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
-
- # 历史量能柱
- ax_vol.bar(
- historical_df["timestamps"],
- historical_df["volume"],
- width=bar_width_hist,
- alpha=0.35,
- label="历史成交量",
+ ax2 = fig.add_subplot(gs[1], sharex=ax1)
+ rr_short_pct = short_m["rr"] * 100.0
+ profit_short_pct = short_m["profit"] * 100.0
+ loss_short_pct = short_m["loss"] * 100.0
+ title_short = f"空头机会:盈亏比 RR={short_m['rr']:.2f}(+{profit_short_pct:.2f}% / -{loss_short_pct:.2f}%)"
+ _draw_panel(
+ ax2,
+ title_short,
+ short_m["target_idx"],
+ short_m["target_price"],
+ short_m["worst_idx"],
+ short_m["worst_price"],
+ "SHORT",
)
- # 预测中位量能柱(不画区间)
- if vol_q50 is not None:
- ax_vol.bar(
- y_timestamp.values,
- vol_q50,
- width=bar_width_pred,
- alpha=0.6,
- label="预测成交量中位(q50)",
- )
+ for lbl in ax1.get_xticklabels():
+ lbl.set_visible(False)
+ ax2.set_xlabel("时间")
+ ax1.set_ylabel("价格 (USD)")
+ ax2.set_ylabel("价格 (USD)")
+ fig.suptitle(
+ f"{symbol} {interval}分钟线 24小时 预测机会评估", fontsize=16, fontweight="bold"
+ )
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
- # 预测起点线
- if len(y_timestamp) > 0:
- ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
+ os.makedirs(out_dir, exist_ok=True)
+ out_png = os.path.join(
+ out_dir, f"trade_opportunity_{symbol.lower()}_{interval}m.png"
+ )
+ fig.savefig(out_png, dpi=150, bbox_inches="tight")
+ return out_png
- ax_vol.set_ylabel("成交量", fontsize=11)
- ax_vol.set_xlabel("时间", fontsize=11)
- ax_vol.grid(True, alpha=0.25)
- ax_vol.legend(loc="best", fontsize=9)
-
- # 避免X轴标签重叠
- plt.setp(ax_price.get_xticklabels(), visible=False)
- plt.tight_layout()
- return fig
def main():
- parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
- parser.add_argument("--symbol", default="BTCUSD")
+ # 加载环境变量
+ load_dotenv()
+
+ parser = argparse.ArgumentParser(description="计算并可视化高盈亏比交易机会")
+ parser.add_argument("--symbol", default="BTCUSDT")
parser.add_argument("--category", default="inverse")
- parser.add_argument("--interval", default="15")
- parser.add_argument("--limit", type=int, default=1000)
- parser.add_argument("--lookback", type=int, default=LOOKBACK)
- parser.add_argument("--pred_len", type=int, default=PRED_LEN)
- parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
- parser.add_argument("--temperature", type=float, default=TEMPERATURE)
- parser.add_argument("--top_p", type=float, default=TOP_P)
- parser.add_argument("--quantiles", default="0.1,0.5,0.9")
- parser.add_argument("--device", default=DEVICE)
+ parser.add_argument("--interval", default="60")
+ parser.add_argument("--limit", type=int, default=500)
+ parser.add_argument("--lookback", type=int, default=360)
+ parser.add_argument("--pred_len", type=int, default=24)
+ parser.add_argument("--samples", type=int, default=30)
+ parser.add_argument("--temperature", type=float, default=1.0)
+ parser.add_argument("--top_p", type=float, default=0.9)
+ parser.add_argument("--device", default=os.getenv("KRONOS_DEVICE", "cuda:0"))
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
- # 载入env & HTTP
- load_dotenv()
- api_key = os.getenv("BYBIT_API_KEY")
- api_secret = os.getenv("BYBIT_API_SECRET")
- session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
- # Kronos
- predictor = load_kronos(device=args.device)
-
- # 拉数据
- df = fetch_kline(
- session,
+ # 获取预测聚合数据(仅未来)
+ agg_df = kronos_predict_df(
symbol=args.symbol,
+ category=args.category,
interval=args.interval,
limit=args.limit,
- category=args.category,
- )
- x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
- df,
lookback=args.lookback,
pred_len=args.pred_len,
- interval_min=int(args.interval),
+ samples=args.samples,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ quantiles=[0.1, 0.5, 0.9],
+ device=args.device,
+ seed=args.seed,
+ verbose=True,
)
- # 历史用于画图的窗口(最近200根,可按需调整)
- plot_start = max(0, end_idx - 200)
- historical_df = df.loc[plot_start : end_idx - 1].copy()
+ # 计算多 / 空的机会
+ long_m = compute_long_metrics(agg_df)
+ short_m = compute_short_metrics(agg_df)
- # 采样参数(直接使用局部变量,不需要修改全局变量)
- temperature = args.temperature
- top_p = args.top_p
- qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
+ # 选择更优机会
+ best = long_m if long_m["rr"] >= short_m["rr"] else short_m
- # 多次预测(传入局部变量)
- preds = run_multi_predictions(
- predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed,
- temperature=temperature, top_p=top_p
+ # 生成汇总字符串(增强可读性与可辨识度)
+ dir_emoji = "📈 多头" if best["direction"] == "LONG" else "📉 空头"
+ long_arrow = "⬆️"
+ short_arrow = "⬇️"
+ summary_lines = [
+ "== 24小时 每日交易机会评估 ==\n",
+ f"📌 标的: {args.symbol} | 周期: {args.interval}m | 当前价: {best['base']:.2f}\n",
+ f"📈 多头 | RR={long_m['rr']:.2f} | 盈利 +{long_m['profit'] * 100.0:.2f}% | 亏损 -{long_m['loss'] * 100.0:.2f}% | 🎯 目标 {long_m['target_price']:.2f} | 🛡️ 最差 {long_m['worst_price']:.2f}\n",
+ f"📉 空头 | RR={short_m['rr']:.2f} | 盈利 +{short_m['profit'] * 100.0:.2f}% | 亏损 -{short_m['loss'] * 100.0:.2f}% | 🎯 目标 {short_m['target_price']:.2f} | 🛡️ 最差 {short_m['worst_price']:.2f}\n",
+ f"✅ 最佳: {dir_emoji} | 盈亏比 RR={best['rr']:.2f}",
+ ]
+ if best["rr"] < 2.0:
+ summary_lines.append("⚠️ 交易机会不佳,建议观望。")
+ summary = "\n".join(summary_lines) + "\n"
+
+ # 打印汇总
+ print(summary)
+
+ # 获取预测前24根K线用于可视化上下文
+ try:
+ full_hist_df = get_recent_kline(
+ symbol=args.symbol,
+ category=args.category,
+ interval=args.interval,
+ limit=args.limit,
+ )
+ hist24 = full_hist_df.tail(24).reset_index(drop=True)
+ except Exception:
+ hist24 = None
+
+ # 可视化
+ out_png = plot_opportunities(
+ agg_df, long_m, short_m, args.symbol, args.interval, out_dir="figures", hist_df=hist24
)
-
- # 聚合分位
- agg_df = aggregate_quantiles(preds, quantiles=qs)
- agg_df["timestamps"] = y_ts.values
-
- # 输出 CSV(聚合)
- out_csv = os.path.join(
- OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
- )
- agg_df.to_csv(out_csv, index=False)
- print(f"预测分位数据已保存到: {out_csv}")
-
- # 作图
- fig = plot_results(
- historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
- )
- out_png = os.path.join(
- OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
- )
- fig.savefig(out_png, dpi=150, bbox_inches="tight")
- print(f"预测区间图表已保存到: {out_png}")
-
- plt.show()
+ print(f"可视化已保存: {out_png}")
+
+ # 发送交易机会评估结果到 QQ
+ try:
+ from qq import send_prediction_result
+ target = os.getenv("QQ_BOT_TARGET_ID")
+ send_prediction_result(summary, png_path=out_png, target_id=target)
+ except Exception as e:
+ print(f"发送 QQ 消息失败: {repr(e)}")
if __name__ == "__main__":
- try:
- main()
- except Exception as e:
- print("运行出错:", repr(e))
- sys.exit(1)
+ main()