# Listing metadata from the code viewer: 409 lines, 14 KiB, Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
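Evaluate the reward/risk ratio of long and short entries from multi-sample
Kronos forecasts, report the better side, and render a chart.

Reward/risk ratio (RR) = potential profit / maximum loss

Long side:
- Potential profit: along the median path close_q50, the largest gain relative
  to the current price base_close at any future step (the "target" step).
- Maximum loss: between now and the target step, the largest drop below the
  current price along the lower quantile band close_q10 of the sampled
  forecasts (the 10% quantile, not the sample minimum).

Short side:
- Potential profit: along the median path close_q50, the largest decline
  relative to the current price base_close at any future step.
- Maximum loss: between now and the target step, the largest rise above the
  current price along the upper quantile band close_q90 of the sampled
  forecasts (the 90% quantile, not the sample maximum).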
|
||
"""
|
||
|
||
import os
|
||
import argparse
|
||
from typing import Tuple
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
|
||
from dotenv import load_dotenv
|
||
from predict import kronos_predict_df, get_recent_kline
|
||
|
||
|
||
# Font fallback order for sans-serif text; the first two entries are CJK fonts
# so the Chinese titles/labels used in the charts below can be rendered.
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
# Use the ASCII hyphen for negative tick labels instead of the Unicode minus
# sign, which the CJK fonts above may not provide.
plt.rcParams["axes.unicode_minus"] = False
|
||
|
||
|
||
def _safe_get(series: pd.Series, default=np.nan):
    """Return the first element of ``series`` converted to ``float``.

    Any failure while reading or converting that element (empty series,
    non-numeric value, an object without ``.iloc``, ...) yields ``default``
    instead of propagating the exception.
    """
    try:
        head = series.iloc[0]
        value = float(head)
    except Exception:
        return default
    return value
|
||
|
||
|
||
def compute_long_metrics(agg_df: pd.DataFrame) -> dict:
    """Evaluate the reward/risk profile of a LONG entry at the current price.

    The target is the forecast step where the median close (``close_q50``,
    falling back to ``close_q10`` when the median column is absent) shows the
    largest gain over ``base_close``. The risk is the lowest value of the
    lower band (``close_q10``, falling back to the median series) from the
    first forecast step up to and including the target step.

    Args:
        agg_df: Aggregated forecast. Must contain ``base_close`` (only the
            first row is read) and at least one of ``close_q50`` /
            ``close_q10``.

    Returns:
        A dict with keys ``direction`` ("LONG"), ``base``, ``profit`` and
        ``loss`` (both fractions of ``base``; ``profit`` is clamped at 0 and
        ``loss`` is floored at 1e-9), ``rr`` (``profit / loss``),
        ``target_idx`` / ``target_price`` and ``worst_idx`` / ``worst_price``
        (positional indices into the forecast rows).

    Raises:
        KeyError: If ``base_close`` is missing.
        ValueError: If the frame is empty, ``base_close`` is not a positive
            finite number, no usable close column exists, or the relevant
            series contains only NaN.
    """
    # Floor applied to the loss so the ratio stays finite when the lower band
    # never drops below the current price (the ratio is then very large).
    min_loss = 1e-9

    if len(agg_df) == 0:
        raise ValueError("预测结果为空,无法计算多头盈亏比")

    base = float(agg_df["base_close"].iloc[0])
    if not np.isfinite(base) or base <= 0:
        raise ValueError(f"base_close 必须为有限正数,当前为 {base}")

    close_med = agg_df.get("close_q50", agg_df.get("close_q10"))
    if close_med is None:
        raise ValueError("预测结果中缺少 close_q50(或 close_q10 兜底)列")

    # Relative return of the median path versus the current price.
    ret_values = ((close_med - base) / base).to_numpy(dtype=float)
    if np.isnan(ret_values).all():
        raise ValueError("close_q50 预测值全部为 NaN,无法确定多头目标点")
    j_target = int(np.nanargmax(ret_values))
    profit = max(float(ret_values[j_target]), 0.0)

    # Worst case inside [0, j_target]: the lower quantile band (q10), not the
    # sample minimum.
    q10 = agg_df.get("close_q10", close_med)
    window = q10.iloc[: j_target + 1].to_numpy(dtype=float)
    if np.isnan(window).all():
        raise ValueError("目标区间内 close_q10 全部为 NaN,无法计算多头最大亏损")
    worst_idx = int(np.nanargmin(window))
    worst_price = float(window[worst_idx])
    loss = max((base - worst_price) / base, min_loss)

    return {
        "direction": "LONG",
        "base": base,
        "profit": profit,
        "loss": loss,
        "rr": profit / loss,
        "target_idx": j_target,
        "target_price": float(close_med.iloc[j_target]),
        "worst_idx": worst_idx,
        "worst_price": worst_price,
    }
|
||
|
||
|
||
def compute_short_metrics(agg_df: pd.DataFrame) -> dict:
    """Evaluate the reward/risk profile of a SHORT entry at the current price.

    The target is the forecast step where the median close (``close_q50``,
    falling back to ``close_q90`` when the median column is absent) shows the
    largest decline below ``base_close``. The risk is the highest value of the
    upper band (``close_q90``, falling back to the median series) from the
    first forecast step up to and including the target step.

    Args:
        agg_df: Aggregated forecast. Must contain ``base_close`` (only the
            first row is read) and at least one of ``close_q50`` /
            ``close_q90``.

    Returns:
        A dict with keys ``direction`` ("SHORT"), ``base``, ``profit`` and
        ``loss`` (both fractions of ``base``; ``profit`` is clamped at 0 and
        ``loss`` is floored at 1e-9), ``rr`` (``profit / loss``),
        ``target_idx`` / ``target_price`` and ``worst_idx`` / ``worst_price``
        (positional indices into the forecast rows).

    Raises:
        KeyError: If ``base_close`` is missing.
        ValueError: If the frame is empty, ``base_close`` is not a positive
            finite number, no usable close column exists, or the relevant
            series contains only NaN.
    """
    # Floor applied to the loss so the ratio stays finite when the upper band
    # never rises above the current price (the ratio is then very large).
    min_loss = 1e-9

    if len(agg_df) == 0:
        raise ValueError("预测结果为空,无法计算空头盈亏比")

    base = float(agg_df["base_close"].iloc[0])
    if not np.isfinite(base) or base <= 0:
        raise ValueError(f"base_close 必须为有限正数,当前为 {base}")

    close_med = agg_df.get("close_q50", agg_df.get("close_q90"))
    if close_med is None:
        raise ValueError("预测结果中缺少 close_q50(或 close_q90 兜底)列")

    # Relative return for a short: profit is the downward move of the median.
    ret_values = ((base - close_med) / base).to_numpy(dtype=float)
    if np.isnan(ret_values).all():
        raise ValueError("close_q50 预测值全部为 NaN,无法确定空头目标点")
    j_target = int(np.nanargmax(ret_values))
    profit = max(float(ret_values[j_target]), 0.0)

    # Worst case inside [0, j_target]: the upper quantile band (q90), not the
    # sample maximum.
    q90 = agg_df.get("close_q90", close_med)
    window = q90.iloc[: j_target + 1].to_numpy(dtype=float)
    if np.isnan(window).all():
        raise ValueError("目标区间内 close_q90 全部为 NaN,无法计算空头最大亏损")
    worst_idx = int(np.nanargmax(window))
    worst_price = float(window[worst_idx])
    loss = max((worst_price - base) / base, min_loss)

    return {
        "direction": "SHORT",
        "base": base,
        "profit": profit,
        "loss": loss,
        "rr": profit / loss,
        "target_idx": j_target,
        "target_price": float(close_med.iloc[j_target]),
        "worst_idx": worst_idx,
        "worst_price": worst_price,
    }
|
||
|
||
|
||
def plot_opportunities(
    agg_df: pd.DataFrame,
    long_m: dict,
    short_m: dict,
    symbol: str,
    interval: str,
    out_dir: str,
    hist_df: pd.DataFrame | None = None,
) -> str:
    """Render a two-panel chart (long on top, short below) and save it as PNG.

    Each panel shows the q10~q90 band, the median path, the current price,
    the target and worst-case points, and shaded rectangles for the potential
    profit (green) and maximum loss (red).

    Args:
        agg_df: Forecast frame with a ``timestamps`` column and, optionally,
            ``close_q10`` / ``close_q50`` / ``close_q90``.
        long_m: Metrics dict for the long side; the keys ``base``, ``rr``,
            ``profit``, ``loss``, ``target_idx``, ``target_price``,
            ``worst_idx`` and ``worst_price`` are read.
        short_m: Metrics dict for the short side (same keys except ``base``).
        symbol: Instrument name, used in the title and the file name.
        interval: Bar interval, used in the title and the file name.
        out_dir: Output directory; created if it does not exist.
        hist_df: Optional recent history with ``timestamps`` and ``close``
            columns, drawn in front of the forecast for context.

    Returns:
        Path of the written PNG file.

    Note:
        The figure is closed before returning so repeated calls do not
        accumulate open figures in pyplot.
    """
    from matplotlib import gridspec

    profit_color = "#2e7d32"
    loss_color = "#c62828"

    # pd.to_datetime returns datetime columns unchanged in value and, unlike
    # an np.issubdtype(...) check, also accepts timezone-aware dtypes.
    ts = pd.to_datetime(agg_df["timestamps"])
    base = long_m["base"]
    q10 = agg_df.get("close_q10")
    q50 = agg_df.get("close_q50")
    q90 = agg_df.get("close_q90")

    # Recent history (if provided) is decorative context: when it cannot be
    # parsed the chart is still drawn without it.
    hist_ts = None
    hist_close = None
    if hist_df is not None and len(hist_df) > 0:
        try:
            hist_ts = pd.to_datetime(hist_df["timestamps"])
            hist_close = hist_df["close"].astype(float).values
        except (KeyError, TypeError, ValueError):
            hist_ts, hist_close = None, None

    def _mark_point(ax, x, y, *, marker, color, label, text, offset):
        # Draw one marker plus a boxed text annotation next to it.
        ax.plot(x, y, marker=marker, color=color, ms=8, label=label)
        ax.annotate(
            text,
            (x, y),
            textcoords="offset points",
            xytext=offset,
            bbox=dict(boxstyle="round,pad=0.25", fc="white", ec=color, alpha=0.8),
            color=color,
        )

    def _draw_panel(
        ax, title, target_idx, target_price, worst_idx, worst_price, direction
    ):
        # Historical closes preceding the forecast.
        if hist_ts is not None and hist_close is not None:
            ax.plot(
                hist_ts.values,
                hist_close,
                color="#6c757d",
                lw=1.6,
                label="历史收盘(近24根)",
            )
            # Vertical marker at the forecast start; purely decorative, so a
            # failure here must not abort the chart.
            try:
                ax.axvline(ts.iloc[0], color="gray", ls=":", lw=1.0, label="预测起点")
            except (IndexError, TypeError, ValueError):
                pass

        # Quantile band.
        if q10 is not None and q90 is not None:
            ax.fill_between(
                ts.values,
                q10.values,
                q90.values,
                color="#7dbcea",
                alpha=0.22,
                label="q10~q90",
            )
        # Median path.
        if q50 is not None:
            ax.plot(ts.values, q50.values, color="#00539c", lw=2.0, label="中位线 q50")

        # Current price reference line.
        ax.axhline(base, color="gray", ls=":", lw=1.2, label=f"当前价 {base:.1f}")

        _mark_point(
            ax,
            ts.iloc[target_idx],
            target_price,
            marker="o",
            color=profit_color,
            label="目标点",
            text=f"目标 {target_price:.1f}",
            offset=(8, 8),
        )
        _mark_point(
            ax,
            ts.iloc[worst_idx],
            worst_price,
            marker="v",
            color=loss_color,
            label="最差点",
            text=f"最差 {worst_price:.1f}",
            offset=(8, -14),
        )

        # Arrow from the entry (first forecast step, current price) to the target.
        arrow_color = profit_color if direction == "LONG" else loss_color
        ax.annotate(
            "",
            xy=(ts.iloc[target_idx], target_price),
            xytext=(ts.iloc[0], base),
            arrowprops=dict(arrowstyle="->", color=arrow_color, lw=2, alpha=0.9),
        )

        # Profit / loss rectangles spanning entry..target on the time axis.
        span = ts.iloc[: target_idx + 1].values
        if direction == "LONG":
            has_profit = target_price > base
            has_loss = (
                q10 is not None and np.isfinite(worst_price) and worst_price < base
            )
            loss_label = "最大亏损(q10)"
        else:
            has_profit = target_price < base
            has_loss = (
                q90 is not None and np.isfinite(worst_price) and worst_price > base
            )
            loss_label = "最大亏损(q90)"

        # fill_between accepts scalar bounds, so no constant arrays are needed.
        if has_profit:
            ax.fill_between(
                span,
                min(base, target_price),
                max(base, target_price),
                color=profit_color,
                alpha=0.18,
                label="潜在盈利",
            )
        if has_loss:
            ax.fill_between(
                span,
                min(base, worst_price),
                max(base, worst_price),
                color=loss_color,
                alpha=0.18,
                label=loss_label,
            )

        ax.set_title(title)
        ax.grid(True, alpha=0.25)
        ax.legend(loc="best", fontsize=9)

    def _panel_title(side, m):
        # Same text for both sides apart from the leading side name.
        return (
            f"{side}机会:盈亏比 RR={m['rr']:.2f}"
            f"(+{m['profit'] * 100.0:.2f}% / -{m['loss'] * 100.0:.2f}%)"
        )

    fig = plt.figure(figsize=(18, 10))
    try:
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)

        ax1 = fig.add_subplot(gs[0])
        _draw_panel(
            ax1,
            _panel_title("多头", long_m),
            long_m["target_idx"],
            long_m["target_price"],
            long_m["worst_idx"],
            long_m["worst_price"],
            "LONG",
        )

        ax2 = fig.add_subplot(gs[1], sharex=ax1)
        _draw_panel(
            ax2,
            _panel_title("空头", short_m),
            short_m["target_idx"],
            short_m["target_price"],
            short_m["worst_idx"],
            short_m["worst_price"],
            "SHORT",
        )

        # The panels share the x axis; only the lower one shows tick labels.
        for lbl in ax1.get_xticklabels():
            lbl.set_visible(False)
        ax2.set_xlabel("时间")
        ax1.set_ylabel("价格 (USD)")
        ax2.set_ylabel("价格 (USD)")
        fig.suptitle(
            f"{symbol} {interval}分钟线 24小时 预测机会评估", fontsize=16, fontweight="bold"
        )
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        os.makedirs(out_dir, exist_ok=True)
        out_png = os.path.join(
            out_dir, f"trade_opportunity_{symbol.lower()}_{interval}m.png"
        )
        fig.savefig(out_png, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return out_png
|
||
|
||
|
||
def main():
    """Command-line entry point: forecast, score both sides, report and plot.

    Steps:
        1. Load environment variables and parse the CLI options.
        2. Request the aggregated forecast from ``kronos_predict_df`` with the
           0.1 / 0.5 / 0.9 quantiles.
        3. Compute long and short metrics and pick the side with the higher
           reward/risk ratio (long wins ties).
        4. Print a text summary; a warning line is added when the best ratio
           is below 2.0.
        5. Fetch recent bars for chart context (best effort), draw the chart
           into ``figures/`` and print its path.
        6. Send summary and chart through ``qq.send_prediction_result``
           (best effort; a failure is printed, not raised).
    """
    load_dotenv()

    parser = argparse.ArgumentParser(description="计算并可视化高盈亏比交易机会")
    parser.add_argument("--symbol", default="BTCUSDT")
    parser.add_argument("--category", default="inverse")
    parser.add_argument("--interval", default="60")
    parser.add_argument("--limit", type=int, default=500)
    parser.add_argument("--lookback", type=int, default=360)
    parser.add_argument("--pred_len", type=int, default=24)
    parser.add_argument("--samples", type=int, default=30)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--device", default=os.getenv("KRONOS_DEVICE", "cuda:0"))
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Aggregated forecast (future steps only).
    agg_df = kronos_predict_df(
        symbol=args.symbol,
        category=args.category,
        interval=args.interval,
        limit=args.limit,
        lookback=args.lookback,
        pred_len=args.pred_len,
        samples=args.samples,
        temperature=args.temperature,
        top_p=args.top_p,
        quantiles=[0.1, 0.5, 0.9],
        device=args.device,
        seed=args.seed,
        verbose=True,
    )

    # Score both directions and keep the better one (long wins ties).
    long_m = compute_long_metrics(agg_df)
    short_m = compute_short_metrics(agg_df)
    best = long_m if long_m["rr"] >= short_m["rr"] else short_m

    # Human-readable summary. Each line except the last carries its own "\n"
    # in addition to the join separator, which leaves blank lines between
    # entries in the emitted text.
    dir_emoji = "📈 多头" if best["direction"] == "LONG" else "📉 空头"
    summary_lines = [
        "== 24小时 每日交易机会评估 ==\n",
        f"📌 标的: {args.symbol} | 周期: {args.interval}m | 当前价: {best['base']:.2f}\n",
        f"📈 多头 | RR={long_m['rr']:.2f} | 盈利 +{long_m['profit'] * 100.0:.2f}% | 亏损 -{long_m['loss'] * 100.0:.2f}% | 🎯 目标 {long_m['target_price']:.2f} | 🛡️ 最差 {long_m['worst_price']:.2f}\n",
        f"📉 空头 | RR={short_m['rr']:.2f} | 盈利 +{short_m['profit'] * 100.0:.2f}% | 亏损 -{short_m['loss'] * 100.0:.2f}% | 🎯 目标 {short_m['target_price']:.2f} | 🛡️ 最差 {short_m['worst_price']:.2f}\n",
        f"✅ 最佳: {dir_emoji} | 盈亏比 RR={best['rr']:.2f}",
    ]
    if best["rr"] < 2.0:
        summary_lines.append("⚠️ 交易机会不佳,建议观望。")
    summary = "\n".join(summary_lines) + "\n"
    print(summary)

    # Last 24 bars before the forecast, used only as chart context. The chart
    # is still produced without them, but the failure is reported instead of
    # being swallowed silently.
    try:
        full_hist_df = get_recent_kline(
            symbol=args.symbol,
            category=args.category,
            interval=args.interval,
            limit=args.limit,
        )
        hist24 = full_hist_df.tail(24).reset_index(drop=True)
    except Exception as e:
        print(f"获取历史K线失败,图中将不显示历史走势: {repr(e)}")
        hist24 = None

    out_png = plot_opportunities(
        agg_df, long_m, short_m, args.symbol, args.interval, out_dir="figures", hist_df=hist24
    )
    print(f"可视化已保存: {out_png}")

    # Push the result to QQ. The import is local so a missing or broken ``qq``
    # module only disables the notification instead of the whole script.
    try:
        from qq import send_prediction_result

        target = os.getenv("QQ_BOT_TARGET_ID")
        send_prediction_result(summary, png_path=out_png, target_id=target)
    except Exception as e:
        print(f"发送 QQ 消息失败: {repr(e)}")
|
||
|
||
|
||
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|