# File-viewer header captured with this listing: 436 lines, 13 KiB, Python.
import os
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from gpiozero import Buzzer, OutputDevice
|
|
import time
|
|
import threading
|
|
import queue
|
|
import uuid
|
|
import json
|
|
import asyncio
|
|
from typing import Optional, Callable, List, Dict, Any
|
|
from pydantic import BaseModel
|
|
from rigol_phase import measure_phase, DEV as RIGOL_DEV
|
|
|
|
# FastAPI application object; every route in this module is registered on it.
app = FastAPI()
|
|
|
|
# One entry of the JSON list accepted by POST /action/batch.
class BatchTaskItem(BaseModel):
    # Command name; action_batch acts on 'move', 'measure' and 'move_measure'.
    cmd: str
    # Command arguments; 'move' and 'move_measure' read args['steps'] (default 0).
    args: Dict[str, Any] = {}
    # How many times to expand this item; action_batch uses max(1, repeat).
    repeat: int = 1
|
|
# NOTE(review): this lock is not acquired anywhere in the visible code; motor
# moves are serialized only by there being a single worker thread. Confirm
# whether it is still needed.
motor_lock = threading.Lock()
|
|
|
|
|
|
def _get_cors_origins() -> list[str]:
    """Return the allowed CORS origins taken from the ``CORS_ORIGINS`` env var.

    The variable is a comma-separated list. Entries are trimmed and blank
    entries are dropped. When the variable is unset or is exactly ``*``
    (after trimming), the wildcard list ``["*"]`` is returned.
    """
    configured = os.getenv("CORS_ORIGINS", "*").strip()
    if configured != "*":
        pieces = (part.strip() for part in configured.split(","))
        return [piece for piece in pieces if piece]
    return ["*"]
|
|
|
|
|
|
# Configure CORS from environment variables:
#   CORS_ORIGINS            comma-separated origins, or "*" (default)
#   CORS_ALLOW_CREDENTIALS  1/true/yes/y/on to allow credentials (default off)
#   CORS_ALLOW_METHODS      comma-separated methods (default "*")
#   CORS_ALLOW_HEADERS      comma-separated headers (default "*")
_cors_origins = _get_cors_origins()

_cors_allow_credentials_env = os.getenv("CORS_ALLOW_CREDENTIALS", "false").strip().lower() in (
    "1",
    "true",
    "yes",
    "y",
    "on",
)

# Never combine credentials with a wildcard origin. When "*" is among the
# allowed origins and credentials are enabled, Starlette answers credentialed
# requests by echoing the caller's Origin, which would let any site make
# authenticated requests. The check is "*" anywhere in the list (for example
# "*,http://x"), not only the exact ["*"] case.
_cors_allow_credentials = False if "*" in _cors_origins else _cors_allow_credentials_env


def _env_csv(name: str, default: str = "*") -> list[str]:
    """Split the comma-separated env var *name* into trimmed, non-empty items."""
    return [item.strip() for item in os.getenv(name, default).split(",") if item.strip()]


