feat: Translate all Chinese content to English in webui module

- Translate all Chinese comments and strings in webui/app.py
- Translate all Chinese comments and strings in webui/run.py
- Translate all Chinese comments and strings in webui/start.sh
- Translate all Chinese content in webui/README.md
- Translate all Chinese content in webui/templates/index.html
- Add prediction results directory for analysis
- Complete internationalization of webui module
This commit is contained in:
Charles 2025-08-27 15:36:47 +08:00
parent 328dc4a3b0
commit 14863929a0
31 changed files with 58661 additions and 544 deletions

View File

@ -1,135 +1,135 @@
# Kronos Web UI # Kronos Web UI
Kronos金融预测模型的Web用户界面提供直观的图形化操作界面。 Web user interface for Kronos financial prediction model, providing intuitive graphical operation interface.
## ✨ 功能特性 ## ✨ Features
- **多格式数据支持**: 支持CSV、Feather等格式的金融数据 - **Multi-format data support**: Supports CSV, Feather and other financial data formats
- **智能时间窗口**: 固定400+120数据点的时间窗口滑条选择 - **Smart time window**: Fixed 400+120 data point time window slider selection
- **真实模型预测**: 集成真实的Kronos模型支持多种模型大小 - **Real model prediction**: Integrated real Kronos model, supports multiple model sizes
- **预测质量控制**: 可调节温度、核采样、样本数量等参数 - **Prediction quality control**: Adjustable temperature, nucleus sampling, sample count and other parameters
- **多设备支持**: 支持CPU、CUDA、MPS等计算设备 - **Multi-device support**: Supports CPU, CUDA, MPS and other computing devices
- **对比分析**: 预测结果与实际数据的详细对比 - **Comparison analysis**: Detailed comparison between prediction results and actual data
- **K线图显示**: 专业的金融K线图表展示 - **K-line chart display**: Professional financial K-line chart display
## 🚀 快速开始 ## 🚀 Quick Start
### 方法1: 使用Python脚本启动 ### Method 1: Start with Python script
```bash ```bash
cd webui cd webui
python run.py python run.py
``` ```
### 方法2: 使用Shell脚本启动 ### Method 2: Start with Shell script
```bash ```bash
cd webui cd webui
chmod +x start.sh chmod +x start.sh
./start.sh ./start.sh
``` ```
### 方法3: 直接启动Flask应用 ### Method 3: Start Flask application directly
```bash ```bash
cd webui cd webui
python app.py python app.py
``` ```
启动成功后,访问 http://localhost:7070 After successful startup, visit http://localhost:7070
## 📋 使用步骤 ## 📋 Usage Steps
1. **加载数据**: 选择data目录中的金融数据文件 1. **Load data**: Select financial data file from data directory
2. **加载模型**: 选择Kronos模型和计算设备 2. **Load model**: Select Kronos model and computing device
3. **设置参数**: 调整预测质量参数 3. **Set parameters**: Adjust prediction quality parameters
4. **选择时间窗口**: 使用滑条选择400+120数据点的时间范围 4. **Select time window**: Use slider to select 400+120 data point time range
5. **开始预测**: 点击预测按钮生成结果 5. **Start prediction**: Click prediction button to generate results
6. **查看结果**: 在图表和表格中查看预测结果 6. **View results**: View prediction results in charts and tables
## 🔧 预测质量参数 ## 🔧 Prediction Quality Parameters
### 温度 (T) ### Temperature (T)
- **范围**: 0.1 - 2.0 - **Range**: 0.1 - 2.0
- **作用**: 控制预测的随机性 - **Effect**: Controls prediction randomness
- **建议**: 1.2-1.5 获得更好的预测质量 - **Recommendation**: 1.2-1.5 for better prediction quality
### 核采样 (top_p) ### Nucleus Sampling (top_p)
- **范围**: 0.1 - 1.0 - **Range**: 0.1 - 1.0
- **作用**: 控制预测的多样性 - **Effect**: Controls prediction diversity
- **建议**: 0.95-1.0 考虑更多可能性 - **Recommendation**: 0.95-1.0 to consider more possibilities
### 样本数量 ### Sample Count
- **范围**: 1 - 5 - **Range**: 1 - 5
- **作用**: 生成多个预测样本 - **Effect**: Generate multiple prediction samples
- **建议**: 2-3 个样本提高质量 - **Recommendation**: 2-3 samples to improve quality
## 📊 支持的数据格式 ## 📊 Supported Data Formats
### 必需列 ### Required Columns
- `open`: 开盘价 - `open`: Opening price
- `high`: 最高价 - `high`: Highest price
- `low`: 最低价 - `low`: Lowest price
- `close`: 收盘价 - `close`: Closing price
### 可选列 ### Optional Columns
- `volume`: 成交量 - `volume`: Trading volume
- `amount`: 成交额(不用于预测) - `amount`: Trading amount (not used for prediction)
- `timestamps`/`timestamp`/`date`: 时间戳 - `timestamps`/`timestamp`/`date`: Timestamp
## 🤖 模型支持 ## 🤖 Model Support
- **Kronos-mini**: 4.1M参数,轻量级快速预测 - **Kronos-mini**: 4.1M parameters, lightweight fast prediction
- **Kronos-small**: 24.7M参数,平衡性能和速度 - **Kronos-small**: 24.7M parameters, balanced performance and speed
- **Kronos-base**: 102.3M参数,高质量预测 - **Kronos-base**: 102.3M parameters, high quality prediction
## 🖥️ GPU加速支持 ## 🖥️ GPU Acceleration Support
- **CPU**: 通用计算,兼容性最好 - **CPU**: General computing, best compatibility
- **CUDA**: NVIDIA GPU加速,性能最佳 - **CUDA**: NVIDIA GPU acceleration, best performance
- **MPS**: Apple Silicon GPU加速Mac用户推荐 - **MPS**: Apple Silicon GPU acceleration, recommended for Mac users
## ⚠️ 注意事项 ## ⚠️ Notes
- `amount`列不会被用于预测,仅用于显示 - `amount` column is not used for prediction, only for display
- 时间窗口固定为400+120=520个数据点 - Time window is fixed at 400+120=520 data points
- 确保数据文件包含足够的历史数据 - Ensure data file contains sufficient historical data
- 首次加载模型可能需要下载,请耐心等待 - First model loading may require download, please be patient
## 🔍 对比分析 ## 🔍 Comparison Analysis
系统会自动提供预测结果与实际数据的对比分析,包括: The system automatically provides comparison analysis between prediction results and actual data, including:
- 价格差异统计 - Price difference statistics
- 误差分析 - Error analysis
- 预测质量评估 - Prediction quality assessment
## 🛠️ 技术架构 ## 🛠️ Technical Architecture
- **后端**: Flask + Python - **Backend**: Flask + Python
- **前端**: HTML + CSS + JavaScript - **Frontend**: HTML + CSS + JavaScript
- **图表**: Plotly.js - **Charts**: Plotly.js
- **数据处理**: Pandas + NumPy - **Data processing**: Pandas + NumPy
- **模型**: Hugging Face Transformers - **Model**: Hugging Face Transformers
## 📝 故障排除 ## 📝 Troubleshooting
### 常见问题 ### Common Issues
1. **端口占用**: 修改app.py中的端口号 1. **Port occupied**: Modify port number in app.py
2. **依赖缺失**: 运行 `pip install -r requirements.txt` 2. **Missing dependencies**: Run `pip install -r requirements.txt`
3. **模型加载失败**: 检查网络连接和模型ID 3. **Model loading failed**: Check network connection and model ID
4. **数据格式错误**: 确保数据列名和格式正确 4. **Data format error**: Ensure data column names and format are correct
### 日志查看 ### Log Viewing
启动时会在控制台显示详细的运行信息,包括模型状态和错误信息。 Detailed runtime information will be displayed in the console at startup, including model status and error messages.
## 📄 许可证 ## 📄 License
本项目遵循原Kronos项目的许可证条款。 This project follows the license terms of the original Kronos project.
## 🤝 贡献 ## 🤝 Contributing
欢迎提交Issue和Pull Request来改进这个Web UI Welcome to submit Issues and Pull Requests to improve this Web UI!
## 📞 支持 ## 📞 Support
如有问题,请查看: If you have questions, please check:
1. 项目文档 1. Project documentation
2. GitHub Issues 2. GitHub Issues
3. 控制台错误信息 3. Console error messages

View File

