"""HTTP control service for a GPIO stepper motor, a buzzer and a Rigol scope.

Requests enqueue "move" and "measure" tasks; a single background worker
thread executes them in order and publishes progress in the shared ``state``
dict, which is exposed through ``GET /state`` and streamed through
``GET /events`` (Server-Sent Events).
"""

import asyncio
import json
import os
import queue
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from gpiozero import Buzzer, OutputDevice
from pydantic import BaseModel

from rigol_phase import measure_phase, DEV as RIGOL_DEV

app = FastAPI()


class BatchTaskItem(BaseModel):
    """One entry of a ``POST /action/batch`` request body."""

    cmd: str  # 'move', 'measure' or 'move_measure'
    args: Dict[str, Any] = {}  # command arguments, e.g. {"steps": 100}
    repeat: int = 1  # values below 1 are treated as 1


# NOTE: not referenced anywhere else in this file.
motor_lock = threading.Lock()


def _get_cors_origins() -> list[str]:
    """Parse the comma-separated ``CORS_ORIGINS`` env var (default ``*``)."""
    raw = os.getenv("CORS_ORIGINS", "*").strip()
    if raw == "*":
        return ["*"]
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


_cors_origins = _get_cors_origins()
_cors_allow_credentials_env = os.getenv("CORS_ALLOW_CREDENTIALS", "false").strip().lower() in (
    "1",
    "true",
    "yes",
    "y",
    "on",
)
# Credentials are never enabled together with a wildcard origin.
_cors_allow_credentials = False if _cors_origins == ["*"] else _cors_allow_credentials_env

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=[m.strip() for m in os.getenv("CORS_ALLOW_METHODS", "*").split(",") if m.strip()],
    allow_headers=[h.strip() for h in os.getenv("CORS_ALLOW_HEADERS", "*").split(",") if h.strip()],
)

pin = int(os.getenv("BEE_PIN", "23"))
BEE = Buzzer(pin)


class GPIOStepper:
    """Step/direction stepper driver built on two gpiozero output pins."""

    def __init__(
        self,
        *,
        dir_pin: int = 22,
        step_pin: int = 27,
        dir_invert: bool = False,
    ) -> None:
        """Claim the direction and step pins and drive the step line low.

        Args:
            dir_pin: Pin driving the direction input.
            step_pin: Pin driving the step (pulse) input.
            dir_invert: When true, swap the level written for each direction.
        """
        self.dir_invert = bool(dir_invert)

        # gpiozero uses BCM pin numbering by default.
        self._dir = OutputDevice(dir_pin)
        self._step = OutputDevice(step_pin)
        self._step.off()
        self._stop_event = threading.Event()

    def close(self) -> None:
        """Drive both outputs low, ignoring any error (best-effort cleanup)."""
        try:
            self._step.off()
        except Exception:
            pass
        try:
            self._dir.off()
        except Exception:
            pass

    def stop(self) -> None:
        """Ask a running ``move_steps`` call to stop before its next pulse."""
        self._stop_event.set()

    def set_dir(self, inc: bool) -> None:
        """Set the direction; ``inc=True`` means the microstep count increases."""
        value = 1 if inc else 0
        if self.dir_invert:
            value = 0 if value else 1
        self._dir.value = value

    def move_steps(self, steps: int, freq_hz: int, on_step: Optional[Callable[[], None]] = None) -> None:
        """Emit ``steps`` step pulses at ``freq_hz`` pulses per second.

        Any stop request left over from an earlier call is cleared first.
        The loop ends early once ``stop()`` is called.

        Args:
            steps: Number of microsteps to emit; must be >= 0.
            freq_hz: Pulse frequency; must be > 0.
            on_step: Optional callback invoked after every completed pulse.

        Raises:
            ValueError: If ``steps`` is negative or ``freq_hz`` is not positive.
        """
        self._stop_event.clear()
        if steps < 0:
            raise ValueError("steps must be >= 0")
        if freq_hz <= 0:
            raise ValueError("freq_hz must be > 0")

        pulses = int(steps)
        if pulses <= 0:
            return

        half_period = 0.5 / float(freq_hz)
        for _ in range(pulses):
            if self._stop_event.is_set():
                break
            self._step.on()
            time.sleep(half_period)
            self._step.off()
            time.sleep(half_period)
            if on_step:
                on_step()


stepper = GPIOStepper()


@app.get("/")
def read_root():
    """Return a static greeting."""
    return {"Hello": "World"}


@app.get("/ping")
def ping():
    """Liveness probe."""
    return {"ping": "pong"}


@app.get("/bee")
def bee_sound():
    """Sound the buzzer once for 0.1 s; respond with HTTP 500 on failure."""
    try:
        BEE.beep(on_time=0.1, off_time=0, n=1)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "Bee buzzed for 0.1 seconds"}


