# File-viewer metadata captured with this source (not code):
# 2026-01-24 12:16:32 +08:00 · 436 lines · 13 KiB · Python

import os
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from gpiozero import Buzzer, OutputDevice
import time
import threading
import queue
import uuid
import json
import asyncio
from typing import Optional, Callable, List, Dict, Any
from pydantic import BaseModel
from rigol_phase import measure_phase, DEV as RIGOL_DEV
# FastAPI application object; the route handlers below register on it via
# @app.get / @app.post decorators, and the CORS middleware is attached to it.
app = FastAPI()
class BatchTaskItem(BaseModel):
    # One entry of the JSON list posted to /action/batch.
    # Command name handled by /action/batch, e.g. "move", "measure", "move_measure".
    cmd: str
    # Command arguments, e.g. {"steps": 200}. Pydantic copies mutable defaults
    # per instance, so the shared {} literal is safe here.
    args: dict[str, Any] = {}
    # How many times to expand this entry; /action/batch treats values < 1 as 1.
    repeat: int = 1
# NOTE(review): motor_lock is not referenced anywhere else in this file. Motor
# access is serialized because only the single background worker thread calls
# the stepper's move methods. Remove this if nothing outside this file uses it.
motor_lock = threading.Lock()
def _get_cors_origins() -> list[str]:
    """Return the CORS origin allow-list configured by ``CORS_ORIGINS``.

    A value of exactly ``*`` (also the default when the variable is unset)
    yields ``["*"]``. Anything else is read as a comma-separated list: each
    entry is trimmed and empty entries are discarded.
    """
    configured = os.getenv("CORS_ORIGINS", "*").strip()
    if configured != "*":
        trimmed = (entry.strip() for entry in configured.split(","))
        return [entry for entry in trimmed if entry]
    return ["*"]
_cors_origins = _get_cors_origins()


def _env_flag(name: str, default: str = "false") -> bool:
    """Interpret env var *name* as a boolean ("1", "true", "yes", "y", "on" -> True)."""
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "y", "on")


def _env_csv(name: str, default: str = "*") -> list[str]:
    """Split env var *name* on commas, trimming whitespace and dropping empty items."""
    return [part.strip() for part in os.getenv(name, default).split(",") if part.strip()]


