kronos-trader/trader.py

409 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
基于 Kronos 多次采样预测,评估多/空两侧的盈亏比,并给出最佳机会与可视化。
盈亏比 = 潜在盈利 / 最大亏损
多头:
- 潜在盈利:以中位线 close_q50 为准,选取未来涨幅最高点(相对当前价 base_close 的涨幅)。
- 最大亏损:从现在到该目标点的区间内,取多次采样下分位线 close_q10 的最低点,计算其相对当前价的跌幅。
空头:
- 潜在盈利:以中位线 close_q50 为准,选取未来跌幅最大点(相对当前价 base_close 的跌幅)。
- 最大亏损:从现在到该目标点的区间内,取多次采样上分位线 close_q90 的最高点,计算其相对当前价的涨幅。
"""
import os
import argparse
from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dotenv import load_dotenv
from predict import kronos_predict_df, get_recent_kline
# Global matplotlib font settings: prefer fonts that can render the Chinese
# titles/labels used below, and render the minus sign as a plain ASCII hyphen
# instead of the Unicode minus glyph.
plt.rcParams.update(
    {
        "font.sans-serif": ["Microsoft YaHei", "SimHei", "Arial"],
        "axes.unicode_minus": False,
    }
)
def _safe_get(series: pd.Series, default=np.nan):
try:
return float(series.iloc[0])
except Exception:
return default
def compute_long_metrics(agg_df: pd.DataFrame, *, loss_floor: float = 1e-9) -> dict:
    """Evaluate the long-side opportunity of an aggregated forecast.

    The target is the bar where the median path (``close_q50``, falling back
    to ``close_q10``) is highest relative to the current price ``base_close``.
    The risk is the lowest value of the lower band (``close_q10``, falling back
    to the median path) from the first forecast bar up to and including the
    target bar. NaNs in either series are ignored.

    Args:
        agg_df: Aggregated forecast frame with a ``base_close`` column (its
            first value is used as the current price) and quantile columns.
        loss_floor: Lower bound applied to the relative loss so the
            reward/risk ratio stays finite when the band never drops below
            the current price. Pass ``0.0`` to get ``rr == inf`` instead.

    Returns:
        dict with ``direction``, ``base``, ``profit``, ``loss``, ``rr``,
        ``target_idx``, ``target_price``, ``worst_idx`` and ``worst_price``.
        ``profit`` and ``loss`` are fractions of ``base`` (0.02 == 2%);
        indices are positional.

    Raises:
        ValueError: if neither ``close_q50`` nor ``close_q10`` is present.
    """
    base = float(agg_df["base_close"].iloc[0])
    close_med = agg_df.get("close_q50", agg_df.get("close_q10"))
    if close_med is None:
        raise ValueError("预测结果中缺少 close_q50或 close_q10 兜底)列")

    # Relative return of the median path versus the current price.
    ret = (close_med - base) / base
    j_target = int(np.nanargmax(ret.values))
    # A median path that never rises above the current price offers no profit.
    profit = max(float(ret.iloc[j_target]), 0.0)

    # Worst point of the lower band between now and the target (inclusive).
    # Using the q10 band rather than the raw sample minimum keeps the risk
    # estimate robust to single outlier samples.
    q10 = agg_df.get("close_q10", close_med)
    window = q10.iloc[: j_target + 1].values
    worst_idx = int(np.nanargmin(window))
    worst_price = float(window[worst_idx])
    loss = max((base - worst_price) / base, loss_floor)

    return {
        "direction": "LONG",
        "base": base,
        "profit": profit,
        "loss": loss,
        "rr": profit / loss if loss > 0 else np.inf,
        "target_idx": j_target,
        "target_price": float(close_med.iloc[j_target]),
        "worst_idx": worst_idx,
        "worst_price": worst_price,
    }
def compute_short_metrics(agg_df: pd.DataFrame, *, loss_floor: float = 1e-9) -> dict:
    """Evaluate the short-side opportunity of an aggregated forecast.

    The target is the bar where the median path (``close_q50``, falling back
    to ``close_q90``) is lowest relative to the current price ``base_close``.
    The risk is the highest value of the upper band (``close_q90``, falling
    back to the median path) from the first forecast bar up to and including
    the target bar. NaNs in either series are ignored.

    Args:
        agg_df: Aggregated forecast frame with a ``base_close`` column (its
            first value is used as the current price) and quantile columns.
        loss_floor: Lower bound applied to the relative loss so the
            reward/risk ratio stays finite when the band never rises above
            the current price. Pass ``0.0`` to get ``rr == inf`` instead.

    Returns:
        dict with ``direction``, ``base``, ``profit``, ``loss``, ``rr``,
        ``target_idx``, ``target_price``, ``worst_idx`` and ``worst_price``.
        ``profit`` and ``loss`` are fractions of ``base`` (0.02 == 2%);
        indices are positional.

    Raises:
        ValueError: if neither ``close_q50`` nor ``close_q90`` is present.
    """
    base = float(agg_df["base_close"].iloc[0])
    close_med = agg_df.get("close_q50", agg_df.get("close_q90"))
    if close_med is None:
        raise ValueError("预测结果中缺少 close_q50或 close_q90 兜底)列")

    # Short-side return: how far the median path falls below the current price.
    ret = (base - close_med) / base
    j_target = int(np.nanargmax(ret.values))
    # A median path that never falls below the current price offers no profit.
    profit = max(float(ret.iloc[j_target]), 0.0)

    # Worst point of the upper band between now and the target (inclusive).
    # Using the q90 band rather than the raw sample maximum keeps the risk
    # estimate robust to single outlier samples.
    q90 = agg_df.get("close_q90", close_med)
    window = q90.iloc[: j_target + 1].values
    worst_idx = int(np.nanargmax(window))
    worst_price = float(window[worst_idx])
    loss = max((worst_price - base) / base, loss_floor)

    return {
        "direction": "SHORT",
        "base": base,
        "profit": profit,
        "loss": loss,
        "rr": profit / loss if loss > 0 else np.inf,
        "target_idx": j_target,
        "target_price": float(close_med.iloc[j_target]),
        "worst_idx": worst_idx,
        "worst_price": worst_price,
    }
def plot_opportunities(
    agg_df: pd.DataFrame,
    long_m: dict,
    short_m: dict,
    symbol: str,
    interval: str,
    out_dir: str,
    hist_df: pd.DataFrame | None = None,
) -> str:
    """Render the long/short opportunity chart and save it as a PNG.

    Two stacked panels share the time axis: the top one shows the long side
    (``long_m``), the bottom one the short side (``short_m``). Each panel shows
    the optional history context, the q10~q90 forecast band, the q50 median,
    the current price, the target and worst points, and shaded areas for the
    potential profit and the maximum loss.

    Args:
        agg_df: Aggregated forecast with a ``timestamps`` column and optional
            ``close_q10`` / ``close_q50`` / ``close_q90`` columns.
        long_m: Metrics dict for the long side (``base``, ``rr``, ``profit``,
            ``loss``, ``target_idx``, ``target_price``, ``worst_idx``,
            ``worst_price``).
        short_m: Metrics dict for the short side, same keys as ``long_m``.
        symbol: Instrument symbol, used in the title and the file name.
        interval: Bar interval, used in the title and the file name.
        out_dir: Output directory; created if it does not exist.
        hist_df: Optional recent bars (``timestamps`` and ``close`` columns)
            drawn before the forecast; skipped with a message if unusable.

    Returns:
        Path of the saved PNG file.
    """
    from matplotlib import gridspec

    # pd.to_datetime leaves datetime columns unchanged and, unlike an
    # np.issubdtype check, does not raise on tz-aware (extension) dtypes.
    ts = pd.to_datetime(agg_df["timestamps"])
    base = long_m["base"]
    q10 = agg_df.get("close_q10")
    q50 = agg_df.get("close_q50")
    q90 = agg_df.get("close_q90")

    # Optional history context; a malformed frame is reported and skipped.
    hist_ts = hist_close = None
    if hist_df is not None and len(hist_df) > 0:
        try:
            hist_ts = pd.to_datetime(hist_df["timestamps"])
            hist_close = hist_df["close"].astype(float).values
        except (KeyError, AttributeError, TypeError, ValueError) as exc:
            print(f"历史K线数据不可用,已跳过: {exc!r}")
            hist_ts = hist_close = None

    def _shade(ax, x, low, high, color, label):
        # Constant-height band between two price levels (scalars broadcast over x).
        ax.fill_between(x, low, high, color=color, alpha=0.18, label=label)

    def _draw_panel(ax, title, target_idx, target_price, worst_idx, worst_price, direction):
        # Draw one direction's panel; ``direction`` is "LONG" or "SHORT".
        if hist_ts is not None and hist_close is not None:
            ax.plot(
                hist_ts.values,
                hist_close,
                color="#6c757d",
                lw=1.6,
                label="历史收盘(近24根)",
            )
        # Forecast start marker; best effort, skipped if the axis rejects it.
        try:
            ax.axvline(ts.iloc[0], color="gray", ls=":", lw=1.0, label="预测起点")
        except (IndexError, TypeError, ValueError):
            pass
        # Forecast band and median path.
        if q10 is not None and q90 is not None:
            ax.fill_between(
                ts.values,
                q10.values,
                q90.values,
                color="#7dbcea",
                alpha=0.22,
                label="q10~q90",
            )
        if q50 is not None:
            ax.plot(ts.values, q50.values, color="#00539c", lw=2.0, label="中位线 q50")
        # Current price reference line.
        ax.axhline(base, color="gray", ls=":", lw=1.2, label=f"当前价 {base:.1f}")

        # Target point.
        target_xy = (ts.iloc[target_idx], target_price)
        ax.plot(*target_xy, marker="o", color="#2e7d32", ms=8, label="目标点")
        ax.annotate(
            f"目标 {target_price:.1f}",
            target_xy,
            textcoords="offset points",
            xytext=(8, 8),
            bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#2e7d32", alpha=0.8),
            color="#2e7d32",
        )
        # Worst point.
        worst_xy = (ts.iloc[worst_idx], worst_price)
        ax.plot(*worst_xy, marker="v", color="#c62828", ms=8, label="最差点")
        ax.annotate(
            f"最差 {worst_price:.1f}",
            worst_xy,
            textcoords="offset points",
            xytext=(8, -14),
            bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#c62828", alpha=0.8),
            color="#c62828",
        )
        # Arrow from the current price at the forecast start to the target.
        arrow_color = "#2e7d32" if direction == "LONG" else "#c62828"
        ax.annotate(
            "",
            xy=target_xy,
            xytext=(ts.iloc[0], base),
            arrowprops=dict(arrowstyle="->", color=arrow_color, lw=2, alpha=0.9),
        )

        # Reward/risk areas spanning the forecast start up to the target bar:
        # green for potential profit, red for the maximum loss.
        x_range = ts.iloc[: target_idx + 1].values
        if direction == "LONG":
            if target_price > base:
                _shade(ax, x_range, base, target_price, "#2e7d32", "潜在盈利")
            if q10 is not None and np.isfinite(worst_price) and worst_price < base:
                _shade(ax, x_range, worst_price, base, "#c62828", "最大亏损(q10)")
        else:
            if target_price < base:
                _shade(ax, x_range, target_price, base, "#2e7d32", "潜在盈利")
            if q90 is not None and np.isfinite(worst_price) and worst_price > base:
                _shade(ax, x_range, base, worst_price, "#c62828", "最大亏损(q90)")

        ax.set_title(title)
        ax.grid(True, alpha=0.25)
        ax.legend(loc="best", fontsize=9)

    fig = plt.figure(figsize=(18, 10))
    try:
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)

        ax1 = fig.add_subplot(gs[0])
        profit_long_pct = long_m["profit"] * 100.0
        loss_long_pct = long_m["loss"] * 100.0
        title_long = f"多头机会:盈亏比 RR={long_m['rr']:.2f}+{profit_long_pct:.2f}% / -{loss_long_pct:.2f}%"
        _draw_panel(
            ax1,
            title_long,
            long_m["target_idx"],
            long_m["target_price"],
            long_m["worst_idx"],
            long_m["worst_price"],
            "LONG",
        )

        ax2 = fig.add_subplot(gs[1], sharex=ax1)
        profit_short_pct = short_m["profit"] * 100.0
        loss_short_pct = short_m["loss"] * 100.0
        title_short = f"空头机会:盈亏比 RR={short_m['rr']:.2f}+{profit_short_pct:.2f}% / -{loss_short_pct:.2f}%"
        _draw_panel(
            ax2,
            title_short,
            short_m["target_idx"],
            short_m["target_price"],
            short_m["worst_idx"],
            short_m["worst_price"],
            "SHORT",
        )

        # The panels share the x axis; only the bottom one shows time labels.
        ax1.tick_params(labelbottom=False)
        ax2.set_xlabel("时间")
        ax1.set_ylabel("价格 (USD)")
        ax2.set_ylabel("价格 (USD)")
        fig.suptitle(
            f"{symbol} {interval}分钟线 24小时 预测机会评估", fontsize=16, fontweight="bold"
        )
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        os.makedirs(out_dir, exist_ok=True)
        out_png = os.path.join(
            out_dir, f"trade_opportunity_{symbol.lower()}_{interval}m.png"
        )
        fig.savefig(out_png, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until closed; release it so that
        # repeated calls (e.g. a scheduled job) do not accumulate memory.
        plt.close(fig)
    return out_png
def _format_summary(
    symbol: str, interval: str, long_m: dict, short_m: dict, min_rr: float = 2.0
) -> str:
    """Build the text report comparing both sides and naming the better one.

    The side with the higher reward/risk ratio wins (ties go to LONG). A
    warning line is appended when the winner's ratio is below ``min_rr``.
    """
    best = long_m if long_m["rr"] >= short_m["rr"] else short_m
    dir_emoji = "📈 多头" if best["direction"] == "LONG" else "📉 空头"
    summary_lines = [
        "== 24小时 每日交易机会评估 ==\n",
        f"📌 标的: {symbol} | 周期: {interval}m | 当前价: {best['base']:.2f}\n",
        f"📈 多头 | RR={long_m['rr']:.2f} | 盈利 +{long_m['profit'] * 100.0:.2f}% | 亏损 -{long_m['loss'] * 100.0:.2f}% | 🎯 目标 {long_m['target_price']:.2f} | 🛡️ 最差 {long_m['worst_price']:.2f}\n",
        f"📉 空头 | RR={short_m['rr']:.2f} | 盈利 +{short_m['profit'] * 100.0:.2f}% | 亏损 -{short_m['loss'] * 100.0:.2f}% | 🎯 目标 {short_m['target_price']:.2f} | 🛡️ 最差 {short_m['worst_price']:.2f}\n",
        f"✅ 最佳: {dir_emoji} | 盈亏比 RR={best['rr']:.2f}",
    ]
    if best["rr"] < min_rr:
        summary_lines.append("⚠️ 交易机会不佳,建议观望。")
    return "\n".join(summary_lines) + "\n"


def main():
    """CLI entry point: forecast, score both sides, print, plot and notify.

    Parses command-line options, requests the aggregated forecast via
    ``kronos_predict_df``, computes long/short metrics, prints a summary,
    saves a chart under ``figures/`` and tries to send the result to QQ.
    """
    # Load environment variables (e.g. KRONOS_DEVICE, QQ_BOT_TARGET_ID).
    load_dotenv()
    parser = argparse.ArgumentParser(description="计算并可视化高盈亏比交易机会")
    parser.add_argument("--symbol", default="BTCUSDT")
    parser.add_argument("--category", default="inverse")
    parser.add_argument("--interval", default="60")
    parser.add_argument("--limit", type=int, default=500)
    parser.add_argument("--lookback", type=int, default=360)
    parser.add_argument("--pred_len", type=int, default=24)
    parser.add_argument("--samples", type=int, default=30)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--device", default=os.getenv("KRONOS_DEVICE", "cuda:0"))
    parser.add_argument("--seed", type=int, default=42)
    # Below this reward/risk ratio the summary advises staying out.
    parser.add_argument("--min_rr", type=float, default=2.0)
    args = parser.parse_args()

    # Aggregated forecast (future bars only) with q10/q50/q90 close quantiles.
    agg_df = kronos_predict_df(
        symbol=args.symbol,
        category=args.category,
        interval=args.interval,
        limit=args.limit,
        lookback=args.lookback,
        pred_len=args.pred_len,
        samples=args.samples,
        temperature=args.temperature,
        top_p=args.top_p,
        quantiles=[0.1, 0.5, 0.9],
        device=args.device,
        seed=args.seed,
        verbose=True,
    )

    long_m = compute_long_metrics(agg_df)
    short_m = compute_short_metrics(agg_df)
    summary = _format_summary(args.symbol, args.interval, long_m, short_m, args.min_rr)
    print(summary)

    # The last 24 bars before the forecast give the chart some context; the
    # chart is still produced without them if fetching fails.
    try:
        full_hist_df = get_recent_kline(
            symbol=args.symbol,
            category=args.category,
            interval=args.interval,
            limit=args.limit,
        )
        hist24 = full_hist_df.tail(24).reset_index(drop=True)
    except Exception as e:
        print(f"获取历史K线失败,可视化将不含历史数据: {repr(e)}")
        hist24 = None

    out_png = plot_opportunities(
        agg_df, long_m, short_m, args.symbol, args.interval, out_dir="figures", hist_df=hist24
    )
    print(f"可视化已保存: {out_png}")

    # Push the report and chart to QQ; a delivery failure is reported, not fatal.
    try:
        from qq import send_prediction_result

        target = os.getenv("QQ_BOT_TARGET_ID")
        send_prediction_result(summary, png_path=out_png, target_id=target)
    except Exception as e:
        print(f"发送 QQ 消息失败: {repr(e)}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()