state = {
    'dis': 16000,  # transmitter-to-receiver distance, in microsteps
    'phase': 0,  # phase difference, rad
    'freq': 0,  # frequency, Hz
    'p2p': 0,  # peak-to-peak value, mV
    'speed': 1200,  # motor speed, Hz
    'tasks': [],  # task list
    'last_measurement': {}  # latest oscilloscope measurement result
}

# Guards every read and write of ``state`` (including the task dicts in it).
state_lock = threading.Lock()
task_queue = queue.Queue()


def _new_move_task(steps: int) -> Dict[str, Any]:
    """Build a pending move task for a signed number of microsteps."""
    return {
        'id': str(uuid.uuid4()),
        'type': 'move',
        'steps': steps,
        'remaining_steps': steps,
        'status': 'pending',
        'created_at': time.time()
    }


def _new_measure_task() -> Dict[str, Any]:
    """Build a pending oscilloscope measurement task."""
    return {
        'id': str(uuid.uuid4()),
        'type': 'measure',
        'status': 'pending',
        'created_at': time.time()
    }


def _enqueue_tasks(tasks: List[Dict[str, Any]]) -> None:
    """Publish tasks in ``state['tasks']`` and hand them to the worker.

    Both steps happen under ``state_lock`` so that a concurrent
    ``/action/cancel`` never leaves a task queued but missing from the list.
    """
    with state_lock:
        state['tasks'].extend(tasks)
        for task in tasks:
            task_queue.put(task)


def _snapshot_state(max_tasks: Optional[int] = None) -> Dict[str, Any]:
    """Return a copy of ``state`` that is safe to serialize without the lock.

    Task dicts are copied because the worker mutates them in place. At most
    ``max_tasks`` leading tasks are included when a limit is given.
    """
    with state_lock:
        snapshot = dict(state)
        tasks = state['tasks']
        if max_tasks is not None:
            tasks = tasks[:max_tasks]
        snapshot['tasks'] = [dict(t) for t in tasks]
    return snapshot


def _mark_running(task_id: str) -> bool:
    """Flag the task as running; return False if it is no longer listed."""
    with state_lock:
        for t in state['tasks']:
            if t['id'] == task_id:
                t['status'] = 'running'
                return True
    return False


def _run_move(task: Dict[str, Any]) -> None:
    """Drive the motor for a move task, updating position and progress."""
    steps = task['steps']
    is_positive = steps > 0
    stepper.set_dir(is_positive)

    abs_steps = abs(steps)
    if abs_steps == 0:
        return

    step_sign = 1 if is_positive else -1

    def on_step():
        # ``task`` is the very dict stored in state['tasks'], so it is
        # updated directly instead of scanning the list on every pulse.
        with state_lock:
            state['dis'] += step_sign
            remaining = task['remaining_steps'] - step_sign
            # Never let the counter cross zero.
            if is_positive and remaining < 0:
                remaining = 0
            elif not is_positive and remaining > 0:
                remaining = 0
            task['remaining_steps'] = remaining

    with state_lock:
        current_speed = state.get('speed', 1200)

    stepper.move_steps(abs_steps, current_speed, on_step=on_step)


def _run_measure() -> None:
    """Take one scope measurement and copy its summary fields into ``state``."""
    try:
        # Default parameters; could be taken from the task if needed.
        meas = measure_phase(
            dev=RIGOL_DEV,
            timeout=3.0,
            points_mode="NORM",
            fetch_idn=False
        )
        with state_lock:
            state['last_measurement'] = meas
            # Mirror the headline values into the flat state fields.
            if 'dphi_rad' in meas:
                state['phase'] = meas['dphi_rad']
            if 'f0_hz' in meas:
                state['freq'] = meas['f0_hz']
            # p2p is assumed to hold channel 1's amplitude ('amp1_pp_adc').
            # NOTE(review): the state comment says mV but the key name
            # suggests ADC counts -- confirm the unit.
            if 'amp1_pp_adc' in meas:
                state['p2p'] = meas['amp1_pp_adc']
    except Exception as e:
        print(f"Measurement failed: {e}")


def worker():
    """Run queued tasks forever, one at a time, on the background thread."""
    while True:
        try:
            task = task_queue.get()
            if task is None:
                # Still account for the item so Queue.join() cannot hang.
                task_queue.task_done()
                continue

            task_id = task['id']
            try:
                # A task missing from state['tasks'] was cancelled after it
                # had already been dequeued; skip it.
                # NOTE: a cancel arriving after this check but before
                # move_steps() clears the stop flag is still not honoured.
                if _mark_running(task_id):
                    if task['type'] == 'move':
                        _run_move(task)
                    elif task['type'] == 'measure':
                        _run_measure()
            except Exception as e:
                print(f"Error in task {task_id}: {e}")
            finally:
                # Remove from tasks list when done.
                with state_lock:
                    state['tasks'] = [t for t in state['tasks'] if t['id'] != task_id]
                task_queue.task_done()
        except Exception as e:
            print(f"Worker exception: {e}")


