新增盈亏比计算和消息推送
This commit is contained in:
parent
2b5bcc55f8
commit
0771ca73a0
526
predict.py
Normal file
526
predict.py
Normal file
@ -0,0 +1,526 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
BTCUSD 15m 多次采样预测 → 区间聚合可视化
|
||||||
|
Author: 你(鹅)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from pybit.unified_trading import HTTP
|
||||||
|
|
||||||
|
# ==== Kronos ====
|
||||||
|
from model import Kronos, KronosTokenizer, KronosPredictor
|
||||||
|
|
||||||
|
# ========== Matplotlib 字体 ==========
|
||||||
|
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
|
||||||
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|
||||||
|
# ========== 默认参数 ==========
|
||||||
|
NUM_SAMPLES = 20 # 多次采样次数(建议 10~50)
|
||||||
|
QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
|
||||||
|
LOOKBACK = 360 # 历史窗口
|
||||||
|
PRED_LEN = 24 # 未来预测点数
|
||||||
|
INTERVAL_MIN = 60 # K 线周期(分钟)
|
||||||
|
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
|
||||||
|
TEMPERATURE = 1.0
|
||||||
|
TOP_P = 0.9
|
||||||
|
DRAW_WINDOW_LEN = 240 # 历史用于画图的窗口
|
||||||
|
|
||||||
|
OUTPUT_DIR = "figures"
|
||||||
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 实用函数 ==========
|
||||||
|
def set_global_seed(seed: int):
|
||||||
|
"""尽可能固定随机性(若底层采样逻辑使用 torch/np/random)"""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_kline(
|
||||||
|
session: HTTP, symbol="BTCUSD", interval="15", limit=500, category="inverse"
|
||||||
|
):
|
||||||
|
print("正在获取K线数据...")
|
||||||
|
resp = session.get_kline(
|
||||||
|
category=category, symbol=symbol, interval=interval, limit=limit
|
||||||
|
)
|
||||||
|
if resp.get("retCode", -1) != 0:
|
||||||
|
raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
|
||||||
|
lst = resp["result"]["list"]
|
||||||
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
lst,
|
||||||
|
columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
|
||||||
|
)
|
||||||
|
df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
|
||||||
|
for c in ["open", "high", "low", "close", "volume", "turnover"]:
|
||||||
|
df[c] = df[c].astype(float)
|
||||||
|
df = df.sort_values("timestamps").reset_index(drop=True)
|
||||||
|
df["amount"] = df["turnover"]
|
||||||
|
# 去掉最新一根(通常为未完成的当前K线)
|
||||||
|
if len(df) > 0:
|
||||||
|
df = df.iloc[:-1].reset_index(drop=True)
|
||||||
|
print(f"获取到 {len(df)} 根K线数据")
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def get_recent_kline(
|
||||||
|
symbol: str = "BTCUSD",
|
||||||
|
category: str = "inverse",
|
||||||
|
interval: str = "60",
|
||||||
|
limit: int = 500,
|
||||||
|
):
|
||||||
|
"""便捷封装:基于环境变量创建会话并获取最近K线数据。
|
||||||
|
|
||||||
|
返回数据按时间升序,且会移除最后一根未完成K线(与 fetch_kline 保持一致)。
|
||||||
|
"""
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv("BYBIT_API_KEY")
|
||||||
|
api_secret = os.getenv("BYBIT_API_SECRET")
|
||||||
|
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
|
||||||
|
return fetch_kline(
|
||||||
|
session=session,
|
||||||
|
symbol=symbol,
|
||||||
|
interval=interval,
|
||||||
|
limit=limit,
|
||||||
|
category=category,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_io_windows(
|
||||||
|
df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
|
||||||
|
):
|
||||||
|
end_idx = len(df)
|
||||||
|
start_idx = max(0, end_idx - lookback)
|
||||||
|
x_df = df.loc[
|
||||||
|
start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
|
||||||
|
].reset_index(drop=True)
|
||||||
|
x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
|
||||||
|
|
||||||
|
last_ts = df.loc[end_idx - 1, "timestamps"]
|
||||||
|
future_timestamps = pd.date_range(
|
||||||
|
start=last_ts + pd.Timedelta(minutes=interval_min),
|
||||||
|
periods=pred_len,
|
||||||
|
freq=f"{interval_min}min",
|
||||||
|
)
|
||||||
|
y_timestamp = pd.Series(future_timestamps)
|
||||||
|
y_timestamp.index = range(len(y_timestamp))
|
||||||
|
|
||||||
|
print(f"数据总量: {len(df)} 根K线")
|
||||||
|
print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})")
|
||||||
|
print(f"最后一根历史K线时间: {last_ts}")
|
||||||
|
print(f"预测未来 {pred_len} 根K线")
|
||||||
|
return x_df, x_timestamp, y_timestamp, start_idx, end_idx
|
||||||
|
|
||||||
|
|
||||||
|
def load_kronos(device=DEVICE):
|
||||||
|
print("正在加载 Kronos 模型...")
|
||||||
|
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
|
||||||
|
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
|
||||||
|
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
|
||||||
|
print("模型加载完成!\n")
|
||||||
|
return predictor
|
||||||
|
|
||||||
|
|
||||||
|
def run_one_prediction(
|
||||||
|
predictor,
|
||||||
|
x_df,
|
||||||
|
x_timestamp,
|
||||||
|
y_timestamp,
|
||||||
|
T=TEMPERATURE,
|
||||||
|
top_p=TOP_P,
|
||||||
|
seed=None,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
if seed is not None:
|
||||||
|
set_global_seed(seed)
|
||||||
|
return predictor.predict(
|
||||||
|
df=x_df,
|
||||||
|
x_timestamp=x_timestamp,
|
||||||
|
y_timestamp=y_timestamp,
|
||||||
|
pred_len=len(y_timestamp),
|
||||||
|
T=T,
|
||||||
|
top_p=top_p,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_predictions(
|
||||||
|
predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
|
||||||
|
temperature=TEMPERATURE, top_p=TOP_P
|
||||||
|
):
|
||||||
|
preds = []
|
||||||
|
print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
|
||||||
|
for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100):
|
||||||
|
seed = base_seed + i # 每次不同 seed
|
||||||
|
pred_df = run_one_prediction(
|
||||||
|
predictor,
|
||||||
|
x_df,
|
||||||
|
x_timestamp,
|
||||||
|
y_timestamp,
|
||||||
|
T=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
seed=seed,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
# 兼容性处理:确保只取需要的列
|
||||||
|
cols_present = [
|
||||||
|
c
|
||||||
|
for c in ["open", "high", "low", "close", "volume", "amount"]
|
||||||
|
if c in pred_df.columns
|
||||||
|
]
|
||||||
|
pred_df = pred_df[cols_present].copy()
|
||||||
|
pred_df.reset_index(drop=True, inplace=True)
|
||||||
|
preds.append(pred_df)
|
||||||
|
print("多次预测完成。\n")
|
||||||
|
return preds
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_quantiles(pred_list, quantiles=QUANTILES):
|
||||||
|
"""
|
||||||
|
将多次预测列表聚合成分位数 DataFrame:
|
||||||
|
输出列命名:<col>_q10, <col>_q50, <col>_q90(按 quantiles 中的值)
|
||||||
|
"""
|
||||||
|
# 先把每次预测拼成 3D:time x feature x samples
|
||||||
|
keys = pred_list[0].columns.tolist()
|
||||||
|
T_len = len(pred_list[0])
|
||||||
|
S = len(pred_list)
|
||||||
|
data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
|
||||||
|
for j, pdf in enumerate(pred_list):
|
||||||
|
for k in keys:
|
||||||
|
data[k][:, j] = pdf[k].values
|
||||||
|
|
||||||
|
out = {}
|
||||||
|
for k in keys:
|
||||||
|
# 极值(用于风险评估)
|
||||||
|
out[f"{k}_min"] = np.min(data[k], axis=1)
|
||||||
|
out[f"{k}_max"] = np.max(data[k], axis=1)
|
||||||
|
for q in quantiles:
|
||||||
|
qv = np.quantile(data[k], q, axis=1)
|
||||||
|
out[f"{k}_q{int(q*100):02d}"] = qv
|
||||||
|
agg_df = pd.DataFrame(out)
|
||||||
|
return agg_df
|
||||||
|
|
||||||
|
|
||||||
|
def kronos_predict_df(
|
||||||
|
symbol: str = "BTCUSD",
|
||||||
|
category: str = "inverse",
|
||||||
|
interval: str = "60",
|
||||||
|
limit: int = 500,
|
||||||
|
lookback: int = LOOKBACK,
|
||||||
|
pred_len: int = PRED_LEN,
|
||||||
|
samples: int = NUM_SAMPLES,
|
||||||
|
temperature: float = TEMPERATURE,
|
||||||
|
top_p: float = TOP_P,
|
||||||
|
quantiles: list[float] | None = None,
|
||||||
|
device: str = DEVICE,
|
||||||
|
seed: int = 42,
|
||||||
|
verbose: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
作为模块使用的API:返回聚合后的预测DataFrame(仅未来部分)。
|
||||||
|
|
||||||
|
返回DataFrame包含:
|
||||||
|
- timestamps: 未来各点时间戳
|
||||||
|
- <col>_qXX: 各分位数(如 close_q50)
|
||||||
|
- <col>_min/<col>_max: 样本极值(用于风控/盈亏比)
|
||||||
|
- base_close: 当前基准价(最后一根历史K线的收盘价,重复列)
|
||||||
|
"""
|
||||||
|
# 环境 & HTTP会话
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv("BYBIT_API_KEY")
|
||||||
|
api_secret = os.getenv("BYBIT_API_SECRET")
|
||||||
|
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
|
||||||
|
|
||||||
|
# 加载模型(懒加载一次)
|
||||||
|
predictor = load_kronos(device=device)
|
||||||
|
|
||||||
|
# 拉取K线并准备窗口
|
||||||
|
df = fetch_kline(
|
||||||
|
session,
|
||||||
|
symbol=symbol,
|
||||||
|
interval=interval,
|
||||||
|
limit=limit,
|
||||||
|
category=category,
|
||||||
|
)
|
||||||
|
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
|
||||||
|
df,
|
||||||
|
lookback=lookback,
|
||||||
|
pred_len=pred_len,
|
||||||
|
interval_min=int(interval),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 多次预测
|
||||||
|
preds = run_multi_predictions(
|
||||||
|
predictor,
|
||||||
|
x_df,
|
||||||
|
x_ts,
|
||||||
|
y_ts,
|
||||||
|
num_samples=samples,
|
||||||
|
base_seed=seed,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 聚合
|
||||||
|
qs = quantiles if quantiles is not None else QUANTILES
|
||||||
|
agg_df = aggregate_quantiles(preds, quantiles=qs)
|
||||||
|
agg_df["timestamps"] = y_ts.values
|
||||||
|
|
||||||
|
# 当前价(基准)
|
||||||
|
base_close = float(df.loc[end_idx - 1, "close"]) if end_idx > 0 else np.nan
|
||||||
|
agg_df["base_close"] = base_close
|
||||||
|
|
||||||
|
# 仅返回DataFrame(调用方自行决定是否绘图/保存)
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
f"完成预测:symbol={symbol} interval={interval}m samples={samples} base_close={base_close}"
|
||||||
|
)
|
||||||
|
return agg_df
|
||||||
|
|
||||||
|
|
||||||
|
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
|
||||||
|
"""
|
||||||
|
上下两个子图(共享X轴):
|
||||||
|
- 上方(高度3):历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
|
||||||
|
- 下方(高度1):历史成交量柱 + 预测成交量中位柱(仅q50,不显示区间)
|
||||||
|
"""
|
||||||
|
import matplotlib.gridspec as gridspec
|
||||||
|
|
||||||
|
# 智能推断K线宽度(柱宽)
|
||||||
|
try:
|
||||||
|
if len(historical_df) >= 2:
|
||||||
|
hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
|
||||||
|
else:
|
||||||
|
hist_step = pd.Timedelta(minutes=10)
|
||||||
|
if len(y_timestamp) >= 2:
|
||||||
|
pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
|
||||||
|
else:
|
||||||
|
pred_step = pd.Timedelta(minutes=10)
|
||||||
|
bar_width_hist = hist_step * 0.8
|
||||||
|
bar_width_pred = pred_step * 0.8
|
||||||
|
except Exception:
|
||||||
|
bar_width_hist = pd.Timedelta(minutes=10)
|
||||||
|
bar_width_pred = pd.Timedelta(minutes=10)
|
||||||
|
|
||||||
|
# 取出预测分位
|
||||||
|
close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
|
||||||
|
close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
|
||||||
|
close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
|
||||||
|
|
||||||
|
vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
|
||||||
|
|
||||||
|
# 图形与网格
|
||||||
|
fig = plt.figure(figsize=(18, 10))
|
||||||
|
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
|
||||||
|
|
||||||
|
# ===== 上:收盘价 =====
|
||||||
|
ax_price = fig.add_subplot(gs[0])
|
||||||
|
|
||||||
|
ax_price.plot(
|
||||||
|
historical_df["timestamps"],
|
||||||
|
historical_df["close"],
|
||||||
|
label="历史收盘价",
|
||||||
|
linewidth=1.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测收盘价区间与中位线
|
||||||
|
if close_q10 is not None and close_q90 is not None:
|
||||||
|
ax_price.fill_between(
|
||||||
|
y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
|
||||||
|
)
|
||||||
|
if close_q50 is not None:
|
||||||
|
ax_price.plot(
|
||||||
|
y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测起点与阴影
|
||||||
|
if len(y_timestamp) > 0:
|
||||||
|
ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
|
||||||
|
ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
|
||||||
|
|
||||||
|
# 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
|
||||||
|
# 合并历史和预测时间范围与价格数据
|
||||||
|
all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
|
||||||
|
hist_close = historical_df["close"]
|
||||||
|
pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
|
||||||
|
all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
|
||||||
|
|
||||||
|
if len(all_timestamps) > 0:
|
||||||
|
start_time = all_timestamps.min()
|
||||||
|
end_time = all_timestamps.max()
|
||||||
|
|
||||||
|
# 生成所有16:00时间点(UTC+8)
|
||||||
|
current_date = start_time.normalize() # 当天零点
|
||||||
|
while current_date <= end_time:
|
||||||
|
target_time = current_date + pd.Timedelta(hours=16) # 16:00
|
||||||
|
if start_time <= target_time <= end_time:
|
||||||
|
# 画虚线
|
||||||
|
ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
|
||||||
|
|
||||||
|
# 找到最接近16:00的时间点的收盘价
|
||||||
|
time_diffs = (all_timestamps - target_time).abs()
|
||||||
|
closest_idx = time_diffs.idxmin()
|
||||||
|
closest_time = all_timestamps.iloc[closest_idx]
|
||||||
|
closest_price = all_close.iloc[closest_idx]
|
||||||
|
|
||||||
|
# 如果时间差不超过1小时,则标注价格
|
||||||
|
if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
|
||||||
|
ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
|
||||||
|
ax_price.text(
|
||||||
|
closest_time, closest_price,
|
||||||
|
f' ${closest_price:.1f}',
|
||||||
|
fontsize=9,
|
||||||
|
color='blue',
|
||||||
|
verticalalignment='bottom',
|
||||||
|
horizontalalignment='left',
|
||||||
|
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
|
||||||
|
)
|
||||||
|
current_date += pd.Timedelta(days=1)
|
||||||
|
|
||||||
|
ax_price.set_ylabel("价格 (USD)", fontsize=11)
|
||||||
|
ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
|
||||||
|
ax_price.grid(True, alpha=0.3)
|
||||||
|
ax_price.legend(loc="best", fontsize=9)
|
||||||
|
|
||||||
|
# ===== 下:成交量(仅预测中位量能柱)=====
|
||||||
|
ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
|
||||||
|
|
||||||
|
# 历史量能柱
|
||||||
|
ax_vol.bar(
|
||||||
|
historical_df["timestamps"],
|
||||||
|
historical_df["volume"],
|
||||||
|
width=bar_width_hist,
|
||||||
|
alpha=0.35,
|
||||||
|
label="历史成交量",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测中位量能柱(不画区间)
|
||||||
|
if vol_q50 is not None:
|
||||||
|
ax_vol.bar(
|
||||||
|
y_timestamp.values,
|
||||||
|
vol_q50,
|
||||||
|
width=bar_width_pred,
|
||||||
|
alpha=0.6,
|
||||||
|
label="预测成交量中位(q50)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测起点线
|
||||||
|
if len(y_timestamp) > 0:
|
||||||
|
ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
|
||||||
|
|
||||||
|
ax_vol.set_ylabel("成交量", fontsize=11)
|
||||||
|
ax_vol.set_xlabel("时间", fontsize=11)
|
||||||
|
ax_vol.grid(True, alpha=0.25)
|
||||||
|
ax_vol.legend(loc="best", fontsize=9)
|
||||||
|
|
||||||
|
# 避免X轴标签重叠
|
||||||
|
plt.setp(ax_price.get_xticklabels(), visible=False)
|
||||||
|
plt.tight_layout()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
|
||||||
|
parser.add_argument("--symbol", default="BTCUSD")
|
||||||
|
parser.add_argument("--category", default="inverse")
|
||||||
|
parser.add_argument("--interval", default="60")
|
||||||
|
parser.add_argument("--limit", type=int, default=500)
|
||||||
|
parser.add_argument("--lookback", type=int, default=LOOKBACK)
|
||||||
|
parser.add_argument("--pred_len", type=int, default=PRED_LEN)
|
||||||
|
parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
|
||||||
|
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
|
||||||
|
parser.add_argument("--top_p", type=float, default=TOP_P)
|
||||||
|
parser.add_argument("--quantiles", default="0.1,0.5,0.9")
|
||||||
|
parser.add_argument("--device", default=DEVICE)
|
||||||
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 载入env & HTTP
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv("BYBIT_API_KEY")
|
||||||
|
api_secret = os.getenv("BYBIT_API_SECRET")
|
||||||
|
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
|
||||||
|
# 为命令行模式执行完整流程
|
||||||
|
predictor = load_kronos(device=args.device)
|
||||||
|
df = fetch_kline(
|
||||||
|
session,
|
||||||
|
symbol=args.symbol,
|
||||||
|
interval=args.interval,
|
||||||
|
limit=args.limit,
|
||||||
|
category=args.category,
|
||||||
|
)
|
||||||
|
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
|
||||||
|
df,
|
||||||
|
lookback=args.lookback,
|
||||||
|
pred_len=args.pred_len,
|
||||||
|
interval_min=int(args.interval),
|
||||||
|
)
|
||||||
|
|
||||||
|
plot_start = max(0, end_idx - DRAW_WINDOW_LEN)
|
||||||
|
historical_df = df.loc[plot_start : end_idx - 1].copy()
|
||||||
|
|
||||||
|
temperature = args.temperature
|
||||||
|
top_p = args.top_p
|
||||||
|
qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
|
||||||
|
|
||||||
|
preds = run_multi_predictions(
|
||||||
|
predictor,
|
||||||
|
x_df,
|
||||||
|
x_ts,
|
||||||
|
y_ts,
|
||||||
|
num_samples=args.samples,
|
||||||
|
base_seed=args.seed,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
agg_df = aggregate_quantiles(preds, quantiles=qs)
|
||||||
|
agg_df["timestamps"] = y_ts.values
|
||||||
|
|
||||||
|
# 输出 CSV(聚合)
|
||||||
|
out_csv = os.path.join(
|
||||||
|
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
|
||||||
|
)
|
||||||
|
agg_df.to_csv(out_csv, index=False)
|
||||||
|
print(f"预测分位数据已保存到: {out_csv}")
|
||||||
|
|
||||||
|
# 作图
|
||||||
|
fig = plot_results(
|
||||||
|
historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
|
||||||
|
)
|
||||||
|
out_png = os.path.join(
|
||||||
|
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
|
||||||
|
)
|
||||||
|
fig.savefig(out_png, dpi=150, bbox_inches="tight")
|
||||||
|
print(f"预测区间图表已保存到: {out_png}")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except Exception as e:
|
||||||
|
print("运行出错:", repr(e))
|
||||||
|
sys.exit(1)
|
||||||
132
qq.py
Normal file
132
qq.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
轻量的 QQ Bot 客户端(用于发送私聊文本和媒体)
|
||||||
|
|
||||||
|
使用环境变量:
|
||||||
|
- QQ_BOT_URL: bot 接收地址(例如 http://localhost:30000/send_private_msg)
|
||||||
|
- QQ_BOT_TARGET_ID: 默认目标 QQ id(可被函数参数覆盖)
|
||||||
|
|
||||||
|
示例:
|
||||||
|
from qq import send_msg, send_media_msg
|
||||||
|
send_msg('hello')
|
||||||
|
# 发送图片时将自动转换为 "base64://..." 格式
|
||||||
|
send_media_msg('figures/btcusd_interval_pred_60m.png', type='image')
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
import base64
|
||||||
|
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except Exception: # pragma: no cover - runtime dependency
|
||||||
|
requests = None
|
||||||
|
|
||||||
|
|
||||||
|
def _bot_url() -> str:
|
||||||
|
return os.getenv("QQ_BOT_URL", "http://localhost:30000/send_private_msg")
|
||||||
|
|
||||||
|
|
||||||
|
def _default_target() -> Optional[str]:
|
||||||
|
return os.getenv("QQ_BOT_TARGET_ID")
|
||||||
|
|
||||||
|
|
||||||
|
def _file_to_base64_uri(path: str) -> str:
|
||||||
|
"""将本地文件读取为 base64 uri 字符串(前缀 base64://)。"""
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
b64 = base64.b64encode(data).decode("ascii")
|
||||||
|
return f"base64://{b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def send_msg(msg: str, target_id: Optional[str] = None, timeout: float = 8.0):
|
||||||
|
"""发送文本消息到 QQ 机器人。返回 requests.Response 或 None(失败)。"""
|
||||||
|
if target_id is None:
|
||||||
|
target_id = _default_target()
|
||||||
|
if target_id is None:
|
||||||
|
raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置")
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"user_id": str(target_id),
|
||||||
|
"message": [
|
||||||
|
{"type": "text", "data": {"text": str(msg)}}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
url = _bot_url()
|
||||||
|
print(f"[QQ] 发送文本到 {target_id}: {msg}")
|
||||||
|
if requests is None:
|
||||||
|
print("requests 未安装,无法发送消息")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = requests.post(url, json=payload, timeout=timeout)
|
||||||
|
try:
|
||||||
|
# 非必要:打印返回的简短信息
|
||||||
|
print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return resp
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[QQ] 发送文本消息出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def send_media_msg(file_path: str, target_id: Optional[str] = None, type: str = "image", timeout: float = 15.0):
|
||||||
|
"""发送媒体消息(image 或 video)。
|
||||||
|
|
||||||
|
- 当 type == "image" 时:读取文件并以 "base64://..." 形式发送。
|
||||||
|
- 当 type == "video" 时:仍使用 "file://" 本地文件引用(保持原有行为)。
|
||||||
|
"""
|
||||||
|
if type not in ("image", "video"):
|
||||||
|
raise ValueError("type 必须为 'image' 或 'video'")
|
||||||
|
|
||||||
|
if target_id is None:
|
||||||
|
target_id = _default_target()
|
||||||
|
if target_id is None:
|
||||||
|
raise ValueError("target_id 未提供,且环境变量 QQ_BOT_TARGET_ID 未设置")
|
||||||
|
|
||||||
|
# 根据类型构造 file 字段
|
||||||
|
if type == "image":
|
||||||
|
try:
|
||||||
|
file_field = _file_to_base64_uri(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[QQ] 读取图片失败: {e}")
|
||||||
|
return None
|
||||||
|
else: # video
|
||||||
|
file_field = f"file://{file_path}"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"user_id": str(target_id),
|
||||||
|
"message": [
|
||||||
|
{"type": type, "data": {"file": file_field}}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
url = _bot_url()
|
||||||
|
print(f"[QQ] 发送媒体 {type} 到 {target_id}: {file_path}")
|
||||||
|
if requests is None:
|
||||||
|
print("requests 未安装,无法发送媒体消息")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = requests.post(url, json=payload, timeout=timeout)
|
||||||
|
try:
|
||||||
|
print(f"[QQ] 返回: {resp.status_code} {resp.text[:200]}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return resp
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[QQ] 发送媒体消息出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def send_prediction_result(summary: str, png_path: Optional[str] = None, target_id: Optional[str] = None):
|
||||||
|
"""便捷函数:先发送 summary 文本,再发送 png(若提供)。"""
|
||||||
|
r1 = send_msg(summary, target_id=target_id)
|
||||||
|
r2 = None
|
||||||
|
if png_path:
|
||||||
|
r2 = send_media_msg(png_path, target_id=target_id, type="image")
|
||||||
|
return r1, r2
|
||||||
@ -1,10 +1,12 @@
|
|||||||
numpy
|
numpy
|
||||||
pandas
|
pandas
|
||||||
torch
|
torch
|
||||||
|
requests
|
||||||
|
dotenv
|
||||||
einops==0.8.1
|
einops==0.8.1
|
||||||
huggingface_hub==0.33.1
|
huggingface_hub==0.33.1
|
||||||
matplotlib==3.9.3
|
matplotlib==3.9.3
|
||||||
pandas==2.2.2
|
pandas==2.2.2
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
safetensors==0.6.2
|
safetensors==0.6.2
|
||||||
|
pybit
|
||||||
|
|||||||
721
trader.py
721
trader.py
@ -1,419 +1,408 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
BTCUSD 15m 多次采样预测 → 区间聚合可视化
|
基于 Kronos 多次采样预测,评估多/空两侧的盈亏比,并给出最佳机会与可视化。
|
||||||
Author: 你(鹅)
|
|
||||||
|
盈亏比 = 潜在盈利 / 最大亏损
|
||||||
|
|
||||||
|
多头:
|
||||||
|
- 潜在盈利:以中位线 close_q50 为准,选取未来涨幅最高点(相对当前价 base_close)的涨幅。
|
||||||
|
- 最大亏损:从现在到该目标点区间内,使用多次采样的最差点(close_min)相对当前价的最大跌幅。
|
||||||
|
|
||||||
|
空头:
|
||||||
|
- 潜在盈利:以中位线 close_q50 为准,选取未来跌幅最大点(相对当前价 base_close)的跌幅。
|
||||||
|
- 最大亏损:从现在到该目标点区间内,使用多次采样的最差点(close_max)相对当前价的最大涨幅。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
from datetime import datetime
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pybit.unified_trading import HTTP
|
from predict import kronos_predict_df, get_recent_kline
|
||||||
|
|
||||||
# ==== Kronos ====
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
|
|
||||||
# ========== Matplotlib 字体 ==========
|
|
||||||
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
|
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
|
||||||
plt.rcParams["axes.unicode_minus"] = False
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|
||||||
# ========== 默认参数 ==========
|
|
||||||
NUM_SAMPLES = 20 # 多次采样次数(建议 10~50)
|
|
||||||
QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
|
|
||||||
LOOKBACK = 360 # 历史窗口
|
|
||||||
PRED_LEN = 96 # 未来预测点数
|
|
||||||
INTERVAL_MIN = 15 # K 线周期(分钟)
|
|
||||||
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
|
|
||||||
TEMPERATURE = 1.0
|
|
||||||
TOP_P = 0.9
|
|
||||||
|
|
||||||
OUTPUT_DIR = "figures"
|
def _safe_get(series: pd.Series, default=np.nan):
|
||||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 实用函数 ==========
|
|
||||||
def set_global_seed(seed: int):
|
|
||||||
"""尽可能固定随机性(若底层采样逻辑使用 torch/np/random)"""
|
|
||||||
try:
|
try:
|
||||||
import torch
|
return float(series.iloc[0])
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
return default
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_kline(
|
def compute_long_metrics(agg_df: pd.DataFrame) -> dict:
|
||||||
session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse"
|
base = float(agg_df["base_close"].iloc[0])
|
||||||
):
|
close_med = agg_df.get("close_q50", agg_df.get("close_q10"))
|
||||||
print("正在获取K线数据...")
|
if close_med is None:
|
||||||
resp = session.get_kline(
|
raise ValueError("预测结果中缺少 close_q50(或 close_q10 兜底)列")
|
||||||
category=category, symbol=symbol, interval=interval, limit=limit
|
|
||||||
|
# 相对收益(百分比)
|
||||||
|
ret = (close_med - base) / base
|
||||||
|
j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0
|
||||||
|
profit = max(float(ret.iloc[j_target]), 0.0)
|
||||||
|
|
||||||
|
# 区间内的“最差的90%分位”(使用下分位线 q10,而非样本极小值)
|
||||||
|
q10 = agg_df.get("close_q10", close_med)
|
||||||
|
worst_price = float(q10.iloc[: j_target + 1].min())
|
||||||
|
loss = max((base - worst_price) / base, 1e-9)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"direction": "LONG",
|
||||||
|
"base": base,
|
||||||
|
"profit": profit,
|
||||||
|
"loss": loss,
|
||||||
|
"rr": profit / loss if loss > 0 else np.inf,
|
||||||
|
"target_idx": j_target,
|
||||||
|
"target_price": float(close_med.iloc[j_target]),
|
||||||
|
"worst_idx": int(np.nanargmin(q10.iloc[: j_target + 1].values)),
|
||||||
|
"worst_price": worst_price,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_short_metrics(agg_df: pd.DataFrame) -> dict:
|
||||||
|
base = float(agg_df["base_close"].iloc[0])
|
||||||
|
close_med = agg_df.get("close_q50", agg_df.get("close_q90"))
|
||||||
|
if close_med is None:
|
||||||
|
raise ValueError("预测结果中缺少 close_q50(或 close_q90 兜底)列")
|
||||||
|
|
||||||
|
# 相对收益(空头盈利为向下幅度)
|
||||||
|
ret = (base - close_med) / base
|
||||||
|
j_target = int(np.nanargmax(ret.values)) if len(ret) > 0 else 0
|
||||||
|
profit = max(float(ret.iloc[j_target]), 0.0)
|
||||||
|
|
||||||
|
# 区间内的“最差的90%分位”(使用上分位线 q90,而非样本极大值)
|
||||||
|
q90 = agg_df.get("close_q90", close_med)
|
||||||
|
worst_price = float(q90.iloc[: j_target + 1].max())
|
||||||
|
loss = max((worst_price - base) / base, 1e-9)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"direction": "SHORT",
|
||||||
|
"base": base,
|
||||||
|
"profit": profit,
|
||||||
|
"loss": loss,
|
||||||
|
"rr": profit / loss if loss > 0 else np.inf,
|
||||||
|
"target_idx": j_target,
|
||||||
|
"target_price": float(close_med.iloc[j_target]),
|
||||||
|
"worst_idx": int(np.nanargmax(q90.iloc[: j_target + 1].values)),
|
||||||
|
"worst_price": worst_price,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def plot_opportunities(
|
||||||
|
agg_df: pd.DataFrame,
|
||||||
|
long_m: dict,
|
||||||
|
short_m: dict,
|
||||||
|
symbol: str,
|
||||||
|
interval: str,
|
||||||
|
out_dir: str,
|
||||||
|
hist_df: pd.DataFrame | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""生成清晰可视化:上下两个子图展示多/空机会与风险。"""
|
||||||
|
from matplotlib import gridspec
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
ts = (
|
||||||
|
pd.to_datetime(agg_df["timestamps"])
|
||||||
|
if not np.issubdtype(agg_df["timestamps"].dtype, np.datetime64)
|
||||||
|
else agg_df["timestamps"]
|
||||||
)
|
)
|
||||||
if resp.get("retCode", -1) != 0:
|
base = long_m["base"]
|
||||||
raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
|
q10 = agg_df.get("close_q10", None)
|
||||||
lst = resp["result"]["list"]
|
q50 = agg_df.get("close_q50", None)
|
||||||
|
q90 = agg_df.get("close_q90", None)
|
||||||
|
|
||||||
df = pd.DataFrame(
|
# 历史24根K线(若提供)
|
||||||
lst,
|
hist_ts = None
|
||||||
columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
|
hist_close = None
|
||||||
)
|
if hist_df is not None and len(hist_df) > 0:
|
||||||
df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
|
try:
|
||||||
for c in ["open", "high", "low", "close", "volume", "turnover"]:
|
hist_ts = (
|
||||||
df[c] = df[c].astype(float)
|
pd.to_datetime(hist_df["timestamps"]) if not np.issubdtype(hist_df["timestamps"].dtype, np.datetime64) else hist_df["timestamps"]
|
||||||
df = df.sort_values("timestamps").reset_index(drop=True)
|
)
|
||||||
df["amount"] = df["turnover"]
|
hist_close = hist_df["close"].astype(float).values
|
||||||
print(f"获取到 {len(df)} 根K线数据")
|
except Exception:
|
||||||
return df
|
hist_ts, hist_close = None, None
|
||||||
|
|
||||||
|
|
||||||
def prepare_io_windows(
|
|
||||||
df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
|
|
||||||
):
|
|
||||||
end_idx = len(df)
|
|
||||||
start_idx = max(0, end_idx - lookback)
|
|
||||||
x_df = df.loc[
|
|
||||||
start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
|
|
||||||
].reset_index(drop=True)
|
|
||||||
x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
|
|
||||||
|
|
||||||
last_ts = df.loc[end_idx - 1, "timestamps"]
|
|
||||||
future_timestamps = pd.date_range(
|
|
||||||
start=last_ts + pd.Timedelta(minutes=interval_min),
|
|
||||||
periods=pred_len,
|
|
||||||
freq=f"{interval_min}min",
|
|
||||||
)
|
|
||||||
y_timestamp = pd.Series(future_timestamps)
|
|
||||||
y_timestamp.index = range(len(y_timestamp))
|
|
||||||
|
|
||||||
print(f"数据总量: {len(df)} 根K线")
|
|
||||||
print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})")
|
|
||||||
print(f"最后一根历史K线时间: {last_ts}")
|
|
||||||
print(f"预测未来 {pred_len} 根K线")
|
|
||||||
return x_df, x_timestamp, y_timestamp, start_idx, end_idx
|
|
||||||
|
|
||||||
|
|
||||||
def load_kronos(device=DEVICE):
|
|
||||||
print("正在加载 Kronos 模型...")
|
|
||||||
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
|
|
||||||
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
|
|
||||||
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
|
|
||||||
print("模型加载完成!\n")
|
|
||||||
return predictor
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_prediction(
|
|
||||||
predictor,
|
|
||||||
x_df,
|
|
||||||
x_timestamp,
|
|
||||||
y_timestamp,
|
|
||||||
T=TEMPERATURE,
|
|
||||||
top_p=TOP_P,
|
|
||||||
seed=None,
|
|
||||||
verbose=False,
|
|
||||||
):
|
|
||||||
if seed is not None:
|
|
||||||
set_global_seed(seed)
|
|
||||||
return predictor.predict(
|
|
||||||
df=x_df,
|
|
||||||
x_timestamp=x_timestamp,
|
|
||||||
y_timestamp=y_timestamp,
|
|
||||||
pred_len=len(y_timestamp),
|
|
||||||
T=T,
|
|
||||||
top_p=top_p,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run_multi_predictions(
|
|
||||||
predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
|
|
||||||
temperature=TEMPERATURE, top_p=TOP_P
|
|
||||||
):
|
|
||||||
preds = []
|
|
||||||
print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
|
|
||||||
for i in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100):
|
|
||||||
seed = base_seed + i # 每次不同 seed
|
|
||||||
pred_df = run_one_prediction(
|
|
||||||
predictor,
|
|
||||||
x_df,
|
|
||||||
x_timestamp,
|
|
||||||
y_timestamp,
|
|
||||||
T=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
seed=seed,
|
|
||||||
verbose=False,
|
|
||||||
)
|
|
||||||
# 兼容性处理:确保只取需要的列
|
|
||||||
cols_present = [
|
|
||||||
c
|
|
||||||
for c in ["open", "high", "low", "close", "volume", "amount"]
|
|
||||||
if c in pred_df.columns
|
|
||||||
]
|
|
||||||
pred_df = pred_df[cols_present].copy()
|
|
||||||
pred_df.reset_index(drop=True, inplace=True)
|
|
||||||
preds.append(pred_df)
|
|
||||||
print("多次预测完成。\n")
|
|
||||||
return preds
|
|
||||||
|
|
||||||
|
|
||||||
def aggregate_quantiles(pred_list, quantiles=QUANTILES):
|
|
||||||
"""
|
|
||||||
将多次预测列表聚合成分位数 DataFrame:
|
|
||||||
输出列命名:<col>_q10, <col>_q50, <col>_q90(按 quantiles 中的值)
|
|
||||||
"""
|
|
||||||
# 先把每次预测拼成 3D:time x feature x samples
|
|
||||||
keys = pred_list[0].columns.tolist()
|
|
||||||
T_len = len(pred_list[0])
|
|
||||||
S = len(pred_list)
|
|
||||||
data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
|
|
||||||
for j, pdf in enumerate(pred_list):
|
|
||||||
for k in keys:
|
|
||||||
data[k][:, j] = pdf[k].values
|
|
||||||
|
|
||||||
out = {}
|
|
||||||
for k in keys:
|
|
||||||
for q in quantiles:
|
|
||||||
qv = np.quantile(data[k], q, axis=1)
|
|
||||||
out[f"{k}_q{int(q*100):02d}"] = qv
|
|
||||||
agg_df = pd.DataFrame(out)
|
|
||||||
return agg_df
|
|
||||||
|
|
||||||
|
|
||||||
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
|
|
||||||
"""
|
|
||||||
上下两个子图(共享X轴):
|
|
||||||
- 上方(高度3):历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
|
|
||||||
- 下方(高度1):历史成交量柱 + 预测成交量中位柱(仅q50,不显示区间)
|
|
||||||
"""
|
|
||||||
import matplotlib.gridspec as gridspec
|
|
||||||
|
|
||||||
# 智能推断K线宽度(柱宽)
|
|
||||||
try:
|
|
||||||
if len(historical_df) >= 2:
|
|
||||||
hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
|
|
||||||
else:
|
|
||||||
hist_step = pd.Timedelta(minutes=10)
|
|
||||||
if len(y_timestamp) >= 2:
|
|
||||||
pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
|
|
||||||
else:
|
|
||||||
pred_step = pd.Timedelta(minutes=10)
|
|
||||||
bar_width_hist = hist_step * 0.8
|
|
||||||
bar_width_pred = pred_step * 0.8
|
|
||||||
except Exception:
|
|
||||||
bar_width_hist = pd.Timedelta(minutes=10)
|
|
||||||
bar_width_pred = pd.Timedelta(minutes=10)
|
|
||||||
|
|
||||||
# 取出预测分位
|
|
||||||
close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
|
|
||||||
close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
|
|
||||||
close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
|
|
||||||
|
|
||||||
vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
|
|
||||||
|
|
||||||
# 图形与网格
|
|
||||||
fig = plt.figure(figsize=(18, 10))
|
fig = plt.figure(figsize=(18, 10))
|
||||||
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
|
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)
|
||||||
|
|
||||||
# ===== 上:收盘价 =====
|
def _draw_panel(
|
||||||
ax_price = fig.add_subplot(gs[0])
|
ax, title, target_idx, target_price, worst_idx, worst_price, direction
|
||||||
|
):
|
||||||
|
# 历史收盘价(预测前24根)
|
||||||
|
if hist_ts is not None and hist_close is not None:
|
||||||
|
ax.plot(
|
||||||
|
hist_ts.values,
|
||||||
|
hist_close,
|
||||||
|
color="#6c757d",
|
||||||
|
lw=1.6,
|
||||||
|
label="历史收盘(近24根)",
|
||||||
|
)
|
||||||
|
# 预测起点标注
|
||||||
|
try:
|
||||||
|
ax.axvline(ts.iloc[0], color="gray", ls=":", lw=1.0, label="预测起点")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
ax_price.plot(
|
# 区间带
|
||||||
historical_df["timestamps"],
|
if q10 is not None and q90 is not None:
|
||||||
historical_df["close"],
|
ax.fill_between(
|
||||||
label="历史收盘价",
|
ts.values,
|
||||||
linewidth=1.8,
|
q10.values,
|
||||||
|
q90.values,
|
||||||
|
color="#7dbcea",
|
||||||
|
alpha=0.22,
|
||||||
|
label="q10~q90",
|
||||||
|
)
|
||||||
|
# 中位线
|
||||||
|
if q50 is not None:
|
||||||
|
ax.plot(ts.values, q50.values, color="#00539c", lw=2.0, label="中位线 q50")
|
||||||
|
|
||||||
|
# 基准线
|
||||||
|
ax.axhline(base, color="gray", ls=":", lw=1.2, label=f"当前价 {base:.1f}")
|
||||||
|
|
||||||
|
# 目标点
|
||||||
|
ax.plot(
|
||||||
|
ts.iloc[target_idx],
|
||||||
|
target_price,
|
||||||
|
marker="o",
|
||||||
|
color="#2e7d32",
|
||||||
|
ms=8,
|
||||||
|
label="目标点",
|
||||||
|
)
|
||||||
|
ax.annotate(
|
||||||
|
f"目标 {target_price:.1f}",
|
||||||
|
(ts.iloc[target_idx], target_price),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(8, 8),
|
||||||
|
bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#2e7d32", alpha=0.8),
|
||||||
|
color="#2e7d32",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 最差点
|
||||||
|
ax.plot(
|
||||||
|
ts.iloc[worst_idx],
|
||||||
|
worst_price,
|
||||||
|
marker="v",
|
||||||
|
color="#c62828",
|
||||||
|
ms=8,
|
||||||
|
label="最差点",
|
||||||
|
)
|
||||||
|
ax.annotate(
|
||||||
|
f"最差 {worst_price:.1f}",
|
||||||
|
(ts.iloc[worst_idx], worst_price),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(8, -14),
|
||||||
|
bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="#c62828", alpha=0.8),
|
||||||
|
color="#c62828",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 连接箭头(方向)
|
||||||
|
color = "#2e7d32" if direction == "LONG" else "#c62828"
|
||||||
|
ax.annotate(
|
||||||
|
"",
|
||||||
|
xy=(ts.iloc[target_idx], target_price),
|
||||||
|
xytext=(ts.iloc[0], base),
|
||||||
|
arrowprops=dict(arrowstyle="->", color=color, lw=2, alpha=0.9),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 盈亏比矩形:用填充矩形表示潜在盈利(绿色)与最大亏损(红色)
|
||||||
|
x_range = ts.iloc[: target_idx + 1]
|
||||||
|
if direction == "LONG":
|
||||||
|
# 盈利(若目标价在基准价之上)
|
||||||
|
if target_price > base:
|
||||||
|
y_low = np.full_like(x_range, base, dtype=float)
|
||||||
|
y_high = np.full_like(x_range, target_price, dtype=float)
|
||||||
|
ax.fill_between(
|
||||||
|
x_range.values,
|
||||||
|
y_low,
|
||||||
|
y_high,
|
||||||
|
color="#2e7d32",
|
||||||
|
alpha=0.18,
|
||||||
|
label="潜在盈利",
|
||||||
|
)
|
||||||
|
# 亏损(q10 代表“最差的90%分位”)
|
||||||
|
if q10 is not None and np.isfinite(worst_price) and worst_price < base:
|
||||||
|
y_low = np.full_like(x_range, worst_price, dtype=float)
|
||||||
|
y_high = np.full_like(x_range, base, dtype=float)
|
||||||
|
ax.fill_between(
|
||||||
|
x_range.values,
|
||||||
|
y_low,
|
||||||
|
y_high,
|
||||||
|
color="#c62828",
|
||||||
|
alpha=0.18,
|
||||||
|
label="最大亏损(q10)",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 盈利(空头目标价在基准价之下)
|
||||||
|
if target_price < base:
|
||||||
|
y_low = np.full_like(x_range, target_price, dtype=float)
|
||||||
|
y_high = np.full_like(x_range, base, dtype=float)
|
||||||
|
ax.fill_between(
|
||||||
|
x_range.values,
|
||||||
|
y_low,
|
||||||
|
y_high,
|
||||||
|
color="#2e7d32",
|
||||||
|
alpha=0.18,
|
||||||
|
label="潜在盈利",
|
||||||
|
)
|
||||||
|
# 亏损(q90 代表“最差的90%分位”)
|
||||||
|
if q90 is not None and np.isfinite(worst_price) and worst_price > base:
|
||||||
|
y_low = np.full_like(x_range, base, dtype=float)
|
||||||
|
y_high = np.full_like(x_range, worst_price, dtype=float)
|
||||||
|
ax.fill_between(
|
||||||
|
x_range.values,
|
||||||
|
y_low,
|
||||||
|
y_high,
|
||||||
|
color="#c62828",
|
||||||
|
alpha=0.18,
|
||||||
|
label="最大亏损(q90)",
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.grid(True, alpha=0.25)
|
||||||
|
ax.legend(loc="best", fontsize=9)
|
||||||
|
|
||||||
|
ax1 = fig.add_subplot(gs[0])
|
||||||
|
rr_long_pct = long_m["rr"] * 100.0
|
||||||
|
profit_long_pct = long_m["profit"] * 100.0
|
||||||
|
loss_long_pct = long_m["loss"] * 100.0
|
||||||
|
title_long = f"多头机会:盈亏比 RR={long_m['rr']:.2f}(+{profit_long_pct:.2f}% / -{loss_long_pct:.2f}%)"
|
||||||
|
_draw_panel(
|
||||||
|
ax1,
|
||||||
|
title_long,
|
||||||
|
long_m["target_idx"],
|
||||||
|
long_m["target_price"],
|
||||||
|
long_m["worst_idx"],
|
||||||
|
long_m["worst_price"],
|
||||||
|
"LONG",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 预测收盘价区间与中位线
|
ax2 = fig.add_subplot(gs[1], sharex=ax1)
|
||||||
if close_q10 is not None and close_q90 is not None:
|
rr_short_pct = short_m["rr"] * 100.0
|
||||||
ax_price.fill_between(
|
profit_short_pct = short_m["profit"] * 100.0
|
||||||
y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
|
loss_short_pct = short_m["loss"] * 100.0
|
||||||
)
|
title_short = f"空头机会:盈亏比 RR={short_m['rr']:.2f}(+{profit_short_pct:.2f}% / -{loss_short_pct:.2f}%)"
|
||||||
if close_q50 is not None:
|
_draw_panel(
|
||||||
ax_price.plot(
|
ax2,
|
||||||
y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
|
title_short,
|
||||||
)
|
short_m["target_idx"],
|
||||||
|
short_m["target_price"],
|
||||||
# 预测起点与阴影
|
short_m["worst_idx"],
|
||||||
if len(y_timestamp) > 0:
|
short_m["worst_price"],
|
||||||
ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
|
"SHORT",
|
||||||
ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
|
|
||||||
|
|
||||||
# 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
|
|
||||||
# 合并历史和预测时间范围与价格数据
|
|
||||||
all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
|
|
||||||
hist_close = historical_df["close"]
|
|
||||||
pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
|
|
||||||
all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
|
|
||||||
|
|
||||||
if len(all_timestamps) > 0:
|
|
||||||
start_time = all_timestamps.min()
|
|
||||||
end_time = all_timestamps.max()
|
|
||||||
|
|
||||||
# 生成所有16:00时间点(UTC+8)
|
|
||||||
current_date = start_time.normalize() # 当天零点
|
|
||||||
while current_date <= end_time:
|
|
||||||
target_time = current_date + pd.Timedelta(hours=16) # 16:00
|
|
||||||
if start_time <= target_time <= end_time:
|
|
||||||
# 画虚线
|
|
||||||
ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
|
|
||||||
|
|
||||||
# 找到最接近16:00的时间点的收盘价
|
|
||||||
time_diffs = (all_timestamps - target_time).abs()
|
|
||||||
closest_idx = time_diffs.idxmin()
|
|
||||||
closest_time = all_timestamps.iloc[closest_idx]
|
|
||||||
closest_price = all_close.iloc[closest_idx]
|
|
||||||
|
|
||||||
# 如果时间差不超过1小时,则标注价格
|
|
||||||
if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
|
|
||||||
ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
|
|
||||||
ax_price.text(
|
|
||||||
closest_time, closest_price,
|
|
||||||
f' ${closest_price:.1f}',
|
|
||||||
fontsize=9,
|
|
||||||
color='blue',
|
|
||||||
verticalalignment='bottom',
|
|
||||||
horizontalalignment='left',
|
|
||||||
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
|
|
||||||
)
|
|
||||||
current_date += pd.Timedelta(days=1)
|
|
||||||
|
|
||||||
ax_price.set_ylabel("价格 (USD)", fontsize=11)
|
|
||||||
ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
|
|
||||||
ax_price.grid(True, alpha=0.3)
|
|
||||||
ax_price.legend(loc="best", fontsize=9)
|
|
||||||
|
|
||||||
# ===== 下:成交量(仅预测中位量能柱)=====
|
|
||||||
ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
|
|
||||||
|
|
||||||
# 历史量能柱
|
|
||||||
ax_vol.bar(
|
|
||||||
historical_df["timestamps"],
|
|
||||||
historical_df["volume"],
|
|
||||||
width=bar_width_hist,
|
|
||||||
alpha=0.35,
|
|
||||||
label="历史成交量",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 预测中位量能柱(不画区间)
|
for lbl in ax1.get_xticklabels():
|
||||||
if vol_q50 is not None:
|
lbl.set_visible(False)
|
||||||
ax_vol.bar(
|
ax2.set_xlabel("时间")
|
||||||
y_timestamp.values,
|
ax1.set_ylabel("价格 (USD)")
|
||||||
vol_q50,
|
ax2.set_ylabel("价格 (USD)")
|
||||||
width=bar_width_pred,
|
fig.suptitle(
|
||||||
alpha=0.6,
|
f"{symbol} {interval}分钟线 24小时 预测机会评估", fontsize=16, fontweight="bold"
|
||||||
label="预测成交量中位(q50)",
|
)
|
||||||
)
|
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||||
|
|
||||||
# 预测起点线
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
if len(y_timestamp) > 0:
|
out_png = os.path.join(
|
||||||
ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
|
out_dir, f"trade_opportunity_{symbol.lower()}_{interval}m.png"
|
||||||
|
)
|
||||||
|
fig.savefig(out_png, dpi=150, bbox_inches="tight")
|
||||||
|
return out_png
|
||||||
|
|
||||||
ax_vol.set_ylabel("成交量", fontsize=11)
|
|
||||||
ax_vol.set_xlabel("时间", fontsize=11)
|
|
||||||
ax_vol.grid(True, alpha=0.25)
|
|
||||||
ax_vol.legend(loc="best", fontsize=9)
|
|
||||||
|
|
||||||
# 避免X轴标签重叠
|
|
||||||
plt.setp(ax_price.get_xticklabels(), visible=False)
|
|
||||||
plt.tight_layout()
|
|
||||||
return fig
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
|
# 加载环境变量
|
||||||
parser.add_argument("--symbol", default="BTCUSD")
|
load_dotenv()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="计算并可视化高盈亏比交易机会")
|
||||||
|
parser.add_argument("--symbol", default="BTCUSDT")
|
||||||
parser.add_argument("--category", default="inverse")
|
parser.add_argument("--category", default="inverse")
|
||||||
parser.add_argument("--interval", default="15")
|
parser.add_argument("--interval", default="60")
|
||||||
parser.add_argument("--limit", type=int, default=1000)
|
parser.add_argument("--limit", type=int, default=500)
|
||||||
parser.add_argument("--lookback", type=int, default=LOOKBACK)
|
parser.add_argument("--lookback", type=int, default=360)
|
||||||
parser.add_argument("--pred_len", type=int, default=PRED_LEN)
|
parser.add_argument("--pred_len", type=int, default=24)
|
||||||
parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
|
parser.add_argument("--samples", type=int, default=30)
|
||||||
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_p", type=float, default=TOP_P)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--quantiles", default="0.1,0.5,0.9")
|
parser.add_argument("--device", default=os.getenv("KRONOS_DEVICE", "cuda:0"))
|
||||||
parser.add_argument("--device", default=DEVICE)
|
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# 载入env & HTTP
|
# 获取预测聚合数据(仅未来)
|
||||||
load_dotenv()
|
agg_df = kronos_predict_df(
|
||||||
api_key = os.getenv("BYBIT_API_KEY")
|
|
||||||
api_secret = os.getenv("BYBIT_API_SECRET")
|
|
||||||
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
|
|
||||||
# Kronos
|
|
||||||
predictor = load_kronos(device=args.device)
|
|
||||||
|
|
||||||
# 拉数据
|
|
||||||
df = fetch_kline(
|
|
||||||
session,
|
|
||||||
symbol=args.symbol,
|
symbol=args.symbol,
|
||||||
|
category=args.category,
|
||||||
interval=args.interval,
|
interval=args.interval,
|
||||||
limit=args.limit,
|
limit=args.limit,
|
||||||
category=args.category,
|
|
||||||
)
|
|
||||||
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
|
|
||||||
df,
|
|
||||||
lookback=args.lookback,
|
lookback=args.lookback,
|
||||||
pred_len=args.pred_len,
|
pred_len=args.pred_len,
|
||||||
interval_min=int(args.interval),
|
samples=args.samples,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
quantiles=[0.1, 0.5, 0.9],
|
||||||
|
device=args.device,
|
||||||
|
seed=args.seed,
|
||||||
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 历史用于画图的窗口(最近200根,可按需调整)
|
# 计算多 / 空的机会
|
||||||
plot_start = max(0, end_idx - 200)
|
long_m = compute_long_metrics(agg_df)
|
||||||
historical_df = df.loc[plot_start : end_idx - 1].copy()
|
short_m = compute_short_metrics(agg_df)
|
||||||
|
|
||||||
# 采样参数(直接使用局部变量,不需要修改全局变量)
|
# 选择更优机会
|
||||||
temperature = args.temperature
|
best = long_m if long_m["rr"] >= short_m["rr"] else short_m
|
||||||
top_p = args.top_p
|
|
||||||
qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
|
|
||||||
|
|
||||||
# 多次预测(传入局部变量)
|
# 生成汇总字符串(增强可读性与可辨识度)
|
||||||
preds = run_multi_predictions(
|
dir_emoji = "📈 多头" if best["direction"] == "LONG" else "📉 空头"
|
||||||
predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed,
|
long_arrow = "⬆️"
|
||||||
temperature=temperature, top_p=top_p
|
short_arrow = "⬇️"
|
||||||
|
summary_lines = [
|
||||||
|
"== 24小时 每日交易机会评估 ==\n",
|
||||||
|
f"📌 标的: {args.symbol} | 周期: {args.interval}m | 当前价: {best['base']:.2f}\n",
|
||||||
|
f"📈 多头 | RR={long_m['rr']:.2f} | 盈利 +{long_m['profit'] * 100.0:.2f}% | 亏损 -{long_m['loss'] * 100.0:.2f}% | 🎯 目标 {long_m['target_price']:.2f} | 🛡️ 最差 {long_m['worst_price']:.2f}\n",
|
||||||
|
f"📉 空头 | RR={short_m['rr']:.2f} | 盈利 +{short_m['profit'] * 100.0:.2f}% | 亏损 -{short_m['loss'] * 100.0:.2f}% | 🎯 目标 {short_m['target_price']:.2f} | 🛡️ 最差 {short_m['worst_price']:.2f}\n",
|
||||||
|
f"✅ 最佳: {dir_emoji} | 盈亏比 RR={best['rr']:.2f}",
|
||||||
|
]
|
||||||
|
if best["rr"] < 2.0:
|
||||||
|
summary_lines.append("⚠️ 交易机会不佳,建议观望。")
|
||||||
|
summary = "\n".join(summary_lines) + "\n"
|
||||||
|
|
||||||
|
# 打印汇总
|
||||||
|
print(summary)
|
||||||
|
|
||||||
|
# 获取预测前24根K线用于可视化上下文
|
||||||
|
try:
|
||||||
|
full_hist_df = get_recent_kline(
|
||||||
|
symbol=args.symbol,
|
||||||
|
category=args.category,
|
||||||
|
interval=args.interval,
|
||||||
|
limit=args.limit,
|
||||||
|
)
|
||||||
|
hist24 = full_hist_df.tail(24).reset_index(drop=True)
|
||||||
|
except Exception:
|
||||||
|
hist24 = None
|
||||||
|
|
||||||
|
# 可视化
|
||||||
|
out_png = plot_opportunities(
|
||||||
|
agg_df, long_m, short_m, args.symbol, args.interval, out_dir="figures", hist_df=hist24
|
||||||
)
|
)
|
||||||
|
print(f"可视化已保存: {out_png}")
|
||||||
# 聚合分位
|
|
||||||
agg_df = aggregate_quantiles(preds, quantiles=qs)
|
# 发送交易机会评估结果到 QQ
|
||||||
agg_df["timestamps"] = y_ts.values
|
try:
|
||||||
|
from qq import send_prediction_result
|
||||||
# 输出 CSV(聚合)
|
target = os.getenv("QQ_BOT_TARGET_ID")
|
||||||
out_csv = os.path.join(
|
send_prediction_result(summary, png_path=out_png, target_id=target)
|
||||||
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
|
except Exception as e:
|
||||||
)
|
print(f"发送 QQ 消息失败: {repr(e)}")
|
||||||
agg_df.to_csv(out_csv, index=False)
|
|
||||||
print(f"预测分位数据已保存到: {out_csv}")
|
|
||||||
|
|
||||||
# 作图
|
|
||||||
fig = plot_results(
|
|
||||||
historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
|
|
||||||
)
|
|
||||||
out_png = os.path.join(
|
|
||||||
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
|
|
||||||
)
|
|
||||||
fig.savefig(out_png, dpi=150, bbox_inches="tight")
|
|
||||||
print(f"预测区间图表已保存到: {out_png}")
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
main()
|
||||||
main()
|
|
||||||
except Exception as e:
|
|
||||||
print("运行出错:", repr(e))
|
|
||||||
sys.exit(1)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user