_cors_allow_credentials_env = _env_flag("CORS_ALLOW_CREDENTIALS")
# Starlette's CORSMiddleware treats a "*" entry anywhere in allow_origins as
# "allow every origin". With allow_credentials=True it then reflects the
# caller's Origin, which grants credentialed access to any site. So disable
# credentials whenever "*" appears at all, not only when it is the sole entry
# (e.g. CORS_ORIGINS="https://app.example, *").
_cors_allow_credentials = False if "*" in _cors_origins else _cors_allow_credentials_env
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=_env_csv("CORS_ALLOW_METHODS"),
    allow_headers=_env_csv("CORS_ALLOW_HEADERS"),
)
# Buzzer GPIO pin, overridable through the BEE_PIN environment variable
# (gpiozero uses BCM numbering by default). The Buzzer is created at import
# time, so the pin is reserved as soon as this module is loaded.
pin = int(os.getenv("BEE_PIN", "23"))
BEE = Buzzer(pin)
class GPIOStepper:
    """STEP/DIR stepper-motor driver on two gpiozero output pins.

    Step pulses are timed in software with time.sleep(), so the real pulse
    rate and its jitter depend on the OS scheduler.
    """

    def __init__(
        self,
        *,
        dir_pin: int = 22,
        step_pin: int = 27,
        dir_invert: bool = False,
    ) -> None:
        """Reserve the DIR and STEP pins and drive STEP low.

        Args:
            dir_pin: GPIO number of the direction line.
            step_pin: GPIO number of the step (pulse) line.
            dir_invert: Flip the DIR level for both directions, for wiring
                where "increasing" needs a low DIR line.
        """
        self.dir_invert = bool(dir_invert)
        # gpiozero uses BCM pin numbering by default.
        self._dir = OutputDevice(dir_pin)
        self._step = OutputDevice(step_pin)
        self._step.off()
        # Set by stop(); polled by move_steps() before every pulse.
        self._stop_event = threading.Event()

    def close(self) -> None:
        """Drive both lines low, then release the GPIO pins (best effort).

        The instance can no longer drive the motor after this call.
        """
        devices = (self._step, self._dir)
        # Teardown is deliberately best effort: one failing pin must not keep
        # the other from being switched off and released.
        for device in devices:
            try:
                device.off()
            except Exception:
                pass
        for device in devices:
            try:
                # Free the pin reservation so the pins can be claimed again.
                device.close()
            except Exception:
                pass

    def stop(self) -> None:
        """Ask the running move_steps() call to end after its current pulse.

        move_steps() clears this request when it starts, so calling stop()
        while no move is in progress has no lasting effect.
        """
        self._stop_event.set()

    def set_dir(self, inc: bool) -> None:
        """Set the DIR line.

        inc=True means the microstep count increases.
        """
        value = 1 if inc else 0
        if self.dir_invert:
            value = 0 if value else 1
        self._dir.value = value

    def move_steps(self, steps: int, freq_hz: int, on_step: Optional[Callable[[], None]] = None) -> None:
        """Emit *steps* pulses on STEP at *freq_hz*; block until done or stopped.

        Args:
            steps: Number of microsteps (pulses), >= 0. The direction is
                whatever set_dir() last selected.
            freq_hz: Pulse rate in Hz, > 0.
            on_step: Optional callback invoked after each completed pulse.

        Raises:
            ValueError: if steps < 0 or freq_hz <= 0.
        """
        self._stop_event.clear()
        if steps < 0:
            raise ValueError("steps must be >= 0")
        if freq_hz <= 0:
            raise ValueError("freq_hz must be > 0")
        pulses = int(steps)
        if pulses <= 0:
            return
        # 50 % duty cycle: high for half a period, low for the other half.
        half_period = 0.5 / float(freq_hz)
        try:
            for _ in range(pulses):
                if self._stop_event.is_set():
                    break
                self._step.on()
                time.sleep(half_period)
                self._step.off()
                time.sleep(half_period)
                if on_step is not None:
                    on_step()
        finally:
            # Never leave STEP high, however the loop exits (e.g. an
            # exception raised while the pin is high).
            self._step.off()