# Start worker thread
threading.Thread(target=worker, daemon=True).start()


@app.get('/state')
def get_state():
    """Return a snapshot of the shared state (copied under the lock)."""
    return _snapshot_state()


@app.post('/state/dis')
def set_state_dis(dis: int):
    """Overwrite the recorded distance (microsteps) without moving the motor."""
    print(f"Setting dis to {dis}")
    with state_lock:
        state['dis'] = dis
        return {'dis': state['dis']}


@app.post('/action/measure')
def action_measure():
    """Queue a single measurement task."""
    task = _new_measure_task()
    _enqueue_tasks([task])
    return {"status": "queued", "task": task}


@app.post('/state/speed')
def set_state_speed(speed: int):
    """Set the step frequency (Hz) used by subsequent moves.

    Raises:
        HTTPException: 400 if ``speed`` is not positive; such a value would
            make every later move fail inside the worker.
    """
    if speed <= 0:
        raise HTTPException(status_code=400, detail="speed must be > 0")
    print(f"Setting speed to {speed}")
    with state_lock:
        state['speed'] = speed
        return {'speed': state['speed']}


@app.post('/action/move')
def action_move(steps: int):
    """Queue a move of ``steps`` microsteps (the sign selects the direction)."""
    task = _new_move_task(steps)
    _enqueue_tasks([task])
    return {"status": "queued", "task": task}


@app.post('/action/move_measure')
def action_move_measure(steps: int):
    """Queue a move followed by a measurement."""
    move_task = _new_move_task(steps)
    measure_task = _new_measure_task()
    _enqueue_tasks([move_task, measure_task])
    return {"status": "queued", "tasks": [move_task, measure_task]}


@app.post('/action/batch')
def action_batch(items: List[BatchTaskItem]):
    """Queue a list of commands, each repeated ``repeat`` times.

    The whole batch is validated before anything is queued.

    Raises:
        HTTPException: 400 for an unknown ``cmd`` or a non-integer ``steps``.
    """
    new_tasks = []
    for item in items:
        if item.cmd not in ('move', 'measure', 'move_measure'):
            raise HTTPException(status_code=400, detail=f"unknown cmd: {item.cmd}")

        steps = 0
        if item.cmd != 'measure':
            try:
                steps = int(item.args.get('steps', 0))
            except (TypeError, ValueError):
                raise HTTPException(status_code=400, detail="steps must be an integer") from None

        for _ in range(max(1, item.repeat)):
            if item.cmd in ('move', 'move_measure'):
                new_tasks.append(_new_move_task(steps))
            if item.cmd in ('measure', 'move_measure'):
                new_tasks.append(_new_measure_task())

    _enqueue_tasks(new_tasks)
    return {"status": "queued", "count": len(new_tasks)}


@app.post('/action/cancel')
def action_cancel():
    """Stop the current move and discard every queued task."""
    # Stop current motor move
    stepper.stop()

    # Drain the queue and clear the list atomically with respect to enqueuing.
    with state_lock:
        while True:
            try:
                task_queue.get_nowait()
            except queue.Empty:
                break
            task_queue.task_done()
        state['tasks'] = []

    return {"status": "cancelled", "message": "All tasks cancelled"}


@app.get("/events")
async def sse(request: Request):
    """Server-Sent Events endpoint to push state updates."""

    async def event_generator():
        last_state_str = None
        while True:
            if await request.is_disconnected():
                break

            with state_lock:
                total_tasks = len(state['tasks'])
            # Keep only the first 5 tasks (active + next pending) to reduce payload.
            current_state = _snapshot_state(max_tasks=5)
            current_state['total_tasks'] = total_tasks

            # Serialize outside the lock; the snapshot owns its task dicts.
            try:
                data = json.dumps(current_state)
            except (TypeError, ValueError):
                data = "{}"  # Fallback

            # Only push when something actually changed.
            if data != last_state_str:
                yield f"data: {data}\n\n"
                last_state_str = data

            await asyncio.sleep(0.1)

    return StreamingResponse(event_generator(), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))
    print(f"Starting FastAPI on http://{host}:{port} (docs: /docs)")
    # When running this file as a script, avoid importing this module twice.
    # Passing the app object directly prevents gpiozero from reserving the same pin twice.
    uvicorn.run(app, host=host, port=port, reload=False)