app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=_env_csv("CORS_ALLOW_METHODS"),
    allow_headers=_env_csv("CORS_ALLOW_HEADERS"),
)
|
|
|
|
# Buzzer output. BEE_PIN selects the pin (default 23); gpiozero numbers pins
# in BCM mode by default.
pin = int(os.getenv("BEE_PIN", default="23"))
BEE = Buzzer(pin=pin)
|
|
|
|
class GPIOStepper:
    """Step/direction stepper driver built on two gpiozero output pins.

    Each microstep is one high/low pulse on the STEP pin, and the DIR pin level
    selects the direction. ``stop()`` sets an event that the pulse loop in
    ``move_steps`` checks before every pulse.
    """

    def __init__(
        self,
        *,
        dir_pin: int = 22,
        step_pin: int = 27,
        dir_invert: bool = False,
    ) -> None:
        self.dir_invert = bool(dir_invert)

        # gpiozero uses BCM pin numbering by default.
        self._dir = OutputDevice(dir_pin)
        self._step = OutputDevice(step_pin)
        self._step.off()  # STEP idles low
        self._stop_event = threading.Event()

    def close(self) -> None:
        """Drive STEP and then DIR low, ignoring any error from either pin."""
        # The attribute lookup stays inside the try so that a missing pin
        # attribute is swallowed too, not just a failing off() call.
        for attr_name in ("_step", "_dir"):
            try:
                getattr(self, attr_name).off()
            except Exception:
                pass

    def stop(self) -> None:
        """Ask a running ``move_steps`` to stop before its next pulse."""
        self._stop_event.set()

    def set_dir(self, inc: bool) -> None:
        """Set the DIR pin level.

        ``inc=True`` selects the direction in which microsteps increase. A
        truthy ``dir_invert`` swaps the mapping between ``inc`` and the level.
        """
        drive_high = bool(inc) != bool(self.dir_invert)
        self._dir.value = 1 if drive_high else 0

    def move_steps(self, steps: int, freq_hz: int, on_step: Optional[Callable[[], None]] = None) -> None:
        """Emit ``int(steps)`` microstep pulses at *freq_hz* pulses per second.

        The stop flag is cleared on entry, so a ``stop()`` issued before this
        call does not affect it. The loop checks the flag before each pulse and
        exits early once it is set. ``on_step``, if truthy, is called after
        every completed pulse.

        Raises:
            ValueError: if ``steps < 0`` or ``freq_hz <= 0``.
        """
        self._stop_event.clear()

        if steps < 0:
            raise ValueError("steps must be >= 0")
        if freq_hz <= 0:
            raise ValueError("freq_hz must be > 0")

        remaining = int(steps)
        if remaining <= 0:
            return

        half_period = 0.5 / float(freq_hz)
        while remaining > 0 and not self._stop_event.is_set():
            self._step.on()
            time.sleep(half_period)
            self._step.off()
            time.sleep(half_period)
            if on_step:
                on_step()
            remaining -= 1