# Module-wide stepper driver using the default DIR/STEP pins (22/27).
# The background worker drives it (set_dir / move_steps), and
# /action/cancel interrupts a running move via stop().
stepper = GPIOStepper()
@app.get("/")
def read_root():
    """Root endpoint; always returns a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
@app.get("/ping")
def ping():
    """Reachability check: always answers {"ping": "pong"}."""
    return dict(ping="pong")
@app.get("/bee")
def bee_sound():
    """Sound the buzzer once for 0.1 s.

    Returns:
        A status message.

    Raises:
        HTTPException(500): if the buzzer driver raises; the detail carries
            the original error text.
    """
    try:
        BEE.beep(on_time=0.1, off_time=0, n=1)
    except Exception as e:
        # Chain the driver error as __cause__ so it is not lost for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "Bee buzzed for 0.1 seconds"}
# Shared runtime state, written by the HTTP handlers and the worker thread.
# It is intended to be accessed while holding state_lock.
state = dict(
    dis=16000,            # transmitter-to-receiver distance, in microsteps
    phase=0,              # phase difference, rad
    freq=0,               # frequency, Hz
    p2p=0,                # peak-to-peak amplitude, mV (NOTE(review): the measure task stores amp1_pp_adc here — confirm unit)
    speed=1200,           # motor step rate, Hz
    tasks=[],             # queued/running task dicts
    last_measurement={},  # most recent oscilloscope measurement result
)
# Guards reads and writes of `state`.
state_lock = threading.Lock()
# FIFO of task dicts, consumed by the background worker thread.
task_queue = queue.Queue()
def _worker_find_task(task_id):
    """Return the dict in state['tasks'] whose 'id' is *task_id*, or None.

    The caller must already hold state_lock.
    """
    for t in state['tasks']:
        if t['id'] == task_id:
            return t
    return None


def _worker_run_move(task):
    """Execute a 'move' task: step the motor by task['steps'] microsteps.

    The sign of the step count selects the direction. After every pulse,
    state['dis'] changes by one microstep in that direction, and the tracked
    task's 'remaining_steps' moves one step toward 0 (clamped so it never
    crosses zero).
    """
    task_id = task['id']
    steps = task['steps']
    is_positive = steps > 0
    stepper.set_dir(is_positive)
    abs_steps = abs(steps)
    if abs_steps == 0:
        return
    step_sign = 1 if is_positive else -1

    def on_step():
        with state_lock:
            state['dis'] += step_sign
            tracked = _worker_find_task(task_id)
            if tracked is not None:
                remaining = tracked['remaining_steps'] - step_sign
                tracked['remaining_steps'] = max(0, remaining) if is_positive else min(0, remaining)

    # The speed is sampled once per move; /state/speed changes apply to the next move.
    with state_lock:
        current_speed = state.get('speed', 1200)
    stepper.move_steps(abs_steps, current_speed, on_step=on_step)


def _worker_run_measure():
    """Execute a 'measure' task: read the Rigol scope and publish the result.

    Stores the whole result in state['last_measurement'] and copies selected
    fields into the summary keys: dphi_rad -> phase, f0_hz -> freq, and
    amp1_pp_adc (channel 1) -> p2p. A failed measurement is printed and
    leaves state untouched.
    """
    try:
        # Fixed acquisition parameters; they could be taken from the task if needed.
        meas = measure_phase(
            dev=RIGOL_DEV,
            timeout=3.0,
            points_mode="NORM",
            fetch_idn=False,
        )
    except Exception as e:
        print(f"Measurement failed: {e}")
        return
    with state_lock:
        state['last_measurement'] = meas
        for source_key, state_key in (('dphi_rad', 'phase'), ('f0_hz', 'freq'), ('amp1_pp_adc', 'p2p')):
            if source_key in meas:
                state[state_key] = meas[source_key]


def _worker_process(task):
    """Run one dequeued task, then drop it from state['tasks'].

    Every enqueuing endpoint adds a task to state['tasks'] before queuing it.
    Only /action/cancel (which clears the list) and this function remove
    entries. So a task that is no longer listed was cancelled after the
    worker dequeued it, and it is skipped rather than executed.
    """
    task_id = task['id']
    with state_lock:
        tracked = _worker_find_task(task_id)
        if tracked is not None:
            tracked['status'] = 'running'
    if tracked is None:
        return
    try:
        task_type = task['type']
        if task_type == 'move':
            _worker_run_move(task)
        elif task_type == 'measure':
            _worker_run_measure()
    except Exception as e:
        print(f"Error in task {task_id}: {e}")
    finally:
        with state_lock:
            state['tasks'] = [t for t in state['tasks'] if t['id'] != task_id]


def worker():
    """Background consumer loop: execute queued tasks one at a time, forever.

    Every item taken from task_queue is acknowledged with task_done(),
    including a None placeholder or a malformed task, so the queue's
    unfinished-task count stays balanced. Errors are printed and the loop
    keeps running.
    """
    while True:
        try:
            task = task_queue.get()
            try:
                if task is not None:
                    _worker_process(task)
            finally:
                task_queue.task_done()
        except Exception as e:
            print(f"Worker exception: {e}")
# Start the single background worker thread at import time. daemon=True means
# it does not keep the process alive at shutdown. Because there is exactly one
# consumer, queued tasks run strictly one at a time, in queue order.
threading.Thread(target=worker, daemon=True).start()
@app.get('/state')
def get_state():
    """Return a consistent snapshot of the shared state.

    FastAPI serializes the return value after this function returns, i.e.
    after state_lock is released. Returning `state` itself would therefore let
    the worker change it mid-serialization. The copy is taken under the lock
    instead. The task dicts are copied too because the worker mutates them in
    place; 'last_measurement' is only ever replaced wholesale, so sharing the
    reference is fine.
    """
    with state_lock:
        snapshot = dict(state)
        snapshot['tasks'] = [dict(t) for t in state['tasks']]
    return snapshot
@app.post('/state/dis')
def set_state_dis(dis: int):
    """Overwrite the tracked transmitter-receiver distance, in microsteps.

    Only the bookkeeping value changes; the motor is not moved. Returns the
    value that was stored.
    """
    print(f"Setting dis to {dis}")
    with state_lock:
        state['dis'] = dis
        # Read back inside the lock: a running move updates state['dis'] on
        # every pulse, so reading after release could echo a different value.
        stored = state['dis']
    return {'dis': stored}
@app.post('/action/measure')
def action_measure():
    """Queue one oscilloscope measurement and return the queued task record.

    The task is appended to state['tasks'] and put on task_queue within one
    state_lock acquisition. Concurrent requests therefore cannot make the
    visible task list disagree with the order the worker executes.
    """
    task = {
        'id': str(uuid.uuid4()),
        'type': 'measure',
        'status': 'pending',
        'created_at': time.time(),
    }
    with state_lock:
        state['tasks'].append(task)
        # task_queue is unbounded, so put() never blocks while the lock is held.
        task_queue.put(task)
    return {"status": "queued", "task": task}
@app.post('/state/speed')
def set_state_speed(speed: int):
    """Set the step rate (Hz) used by subsequently started move tasks.

    Returns the value that was stored.

    Raises:
        HTTPException(400): if speed is not positive. GPIOStepper.move_steps
            rejects freq_hz <= 0, so storing such a value would make every
            later move task fail inside the worker.
    """
    if speed <= 0:
        raise HTTPException(status_code=400, detail="speed must be > 0")
    print(f"Setting speed to {speed}")
    with state_lock:
        state['speed'] = speed
        stored = state['speed']
    return {'speed': stored}
@app.post('/action/move')
def action_move(steps: int):
    """Queue a relative move of ``steps`` microsteps and return the task record.

    The sign of ``steps`` selects the direction: positive steps increase
    state['dis'] and negative steps decrease it. The task is appended to
    state['tasks'] and put on task_queue within one state_lock acquisition,
    so the visible task list always matches execution order.
    """
    task = {
        'id': str(uuid.uuid4()),
        'type': 'move',
        'steps': steps,
        'remaining_steps': steps,
        'status': 'pending',
        'created_at': time.time(),
    }
    with state_lock:
        state['tasks'].append(task)
        # task_queue is unbounded, so put() never blocks while the lock is held.
        task_queue.put(task)
    return {"status": "queued", "task": task}
@app.post('/action/move_measure')
def action_move_measure(steps: int):
    """Queue a move of ``steps`` microsteps immediately followed by a measurement.

    Both tasks are appended to state['tasks'] and put on task_queue inside one
    state_lock acquisition. Sync endpoints run concurrently in FastAPI's
    threadpool, so enqueuing outside the lock would let another request slip
    a task between the move and the measurement. The measurement could then
    be taken at the wrong position. Returns both queued task records.
    """
    move_task = {
        'id': str(uuid.uuid4()),
        'type': 'move',
        'steps': steps,
        'remaining_steps': steps,
        'status': 'pending',
        'created_at': time.time(),
    }
    measure_task = {
        'id': str(uuid.uuid4()),
        'type': 'measure',
        'status': 'pending',
        'created_at': time.time(),
    }
    with state_lock:
        state['tasks'].append(move_task)
        state['tasks'].append(measure_task)
        # task_queue is unbounded, so put() never blocks while the lock is held.
        task_queue.put(move_task)
        task_queue.put(measure_task)
    return {"status": "queued", "tasks": [move_task, measure_task]}
def _batch_move_task(steps):
    """Build a pending 'move' task record for ``steps`` microsteps."""
    return {
        'id': str(uuid.uuid4()),
        'type': 'move',
        'steps': steps,
        'remaining_steps': steps,
        'status': 'pending',
        'created_at': time.time(),
    }


def _batch_measure_task():
    """Build a pending 'measure' task record."""
    return {
        'id': str(uuid.uuid4()),
        'type': 'measure',
        'status': 'pending',
        'created_at': time.time(),
    }


@app.post('/action/batch')
def action_batch(items: List[BatchTaskItem]):
    """Expand a list of batch commands into tasks and queue them atomically.

    Supported ``cmd`` values:
        - ``move``: args ``steps`` (int, default 0)
        - ``measure``: no args
        - ``move_measure``: args ``steps``; a move followed by a measurement

    Each item is expanded ``max(1, repeat)`` times. The whole batch is
    validated before anything is queued. It is then appended and enqueued
    under one state_lock acquisition, so its tasks stay contiguous and in
    order even when other requests arrive concurrently.

    Returns the number of queued tasks.

    Raises:
        HTTPException(400): for an unknown ``cmd`` or a ``steps`` value that
            cannot be converted to int.
    """
    new_tasks = []
    for index, item in enumerate(items):
        wants_move = item.cmd in ('move', 'move_measure')
        wants_measure = item.cmd in ('measure', 'move_measure')
        if not (wants_move or wants_measure):
            raise HTTPException(status_code=400, detail=f"item {index}: unknown cmd {item.cmd!r}")
        steps = 0
        if wants_move:
            raw_steps = item.args.get('steps', 0)
            try:
                steps = int(raw_steps)
            except (TypeError, ValueError, OverflowError) as e:
                raise HTTPException(
                    status_code=400, detail=f"item {index}: invalid steps {raw_steps!r}"
                ) from e
        for _ in range(max(1, item.repeat)):
            if wants_move:
                new_tasks.append(_batch_move_task(steps))
            if wants_measure:
                new_tasks.append(_batch_measure_task())
    with state_lock:
        state['tasks'].extend(new_tasks)
        # task_queue is unbounded, so put() never blocks while the lock is held.
        for t in new_tasks:
            task_queue.put(t)
    return {"status": "queued", "count": len(new_tasks)}
@app.post('/action/cancel')
def action_cancel():
    """Discard all pending tasks, then interrupt the move in progress, if any.

    The queue is drained and state['tasks'] cleared under state_lock first,
    and only then is the stepper told to stop. Stopping first would let the
    worker, freed by the stop, dequeue the next pending task before the
    drain. GPIOStepper.move_steps clears the stop flag on entry, so that task
    would then run to completion. A measurement already in progress is not
    interrupted.
    """
    with state_lock:
        while True:
            try:
                task_queue.get_nowait()
            except queue.Empty:
                break
            task_queue.task_done()
        state['tasks'] = []
    stepper.stop()
    return {"status": "cancelled", "message": "All tasks cancelled"}
@app.get("/events")
async def sse(request: Request):
    """Server-Sent Events endpoint that pushes state updates.

    Polls `state` every 0.1 s and emits a ``data:`` event only when the JSON
    payload differs from the last one sent. The payload is a copy of `state`
    with ``tasks`` truncated to its first 5 entries and an added
    ``total_tasks`` count. If serialization fails, ``{}`` is sent and the
    error is printed once per failure streak.
    """
    async def event_generator():
        last_state_str = None
        serialize_failed = False
        while True:
            if await request.is_disconnected():
                break
            with state_lock:
                current_state = state.copy()
                all_tasks = state['tasks']
                current_state['total_tasks'] = len(all_tasks)
                # Keep only the first 5 tasks (active + next pending) to reduce
                # the payload. Copy each dict: the worker mutates task dicts
                # and json.dumps below runs outside the lock.
                current_state['tasks'] = [dict(t) for t in all_tasks[:5]]
            try:
                data = json.dumps(current_state)
            except Exception as e:
                # Report once per failure streak rather than on every 0.1 s poll.
                if not serialize_failed:
                    print(f"SSE state serialization failed: {e}")
                serialize_failed = True
                data = "{}"  # Fallback
            else:
                serialize_failed = False
            if data != last_state_str:
                yield f"data: {data}\n\n"
                last_state_str = data
            await asyncio.sleep(0.1)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
if __name__ == "__main__":
    import uvicorn

    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "8000"))
    print(f"Starting FastAPI on http://{bind_host}:{bind_port} (docs: /docs)")
    # Give uvicorn the app object rather than an import string. An import
    # string would load this module a second time, and gpiozero would then
    # try to reserve the same GPIO pins again.
    uvicorn.run(app, host=bind_host, port=bind_port, reload=False)