import base64
import os
import re
import subprocess
import tempfile
import time
import urllib.request
from contextlib import asynccontextmanager
from typing import Any

import numpy as np
import onnxruntime as ort
import scipy.signal
import soundfile as sf
from fastapi import Body, FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.responses import PlainTextResponse

# Whisper audio front-end parameters.
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
MAX_MEL_FRAMES = 2000  # 20 s at 100 mel frames/s

# Whisper multilingual special-token ids.
END_TOKEN = 50257  # <|endoftext|>
TASK_CODE = {"en": 50259, "zh": 50260}  # <|en|>, <|zh|>
TIMESTAMP_BEGIN = 50364  # <|0.00|>

MODEL_NAME = os.getenv("MODEL_NAME", "whisper-base-onnx")
API_KEY = os.getenv("STT_API_KEY", "")
MAX_DECODE_TOKENS = int(os.getenv("MAX_DECODE_TOKENS", "128"))

VLM_ENABLED = os.getenv("VLM_ENABLED", "false").lower() in {"1", "true", "yes", "on"}
VLM_MODEL_NAME = os.getenv("VLM_MODEL_NAME", "qwen3-vl-2b-rkllm")
VLM_DEMO_BIN = os.getenv(
    "VLM_DEMO_BIN", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/demo"
)
VLM_LIB_DIR = os.getenv(
    "VLM_LIB_DIR", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/lib"
)
VLM_ENCODER_MODEL_PATH = os.getenv(
    "VLM_ENCODER_MODEL_PATH", "/opt/rkllm-root/models/qwen3-vl-2b_vision_rk3588.rknn"
)
VLM_LLM_MODEL_PATH = os.getenv(
    "VLM_LLM_MODEL_PATH",
    "/opt/rkllm-root/models/qwen3-vl-2b-instruct_w8a8_rk3588.rkllm",
)
VLM_CORE_NUM = int(os.getenv("VLM_CORE_NUM", "3"))
VLM_MAX_NEW_TOKENS = int(os.getenv("VLM_MAX_NEW_TOKENS", "256"))
VLM_MAX_CONTEXT_LEN = int(os.getenv("VLM_MAX_CONTEXT_LEN", "4096"))
VLM_IMG_START = os.getenv("VLM_IMG_START", "<|vision_start|>")
VLM_IMG_END = os.getenv("VLM_IMG_END", "<|vision_end|>")
VLM_IMG_CONTENT = os.getenv("VLM_IMG_CONTENT", "<|image_pad|>")
VLM_TIMEOUT_SEC = int(os.getenv("VLM_TIMEOUT_SEC", "300"))

ENCODER_MODEL_PATH = os.getenv(
    "ENCODER_MODEL_PATH", "/models/whisper_encoder_base_20s.onnx"
)
DECODER_MODEL_PATH = os.getenv(
    "DECODER_MODEL_PATH", "/models/whisper_decoder_base_20s.onnx"
)
MEL_FILTERS_PATH = os.getenv("MEL_FILTERS_PATH", "/models/mel_80_filters.txt")
VOCAB_EN_PATH = os.getenv("VOCAB_EN_PATH", "/models/vocab_en.txt")
VOCAB_ZH_PATH = os.getenv("VOCAB_ZH_PATH", "/models/vocab_zh.txt")

# Model sessions and lookup tables, populated once in the lifespan hook.
STATE: dict[str, Any] = {
    "encoder": None,
    "decoder": None,
    "mel_filters": None,
    "vocab_en": {},
    "vocab_zh": {},
}

# Matches ANSI escape sequences emitted by the interactive VLM demo binary.
ANSI_RE = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")


def read_vocab(path: str) -> dict[str, str]:
    """Parse a `<token_id> <token_text>` vocab file into an id -> text map."""
    vocab: dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            parts = line.split(" ", 1)
            token_id = parts[0]
            token_text = parts[1] if len(parts) > 1 else ""
            vocab[token_id] = token_text
    return vocab


def load_mel_filters(path: str) -> np.ndarray:
    # 80 mel bands x 201 FFT bins (N_FFT // 2 + 1).
    data = np.loadtxt(path, dtype=np.float32)
    return data.reshape((N_MELS, N_FFT // 2 + 1))


def ensure_sample_rate(waveform: np.ndarray, source_rate: int) -> np.ndarray:
    if source_rate == SAMPLE_RATE:
        return waveform
    target_len = int(round(len(waveform) * SAMPLE_RATE / source_rate))
    return scipy.signal.resample(waveform, target_len).astype(np.float32)


def to_mono(waveform: np.ndarray) -> np.ndarray:
    if waveform.ndim == 1:
        return waveform
    return waveform.mean(axis=1)


def log_mel_spectrogram(audio: np.ndarray, mel_filters: np.ndarray) -> np.ndarray:
    _, _, stft = scipy.signal.stft(
        audio,
        fs=SAMPLE_RATE,
        window="hann",
        nperseg=N_FFT,
        noverlap=N_FFT - HOP_LENGTH,
        nfft=N_FFT,
        boundary=None,
        padded=False,
    )
    magnitudes = np.abs(stft).astype(np.float32) ** 2
    if magnitudes.shape[1] > 0:
        # Whisper's reference implementation drops the final STFT frame.
        magnitudes = magnitudes[:, :-1]
    mel_spec = mel_filters @ magnitudes
    log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec.astype(np.float32)
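

# Framing sanity check (a sketch, not part of the request path): at 16 kHz with
# HOP_LENGTH=160 the pipeline yields roughly 100 mel frames per second, so
# MAX_MEL_FRAMES=2000 corresponds to the encoder's fixed 20 s window. `filters`
# below is a hypothetical array loaded via load_mel_filters():
#
#   mel = log_mel_spectrogram(np.zeros(SAMPLE_RATE, dtype=np.float32), filters)
#   assert mel.shape[0] == N_MELS        # 80 mel bands
#   assert 90 <= mel.shape[1] <= 100     # ~1 s of audio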


def pad_or_trim(mel: np.ndarray) -> np.ndarray:
    """Zero-pad or truncate to the encoder's 20 s window and add a batch dim."""
    out = np.zeros((N_MELS, MAX_MEL_FRAMES), dtype=np.float32)
    frames = min(mel.shape[1], MAX_MEL_FRAMES)
    out[:, :frames] = mel[:, :frames]
    return np.expand_dims(out, 0)


def decode_tokens(vocab: dict[str, str], token_ids: list[int], language: str) -> str:
    pieces = [vocab.get(str(t), "") for t in token_ids]
    text = (
        "".join(pieces)
        .replace("\u0120", " ")  # GPT-2 BPE space marker
        .replace("<|endoftext|>", "")
        .replace("\n", "")
        .strip()
    )
    if language == "zh":
        # The zh vocab stores token text base64-encoded; fall back to the raw
        # string if the joined payload does not decode.
        try:
            text = base64.b64decode(text).decode("utf-8", errors="replace")
        except Exception:
            pass
    return text


def transcribe_file(path: str, language: str) -> str:
    waveform, sr = sf.read(path)
    waveform = to_mono(np.asarray(waveform, dtype=np.float32))
    waveform = ensure_sample_rate(waveform, sr)
    mel = log_mel_spectrogram(waveform, STATE["mel_filters"])
    encoder_input = pad_or_trim(mel)
    encoded = STATE["encoder"].run(None, {"x": encoder_input})[0]
    # <|startoftranscript|>, language tag, <|transcribe|>, <|notimestamps|>
    tokens = [50258, TASK_CODE[language], 50359, 50363]
    emitted: list[int] = []
    # Greedy decode: re-run the decoder on the growing token sequence and take
    # the argmax of the last position each step.
    for _ in range(MAX_DECODE_TOKENS):
        decoder_out = STATE["decoder"].run(
            None,
            {
                "tokens": np.asarray([tokens], dtype=np.int64),
                "audio": encoded,
            },
        )[0]
        next_token = int(decoder_out[0, -1].argmax())
        if next_token == END_TOKEN:
            break
        tokens.append(next_token)
        # Drop timestamp tokens above <|0.00|> from the emitted text.
        if next_token <= TIMESTAMP_BEGIN:
            emitted.append(next_token)
    vocab = STATE["vocab_en"] if language == "en" else STATE["vocab_zh"]
    return decode_tokens(vocab, emitted, language)


def convert_to_wav(src_path: str) -> str:
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    cmd = [
        "ffmpeg",
        "-y",
        "-v",
        "error",
        "-i",
        src_path,
        "-ac",
        "1",
        "-ar",
        str(SAMPLE_RATE),
        out_path,
    ]
    try:
        subprocess.run(cmd, check=True)
        return out_path
    except subprocess.CalledProcessError as exc:
        # Don't leak the temp file when ffmpeg fails.
        if os.path.exists(out_path):
            os.unlink(out_path)
        raise HTTPException(status_code=400, detail=f"Failed to decode audio: {exc}")


def check_api_key(authorization: str | None) -> None:
    if not API_KEY:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing Bearer token")
    token = authorization.split(" ", 1)[1].strip()
    if token != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")


def validate_vlm_enabled() -> None:
    if not VLM_ENABLED:
        raise HTTPException(
            status_code=503,
            detail="VLM endpoint is disabled. Set VLM_ENABLED=true.",
        )
    required = [VLM_DEMO_BIN, VLM_ENCODER_MODEL_PATH, VLM_LLM_MODEL_PATH]
    for path in required:
        if not os.path.exists(path):
            raise HTTPException(status_code=500, detail=f"Missing VLM file: {path}")
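

# When STT_API_KEY is set, every endpoint expects a Bearer token. A minimal
# request sketch (host and port are assumptions, not fixed by this module):
#
#   curl -H "Authorization: Bearer $STT_API_KEY" http://localhost:8000/health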


def image_url_to_file(url: str) -> str:
    fd, out_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if url.startswith("data:"):
            payload = url.split(",", 1)[1]
            image_bytes = base64.b64decode(payload)
            with open(out_path, "wb") as f:
                f.write(image_bytes)
            return out_path
        if url.startswith("http://") or url.startswith("https://"):
            with urllib.request.urlopen(url, timeout=30) as resp:
                image_bytes = resp.read()
            with open(out_path, "wb") as f:
                f.write(image_bytes)
            return out_path
        raise HTTPException(
            status_code=400,
            detail="Unsupported image_url. Use data: or https:// URL.",
        )
    except HTTPException:
        if os.path.exists(out_path):
            os.unlink(out_path)
        raise
    except Exception as exc:
        if os.path.exists(out_path):
            os.unlink(out_path)
        raise HTTPException(status_code=400, detail=f"Failed to load image_url: {exc}")


def clean_vlm_output(text: str) -> str:
    """Strip ANSI escapes and the demo binary's chat framing from stdout."""
    text = ANSI_RE.sub("", text)
    if "robot:" in text:
        text = text.rsplit("robot:", 1)[1]
    if "\nuser:" in text:
        text = text.split("\nuser:", 1)[0]
    return text.strip()


def run_vlm(image_path: str, prompt: str) -> str:
    validate_vlm_enabled()
    # The image placeholder tokens are passed as argv; the prompt itself is
    # sent to the demo binary unmodified via stdin.
    llm_input = prompt
    cmd = [
        VLM_DEMO_BIN,
        image_path,
        VLM_ENCODER_MODEL_PATH,
        VLM_LLM_MODEL_PATH,
        str(VLM_MAX_NEW_TOKENS),
        str(VLM_MAX_CONTEXT_LEN),
        str(VLM_CORE_NUM),
        VLM_IMG_START,
        VLM_IMG_END,
        VLM_IMG_CONTENT,
    ]
    env = os.environ.copy()
    current_ld = env.get("LD_LIBRARY_PATH", "")
    env["LD_LIBRARY_PATH"] = (
        f"{VLM_LIB_DIR}:{current_ld}" if current_ld else VLM_LIB_DIR
    )
    try:
        proc = subprocess.run(
            cmd,
            input=f"{llm_input}\nexit\n",
            text=True,
            capture_output=True,
            check=True,
            env=env,
            timeout=VLM_TIMEOUT_SEC,
        )
    except subprocess.TimeoutExpired as exc:
        raise HTTPException(status_code=504, detail=f"VLM timed out: {exc}")
    except subprocess.CalledProcessError as exc:
        message = exc.stderr.strip() if exc.stderr else str(exc)
        raise HTTPException(status_code=500, detail=f"VLM execution failed: {message}")
    output = clean_vlm_output(proc.stdout)
    if not output:
        raise HTTPException(status_code=500, detail="VLM returned empty output")
    return output
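

# run_vlm drives the interactive RKLLM demo binary over stdin/stdout: the
# prompt is written once, followed by "exit" to end the session, and
# clean_vlm_output() recovers the answer between the "robot:" and "user:"
# chat markers. A rough transcript sketch (exact demo output may differ):
#
#   user: <prompt>
#   robot: <generated answer>     <- returned to the caller
#   user: exit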


def extract_prompt_and_image(messages: list[dict[str, Any]]) -> tuple[str, str]:
    """Pull the text prompt and image URL from the most recent user message."""
    prompt = ""
    image_url = ""
    for msg in reversed(messages):
        if msg.get("role") != "user":
            continue
        content = msg.get("content")
        if isinstance(content, str):
            prompt = content.strip()
        elif isinstance(content, list):
            text_parts: list[str] = []
            for part in content:
                if not isinstance(part, dict):
                    continue
                if part.get("type") == "text" and part.get("text"):
                    text_parts.append(str(part["text"]))
                if part.get("type") == "image_url":
                    image_data = part.get("image_url")
                    if isinstance(image_data, dict):
                        image_url = str(image_data.get("url", ""))
                    elif isinstance(image_data, str):
                        image_url = image_data
            prompt = "\n".join([p for p in text_parts if p.strip()]).strip()
        if prompt or image_url:
            break
    if not prompt:
        prompt = "Describe this image in English."
    if not image_url:
        raise HTTPException(
            status_code=400,
            detail="messages must include image_url content in the user message",
        )
    return prompt, image_url


@asynccontextmanager
async def lifespan(_: FastAPI):
    # Fail fast if any model or vocab file is missing, then load everything once.
    for path in [
        ENCODER_MODEL_PATH,
        DECODER_MODEL_PATH,
        MEL_FILTERS_PATH,
        VOCAB_EN_PATH,
        VOCAB_ZH_PATH,
    ]:
        if not os.path.exists(path):
            raise RuntimeError(f"Required file not found: {path}")
    STATE["encoder"] = ort.InferenceSession(
        ENCODER_MODEL_PATH, providers=["CPUExecutionProvider"]
    )
    STATE["decoder"] = ort.InferenceSession(
        DECODER_MODEL_PATH, providers=["CPUExecutionProvider"]
    )
    STATE["mel_filters"] = load_mel_filters(MEL_FILTERS_PATH)
    STATE["vocab_en"] = read_vocab(VOCAB_EN_PATH)
    STATE["vocab_zh"] = read_vocab(VOCAB_ZH_PATH)
    yield


app = FastAPI(title="RK Whisper STT API", version="0.1.0", lifespan=lifespan)


@app.get("/health")
async def health() -> dict[str, Any]:
    return {
        "ok": True,
        "model": MODEL_NAME,
        "encoder": ENCODER_MODEL_PATH,
        "decoder": DECODER_MODEL_PATH,
        "vlm_enabled": VLM_ENABLED,
        "vlm_model": VLM_MODEL_NAME,
    }


@app.post("/v1/audio/transcriptions")
async def transcriptions(
    file: UploadFile = File(...),
    model: str = Form(default=MODEL_NAME),
    language: str = Form(default="en"),
    response_format: str = Form(default="json"),
    authorization: str | None = Header(default=None),
):
    check_api_key(authorization)
    if model != MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{MODEL_NAME}'",
        )
    if language not in TASK_CODE:
        raise HTTPException(status_code=400, detail="language must be en or zh")
    fd, input_path = tempfile.mkstemp(suffix="_upload")
    os.close(fd)
    wav_path = ""
    try:
        payload = await file.read()
        with open(input_path, "wb") as f:
            f.write(payload)
        wav_path = convert_to_wav(input_path)
        text = transcribe_file(wav_path, language)
    finally:
        if os.path.exists(input_path):
            os.unlink(input_path)
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)
    if response_format == "text":
        return PlainTextResponse(text)
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language,
            "model": MODEL_NAME,
            "text": text,
            "segments": [],
        }
    return {"text": text}


@app.post("/v1/vision/understand")
async def vision_understand(
    file: UploadFile = File(...),
    prompt: str = Form(default="Describe this image in English."),
    model: str = Form(default=VLM_MODEL_NAME),
    response_format: str = Form(default="json"),
    authorization: str | None = Header(default=None),
):
    check_api_key(authorization)
    if model != VLM_MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
        )
    fd, image_path = tempfile.mkstemp(suffix="_image")
    os.close(fd)
    try:
        payload = await file.read()
        with open(image_path, "wb") as f:
            f.write(payload)
        text = run_vlm(image_path, prompt)
    finally:
        if os.path.exists(image_path):
            os.unlink(image_path)
    if response_format == "text":
        return PlainTextResponse(text)
    return {"text": text, "model": VLM_MODEL_NAME}
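

# Multipart request sketches for the two endpoints above (localhost:8000 and
# the file names are assumptions):
#
#   curl -F file=@clip.wav -F language=en \
#        http://localhost:8000/v1/audio/transcriptions
#
#   curl -F file=@photo.jpg -F prompt="What is in this image?" \
#        http://localhost:8000/v1/vision/understand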


@app.post("/v1/chat/completions")
async def chat_completions(
    body: dict[str, Any] = Body(...),
    authorization: str | None = Header(default=None),
):
    check_api_key(authorization)
    model = str(body.get("model", VLM_MODEL_NAME))
    if model != VLM_MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
        )
    messages = body.get("messages")
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")
    prompt, image_url = extract_prompt_and_image(messages)
    image_path = image_url_to_file(image_url)
    try:
        text = run_vlm(image_path, prompt)
    finally:
        if os.path.exists(image_path):
            os.unlink(image_path)
    return {
        "id": "chatcmpl-rk-vl-1",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": VLM_MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
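

# Local entry point: a minimal sketch assuming uvicorn is installed (this
# module does not otherwise depend on it). In deployment the app would
# typically be served as `uvicorn main:app`; the module name, bind address,
# and port here are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)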