|
|
|
|
# Global stepper driver. The pins can be overridden through environment
# variables, mirroring BEE_PIN for the buzzer. Defaults equal GPIOStepper's own
# defaults (DIR=22, STEP=27, no inversion), so behaviour is unchanged when
# nothing is set. gpiozero numbers pins in BCM mode by default.
stepper = GPIOStepper(
    dir_pin=int(os.getenv("STEPPER_DIR_PIN", "22")),
    step_pin=int(os.getenv("STEPPER_STEP_PIN", "27")),
    dir_invert=os.getenv("STEPPER_DIR_INVERT", "false").strip().lower()
    in ("1", "true", "yes", "y", "on"),
)
|
|
|
|
@app.get("/")
def read_root():
    """Return a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
|
|
|
|
@app.get("/ping")
def ping():
    """Return the constant body ``{"ping": "pong"}``."""
    return dict(ping="pong")
|
|
|
|
@app.get("/bee")
def bee_sound():
    """Sound the buzzer once for 0.1 s.

    Any exception raised by the buzzer call is returned to the client as
    HTTP 500, with the exception text as ``detail``.
    """
    beep_kwargs = dict(on_time=0.1, off_time=0, n=1)
    try:
        BEE.beep(**beep_kwargs)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return {"status": "Bee buzzed for 0.1 seconds"}
|
|
|
|
# Shared application state, read and written by the HTTP handlers and the
# worker thread while holding state_lock.
state = dict(
    dis=16000,            # distance between transmitter and receiver, in microsteps
    phase=0,              # phase difference, rad
    freq=0,               # frequency, Hz
    p2p=0,                # peak-to-peak amplitude (originally annotated "mV").
                          # NOTE(review): the worker stores meas['amp1_pp_adc'] here,
                          # which by name may be raw ADC units; confirm the unit.
    speed=1200,           # motor speed: step pulse frequency in Hz, passed to move_steps
    tasks=[],             # task records (pending and running)
    last_measurement={},  # most recent oscilloscope measurement result
)

# Guards every field of `state`, including the task records inside state['tasks'].
state_lock = threading.Lock()

# FIFO of task dicts consumed by the background worker thread.
task_queue = queue.Queue()
|
|
|
|
def _find_task(task_id):
    """Return the first record in ``state['tasks']`` with id *task_id*, or None.

    The caller must hold ``state_lock``.
    """
    for t in state['tasks']:
        if t['id'] == task_id:
            return t
    return None


def _run_move_task(task):
    """Execute a 'move' task: drive the stepper by ``task['steps']`` microsteps.

    The sign of ``steps`` selects the direction. After every emitted pulse,
    ``state['dis']`` changes by one in that direction, and the task's
    ``remaining_steps`` moves one closer to zero. Both therefore reflect how far
    the motor actually went, even when the move is stopped early.
    """
    task_id = task['id']
    steps = task['steps']
    is_positive = steps > 0
    stepper.set_dir(is_positive)

    abs_steps = abs(steps)
    if abs_steps == 0:
        return
    step_sign = 1 if is_positive else -1

    def on_step():
        with state_lock:
            state['dis'] += step_sign
            t = _find_task(task_id)
            if t is not None:
                t['remaining_steps'] -= step_sign
                # Clamp at zero so the counter never crosses to the other sign.
                if is_positive and t['remaining_steps'] < 0:
                    t['remaining_steps'] = 0
                elif not is_positive and t['remaining_steps'] > 0:
                    t['remaining_steps'] = 0

    with state_lock:
        current_speed = state.get('speed', 1200)

    stepper.move_steps(abs_steps, current_speed, on_step=on_step)


def _run_measure_task():
    """Execute a 'measure' task and publish the result into ``state``.

    Errors are printed and swallowed, so a failed measurement does not stop
    the worker.
    """
    try:
        # Fixed acquisition parameters; these could be taken from the task instead.
        meas = measure_phase(
            dev=RIGOL_DEV,
            timeout=3.0,
            points_mode="NORM",
            fetch_idn=False,
        )
        with state_lock:
            state['last_measurement'] = meas
            # Mirror the headline values into the summary fields of `state`.
            if 'dphi_rad' in meas:
                state['phase'] = meas['dphi_rad']
            if 'f0_hz' in meas:
                state['freq'] = meas['f0_hz']
            # The result carries 'amp1_pp_adc' and 'amp2_pp_adc'; p2p keeps channel 1.
            if 'amp1_pp_adc' in meas:
                state['p2p'] = meas['amp1_pp_adc']
    except Exception as e:
        print(f"Measurement failed: {e}")


def _process_task(task):
    """Run one dequeued task, then remove its record from ``state['tasks']``.

    A task whose id is no longer in ``state['tasks']`` was cancelled by
    /action/cancel after the worker had already dequeued it. It is skipped
    rather than started. Exceptions raised while running a task are printed.
    """
    task_id = task['id']

    with state_lock:
        entry = _find_task(task_id)
        if entry is not None:
            entry['status'] = 'running'

    try:
        if entry is None:
            return  # cancelled between dequeue and start
        kind = task['type']
        if kind == 'move':
            _run_move_task(task)
        elif kind == 'measure':
            _run_measure_task()
    except Exception as e:
        print(f"Error in task {task_id}: {e}")
    finally:
        # Remove the record when done, whether it succeeded, failed or was skipped.
        with state_lock:
            state['tasks'] = [t for t in state['tasks'] if t['id'] != task_id]


def worker():
    """Background loop: execute queued tasks one at a time, forever.

    ``task_queue.task_done()`` is called exactly once for every item taken
    from the queue, including ``None`` items, which are otherwise ignored.
    This keeps the queue's unfinished-task count balanced for ``join()``.
    """
    while True:
        task = task_queue.get()
        try:
            if task is not None:
                _process_task(task)
        except Exception as e:
            # Last-resort guard that keeps the worker thread alive.
            print(f"Worker exception: {e}")
        finally:
            task_queue.task_done()
|
|
|
|
# Launch the one background worker that executes queued tasks in FIFO order.
# daemon=True means this thread does not keep the process alive at exit.
_worker_thread = threading.Thread(target=worker, daemon=True)
_worker_thread.start()
|
|
|
|
@app.get('/state')
def get_state():
    """Return a point-in-time snapshot of the shared ``state``.

    FastAPI serializes the return value after this function returns, which is
    after ``state_lock`` has been released. Returning ``state`` itself would
    therefore let the worker thread mutate it mid-serialization and make the
    lock pointless. The snapshot is taken while the lock is held:

    - the top-level dict is copied;
    - every task dict is copied, because the worker updates ``status`` and
      ``remaining_steps`` in place.

    ``last_measurement`` is only ever reassigned, never mutated, so it can be
    shared.
    """
    with state_lock:
        snapshot = dict(state)
        snapshot['tasks'] = [dict(t) for t in state['tasks']]
    return snapshot
|
|
|
|
@app.post('/state/dis')
def set_state_dis(dis: int):
    """Overwrite the tracked transmitter-receiver distance (microsteps).

    Only the bookkeeping value changes; the motor is not moved.
    """
    print(f"Setting dis to {dis}")
    with state_lock:
        state['dis'] = dis
        stored = state['dis']
    return {'dis': stored}
|
|
|
|
@app.post('/action/measure')
def action_measure():
    """Queue one oscilloscope phase measurement and return its task record."""
    new_task = dict(
        id=str(uuid.uuid4()),
        type='measure',
        status='pending',
        created_at=time.time(),
    )

    with state_lock:
        state['tasks'].append(new_task)

    task_queue.put(new_task)
    return {"status": "queued", "task": new_task}
|
|
|
|
@app.post('/state/speed')
def set_state_speed(speed: int):
    """Set the step pulse frequency (Hz) used by subsequent move tasks.

    Raises:
        HTTPException(400): if *speed* is not positive. The worker passes this
            value as ``freq_hz`` to ``GPIOStepper.move_steps``, which raises
            ``ValueError`` for ``freq_hz <= 0``. Accepting such a value here
            would make every later move task fail inside the worker (reported
            only on stdout) and be dropped.
    """
    if speed <= 0:
        raise HTTPException(status_code=400, detail="speed must be > 0")

    print(f"Setting speed to {speed}")
    with state_lock:
        state['speed'] = speed
        return {'speed': state['speed']}
|
|
|
|
@app.post('/action/move')
def action_move(steps: int):
    """Queue a relative move of *steps* microsteps and return its task record.

    Negative values move in the direction that decreases ``state['dis']``.
    ``remaining_steps`` starts equal to ``steps``.
    """
    new_task = dict(
        id=str(uuid.uuid4()),
        type='move',
        steps=steps,
        remaining_steps=steps,
        status='pending',
        created_at=time.time(),
    )

    with state_lock:
        state['tasks'].append(new_task)

    task_queue.put(new_task)
    return {"status": "queued", "task": new_task}
|
|
|
|
@app.post('/action/move_measure')
def action_move_measure(steps: int):
    """Queue a move of *steps* microsteps followed by a measurement.

    Both records are registered in ``state['tasks']`` under one lock
    acquisition and then queued in order, move first.
    """
    pair = [
        dict(
            id=str(uuid.uuid4()),
            type='move',
            steps=steps,
            remaining_steps=steps,
            status='pending',
            created_at=time.time(),
        ),
        dict(
            id=str(uuid.uuid4()),
            type='measure',
            status='pending',
            created_at=time.time(),
        ),
    ]

    with state_lock:
        state['tasks'].extend(pair)

    for queued in pair:
        task_queue.put(queued)

    return {"status": "queued", "tasks": pair}
|
|
|
|
@app.post('/action/batch')
def action_batch(items: List[BatchTaskItem]):
    """Expand a list of batch commands into tasks and queue them.

    Supported ``cmd`` values:

    - ``move``: reads ``args['steps']`` (default 0).
    - ``move_measure``: reads ``args['steps']`` (default 0); each repetition
      queues a move followed by a measure.
    - ``measure``: takes no arguments.

    Each item is expanded ``max(1, repeat)`` times. The whole request is
    validated before anything is queued, so a bad item rejects the batch
    without leaving a partial batch behind.

    Raises:
        HTTPException(400): for an unknown ``cmd`` (previously it was silently
            ignored), or when ``steps`` cannot be converted to ``int``
            (previously this surfaced as a 500).
    """

    def make_move(steps):
        return {
            'id': str(uuid.uuid4()),
            'type': 'move',
            'steps': steps,
            'remaining_steps': steps,
            'status': 'pending',
            'created_at': time.time(),
        }

    def make_measure():
        return {
            'id': str(uuid.uuid4()),
            'type': 'measure',
            'status': 'pending',
            'created_at': time.time(),
        }

    moving_cmds = ('move', 'move_measure')
    measuring_cmds = ('measure', 'move_measure')

    new_tasks = []
    for index, item in enumerate(items):
        if item.cmd not in moving_cmds and item.cmd not in measuring_cmds:
            raise HTTPException(
                status_code=400,
                detail=f"items[{index}]: unknown cmd {item.cmd!r}",
            )

        steps = 0
        if item.cmd in moving_cmds:
            raw_steps = item.args.get('steps', 0)
            try:
                steps = int(raw_steps)
            except (TypeError, ValueError, OverflowError):
                raise HTTPException(
                    status_code=400,
                    detail=f"items[{index}]: 'steps' must be an integer, got {raw_steps!r}",
                ) from None

        for _ in range(max(1, item.repeat)):
            if item.cmd in moving_cmds:
                new_tasks.append(make_move(steps))
            if item.cmd in measuring_cmds:
                new_tasks.append(make_measure())

    with state_lock:
        state['tasks'].extend(new_tasks)

    for t in new_tasks:
        task_queue.put(t)

    return {"status": "queued", "count": len(new_tasks)}
|
|
|
|
@app.post('/action/cancel')
def action_cancel():
    """Cancel every queued task and stop the move that is currently running.

    Order matters. The queue is drained and ``state['tasks']`` is cleared
    before the motor is stopped, so once the in-flight move is interrupted the
    worker has nothing left to pick up.

    The old order (stop first) left a window: the worker could finish the
    interrupted move and dequeue the next task before the drain ran. Because
    ``GPIOStepper.move_steps`` clears the stop flag on entry, that next task
    would then run to completion.

    NOTE(review): a task the worker has already dequeued but not yet started
    can still race with this. Closing that window completely needs a
    cancellation token shared with the worker.
    """
    # Drain everything still waiting. Calling task_done() keeps the queue's
    # unfinished-task counter balanced for any join() callers.
    while True:
        try:
            task_queue.get_nowait()
        except queue.Empty:
            break
        task_queue.task_done()

    with state_lock:
        state['tasks'] = []

    # Interrupt the move that may be running right now.
    stepper.stop()

    return {"status": "cancelled", "message": "All tasks cancelled"}
|
|
|
|
@app.get("/events")
async def sse(request: Request):
    """Server-Sent Events endpoint to push state updates.

    The state is polled every 0.1 s, and a ``data:`` frame is sent only when
    the serialized state differs from the last frame sent. To keep the payload
    small, only the first 5 tasks are included; ``total_tasks`` gives the full
    count.
    """

    async def event_generator():
        last_sent = None
        while not await request.is_disconnected():
            with state_lock:
                # Take a consistent snapshot under the lock. The task dicts are
                # copied as well, because the worker mutates them in place
                # (status, remaining_steps) and json.dumps below runs without
                # the lock.
                snapshot = dict(state)
                tasks = state['tasks']
                snapshot['total_tasks'] = len(tasks)
                snapshot['tasks'] = [dict(t) for t in tasks[:5]]

            try:
                payload = json.dumps(snapshot)
            except Exception:
                payload = "{}"  # Fallback for values json cannot encode

            if payload != last_sent:
                yield f"data: {payload}\n\n"
                last_sent = payload

            await asyncio.sleep(0.1)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "8000"))
    print(f"Starting FastAPI on http://{bind_host}:{bind_port} (docs: /docs)")

    # Hand uvicorn the app object itself rather than a "module:app" import
    # string. Importing this module a second time would construct the gpiozero
    # devices again and try to reserve the same pins twice.
    uvicorn.run(app, host=bind_host, port=bind_port, reload=False)
|