@ -11,7 +11,7 @@ import warnings
import datetime import datetime
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
# 添加项目根目录到路径 # Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try: try:
@ -19,17 +19,17 @@ try:
MODEL_AVAILABLE = True MODEL_AVAILABLE = True
except ImportError: except ImportError:
MODEL_AVAILABLE = False MODEL_AVAILABLE = False
print("警告: Kronos模型无法导入将使用模拟数据进行演示") print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
# 全局变量存储模型 # Global variables to store models
tokenizer = None tokenizer = None
model = None model = None
predictor = None predictor = None
# 可用的模型配置 # Available model configurations
AVAILABLE_MODELS = { AVAILABLE_MODELS = {
'kronos-mini': { 'kronos-mini': {
'name': 'Kronos-mini', 'name': 'Kronos-mini',
@ -37,7 +37,7 @@ AVAILABLE_MODELS = {
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k', 'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
'context_length': 2048, 'context_length': 2048,
'params': '4.1M', 'params': '4.1M',
'description': '轻量级模型,适合快速预测' 'description': 'Lightweight model, suitable for fast prediction'
}, },
'kronos-small': { 'kronos-small': {
'name': 'Kronos-small', 'name': 'Kronos-small',
@ -45,7 +45,7 @@ AVAILABLE_MODELS = {
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base', 'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
'context_length': 512, 'context_length': 512,
'params': '24.7M', 'params': '24.7M',
'description': '小型模型,平衡性能和速度' 'description': 'Small model, balanced performance and speed'
}, },
'kronos-base': { 'kronos-base': {
'name': 'Kronos-base', 'name': 'Kronos-base',
@ -53,12 +53,12 @@ AVAILABLE_MODELS = {
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base', 'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
'context_length': 512, 'context_length': 512,
'params': '102.3M', 'params': '102.3M',
'description': '基础模型,提供更好的预测质量' 'description': 'Base model, provides better prediction quality'
} }
} }
def load_data_files(): def load_data_files():
"""扫描data目录并返回可用的数据文件""" """Scan data directory and return available data files"""
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data') data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
data_files = [] data_files = []
@ -76,119 +76,65 @@ def load_data_files():
return data_files return data_files
def load_data_file(file_path): def load_data_file(file_path):
"""加载数据文件""" """Load data file"""
try: try:
if file_path.endswith('.csv'): if file_path.endswith('.csv'):
df = pd.read_csv(file_path) df = pd.read_csv(file_path)
elif file_path.endswith('.feather'): elif file_path.endswith('.feather'):
df = pd.read_feather(file_path) df = pd.read_feather(file_path)
else: else:
return None, "不支持的文件格式" return None, "Unsupported file format"
# 检查必要的列 # Check required columns
required_cols = ['open', 'high', 'low', 'close'] required_cols = ['open', 'high', 'low', 'close']
if not all(col in df.columns for col in required_cols): if not all(col in df.columns for col in required_cols):
return None, f"缺少必要的列: {required_cols}" return None, f"Missing required columns: {required_cols}"
# 处理时间戳列 # Process timestamp column
if 'timestamps' in df.columns: if 'timestamps' in df.columns:
df['timestamps'] = pd.to_datetime(df['timestamps']) df['timestamps'] = pd.to_datetime(df['timestamps'])
elif 'timestamp' in df.columns: elif 'timestamp' in df.columns:
df['timestamps'] = pd.to_datetime(df['timestamp']) df['timestamps'] = pd.to_datetime(df['timestamp'])
elif 'date' in df.columns: elif 'date' in df.columns:
# 如果列名是'date',将其重命名为'timestamps' # If column name is 'date', rename it to 'timestamps'
df['timestamps'] = pd.to_datetime(df['date']) df['timestamps'] = pd.to_datetime(df['date'])
else: else:
# 如果没有时间戳列,创建一个 # If no timestamp column exists, create one
df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H') df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')
# 确保数值列是数值类型 # Ensure numeric columns are numeric type
for col in ['open', 'high', 'low', 'close']: for col in ['open', 'high', 'low', 'close']:
df[col] = pd.to_numeric(df[col], errors='coerce') df[col] = pd.to_numeric(df[col], errors='coerce')
# 处理volume列可选 # Process volume column (optional)
if 'volume' in df.columns: if 'volume' in df.columns:
df['volume'] = pd.to_numeric(df['volume'], errors='coerce') df['volume'] = pd.to_numeric(df['volume'], errors='coerce')
# 处理amount列可选但不用于预测 # Process amount column (optional, but not used for prediction)
if 'amount' in df.columns: if 'amount' in df.columns:
df['amount'] = pd.to_numeric(df['amount'], errors='coerce') df['amount'] = pd.to_numeric(df['amount'], errors='coerce')
# 删除包含NaN的行 # Remove rows containing NaN values
df = df.dropna() df = df.dropna()
return df, None return df, None
except Exception as e: except Exception as e:
return None, f"加载文件失败: {str(e)}" return None, f"Failed to load file: {str(e)}"
def convert_price_to_returns(df, price_cols=['open', 'high', 'low', 'close']):
"""将绝对价格转换为相对变化率解决Kronos模型的连续性问题"""
print(f"[DEBUG] convert_price_to_returns: 开始转换,输入数据形状: {df.shape}")
print(f"[DEBUG] convert_price_to_returns: 输入数据前3行:")
print(df[price_cols].head(3))
returns_df = df.copy()
# 计算相对变化率 (pct_change)
for col in price_cols:
if col in returns_df.columns:
print(f"[DEBUG] convert_price_to_returns: 转换列 {col}")
returns_df[col] = returns_df[col].pct_change()
# 第一行会是NaN用0填充
returns_df = returns_df.fillna(0)
print(f"[DEBUG] convert_price_to_returns: 转换完成输出数据前3行:")
print(returns_df[price_cols].head(3))
return returns_df
def convert_returns_to_price(returns_df, initial_prices, price_cols=['open', 'high', 'low', 'close']):
"""将相对变化率转换回绝对价格"""
print(f"[DEBUG] convert_returns_to_price: 开始转换,输入数据形状: {returns_df.shape}")
print(f"[DEBUG] convert_returns_to_price: 初始价格: {initial_prices}")
print(f"[DEBUG] convert_returns_to_price: 输入相对变化率前3行:")
print(returns_df[price_cols].head(3))
price_df = returns_df.copy()
for col in price_cols:
if col in price_df.columns:
print(f"[DEBUG] convert_returns_to_price: 转换列 {col}")
# 正确的转换逻辑:从初始价格开始,逐步累积计算绝对价格
# 第一个点price = initial_price * (1 + return_1)
# 第二个点price = price_1 * (1 + return_2)
# 第三个点price = price_2 * (1 + return_3)
# 以此类推...
# 使用cumprod来累积计算
price_df[col] = initial_prices[col] * (1 + returns_df[col]).cumprod()
print(f"[DEBUG] convert_returns_to_price: 列 {col} 转换完成前3个值:")
print(f" 初始价格: {initial_prices[col]}")
print(f" 相对变化率: {returns_df[col].head(3).tolist()}")
print(f" 累积因子: {(1 + returns_df[col]).cumprod().head(3).tolist()}")
print(f" 最终价格: {price_df[col].head(3).tolist()}")
print(f"[DEBUG] convert_returns_to_price: 转换完成输出绝对价格前3行:")
print(price_df[price_cols].head(3))
return price_df
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params): def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
"""保存预测结果到文件""" """Save prediction results to file"""
try: try:
# 创建预测结果目录 # Create prediction results directory
results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results') results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
os.makedirs(results_dir, exist_ok=True) os.makedirs(results_dir, exist_ok=True)
# 生成文件名 # Generate filename
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'prediction_{timestamp}.json' filename = f'prediction_{timestamp}.json'
filepath = os.path.join(results_dir, filename) filepath = os.path.join(results_dir, filename)
# 准备保存的数据 # Prepare data for saving
save_data = { save_data = {
'timestamp': datetime.datetime.now().isoformat(), 'timestamp': datetime.datetime.now().isoformat(),
'file_path': file_path, 'file_path': file_path,
@ -215,14 +161,14 @@ def save_prediction_results(file_path, prediction_type, prediction_results, actu
'analysis': {} 'analysis': {}
} }
# 如果有实际数据,进行对比分析 # If actual data exists, perform comparison analysis
if actual_data and len(actual_data) > 0: if actual_data and len(actual_data) > 0:
# 计算连续性分析 # Calculate continuity analysis
if len(prediction_results) > 0 and len(actual_data) > 0: if len(prediction_results) > 0 and len(actual_data) > 0:
last_pred = prediction_results[0] # 第一个预测点 last_pred = prediction_results[0] # First prediction point
first_actual = actual_data[0] # 第一个实际点 first_actual = actual_data[0] # First actual point
save_data['analysis']['continuity'] = { save_data['analysis']['continuity'] = {
'last_prediction': { 'last_prediction': {
'open': last_pred['open'], 'open': last_pred['open'],
'high': last_pred['high'], 'high': last_pred['high'],
@ -249,51 +195,51 @@ def save_prediction_results(file_path, prediction_type, prediction_results, actu
} }
} }
# 保存到文件 # Save to file
with open(filepath, 'w', encoding='utf-8') as f: with open(filepath, 'w', encoding='utf-8') as f:
json.dump(save_data, f, indent=2, ensure_ascii=False) json.dump(save_data, f, indent=2, ensure_ascii=False)
print(f"预测结果已保存到: {filepath}") print(f"Prediction results saved to: {filepath}")
return filepath return filepath
except Exception as e: except Exception as e:
print(f"保存预测结果失败: {e}") print(f"Failed to save prediction results: {e}")
return None return None
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0): def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
"""创建预测图表""" """Create prediction chart"""
# 使用指定的历史数据起始位置而不是总是从df的开头开始 # Use specified historical data start position, not always from the beginning of df
if historical_start_idx + lookback + pred_len <= len(df): if historical_start_idx + lookback + pred_len <= len(df):
# 显示指定位置开始的lookback个历史点 + pred_len个预测点 # Display lookback historical points + pred_len prediction points starting from specified position
historical_df = df.iloc[historical_start_idx:historical_start_idx+lookback] historical_df = df.iloc[historical_start_idx:historical_start_idx+lookback]
prediction_range = range(historical_start_idx+lookback, historical_start_idx+lookback+pred_len) prediction_range = range(historical_start_idx+lookback, historical_start_idx+lookback+pred_len)
else: else:
# 如果数据不够,调整到可用的最大范围 # If data is insufficient, adjust to maximum available range
available_lookback = min(lookback, len(df) - historical_start_idx) available_lookback = min(lookback, len(df) - historical_start_idx)
available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback)) available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
historical_df = df.iloc[historical_start_idx:historical_start_idx+available_lookback] historical_df = df.iloc[historical_start_idx:historical_start_idx+available_lookback]
prediction_range = range(historical_start_idx+available_lookback, historical_start_idx+available_lookback+available_pred_len) prediction_range = range(historical_start_idx+available_lookback, historical_start_idx+available_lookback+available_pred_len)
# 创建图表 # Create chart
fig = go.Figure() fig = go.Figure()
# 添加历史数据K线图 # Add historical data (candlestick chart)
fig.add_trace(go.Candlestick( fig.add_trace(go.Candlestick(
x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index, x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
open=historical_df['open'], open=historical_df['open'],
high=historical_df['high'], high=historical_df['high'],
low=historical_df['low'], low=historical_df['low'],
close=historical_df['close'], close=historical_df['close'],
name='历史数据 (400个数据点)', name='Historical Data (400 data points)',
increasing_line_color='#26A69A', increasing_line_color='#26A69A',
decreasing_line_color='#EF5350' decreasing_line_color='#EF5350'
)) ))
# 添加预测数据K线图 # Add prediction data (candlestick chart)
if pred_df is not None and len(pred_df) > 0: if pred_df is not None and len(pred_df) > 0:
# 计算预测数据的时间戳 - 确保与历史数据连续 # Calculate prediction data timestamps - ensure continuity with historical data
if 'timestamps' in df.columns and len(historical_df) > 0: if 'timestamps' in df.columns and len(historical_df) > 0:
# 从历史数据的最后一个时间点开始,按相同的时间间隔创建预测时间戳 # Start from the last timestamp of historical data, create prediction timestamps with the same time interval
last_timestamp = historical_df['timestamps'].iloc[-1] last_timestamp = historical_df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1) time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
@ -303,7 +249,7 @@ def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, his
freq=time_diff freq=time_diff
) )
else: else:
# 如果没有时间戳,使用索引 # If no timestamps, use index
pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df)) pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
fig.add_trace(go.Candlestick( fig.add_trace(go.Candlestick(
@ -312,20 +258,20 @@ def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, his
high=pred_df['high'], high=pred_df['high'],
low=pred_df['low'], low=pred_df['low'],
close=pred_df['close'], close=pred_df['close'],
name='预测数据 (120个数据点)', name='Prediction Data (120 data points)',
increasing_line_color='#66BB6A', increasing_line_color='#66BB6A',
decreasing_line_color='#FF7043' decreasing_line_color='#FF7043'
)) ))
# 添加实际数据用于对比(如果存在) # Add actual data for comparison (if exists)
if actual_df is not None and len(actual_df) > 0: if actual_df is not None and len(actual_df) > 0:
# 实际数据应该与预测数据在同一个时间段 # Actual data should be in the same time period as prediction data
if 'timestamps' in df.columns: if 'timestamps' in df.columns:
# 实际数据应该使用与预测数据相同的时间戳,确保时间对齐 # Actual data should use the same timestamps as prediction data to ensure time alignment
if 'pred_timestamps' in locals(): if 'pred_timestamps' in locals():
actual_timestamps = pred_timestamps actual_timestamps = pred_timestamps
else: else:
# 如果没有预测时间戳,从历史数据最后一个时间点开始计算 # If no prediction timestamps, calculate from the last timestamp of historical data
if len(historical_df) > 0: if len(historical_df) > 0:
last_timestamp = historical_df['timestamps'].iloc[-1] last_timestamp = historical_df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1) time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
@ -345,24 +291,24 @@ def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, his
high=actual_df['high'], high=actual_df['high'],
low=actual_df['low'], low=actual_df['low'],
close=actual_df['close'], close=actual_df['close'],
name='实际数据 (120个数据点)', name='Actual Data (120 data points)',
increasing_line_color='#FF9800', increasing_line_color='#FF9800',
decreasing_line_color='#F44336' decreasing_line_color='#F44336'
)) ))
# 更新布局 # Update layout
fig.update_layout( fig.update_layout(
title='Kronos 金融预测结果 - 400个历史点 + 120个预测点 vs 120个实际点', title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
xaxis_title='时间', xaxis_title='Time',
yaxis_title='价格', yaxis_title='Price',
template='plotly_white', template='plotly_white',
height=600, height=600,
showlegend=True showlegend=True
) )
# 确保x轴时间连续 # Ensure x-axis time continuity
if 'timestamps' in historical_df.columns: if 'timestamps' in historical_df.columns:
# 获取所有时间戳并排序 # Get all timestamps and sort them
all_timestamps = [] all_timestamps = []
if len(historical_df) > 0: if len(historical_df) > 0:
all_timestamps.extend(historical_df['timestamps']) all_timestamps.extend(historical_df['timestamps'])
@ -383,56 +329,56 @@ def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, his
@app.route('/') @app.route('/')
def index(): def index():
"""主页""" """Home page"""
return render_template('index.html') return render_template('index.html')
@app.route('/api/data-files') @app.route('/api/data-files')
def get_data_files(): def get_data_files():
"""获取可用的数据文件列表""" """Get available data file list"""
data_files = load_data_files() data_files = load_data_files()
return jsonify(data_files) return jsonify(data_files)
@app.route('/api/load-data', methods=['POST']) @app.route('/api/load-data', methods=['POST'])
def load_data(): def load_data():
"""加载数据文件""" """Load data file"""
try: try:
data = request.get_json() data = request.get_json()
file_path = data.get('file_path') file_path = data.get('file_path')
if not file_path: if not file_path:
return jsonify({'error': '文件路径不能为空'}), 400 return jsonify({'error': 'File path cannot be empty'}), 400
df, error = load_data_file(file_path) df, error = load_data_file(file_path)
if error: if error:
return jsonify({'error': error}), 400 return jsonify({'error': error}), 400
# 检测数据的时间频率 # Detect data time frequency
def detect_timeframe(df): def detect_timeframe(df):
if len(df) < 2: if len(df) < 2:
return "未知" return "Unknown"
time_diffs = [] time_diffs = []
for i in range(1, min(10, len(df))): # 检查前10个时间差 for i in range(1, min(10, len(df))): # Check first 10 time differences
diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i-1] diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i-1]
time_diffs.append(diff) time_diffs.append(diff)
if not time_diffs: if not time_diffs:
return "未知" return "Unknown"
# 计算平均时间差 # Calculate average time difference
avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs) avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
# 转换为可读格式 # Convert to readable format
if avg_diff < pd.Timedelta(minutes=1): if avg_diff < pd.Timedelta(minutes=1):
return f"{avg_diff.total_seconds():.0f}" return f"{avg_diff.total_seconds():.0f} seconds"
elif avg_diff < pd.Timedelta(hours=1): elif avg_diff < pd.Timedelta(hours=1):
return f"{avg_diff.total_seconds() / 60:.0f}分钟" return f"{avg_diff.total_seconds() / 60:.0f} minutes"
elif avg_diff < pd.Timedelta(days=1): elif avg_diff < pd.Timedelta(days=1):
return f"{avg_diff.total_seconds() / 3600:.0f}小时" return f"{avg_diff.total_seconds() / 3600:.0f} hours"
else: else:
return f"{avg_diff.days}" return f"{avg_diff.days} days"
# 返回数据信息 # Return data information
data_info = { data_info = {
'rows': len(df), 'rows': len(df),
'columns': list(df.columns), 'columns': list(df.columns),
@ -449,111 +395,89 @@ def load_data():
return jsonify({ return jsonify({
'success': True, 'success': True,
'data_info': data_info, 'data_info': data_info,
'message': f'成功加载数据,共 {len(df)}' 'message': f'Successfully loaded data, total {len(df)} rows'
}) })
except Exception as e: except Exception as e:
return jsonify({'error': f'加载数据失败: {str(e)}'}), 500 return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
@app.route('/api/predict', methods=['POST']) @app.route('/api/predict', methods=['POST'])
def predict(): def predict():
"""进行预测""" """Perform prediction"""
try: try:
data = request.get_json() data = request.get_json()
file_path = data.get('file_path') file_path = data.get('file_path')
lookback = int(data.get('lookback', 400)) lookback = int(data.get('lookback', 400))
pred_len = int(data.get('pred_len', 120)) pred_len = int(data.get('pred_len', 120))
# 获取预测质量参数 # Get prediction quality parameters
temperature = float(data.get('temperature', 1.0)) temperature = float(data.get('temperature', 1.0))
top_p = float(data.get('top_p', 0.9)) top_p = float(data.get('top_p', 0.9))
sample_count = int(data.get('sample_count', 1)) sample_count = int(data.get('sample_count', 1))
if not file_path: if not file_path:
return jsonify({'error': '文件路径不能为空'}), 400 return jsonify({'error': 'File path cannot be empty'}), 400
# 加载数据 # Load data
df, error = load_data_file(file_path) df, error = load_data_file(file_path)
if error: if error:
return jsonify({'error': error}), 400 return jsonify({'error': error}), 400
if len(df) < lookback: if len(df) < lookback:
return jsonify({'error': f'数据长度不足,需要至少 {lookback} 行数据'}), 400 return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
# 进行预测 # Perform prediction
if MODEL_AVAILABLE and predictor is not None: if MODEL_AVAILABLE and predictor is not None:
try: try:
# 使用真实的Kronos模型 # Use real Kronos model
# 只使用必要的列OHLCV不包含amount # Only use necessary columns: OHLCV, excluding amount
required_cols = ['open', 'high', 'low', 'close'] required_cols = ['open', 'high', 'low', 'close']
if 'volume' in df.columns: if 'volume' in df.columns:
required_cols.append('volume') required_cols.append('volume')
# 处理时间段选择 # Process time period selection
start_date = data.get('start_date') start_date = data.get('start_date')
if start_date: if start_date:
# 自定义时间段 - 修复逻辑:使用选择的窗口内的数据 # Custom time period - fix logic: use data within selected window
start_dt = pd.to_datetime(start_date) start_dt = pd.to_datetime(start_date)
# 找到开始时间之后的数据 # Find data after start time
mask = df['timestamps'] >= start_dt mask = df['timestamps'] >= start_dt
time_range_df = df[mask] time_range_df = df[mask]
# 确保有足够的数据:lookback + pred_len # Ensure sufficient data: lookback + pred_len
if len(time_range_df) < lookback + pred_len: if len(time_range_df) < lookback + pred_len:
return jsonify({'error': f'从开始时间 {start_dt.strftime("%Y-%m-%d %H:%M")} 开始的数据不足,需要至少 {lookback + pred_len} 个数据点,当前只有 {len(time_range_df)}'}), 400 return jsonify({'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
# 使用选择的窗口内的前lookback个数据点进行预测 # Use first lookback data points within selected window for prediction
x_df = time_range_df.iloc[:lookback][required_cols] x_df = time_range_df.iloc[:lookback][required_cols]
x_timestamp = time_range_df.iloc[:lookback]['timestamps'] x_timestamp = time_range_df.iloc[:lookback]['timestamps']
# 使用选择的窗口内的后pred_len个数据点作为实际值 # Use last pred_len data points within selected window as actual values
y_timestamp = time_range_df.iloc[lookback:lookback+pred_len]['timestamps'] y_timestamp = time_range_df.iloc[lookback:lookback+pred_len]['timestamps']
# 计算实际的时间段长度 # Calculate actual time period length
start_timestamp = time_range_df['timestamps'].iloc[0] start_timestamp = time_range_df['timestamps'].iloc[0]
end_timestamp = time_range_df['timestamps'].iloc[lookback+pred_len-1] end_timestamp = time_range_df['timestamps'].iloc[lookback+pred_len-1]
time_span = end_timestamp - start_timestamp time_span = end_timestamp - start_timestamp
prediction_type = f"Kronos模型预测 (选择的窗口内:前{lookback}个数据点预测,后{pred_len}个数据点对比,时间跨度: {time_span})" prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, last {pred_len} data points for comparison, time span: {time_span})"
else: else:
# 使用最新数据 # Use latest data
x_df = df.iloc[:lookback][required_cols] x_df = df.iloc[:lookback][required_cols]
x_timestamp = df.iloc[:lookback]['timestamps'] x_timestamp = df.iloc[:lookback]['timestamps']
y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps'] y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']
prediction_type = "Kronos模型预测 (最新数据)" prediction_type = "Kronos model prediction (latest data)"
print(f"[DEBUG] 步骤1: 检查并转换时间戳格式...") # Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
# 确保时间戳是Series格式不是DatetimeIndex避免Kronos模型的.dt属性错误
if isinstance(x_timestamp, pd.DatetimeIndex): if isinstance(x_timestamp, pd.DatetimeIndex):
print(f"[DEBUG] 将x_timestamp从DatetimeIndex转换为Series")
x_timestamp = pd.Series(x_timestamp, name='timestamps') x_timestamp = pd.Series(x_timestamp, name='timestamps')
if isinstance(y_timestamp, pd.DatetimeIndex): if isinstance(y_timestamp, pd.DatetimeIndex):
print(f"[DEBUG] 将y_timestamp从DatetimeIndex转换为Series")
y_timestamp = pd.Series(y_timestamp, name='timestamps') y_timestamp = pd.Series(y_timestamp, name='timestamps')
print(f"[DEBUG] 步骤2: 准备价格转换,保存初始价格...") pred_df = predictor.predict(
# 解决Kronos模型连续性问题将绝对价格转换为相对变化率 df=x_df,
original_x_df = x_df.copy()
initial_prices = {
'open': x_df['open'].iloc[0],
'high': x_df['high'].iloc[0],
'low': x_df['low'].iloc[0],
'close': x_df['close'].iloc[0]
}
print(f"[DEBUG] 初始价格: {initial_prices}")
print(f"[DEBUG] 步骤3: 将绝对价格转换为相对变化率...")
# 转换为相对变化率
x_df_returns = convert_price_to_returns(x_df, ['open', 'high', 'low', 'close'])
print(f"[DEBUG] 转换完成输入数据前5行相对变化率:")
print(x_df_returns[['open', 'high', 'low', 'close']].head())
print(f"[DEBUG] 步骤4: 使用转换后的数据进行Kronos模型预测...")
# 使用转换后的数据进行预测
pred_df_returns = predictor.predict(
df=x_df_returns,
x_timestamp=x_timestamp, x_timestamp=x_timestamp,
y_timestamp=y_timestamp, y_timestamp=y_timestamp,
pred_len=pred_len, pred_len=pred_len,
@ -561,49 +485,28 @@ def predict():
top_p=top_p, top_p=top_p,
sample_count=sample_count sample_count=sample_count
) )
print(f"[DEBUG] 预测完成预测结果前5行相对变化率:")
print(pred_df_returns[['open', 'high', 'low', 'close']].head())
print(f"[DEBUG] 步骤5: 将预测结果转换回绝对价格...")
# 将预测结果转换回绝对价格
pred_df = convert_returns_to_price(pred_df_returns, initial_prices, ['open', 'high', 'low', 'close'])
print(f"[DEBUG] 转换完成预测结果前5行绝对价格:")
print(pred_df[['open', 'high', 'low', 'close']].head())
print(f"[DEBUG] 步骤6: 处理volume和amount列...")
# 保持volume列不变如果存在
if 'volume' in pred_df.columns:
pred_df['volume'] = pred_df_returns['volume']
print(f"[DEBUG] 已复制volume列")
if 'amount' in pred_df.columns:
pred_df['amount'] = pred_df_returns['amount']
print(f"[DEBUG] 已复制amount列")
print(f"[DEBUG] 步骤7: 检查最终预测结果...")
print(f"[DEBUG] 最终预测结果形状: {pred_df.shape}")
print(f"[DEBUG] 最终预测结果列: {list(pred_df.columns)}")
except Exception as e: except Exception as e:
return jsonify({'error': f'Kronos模型预测失败: {str(e)}'}), 500 return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
else: else:
return jsonify({'error': 'Kronos模型未加载,请先加载模型'}), 400 return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
# 准备实际数据用于对比(如果存在) # Prepare actual data for comparison (if exists)
actual_data = [] actual_data = []
actual_df = None actual_df = None
if start_date: # 自定义时间段 if start_date: # Custom time period
# 修复逻辑:使用选择的窗口内的数据 # Fix logic: use data within selected window
# 预测使用的是选择的窗口内的前400个数据点 # Prediction uses first 400 data points within selected window
# 实际数据应该是选择的窗口内的后120个数据点 # Actual data should be last 120 data points within selected window
start_dt = pd.to_datetime(start_date) start_dt = pd.to_datetime(start_date)
# 找到从start_date开始的数据 # Find data starting from start_date
mask = df['timestamps'] >= start_dt mask = df['timestamps'] >= start_dt
time_range_df = df[mask] time_range_df = df[mask]
if len(time_range_df) >= lookback + pred_len: if len(time_range_df) >= lookback + pred_len:
# 获取选择的窗口内的后120个数据点作为实际值 # Get last 120 data points within selected window as actual values
actual_df = time_range_df.iloc[lookback:lookback+pred_len] actual_df = time_range_df.iloc[lookback:lookback+pred_len]
for i, (_, row) in enumerate(actual_df.iterrows()): for i, (_, row) in enumerate(actual_df.iterrows()):
@ -616,9 +519,9 @@ def predict():
'volume': float(row['volume']) if 'volume' in row else 0, 'volume': float(row['volume']) if 'volume' in row else 0,
'amount': float(row['amount']) if 'amount' in row else 0 'amount': float(row['amount']) if 'amount' in row else 0
}) })
else: # 最新数据 else: # Latest data
# 预测使用的是前400个数据点 # Prediction uses first 400 data points
# 实际数据应该是400个数据点之后的120个数据点 # Actual data should be 120 data points after first 400 data points
if len(df) >= lookback + pred_len: if len(df) >= lookback + pred_len:
actual_df = df.iloc[lookback:lookback+pred_len] actual_df = df.iloc[lookback:lookback+pred_len]
for i, (_, row) in enumerate(actual_df.iterrows()): for i, (_, row) in enumerate(actual_df.iterrows()):
@ -632,28 +535,28 @@ def predict():
'amount': float(row['amount']) if 'amount' in row else 0 'amount': float(row['amount']) if 'amount' in row else 0
}) })
# 创建图表 - 传递历史数据的起始位置 # Create chart - pass historical data start position
if start_date: if start_date:
# 自定义时间段找到历史数据在原始df中的起始位置 # Custom time period: find starting position of historical data in original df
start_dt = pd.to_datetime(start_date) start_dt = pd.to_datetime(start_date)
mask = df['timestamps'] >= start_dt mask = df['timestamps'] >= start_dt
historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0 historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
else: else:
# 最新数据:从开头开始 # Latest data: start from beginning
historical_start_idx = 0 historical_start_idx = 0
chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx) chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
# 准备预测结果数据 - 修复时间戳计算逻辑 # Prepare prediction result data - fix timestamp calculation logic
if 'timestamps' in df.columns: if 'timestamps' in df.columns:
if start_date: if start_date:
# 自定义时间段:使用选择的窗口数据计算时间戳 # Custom time period: use selected window data to calculate timestamps
start_dt = pd.to_datetime(start_date) start_dt = pd.to_datetime(start_date)
mask = df['timestamps'] >= start_dt mask = df['timestamps'] >= start_dt
time_range_df = df[mask] time_range_df = df[mask]
if len(time_range_df) >= lookback: if len(time_range_df) >= lookback:
# 从选择的窗口的最后一个时间点开始计算预测时间戳 # Calculate prediction timestamps starting from last time point of selected window
last_timestamp = time_range_df['timestamps'].iloc[lookback-1] last_timestamp = time_range_df['timestamps'].iloc[lookback-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
future_timestamps = pd.date_range( future_timestamps = pd.date_range(
@ -664,7 +567,7 @@ def predict():
else: else:
future_timestamps = [] future_timestamps = []
else: else:
# 最新数据:从整个数据文件的最后时间点开始计算 # Latest data: calculate from last time point of entire data file
last_timestamp = df['timestamps'].iloc[-1] last_timestamp = df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
future_timestamps = pd.date_range( future_timestamps = pd.date_range(
@ -687,7 +590,7 @@ def predict():
'amount': float(row['amount']) if 'amount' in row else 0 'amount': float(row['amount']) if 'amount' in row else 0
}) })
# 保存预测结果到文件 # Save prediction results to file
try: try:
save_prediction_results( save_prediction_results(
file_path=file_path, file_path=file_path,
@ -705,7 +608,7 @@ def predict():
} }
) )
except Exception as e: except Exception as e:
print(f"保存预测结果失败: {e}") print(f"Failed to save prediction results: {e}")
return jsonify({ return jsonify({
'success': True, 'success': True,
@ -714,40 +617,40 @@ def predict():
'prediction_results': prediction_results, 'prediction_results': prediction_results,
'actual_data': actual_data, 'actual_data': actual_data,
'has_comparison': len(actual_data) > 0, 'has_comparison': len(actual_data) > 0,
'message': f'预测完成,生成了 {pred_len} 个预测点' + (f',包含 {len(actual_data)} 个实际数据点用于对比' if len(actual_data) > 0 else '') 'message': f'Prediction completed, generated {pred_len} prediction points' + (f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
}) })
except Exception as e: except Exception as e:
return jsonify({'error': f'预测失败: {str(e)}'}), 500 return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
@app.route('/api/load-model', methods=['POST']) @app.route('/api/load-model', methods=['POST'])
def load_model(): def load_model():
"""加载Kronos模型""" """Load Kronos model"""
global tokenizer, model, predictor global tokenizer, model, predictor
try: try:
if not MODEL_AVAILABLE: if not MODEL_AVAILABLE:
return jsonify({'error': 'Kronos模型库不可用'}), 400 return jsonify({'error': 'Kronos model library not available'}), 400
data = request.get_json() data = request.get_json()
model_key = data.get('model_key', 'kronos-small') model_key = data.get('model_key', 'kronos-small')
device = data.get('device', 'cpu') device = data.get('device', 'cpu')
if model_key not in AVAILABLE_MODELS: if model_key not in AVAILABLE_MODELS:
return jsonify({'error': f'不支持的模型: {model_key}'}), 400 return jsonify({'error': f'Unsupported model: {model_key}'}), 400
model_config = AVAILABLE_MODELS[model_key] model_config = AVAILABLE_MODELS[model_key]
# 加载tokenizer和模型 # Load tokenizer and model
tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id']) tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
model = Kronos.from_pretrained(model_config['model_id']) model = Kronos.from_pretrained(model_config['model_id'])
# 创建predictor # Create predictor
predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length']) predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])
return jsonify({ return jsonify({
'success': True, 'success': True,
'message': f'模型加载成功: {model_config["name"]} ({model_config["params"]}) on {device}', 'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
'model_info': { 'model_info': {
'name': model_config['name'], 'name': model_config['name'],
'params': model_config['params'], 'params': model_config['params'],
@ -757,11 +660,11 @@ def load_model():
}) })
except Exception as e: except Exception as e:
return jsonify({'error': f'模型加载失败: {str(e)}'}), 500 return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
@app.route('/api/available-models') @app.route('/api/available-models')
def get_available_models(): def get_available_models():
"""获取可用的模型列表""" """Get available model list"""
return jsonify({ return jsonify({
'models': AVAILABLE_MODELS, 'models': AVAILABLE_MODELS,
'model_available': MODEL_AVAILABLE 'model_available': MODEL_AVAILABLE
@ -769,13 +672,13 @@ def get_available_models():
@app.route('/api/model-status') @app.route('/api/model-status')
def get_model_status(): def get_model_status():
"""获取模型状态""" """Get model status"""
if MODEL_AVAILABLE: if MODEL_AVAILABLE:
if predictor is not None: if predictor is not None:
return jsonify({ return jsonify({
'available': True, 'available': True,
'loaded': True, 'loaded': True,
'message': 'Kronos模型已加载并可用', 'message': 'Kronos model loaded and available',
'current_model': { 'current_model': {
'name': predictor.model.__class__.__name__, 'name': predictor.model.__class__.__name__,
'device': str(next(predictor.model.parameters()).device) 'device': str(next(predictor.model.parameters()).device)
@ -785,21 +688,21 @@ def get_model_status():
return jsonify({ return jsonify({
'available': True, 'available': True,
'loaded': False, 'loaded': False,
'message': 'Kronos模型可用但未加载' 'message': 'Kronos model available but not loaded'
}) })
else: else:
return jsonify({ return jsonify({
'available': False, 'available': False,
'loaded': False, 'loaded': False,
'message': 'Kronos模型库不可用,请安装相关依赖' 'message': 'Kronos model library not available, please install related dependencies'
}) })
if __name__ == '__main__': if __name__ == '__main__':
print("启动Kronos Web UI...") print("Starting Kronos Web UI...")
print(f"模型可用性: {MODEL_AVAILABLE}") print(f"Model availability: {MODEL_AVAILABLE}")
if MODEL_AVAILABLE: if MODEL_AVAILABLE:
print("提示: 可以通过 /api/load-model 接口加载Kronos模型") print("Tip: You can load Kronos model through /api/load-model endpoint")
else: else:
print("提示: 将使用模拟数据进行演示") print("Tip: Will use simulated data for demonstration")
app.run(debug=True, host='0.0.0.0', port=7070) app.run(debug=True, host='0.0.0.0', port=7070)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Kronos Web UI 启动脚本 Kronos Web UI startup script
""" """
import os import os
@ -10,80 +10,80 @@ import webbrowser
import time import time
def check_dependencies(): def check_dependencies():
"""检查依赖是否安装""" """Check if dependencies are installed"""
try: try:
import flask import flask
import flask_cors import flask_cors
import pandas import pandas
import numpy import numpy
import plotly import plotly
print("所有依赖已安装") print("All dependencies installed")
return True return True
except ImportError as e: except ImportError as e:
print(f"缺少依赖: {e}") print(f"Missing dependency: {e}")
print("请运行: pip install -r requirements.txt") print("Please run: pip install -r requirements.txt")
return False return False
def install_dependencies(): def install_dependencies():
"""安装依赖""" """Install dependencies"""
print("正在安装依赖...") print("Installing dependencies...")
try: try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]) subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
print("依赖安装完成") print("Dependencies installation completed")
return True return True
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print("依赖安装失败") print("Dependencies installation failed")
return False return False
def main(): def main():
"""主函数""" """Main function"""
print("🚀 启动 Kronos Web UI...") print("🚀 Starting Kronos Web UI...")
print("=" * 50) print("=" * 50)
# 检查依赖 # Check dependencies
if not check_dependencies(): if not check_dependencies():
print("\n是否自动安装依赖? (y/n): ", end="") print("\nAuto-install dependencies? (y/n): ", end="")
if input().lower() == 'y': if input().lower() == 'y':
if not install_dependencies(): if not install_dependencies():
return return
else: else:
print("请手动安装依赖后重试") print("Please manually install dependencies and retry")
return return
# 检查模型可用性 # Check model availability
try: try:
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model import Kronos, KronosTokenizer, KronosPredictor from model import Kronos, KronosTokenizer, KronosPredictor
print("✅ Kronos模型库可用") print("✅ Kronos model library available")
model_available = True model_available = True
except ImportError: except ImportError:
print("⚠️ Kronos模型库不可用,将使用模拟预测") print("⚠️ Kronos model library not available, will use simulated prediction")
model_available = False model_available = False
# 启动Flask应用 # Start Flask application
print("\n🌐 启动Web服务器...") print("\n🌐 Starting Web server...")
# 设置环境变量 # Set environment variables
os.environ['FLASK_APP'] = 'app.py' os.environ['FLASK_APP'] = 'app.py'
os.environ['FLASK_ENV'] = 'development' os.environ['FLASK_ENV'] = 'development'
# 启动服务器 # Start server
try: try:
from app import app from app import app
print("✅ Web服务器启动成功!") print("✅ Web server started successfully!")
print(f"🌐 访问地址: http://localhost:7070") print(f"🌐 Access URL: http://localhost:7070")
print("💡 提示: 按 Ctrl+C 停止服务器") print("💡 Tip: Press Ctrl+C to stop server")
# 自动打开浏览器 # Auto-open browser
time.sleep(2) time.sleep(2)
webbrowser.open('http://localhost:7070') webbrowser.open('http://localhost:7070')
# 启动Flask应用 # Start Flask application
app.run(debug=True, host='0.0.0.0', port=7070) app.run(debug=True, host='0.0.0.0', port=7070)
except Exception as e: except Exception as e:
print(f"启动失败: {e}") print(f"Startup failed: {e}")
print("请检查端口7070是否被占用") print("Please check if port 7070 is occupied")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,40 +1,40 @@
#!/bin/bash #!/bin/bash
# Kronos Web UI 启动脚本 # Kronos Web UI startup script
echo "🚀 启动 Kronos Web UI..." echo "🚀 Starting Kronos Web UI..."
echo "================================" echo "================================"
# 检查Python是否安装 # Check if Python is installed
if ! command -v python3 &> /dev/null; then if ! command -v python3 &> /dev/null; then
echo "❌ Python3 未安装请先安装Python3" echo "❌ Python3 not installed, please install Python3 first"
exit 1 exit 1
fi fi
# 检查是否在正确的目录 # Check if in correct directory
if [ ! -f "app.py" ]; then if [ ! -f "app.py" ]; then
echo "❌ 请在webui目录下运行此脚本" echo "❌ Please run this script in the webui directory"
exit 1 exit 1
fi fi
# 检查依赖 # Check dependencies
echo "📦 检查依赖..." echo "📦 Checking dependencies..."
if ! python3 -c "import flask, flask_cors, pandas, numpy, plotly" &> /dev/null; then if ! python3 -c "import flask, flask_cors, pandas, numpy, plotly" &> /dev/null; then
echo "⚠️ 缺少依赖,正在安装..." echo "⚠️ Missing dependencies, installing..."
pip3 install -r requirements.txt pip3 install -r requirements.txt
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "❌ 依赖安装失败" echo "❌ Dependencies installation failed"
exit 1 exit 1
fi fi
echo "✅ 依赖安装完成" echo "✅ Dependencies installation completed"
else else
echo "✅ 所有依赖已安装" echo "✅ All dependencies installed"
fi fi
# 启动应用 # Start application
echo "🌐 启动Web服务器..." echo "🌐 Starting Web server..."
echo "访问地址: http://localhost:7070" echo "Access URL: http://localhost:7070"
echo "按 Ctrl+C 停止服务器" echo "Press Ctrl+C to stop server"
echo "" echo ""
python3 app.py python3 app.py

View File

@ -3,7 +3,7 @@
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Kronos 金融预测 Web UI</title> <title>Kronos Financial Prediction Web UI</title>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script> <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
<style> <style>
@ -93,7 +93,7 @@
border-color: #667eea; border-color: #667eea;
} }
/* 预测质量参数样式 */ /* Prediction quality parameter styles */
.form-group input[type="range"] { .form-group input[type="range"] {
width: 70%; width: 70%;
margin-right: 10px; margin-right: 10px;
@ -227,7 +227,7 @@
color: #2d3748; color: #2d3748;
} }
/* 时间窗口选择器样式 */ /* Time window selector styles */
.time-window-container { .time-window-container {
background: #f7fafc; background: #f7fafc;
border: 1px solid #e2e8f0; border: 1px solid #e2e8f0;
@ -301,7 +301,7 @@
margin-top: 5px; margin-top: 5px;
} }
/* 对比分析样式 */ /* Comparison analysis styles */
.comparison-section { .comparison-section {
background: #f7fafc; background: #f7fafc;
border: 1px solid #e2e8f0; border: 1px solid #e2e8f0;
@ -441,79 +441,79 @@
<body> <body>
<div class="container"> <div class="container">
<div class="header"> <div class="header">
<h1>🚀 Kronos 金融预测 Web UI</h1> <h1>🚀 Kronos Financial Prediction Web UI</h1>
<p>基于AI的金融K线数据预测分析平台</p> <p>AI-based financial K-line data prediction analysis platform</p>
</div> </div>
<div class="main-content"> <div class="main-content">
<div class="control-panel"> <div class="control-panel">
<h2>🎯 控制面板</h2> <h2>🎯 Control Panel</h2>
<!-- 模型选择 --> <!-- Model Selection -->
<div class="form-group"> <div class="form-group">
<label for="model-select">选择模型:</label> <label for="model-select">Select Model:</label>
<select id="model-select"> <select id="model-select">
<option value="">请先加载可用模型</option> <option value="">Please load available models first</option>
</select> </select>
<small class="form-text">选择要使用的Kronos模型</small> <small class="form-text">Select the Kronos model to use</small>
</div> </div>
<!-- 设备选择 --> <!-- Device Selection -->
<div class="form-group"> <div class="form-group">
<label for="device-select">选择设备:</label> <label for="device-select">Select Device:</label>
<select id="device-select"> <select id="device-select">
<option value="cpu">CPU</option> <option value="cpu">CPU</option>
<option value="cuda">CUDA (NVIDIA GPU)</option> <option value="cuda">CUDA (NVIDIA GPU)</option>
<option value="mps">MPS (Apple Silicon)</option> <option value="mps">MPS (Apple Silicon)</option>
</select> </select>
<small class="form-text">选择模型运行的设备</small> <small class="form-text">Select the device to run the model on</small>
</div> </div>
<!-- 模型状态 --> <!-- Model Status -->
<div id="model-status" class="status info" style="display: none;"> <div id="model-status" class="status info" style="display: none;">
模型状态信息 Model status information
</div> </div>
<!-- 加载模型按钮 --> <!-- Load Model Button -->
<button id="load-model-btn" class="btn btn-secondary"> <button id="load-model-btn" class="btn btn-secondary">
🔄 加载模型 🔄 Load Model
</button> </button>
<hr style="margin: 20px 0; border: 1px solid #e2e8f0;"> <hr style="margin: 20px 0; border: 1px solid #e2e8f0;">
<!-- 数据文件选择 --> <!-- Data File Selection -->
<div class="form-group"> <div class="form-group">
<label for="data-file-select">选择数据文件:</label> <label for="data-file-select">Select Data File:</label>
<select id="data-file-select"> <select id="data-file-select">
<option value="">请先加载数据文件列表</option> <option value="">Please load data file list first</option>
</select> </select>
<small class="form-text">从data目录选择K线数据文件</small> <small class="form-text">Select K-line data file from data directory</small>
</div> </div>
<button id="load-data-btn" class="btn btn-secondary"> <button id="load-data-btn" class="btn btn-secondary">
📁 加载数据 📁 Load Data
</button> </button>
<!-- 数据信息显示 --> <!-- Data Information Display -->
<div id="data-info" class="data-info" style="display: none;"> <div id="data-info" class="data-info" style="display: none;">
<h3>📊 数据信息</h3> <h3>📊 Data Information</h3>
<p><strong>行数:</strong> <span id="data-rows">-</span></p> <p><strong>Rows:</strong> <span id="data-rows">-</span></p>
<p><strong>列数:</strong> <span id="data-cols">-</span></p> <p><strong>Columns:</strong> <span id="data-cols">-</span></p>
<p><strong>时间范围:</strong> <span id="data-time-range">-</span></p> <p><strong>Time Range:</strong> <span id="data-time-range">-</span></p>
<p><strong>价格范围:</strong> <span id="data-price-range">-</span></p> <p><strong>Price Range:</strong> <span id="data-price-range">-</span></p>
<p><strong>时间频率:</strong> <span id="data-timeframe">-</span></p> <p><strong>Time Frequency:</strong> <span id="data-timeframe">-</span></p>
<p><strong>预测列:</strong> <span id="data-prediction-cols">-</span></p> <p><strong>Prediction Columns:</strong> <span id="data-prediction-cols">-</span></p>
</div> </div>
<hr style="margin: 20px 0; border: 1px solid #e2e8f0;"> <hr style="margin: 20px 0; border: 1px solid #e2e8f0;">
<!-- 时间窗口选择器 --> <!-- Time Window Selector -->
<div class="time-window-container"> <div class="time-window-container">
<h3>时间窗口选择</h3> <h3>Time Window Selection</h3>
<div class="time-window-info"> <div class="time-window-info">
<span id="window-start">开始: --</span> <span id="window-start">Start: --</span>
<span id="window-end">结束: --</span> <span id="window-end">End: --</span>
<span id="window-size">窗口大小: 400+120=520个数据点</span> <span id="window-size">Window Size: 400+120=520 data points</span>
</div> </div>
<div class="time-window-slider"> <div class="time-window-slider">
@ -523,104 +523,104 @@
<div class="slider-handle end-handle" id="end-handle"></div> <div class="slider-handle end-handle" id="end-handle"></div>
</div> </div>
<div class="slider-labels"> <div class="slider-labels">
<span id="min-label">最早</span> <span id="min-label">Earliest</span>
<span id="max-label">最新</span> <span id="max-label">Latest</span>
</div> </div>
</div> </div>
<small class="form-text">拖动滑条选择520个数据点的时间窗口位置绿色区域表示固定的400+120数据点范围</small> <small class="form-text">Drag slider to select time window position for 520 data points, green area represents fixed 400+120 data point range</small>
</div> </div>
<!-- 预测参数 --> <!-- Prediction Parameters -->
<div class="form-group"> <div class="form-group">
<label for="lookback">回看窗口大小:</label> <label for="lookback">Lookback Window Size:</label>
<input type="number" id="lookback" value="400" readonly> <input type="number" id="lookback" value="400" readonly>
<small class="form-text">固定为400个数据点</small> <small class="form-text">Fixed at 400 data points</small>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="pred-len">预测长度:</label> <label for="pred-len">Prediction Length:</label>
<input type="number" id="pred-len" value="120" readonly> <input type="number" id="pred-len" value="120" readonly>
<small class="form-text">固定为120个数据点</small> <small class="form-text">Fixed at 120 data points</small>
</div> </div>
<!-- 预测质量参数 --> <!-- Prediction Quality Parameters -->
<div class="form-group"> <div class="form-group">
<label for="temperature">预测温度 (T):</label> <label for="temperature">Prediction Temperature (T):</label>
<input type="range" id="temperature" value="1.0" min="0.1" max="2.0" step="0.1"> <input type="range" id="temperature" value="1.0" min="0.1" max="2.0" step="0.1">
<span id="temperature-value">1.0</span> <span id="temperature-value">1.0</span>
<small class="form-text">控制预测的随机性,值越高预测越多样化,值越低预测越保守</small> <small class="form-text">Controls prediction randomness, higher values make predictions more diverse, lower values make predictions more conservative</small>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="top-p">核采样参数 (top_p):</label> <label for="top-p">Nucleus Sampling Parameter (top_p):</label>
<input type="range" id="top-p" value="0.9" min="0.1" max="1.0" step="0.1"> <input type="range" id="top-p" value="0.9" min="0.1" max="1.0" step="0.1">
<span id="top-p-value">0.9</span> <span id="top-p-value">0.9</span>
<small class="form-text">控制预测的多样性,值越高考虑的概率分布越广</small> <small class="form-text">Controls prediction diversity, higher values consider broader probability distributions</small>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="sample-count">样本数量:</label> <label for="sample-count">Sample Count:</label>
<input type="number" id="sample-count" value="1" min="1" max="5" step="1"> <input type="number" id="sample-count" value="1" min="1" max="5" step="1">
<small class="form-text">生成多个预测样本以提高质量建议1-3个</small> <small class="form-text">Generate multiple prediction samples to improve quality (recommended 1-3)</small>
</div> </div>
<button id="predict-btn" class="btn btn-success" disabled> <button id="predict-btn" class="btn btn-success" disabled>
🔮 开始预测 🔮 Start Prediction
</button> </button>
<!-- 加载状态 --> <!-- Loading Status -->
<div id="loading" class="loading"> <div id="loading" class="loading">
<div class="spinner"></div> <div class="spinner"></div>
<p>正在处理,请稍候...</p> <p>Processing, please wait...</p>
</div> </div>
</div> </div>
<div class="chart-container"> <div class="chart-container">
<h2>📈 预测结果图表</h2> <h2>📈 Prediction Results Chart</h2>
<div id="chart"></div> <div id="chart"></div>
<!-- 对比分析 --> <!-- Comparison Analysis -->
<div id="comparison-section" class="comparison-section" style="display: none;"> <div id="comparison-section" class="comparison-section" style="display: none;">
<h3>📊 预测 vs 实际数据对比</h3> <h3>📊 Prediction vs Actual Data Comparison</h3>
<div id="comparison-info" class="comparison-info"> <div id="comparison-info" class="comparison-info">
<p><strong>预测类型:</strong> <span id="prediction-type">-</span></p> <p><strong>Prediction Type:</strong> <span id="prediction-type">-</span></p>
<p><strong>对比数据:</strong> <span id="comparison-data">-</span></p> <p><strong>Comparison Data:</strong> <span id="comparison-data">-</span></p>
</div> </div>
<div class="error-stats"> <div class="error-stats">
<div class="error-stat"> <div class="error-stat">
<h4>平均绝对误差</h4> <h4>Mean Absolute Error</h4>
<div class="value" id="mae">-</div> <div class="value" id="mae">-</div>
<div class="unit">价格单位</div> <div class="unit">Price Units</div>
</div> </div>
<div class="error-stat"> <div class="error-stat">
<h4>均方根误差</h4> <h4>Root Mean Square Error</h4>
<div class="value" id="rmse">-</div> <div class="value" id="rmse">-</div>
<div class="unit">价格单位</div> <div class="unit">Price Units</div>
</div> </div>
<div class="error-stat"> <div class="error-stat">
<h4>平均绝对百分比误差</h4> <h4>Mean Absolute Percentage Error</h4>
<div class="value" id="mape">-</div> <div class="value" id="mape">-</div>
<div class="unit">%</div> <div class="unit">%</div>
</div> </div>
</div> </div>
<div class="error-details"> <div class="error-details">
<h4>详细对比数据:</h4> <h4>Detailed Comparison Data:</h4>
<div style="max-height: 300px; overflow-y: auto;"> <div style="max-height: 300px; overflow-y: auto;">
<table class="comparison-table"> <table class="comparison-table">
<thead> <thead>
<tr> <tr>
<th>时间</th> <th>Time</th>
<th>实际开盘</th> <th>Actual Open</th>
<th>预测开盘</th> <th>Predicted Open</th>
<th>实际最高</th> <th>Actual High</th>
<th>预测最高</th> <th>Predicted High</th>
<th>实际最低</th> <th>Actual Low</th>
<th>预测最低</th> <th>Predicted Low</th>
<th>实际收盘</th> <th>Actual Close</th>
<th>预测收盘</th> <th>Predicted Close</th>
</tr> </tr>
</thead> </thead>
<tbody id="comparison-tbody"> <tbody id="comparison-tbody">
@ -634,58 +634,58 @@
</div> </div>
<script> <script>
// 全局变量 // Global variables
let currentDataFile = null; let currentDataFile = null;
let currentDataInfo = null; let currentDataInfo = null;
let availableModels = []; let availableModels = [];
let modelLoaded = false; let modelLoaded = false;
// 页面加载完成后初始化 // Initialize after page loads
document.addEventListener('DOMContentLoaded', function() { document.addEventListener('DOMContentLoaded', function() {
initializeApp(); initializeApp();
}); });
// 初始化应用 // Initialize application
async function initializeApp() { async function initializeApp() {
console.log('🚀 初始化 Kronos Web UI...'); console.log('🚀 Initializing Kronos Web UI...');
// 加载可用模型 // Load available models
await loadAvailableModels(); await loadAvailableModels();
// 加载数据文件列表 // Load data file list
await loadDataFiles(); await loadDataFiles();
// 设置事件监听器 // Set up event listeners
setupEventListeners(); setupEventListeners();
// 初始化时间滑块 // Initialize time slider
initializeTimeSlider(); initializeTimeSlider();
console.log('✅ 应用初始化完成'); console.log('✅ Application initialization completed');
} }
// 加载可用模型 // Load available models
async function loadAvailableModels() { async function loadAvailableModels() {
try { try {
const response = await axios.get('/api/available-models'); const response = await axios.get('/api/available-models');
if (response.data.model_available) { if (response.data.model_available) {
availableModels = response.data.models; availableModels = response.data.models;
populateModelSelect(); populateModelSelect();
console.log('✅ 可用模型加载成功:', availableModels); console.log('✅ Available models loaded successfully:', availableModels);
} else { } else {
console.warn('⚠️ Kronos模型库不可用'); console.warn('⚠️ Kronos model library not available');
showStatus('warning', 'Kronos模型库不可用,将使用模拟预测'); showStatus('warning', 'Kronos model library not available, will use simulated prediction');
} }
} catch (error) { } catch (error) {
console.error('❌ 加载可用模型失败:', error); console.error('❌ Failed to load available models:', error);
showStatus('error', '加载可用模型失败'); showStatus('error', 'Failed to load available models');
} }
} }
// 填充模型选择下拉框 // Populate model selection dropdown
function populateModelSelect() { function populateModelSelect() {
const modelSelect = document.getElementById('model-select'); const modelSelect = document.getElementById('model-select');
modelSelect.innerHTML = '<option value="">请选择模型</option>'; modelSelect.innerHTML = '<option value="">Please select model</option>';
Object.entries(availableModels).forEach(([key, model]) => { Object.entries(availableModels).forEach(([key, model]) => {
const option = document.createElement('option'); const option = document.createElement('option');
@ -695,13 +695,13 @@
}); });
} }
// 加载模型 // Load model
async function loadModel() { async function loadModel() {
const modelKey = document.getElementById('model-select').value; const modelKey = document.getElementById('model-select').value;
const device = document.getElementById('device-select').value; const device = document.getElementById('device-select').value;
if (!modelKey) { if (!modelKey) {
showStatus('error', '请选择要加载的模型'); showStatus('error', 'Please select a model to load');
return; return;
} }
@ -719,45 +719,45 @@
showStatus('success', response.data.message); showStatus('success', response.data.message);
updateModelStatus(); updateModelStatus();
document.getElementById('predict-btn').disabled = false; document.getElementById('predict-btn').disabled = false;
console.log('✅ 模型加载成功:', response.data.model_info); console.log('✅ Model loaded successfully:', response.data.model_info);
} else { } else {
showStatus('error', response.data.error); showStatus('error', response.data.error);
} }
} catch (error) { } catch (error) {
console.error('❌ 模型加载失败:', error); console.error('❌ Model loading failed:', error);
showStatus('error', `模型加载失败: ${error.response?.data?.error || error.message}`); showStatus('error', `Model loading failed: ${error.response?.data?.error || error.message}`);
} finally { } finally {
showLoading(false); showLoading(false);
document.getElementById('load-model-btn').disabled = false; document.getElementById('load-model-btn').disabled = false;
} }
} }
// 更新模型状态 // Update model status
async function updateModelStatus() { async function updateModelStatus() {
try { try {
const response = await axios.get('/api/model-status'); const response = await axios.get('/api/model-status');
const status = response.data; const status = response.data;
if (status.loaded) { if (status.loaded) {
showStatus('success', `模型已加载: ${status.current_model.name} on ${status.current_model.device}`); showStatus('success', `Model loaded: ${status.current_model.name} on ${status.current_model.device}`);
} else if (status.available) { } else if (status.available) {
showStatus('info', '模型可用但未加载'); showStatus('info', 'Model available but not loaded');
} else { } else {
showStatus('warning', '模型库不可用'); showStatus('warning', 'Model library not available');
} }
} catch (error) { } catch (error) {
console.error('❌ 获取模型状态失败:', error); console.error('❌ Failed to get model status:', error);
} }
} }
// 加载数据文件列表 // Load data file list
async function loadDataFiles() { async function loadDataFiles() {
try { try {
const response = await axios.get('/api/data-files'); const response = await axios.get('/api/data-files');
const dataFiles = response.data; const dataFiles = response.data;
const dataFileSelect = document.getElementById('data-file-select'); const dataFileSelect = document.getElementById('data-file-select');
dataFileSelect.innerHTML = '<option value="">请选择数据文件</option>'; dataFileSelect.innerHTML = '<option value="">Please select data file</option>';
dataFiles.forEach(file => { dataFiles.forEach(file => {
const option = document.createElement('option'); const option = document.createElement('option');
@ -766,19 +766,19 @@
dataFileSelect.appendChild(option); dataFileSelect.appendChild(option);
}); });
console.log('✅ 数据文件列表加载成功:', dataFiles); console.log('✅ Data file list loaded successfully:', dataFiles);
} catch (error) { } catch (error) {
console.error('❌ 加载数据文件列表失败:', error); console.error('❌ Failed to load data file list:', error);
showStatus('error', '加载数据文件列表失败'); showStatus('error', 'Failed to load data file list');
} }
} }
// 加载数据文件 // Load data file
async function loadData() { async function loadData() {
const filePath = document.getElementById('data-file-select').value; const filePath = document.getElementById('data-file-select').value;
if (!filePath) { if (!filePath) {
showStatus('error', '请选择要加载的数据文件'); showStatus('error', 'Please select a data file to load');
return; return;
} }
@ -796,56 +796,56 @@
showDataInfo(response.data.data_info); showDataInfo(response.data.data_info);
showStatus('success', response.data.message); showStatus('success', response.data.message);
// 更新预测按钮状态 // Update prediction button status
if (modelLoaded) { if (modelLoaded) {
document.getElementById('predict-btn').disabled = false; document.getElementById('predict-btn').disabled = false;
} }
console.log('✅ 数据加载成功:', response.data.data_info); console.log('✅ Data loaded successfully:', response.data.data_info);
} else { } else {
showStatus('error', response.data.error); showStatus('error', response.data.error);
} }
} catch (error) { } catch (error) {
console.error('❌ 数据加载失败:', error); console.error('❌ Data loading failed:', error);
showStatus('error', `数据加载失败: ${error.response?.data?.error || error.message}`); showStatus('error', `Data loading failed: ${error.response?.data?.error || error.message}`);
} finally { } finally {
showLoading(false); showLoading(false);
document.getElementById('load-data-btn').disabled = false; document.getElementById('load-data-btn').disabled = false;
} }
} }
// 显示数据信息 // Display data information
function showDataInfo(dataInfo) { function showDataInfo(dataInfo) {
document.getElementById('data-info').style.display = 'block'; document.getElementById('data-info').style.display = 'block';
document.getElementById('data-rows').textContent = dataInfo.rows; document.getElementById('data-rows').textContent = dataInfo.rows;
document.getElementById('data-cols').textContent = dataInfo.columns.length; document.getElementById('data-cols').textContent = dataInfo.columns.length;
document.getElementById('data-time-range').textContent = `${dataInfo.start_date} ${dataInfo.end_date}`; document.getElementById('data-time-range').textContent = `${dataInfo.start_date} to ${dataInfo.end_date}`;
document.getElementById('data-price-range').textContent = `${dataInfo.price_range.min.toFixed(4)} - ${dataInfo.price_range.max.toFixed(4)}`; document.getElementById('data-price-range').textContent = `${dataInfo.price_range.min.toFixed(4)} - ${dataInfo.price_range.max.toFixed(4)}`;
document.getElementById('data-timeframe').textContent = dataInfo.timeframe; document.getElementById('data-timeframe').textContent = dataInfo.timeframe;
document.getElementById('data-prediction-cols').textContent = dataInfo.prediction_columns.join(', '); document.getElementById('data-prediction-cols').textContent = dataInfo.prediction_columns.join(', ');
// 初始化时间窗口滑条 // Initialize time window slider
initializeTimeWindowSlider(dataInfo); initializeTimeWindowSlider(dataInfo);
} }
// 时间窗口滑条相关变量 // Time window slider related variables
let sliderData = null; let sliderData = null;
let isDragging = false; let isDragging = false;
let currentHandle = null; let currentHandle = null;
// 初始化时间窗口滑条 // Initialize time window slider
function initializeTimeSlider() { function initializeTimeSlider() {
// 设置滑条事件监听器 // Set up slider event listeners
setupSliderEventListeners(); setupSliderEventListeners();
} }
// 设置滑条事件监听器 // Set up slider event listeners
function setupSliderEventListeners() { function setupSliderEventListeners() {
const startHandle = document.getElementById('start-handle'); const startHandle = document.getElementById('start-handle');
const endHandle = document.getElementById('end-handle'); const endHandle = document.getElementById('end-handle');
const track = document.querySelector('.slider-track'); const track = document.querySelector('.slider-track');
// 开始拖拽 // Start dragging
startHandle.addEventListener('mousedown', (e) => { startHandle.addEventListener('mousedown', (e) => {
isDragging = true; isDragging = true;
currentHandle = 'start'; currentHandle = 'start';
@ -858,7 +858,7 @@
e.preventDefault(); e.preventDefault();
}); });
// 拖拽中 // Dragging
document.addEventListener('mousemove', (e) => { document.addEventListener('mousemove', (e) => {
if (!isDragging) return; if (!isDragging) return;
@ -875,19 +875,19 @@
updateSliderFromHandles(); updateSliderFromHandles();
}); });
// 结束拖拽 // End dragging
document.addEventListener('mouseup', () => { document.addEventListener('mouseup', () => {
isDragging = false; isDragging = false;
currentHandle = null; currentHandle = null;
}); });
// 点击轨道直接设置位置 // Click track to set position directly
track.addEventListener('click', (e) => { track.addEventListener('click', (e) => {
const rect = track.getBoundingClientRect(); const rect = track.getBoundingClientRect();
const x = e.clientX - rect.left; const x = e.clientX - rect.left;
const percentage = Math.max(0, Math.min(1, x / rect.width)); const percentage = Math.max(0, Math.min(1, x / rect.width));
// 判断点击位置更接近哪个手柄 // Determine which handle is closer to the click position
const startHandle = document.getElementById('start-handle'); const startHandle = document.getElementById('start-handle');
const endHandle = document.getElementById('end-handle'); const endHandle = document.getElementById('end-handle');
const startRect = startHandle.getBoundingClientRect(); const startRect = startHandle.getBoundingClientRect();
@ -903,17 +903,17 @@
}); });
} }
// 更新开始手柄位置 // Update start handle position
function updateStartHandle(percentage) { function updateStartHandle(percentage) {
const startHandle = document.getElementById('start-handle'); const startHandle = document.getElementById('start-handle');
const selection = document.getElementById('slider-selection'); const selection = document.getElementById('slider-selection');
// 固定窗口大小为520个数据点 // Fixed window size of 520 data points
const windowSize = 520; const windowSize = 520;
const totalRows = sliderData ? sliderData.totalRows : 1000; const totalRows = sliderData ? sliderData.totalRows : 1000;
const windowPercentage = windowSize / totalRows; const windowPercentage = windowSize / totalRows;
// 确保开始手柄不会导致窗口超出数据范围 // Ensure start handle doesn't cause window to exceed data range
if (percentage + windowPercentage > 1) { if (percentage + windowPercentage > 1) {
percentage = 1 - windowPercentage; percentage = 1 - windowPercentage;
} }
@ -922,22 +922,22 @@
selection.style.left = (percentage * 100) + '%'; selection.style.left = (percentage * 100) + '%';
selection.style.width = (windowPercentage * 100) + '%'; selection.style.width = (windowPercentage * 100) + '%';
// 自动调整结束手柄位置,保持固定窗口大小 // Automatically adjust end handle position to maintain fixed window size
const endHandle = document.getElementById('end-handle'); const endHandle = document.getElementById('end-handle');
endHandle.style.left = ((percentage + windowPercentage) * 100) + '%'; endHandle.style.left = ((percentage + windowPercentage) * 100) + '%';
} }
// 更新结束手柄位置 // Update end handle position
function updateEndHandle(percentage) { function updateEndHandle(percentage) {
const endHandle = document.getElementById('end-handle'); const endHandle = document.getElementById('end-handle');
const selection = document.getElementById('slider-selection'); const selection = document.getElementById('slider-selection');
// 固定窗口大小为520个数据点 // Fixed window size of 520 data points
const windowSize = 520; const windowSize = 520;
const totalRows = sliderData ? sliderData.totalRows : 1000; const totalRows = sliderData ? sliderData.totalRows : 1000;
const windowPercentage = windowSize / totalRows; const windowPercentage = windowSize / totalRows;
// 确保结束手柄不会导致窗口超出数据范围 // Ensure end handle doesn't cause window to exceed data range
if (percentage - windowPercentage < 0) { if (percentage - windowPercentage < 0) {
percentage = windowPercentage; percentage = windowPercentage;
} }
@ -946,12 +946,12 @@
selection.style.left = ((percentage - windowPercentage) * 100) + '%'; selection.style.left = ((percentage - windowPercentage) * 100) + '%';
selection.style.width = (windowPercentage * 100) + '%'; selection.style.width = (windowPercentage * 100) + '%';
// 自动调整开始手柄位置,保持固定窗口大小 // Automatically adjust start handle position to maintain fixed window size
const startHandle = document.getElementById('start-handle'); const startHandle = document.getElementById('start-handle');
startHandle.style.left = ((percentage - windowPercentage) * 100) + '%'; startHandle.style.left = ((percentage - windowPercentage) * 100) + '%';
} }
// 根据手柄位置更新滑条显示 // Update slider display based on handle positions
function updateSliderFromHandles() { function updateSliderFromHandles() {
const startHandle = document.getElementById('start-handle'); const startHandle = document.getElementById('start-handle');
const endHandle = document.getElementById('end-handle'); const endHandle = document.getElementById('end-handle');
@ -961,7 +961,7 @@
if (!sliderData) return; if (!sliderData) return;
// 计算选中的时间范围 // Calculate selected time range
const totalTime = sliderData.endDate.getTime() - sliderData.startDate.getTime(); const totalTime = sliderData.endDate.getTime() - sliderData.startDate.getTime();
const startTime = sliderData.startDate.getTime() + (totalTime * startPercentage); const startTime = sliderData.startDate.getTime() + (totalTime * startPercentage);
const endTime = sliderData.startDate.getTime() + (totalTime * endPercentage); const endTime = sliderData.startDate.getTime() + (totalTime * endPercentage);
@ -969,49 +969,49 @@
const startDate = new Date(startTime); const startDate = new Date(startTime);
const endDate = new Date(endTime); const endDate = new Date(endTime);
// 更新显示信息 // Update display information
document.getElementById('window-start').textContent = `开始: ${startDate.toLocaleDateString()}`; document.getElementById('window-start').textContent = `Start: ${startDate.toLocaleDateString()}`;
document.getElementById('window-end').textContent = `结束: ${endDate.toLocaleDateString()}`; document.getElementById('window-end').textContent = `End: ${endDate.toLocaleDateString()}`;
// 显示固定的窗口大小 // Display fixed window size
document.getElementById('window-size').textContent = `窗口大小: 400 + 120 = 520 个数据点 (固定)`; document.getElementById('window-size').textContent = `Window Size: 400 + 120 = 520 data points (fixed)`;
// 输入框值保持固定 // Input field values remain fixed
document.getElementById('lookback').value = 400; document.getElementById('lookback').value = 400;
document.getElementById('pred-len').value = 120; document.getElementById('pred-len').value = 120;
} }
// 根据输入框更新滑条 // Update slider based on input fields
function updateSliderFromInputs() { function updateSliderFromInputs() {
if (!sliderData) return; if (!sliderData) return;
// 固定窗口大小400 + 120 = 520个数据点 // Fixed window size: 400 + 120 = 520 data points
const lookback = 400; const lookback = 400;
const predLen = 120; const predLen = 120;
const windowSize = lookback + predLen; // 固定为520 const windowSize = lookback + predLen; // Fixed at 520
// 计算滑条位置 // Calculate slider position
const totalRows = sliderData.totalRows; const totalRows = sliderData.totalRows;
if (windowSize > totalRows) { if (windowSize > totalRows) {
// 如果窗口大小超过总数据量,显示错误 // If window size exceeds total data amount, show error
showStatus('error', `数据量不足,需要至少${windowSize}个数据点,当前只有${totalRows}个`); showStatus('error', `Insufficient data, need at least ${windowSize} data points, currently only ${totalRows} available`);
return; return;
} }
// 计算滑条位置(默认选择数据的前半部分) // Calculate slider position (default select first half of data)
const startPercentage = 0.1; // 从10%开始 const startPercentage = 0.1; // Start from 10%
const endPercentage = startPercentage + (windowSize / totalRows); const endPercentage = startPercentage + (windowSize / totalRows);
// 更新手柄位置 // Update handle positions
updateStartHandle(startPercentage); updateStartHandle(startPercentage);
updateEndHandle(endPercentage); updateEndHandle(endPercentage);
// 更新显示信息 // Update display information
updateSliderFromHandles(); updateSliderFromHandles();
} }
// 初始化时间窗口滑条 // Initialize time window slider
function initializeTimeWindowSlider(dataInfo) { function initializeTimeWindowSlider(dataInfo) {
sliderData = { sliderData = {
startDate: new Date(dataInfo.start_date), startDate: new Date(dataInfo.start_date),
@ -1020,23 +1020,23 @@
timeframe: dataInfo.timeframe timeframe: dataInfo.timeframe
}; };
// 设置滑条标签 // Set slider labels
document.getElementById('min-label').textContent = dataInfo.start_date.split('T')[0]; document.getElementById('min-label').textContent = dataInfo.start_date.split('T')[0];
document.getElementById('max-label').textContent = dataInfo.end_date.split('T')[0]; document.getElementById('max-label').textContent = dataInfo.end_date.split('T')[0];
// 初始化滑条位置 // Initialize slider position
updateSliderFromInputs(); updateSliderFromInputs();
} }
// 开始预测 // Start prediction
async function startPrediction() { async function startPrediction() {
if (!currentDataFile) { if (!currentDataFile) {
showStatus('error', '请先加载数据文件'); showStatus('error', 'Please load data file first');
return; return;
} }
if (!modelLoaded) { if (!modelLoaded) {
showStatus('error', '请先加载模型'); showStatus('error', 'Please load model first');
return; return;
} }
@ -1047,21 +1047,21 @@
const lookback = parseInt(document.getElementById('lookback').value); const lookback = parseInt(document.getElementById('lookback').value);
const predLen = parseInt(document.getElementById('pred-len').value); const predLen = parseInt(document.getElementById('pred-len').value);
// 从时间窗口滑条获取选择的时间范围 // Get selected time range from time window slider
const startHandle = document.getElementById('start-handle'); const startHandle = document.getElementById('start-handle');
const startPercentage = parseFloat(startHandle.style.left) / 100; const startPercentage = parseFloat(startHandle.style.left) / 100;
if (!sliderData) { if (!sliderData) {
showStatus('error', '时间窗口滑条未初始化'); showStatus('error', 'Time window slider not initialized');
return; return;
} }
// 计算选择的时间范围 // Calculate selected time range
const totalTime = sliderData.endDate.getTime() - sliderData.startDate.getTime(); const totalTime = sliderData.endDate.getTime() - sliderData.startDate.getTime();
const startTime = sliderData.startDate.getTime() + (totalTime * startPercentage); const startTime = sliderData.startDate.getTime() + (totalTime * startPercentage);
const startDate = new Date(startTime); const startDate = new Date(startTime);
// 获取预测质量参数 // Get prediction quality parameters
const temperature = parseFloat(document.getElementById('temperature').value); const temperature = parseFloat(document.getElementById('temperature').value);
const topP = parseFloat(document.getElementById('top-p').value); const topP = parseFloat(document.getElementById('top-p').value);
const sampleCount = parseInt(document.getElementById('sample-count').value); const sampleCount = parseInt(document.getElementById('sample-count').value);
@ -1070,39 +1070,39 @@
file_path: currentDataFile, file_path: currentDataFile,
lookback: lookback, lookback: lookback,
pred_len: predLen, pred_len: predLen,
start_date: startDate.toISOString().slice(0, 16), // 格式化为 YYYY-MM-DDTHH:MM start_date: startDate.toISOString().slice(0, 16), // Format as YYYY-MM-DDTHH:MM
temperature: temperature, temperature: temperature,
top_p: topP, top_p: topP,
sample_count: sampleCount sample_count: sampleCount
}; };
console.log('🚀 开始预测,参数:', predictionParams); console.log('🚀 Starting prediction, parameters:', predictionParams);
const response = await axios.post('/api/predict', predictionParams); const response = await axios.post('/api/predict', predictionParams);
if (response.data.success) { if (response.data.success) {
// 显示预测结果 // Display prediction results
displayPredictionResult(response.data); displayPredictionResult(response.data);
showStatus('success', response.data.message); showStatus('success', response.data.message);
} else { } else {
showStatus('error', response.data.error); showStatus('error', response.data.error);
} }
} catch (error) { } catch (error) {
console.error('❌ 预测失败:', error); console.error('❌ Prediction failed:', error);
showStatus('error', `预测失败: ${error.response?.data?.error || error.message}`); showStatus('error', `Prediction failed: ${error.response?.data?.error || error.message}`);
} finally { } finally {
showLoading(false); showLoading(false);
document.getElementById('predict-btn').disabled = false; document.getElementById('predict-btn').disabled = false;
} }
} }
// 显示预测结果 // Display prediction results
function displayPredictionResult(result) { function displayPredictionResult(result) {
// 显示图表 // Display chart
const chartData = JSON.parse(result.chart); const chartData = JSON.parse(result.chart);
Plotly.newPlot('chart', chartData.data, chartData.layout); Plotly.newPlot('chart', chartData.data, chartData.layout);
// 显示对比分析(如果有实际数据) // Display comparison analysis (if actual data exists)
if (result.has_comparison) { if (result.has_comparison) {
displayComparisonAnalysis(result); displayComparisonAnalysis(result);
} else { } else {
@ -1110,27 +1110,27 @@
} }
} }
// 显示对比分析 // Display comparison analysis
function displayComparisonAnalysis(result) { function displayComparisonAnalysis(result) {
document.getElementById('comparison-section').style.display = 'block'; document.getElementById('comparison-section').style.display = 'block';
// 更新对比信息 // Update comparison information
document.getElementById('prediction-type').textContent = result.prediction_type; document.getElementById('prediction-type').textContent = result.prediction_type;
document.getElementById('comparison-data').textContent = `${result.actual_data.length} 个实际数据点`; document.getElementById('comparison-data').textContent = `${result.actual_data.length} actual data points`;
// 计算误差统计 // Calculate error statistics
const errorStats = getPredictionQuality(result.prediction_results, result.actual_data); const errorStats = getPredictionQuality(result.prediction_results, result.actual_data);
// 显示误差统计 // Display error statistics
document.getElementById('mae').textContent = errorStats.mae.toFixed(4); document.getElementById('mae').textContent = errorStats.mae.toFixed(4);
document.getElementById('rmse').textContent = errorStats.rmse.toFixed(4); document.getElementById('rmse').textContent = errorStats.rmse.toFixed(4);
document.getElementById('mape').textContent = errorStats.mape.toFixed(2); document.getElementById('mape').textContent = errorStats.mape.toFixed(2);
// 填充对比表格 // Fill comparison table
fillComparisonTable(result.prediction_results, result.actual_data); fillComparisonTable(result.prediction_results, result.actual_data);
} }
// 计算预测质量指标 // Calculate prediction quality metrics
function getPredictionQuality(predictions, actuals) { function getPredictionQuality(predictions, actuals) {
if (!predictions || !actuals || predictions.length === 0 || actuals.length === 0) { if (!predictions || !actuals || predictions.length === 0 || actuals.length === 0) {
return { mae: 0, rmse: 0, mape: 0 }; return { mae: 0, rmse: 0, mape: 0 };
@ -1143,7 +1143,7 @@
const pred = predictions[i]; const pred = predictions[i];
const act = actuals[i]; const act = actuals[i];
// 使用收盘价计算误差 // Use closing price to calculate errors
const error = Math.abs(pred.close - act.close); const error = Math.abs(pred.close - act.close);
const percentError = (error / act.close) * 100; const percentError = (error / act.close) * 100;
@ -1159,7 +1159,7 @@
return { mae, rmse, mape }; return { mae, rmse, mape };
} }
// 填充对比表格 // Fill comparison table
function fillComparisonTable(predictions, actuals) { function fillComparisonTable(predictions, actuals) {
const tbody = document.getElementById('comparison-tbody'); const tbody = document.getElementById('comparison-tbody');
tbody.innerHTML = ''; tbody.innerHTML = '';
@ -1186,18 +1186,18 @@
} }
} }
// 设置事件监听器 // Set up event listeners
function setupEventListeners() { function setupEventListeners() {
// 加载模型按钮 // Load model button
document.getElementById('load-model-btn').addEventListener('click', loadModel); document.getElementById('load-model-btn').addEventListener('click', loadModel);
// 加载数据按钮 // Load data button
document.getElementById('load-data-btn').addEventListener('click', loadData); document.getElementById('load-data-btn').addEventListener('click', loadData);
// 预测按钮 // Prediction button
document.getElementById('predict-btn').addEventListener('click', startPrediction); document.getElementById('predict-btn').addEventListener('click', startPrediction);
// 预测质量参数滑块 // Prediction quality parameter sliders
document.getElementById('temperature').addEventListener('input', function() { document.getElementById('temperature').addEventListener('input', function() {
document.getElementById('temperature-value').textContent = this.value; document.getElementById('temperature-value').textContent = this.value;
}); });
@ -1206,25 +1206,25 @@
document.getElementById('top-p-value').textContent = this.value; document.getElementById('top-p-value').textContent = this.value;
}); });
// 回看窗口大小变化时更新滑条 // Update slider when lookback window size changes
document.getElementById('lookback').addEventListener('input', updateSliderFromInputs); document.getElementById('lookback').addEventListener('input', updateSliderFromInputs);
document.getElementById('pred-len').addEventListener('input', updateSliderFromInputs); document.getElementById('pred-len').addEventListener('input', updateSliderFromInputs);
} }
// 显示状态信息 // Display status information
function showStatus(type, message) { function showStatus(type, message) {
const statusDiv = document.getElementById('model-status'); const statusDiv = document.getElementById('model-status');
statusDiv.className = `status ${type}`; statusDiv.className = `status ${type}`;
statusDiv.textContent = message; statusDiv.textContent = message;
statusDiv.style.display = 'block'; statusDiv.style.display = 'block';
// 自动隐藏 // Auto-hide
setTimeout(() => { setTimeout(() => {
statusDiv.style.display = 'none'; statusDiv.style.display = 'none';
}, 5000); }, 5000);
} }
// 显示/隐藏加载状态 // Show/hide loading status
function showLoading(show) { function showLoading(show) {
const loadingDiv = document.getElementById('loading'); const loadingDiv = document.getElementById('loading');
if (show) { if (show) {