# rknn-inference-server/app/server.py

import base64
import os
import re
import subprocess
import tempfile
import time
import urllib.request
from contextlib import asynccontextmanager
from typing import Any
import numpy as np
import onnxruntime as ort
import scipy.signal
import soundfile as sf
from fastapi import Body, FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.responses import PlainTextResponse
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
MAX_MEL_FRAMES = 2000
END_TOKEN = 50257  # <|endoftext|>
TASK_CODE = {"en": 50259, "zh": 50260}  # Whisper language token ids
TIMESTAMP_BEGIN = 50364  # first timestamp token, <|0.00|>
MODEL_NAME = os.getenv("MODEL_NAME", "whisper-base-onnx")
API_KEY = os.getenv("STT_API_KEY", "")
MAX_DECODE_TOKENS = int(os.getenv("MAX_DECODE_TOKENS", "128"))
VLM_ENABLED = os.getenv("VLM_ENABLED", "false").lower() in {
"1",
"true",
"yes",
"on",
}
VLM_MODEL_NAME = os.getenv("VLM_MODEL_NAME", "qwen3-vl-2b-rkllm")
VLM_DEMO_BIN = os.getenv(
"VLM_DEMO_BIN", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/demo"
)
VLM_LIB_DIR = os.getenv(
"VLM_LIB_DIR", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/lib"
)
VLM_ENCODER_MODEL_PATH = os.getenv(
"VLM_ENCODER_MODEL_PATH", "/opt/rkllm-root/models/qwen3-vl-2b_vision_rk3588.rknn"
)
VLM_LLM_MODEL_PATH = os.getenv(
"VLM_LLM_MODEL_PATH",
"/opt/rkllm-root/models/qwen3-vl-2b-instruct_w8a8_rk3588.rkllm",
)
VLM_CORE_NUM = int(os.getenv("VLM_CORE_NUM", "3"))
VLM_MAX_NEW_TOKENS = int(os.getenv("VLM_MAX_NEW_TOKENS", "256"))
VLM_MAX_CONTEXT_LEN = int(os.getenv("VLM_MAX_CONTEXT_LEN", "4096"))
VLM_IMG_START = os.getenv("VLM_IMG_START", "<|vision_start|>")
VLM_IMG_END = os.getenv("VLM_IMG_END", "<|vision_end|>")
VLM_IMG_CONTENT = os.getenv("VLM_IMG_CONTENT", "<|image_pad|>")
VLM_TIMEOUT_SEC = int(os.getenv("VLM_TIMEOUT_SEC", "300"))
ENCODER_MODEL_PATH = os.getenv(
"ENCODER_MODEL_PATH", "/models/whisper_encoder_base_20s.onnx"
)
DECODER_MODEL_PATH = os.getenv(
"DECODER_MODEL_PATH", "/models/whisper_decoder_base_20s.onnx"
)
MEL_FILTERS_PATH = os.getenv("MEL_FILTERS_PATH", "/models/mel_80_filters.txt")
VOCAB_EN_PATH = os.getenv("VOCAB_EN_PATH", "/models/vocab_en.txt")
VOCAB_ZH_PATH = os.getenv("VOCAB_ZH_PATH", "/models/vocab_zh.txt")
STATE: dict[str, Any] = {
"encoder": None,
"decoder": None,
"mel_filters": None,
"vocab_en": {},
"vocab_zh": {},
}
ANSI_RE = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
def read_vocab(path: str) -> dict[str, str]:
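    """Load a vocab file with one `token_id token_text` pair per line."""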
vocab: dict[str, str] = {}
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.rstrip("\n")
if not line:
continue
parts = line.split(" ", 1)
token_id = parts[0]
token_text = parts[1] if len(parts) > 1 else ""
vocab[token_id] = token_text
return vocab
def load_mel_filters(path: str) -> np.ndarray:
    # The filterbank is stored as flat text: N_MELS rows of N_FFT // 2 + 1 bins.
    data = np.loadtxt(path, dtype=np.float32)
    return data.reshape((N_MELS, N_FFT // 2 + 1))
def ensure_sample_rate(waveform: np.ndarray, source_rate: int) -> np.ndarray:
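    """Resample to SAMPLE_RATE using scipy's FFT-based resampler if needed."""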
if source_rate == SAMPLE_RATE:
return waveform
target_len = int(round(len(waveform) * SAMPLE_RATE / source_rate))
return scipy.signal.resample(waveform, target_len).astype(np.float32)
def to_mono(waveform: np.ndarray) -> np.ndarray:
if waveform.ndim == 1:
return waveform
return waveform.mean(axis=1)
def log_mel_spectrogram(audio: np.ndarray, mel_filters: np.ndarray) -> np.ndarray:
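    """Compute Whisper-style log-mel features: power STFT, mel projection,
    log10, clamp to within 8 decades of the peak, then shift/scale."""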
_, _, stft = scipy.signal.stft(
audio,
fs=SAMPLE_RATE,
window="hann",
nperseg=N_FFT,
noverlap=N_FFT - HOP_LENGTH,
nfft=N_FFT,
boundary=None,
padded=False,
)
magnitudes = np.abs(stft).astype(np.float32) ** 2
if magnitudes.shape[1] > 0:
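        # Drop the last STFT frame so the frame count matches Whisper's
        # reference preprocessing.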
magnitudes = magnitudes[:, :-1]
mel_spec = mel_filters @ magnitudes
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec.astype(np.float32)
def pad_or_trim(mel: np.ndarray) -> np.ndarray:
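    """Zero-pad or truncate to MAX_MEL_FRAMES (20 s at a 10 ms hop) and add a
    batch dimension for the encoder."""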
out = np.zeros((N_MELS, MAX_MEL_FRAMES), dtype=np.float32)
frames = min(mel.shape[1], MAX_MEL_FRAMES)
out[:, :frames] = mel[:, :frames]
return np.expand_dims(out, 0)
def decode_tokens(vocab: dict[str, str], token_ids: list[int], language: str) -> str:
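    """Join token texts into a transcript, mapping the GPT-2 space marker back
    to spaces and stripping end-of-text markers."""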
pieces = [vocab.get(str(t), "") for t in token_ids]
text = (
"".join(pieces)
.replace("\u0120", " ")
.replace("<|endoftext|>", "")
.replace("\n", "")
.strip()
)
if language == "zh":
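        # The zh vocab is presumably stored base64-encoded; fall back to the
        # raw text if decoding fails.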
try:
text = base64.b64decode(text).decode("utf-8", errors="replace")
except Exception:
pass
return text
def transcribe_file(path: str, language: str) -> str:
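    """Encode the audio once, then greedily decode token by token, re-running
    the full decoder each step (no KV cache)."""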
waveform, sr = sf.read(path)
waveform = to_mono(np.asarray(waveform, dtype=np.float32))
waveform = ensure_sample_rate(waveform, sr)
mel = log_mel_spectrogram(waveform, STATE["mel_filters"])
encoder_input = pad_or_trim(mel)
encoded = STATE["encoder"].run(None, {"x": encoder_input})[0]
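    # Decoder prefix: <|startoftranscript|>, language, <|transcribe|>, <|notimestamps|>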
tokens = [50258, TASK_CODE[language], 50359, 50363]
emitted: list[int] = []
for _ in range(MAX_DECODE_TOKENS):
decoder_out = STATE["decoder"].run(
None,
{
"tokens": np.asarray([tokens], dtype=np.int64),
"audio": encoded,
},
)[0]
next_token = int(decoder_out[0, -1].argmax())
if next_token == END_TOKEN:
break
tokens.append(next_token)
        if next_token < TIMESTAMP_BEGIN:  # TIMESTAMP_BEGIN itself is a timestamp
            emitted.append(next_token)
vocab = STATE["vocab_en"] if language == "en" else STATE["vocab_zh"]
return decode_tokens(vocab, emitted, language)
def convert_to_wav(src_path: str) -> str:
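    """Transcode any ffmpeg-readable input to a mono 16 kHz WAV temp file."""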
fd, out_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
cmd = [
"ffmpeg",
"-y",
"-v",
"error",
"-i",
src_path,
"-ac",
"1",
"-ar",
str(SAMPLE_RATE),
out_path,
]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return out_path
    except subprocess.CalledProcessError as exc:
        # Drop the temp file so failed conversions do not leak to disk.
        if os.path.exists(out_path):
            os.unlink(out_path)
        stderr = exc.stderr.decode(errors="replace").strip() if exc.stderr else str(exc)
        raise HTTPException(status_code=400, detail=f"Failed to decode audio: {stderr}")
def check_api_key(authorization: str | None) -> None:
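    """Require `Authorization: Bearer <key>` only when STT_API_KEY is set."""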
if not API_KEY:
return
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing Bearer token")
token = authorization.split(" ", 1)[1].strip()
if token != API_KEY:
raise HTTPException(status_code=401, detail="Invalid API key")
def validate_vlm_enabled() -> None:
if not VLM_ENABLED:
raise HTTPException(
status_code=503,
detail="VLM endpoint is disabled. Set VLM_ENABLED=true.",
)
required = [VLM_DEMO_BIN, VLM_ENCODER_MODEL_PATH, VLM_LLM_MODEL_PATH]
for path in required:
if not os.path.exists(path):
raise HTTPException(status_code=500, detail=f"Missing VLM file: {path}")
def image_url_to_file(url: str) -> str:
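    """Write a data: URL or http(s) image URL to a temp file; the caller is
    responsible for deleting it."""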
fd, out_path = tempfile.mkstemp(suffix=".jpg")
os.close(fd)
try:
if url.startswith("data:"):
payload = url.split(",", 1)[1]
image_bytes = base64.b64decode(payload)
with open(out_path, "wb") as f:
f.write(image_bytes)
return out_path
if url.startswith("http://") or url.startswith("https://"):
with urllib.request.urlopen(url, timeout=30) as resp:
image_bytes = resp.read()
with open(out_path, "wb") as f:
f.write(image_bytes)
return out_path
        raise HTTPException(
            status_code=400,
            detail="Unsupported image_url. Use a data:, http://, or https:// URL.",
        )
except HTTPException:
if os.path.exists(out_path):
os.unlink(out_path)
raise
except Exception as exc:
if os.path.exists(out_path):
os.unlink(out_path)
raise HTTPException(status_code=400, detail=f"Failed to load image_url: {exc}")
def clean_vlm_output(text: str) -> str:
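    """Strip ANSI escapes, then keep only the text after the final `robot:`
    marker and before the next `user:` prompt."""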
text = ANSI_RE.sub("", text)
if "robot:" in text:
text = text.rsplit("robot:", 1)[1]
if "\nuser:" in text:
text = text.split("\nuser:", 1)[0]
return text.strip()
def run_vlm(image_path: str, prompt: str) -> str:
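    """Run the rkllm demo binary on one image/prompt pair. The prompt is piped
    to stdin followed by `exit`, with LD_LIBRARY_PATH pointing at the bundled
    runtime libraries."""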
validate_vlm_enabled()
llm_input = prompt if prompt.strip().startswith("<image>") else f"<image>{prompt}"
cmd = [
VLM_DEMO_BIN,
image_path,
VLM_ENCODER_MODEL_PATH,
VLM_LLM_MODEL_PATH,
str(VLM_MAX_NEW_TOKENS),
str(VLM_MAX_CONTEXT_LEN),
str(VLM_CORE_NUM),
VLM_IMG_START,
VLM_IMG_END,
VLM_IMG_CONTENT,
]
env = os.environ.copy()
current_ld = env.get("LD_LIBRARY_PATH", "")
env["LD_LIBRARY_PATH"] = (
f"{VLM_LIB_DIR}:{current_ld}" if current_ld else VLM_LIB_DIR
)
try:
proc = subprocess.run(
cmd,
input=f"{llm_input}\nexit\n",
text=True,
capture_output=True,
check=True,
env=env,
timeout=VLM_TIMEOUT_SEC,
)
except subprocess.TimeoutExpired as exc:
raise HTTPException(status_code=504, detail=f"VLM timed out: {exc}")
except subprocess.CalledProcessError as exc:
message = exc.stderr.strip() if exc.stderr else str(exc)
raise HTTPException(status_code=500, detail=f"VLM execution failed: {message}")
output = clean_vlm_output(proc.stdout)
if not output:
raise HTTPException(status_code=500, detail="VLM returned empty output")
return output
def extract_prompt_and_image(messages: list[dict[str, Any]]) -> tuple[str, str]:
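    """Pull the newest user prompt and image_url from OpenAI-style messages,
    accepting both plain-string content and content-part lists."""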
prompt = ""
image_url = ""
for msg in reversed(messages):
if msg.get("role") != "user":
continue
content = msg.get("content")
if isinstance(content, str):
prompt = content.strip()
elif isinstance(content, list):
text_parts: list[str] = []
for part in content:
if part.get("type") == "text" and part.get("text"):
text_parts.append(str(part["text"]))
if part.get("type") == "image_url":
image_data = part.get("image_url")
if isinstance(image_data, dict):
image_url = str(image_data.get("url", ""))
elif isinstance(image_data, str):
image_url = image_data
prompt = "\n".join([p for p in text_parts if p.strip()]).strip()
if prompt or image_url:
break
if not prompt:
prompt = "Describe this image in English."
if not image_url:
raise HTTPException(
status_code=400,
detail="messages must include image_url content in the user message",
)
return prompt, image_url
@asynccontextmanager
async def lifespan(_: FastAPI):
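    """Fail fast on missing model assets, then load the ONNX sessions, mel
    filterbank, and vocabs once at startup."""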
for path in [
ENCODER_MODEL_PATH,
DECODER_MODEL_PATH,
MEL_FILTERS_PATH,
VOCAB_EN_PATH,
VOCAB_ZH_PATH,
]:
if not os.path.exists(path):
raise RuntimeError(f"Required file not found: {path}")
STATE["encoder"] = ort.InferenceSession(
ENCODER_MODEL_PATH, providers=["CPUExecutionProvider"]
)
STATE["decoder"] = ort.InferenceSession(
DECODER_MODEL_PATH, providers=["CPUExecutionProvider"]
)
STATE["mel_filters"] = load_mel_filters(MEL_FILTERS_PATH)
STATE["vocab_en"] = read_vocab(VOCAB_EN_PATH)
STATE["vocab_zh"] = read_vocab(VOCAB_ZH_PATH)
yield
app = FastAPI(title="RK Whisper STT API", version="0.1.0", lifespan=lifespan)
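# Example requests (host and port are assumptions: nothing in this file starts
# the ASGI server, so substitute wherever the app is actually served):
#   curl -H "Authorization: Bearer $STT_API_KEY" \
#     -F file=@clip.wav -F language=en -F response_format=text \
#     http://localhost:8000/v1/audio/transcriptions
#   curl -F file=@photo.jpg -F "prompt=Describe this image." \
#     http://localhost:8000/v1/vision/understand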
@app.get("/health")
async def health() -> dict[str, Any]:
return {
"ok": True,
"model": MODEL_NAME,
"encoder": ENCODER_MODEL_PATH,
"decoder": DECODER_MODEL_PATH,
"vlm_enabled": VLM_ENABLED,
"vlm_model": VLM_MODEL_NAME,
}
@app.post("/v1/audio/transcriptions")
async def transcriptions(
file: UploadFile = File(...),
model: str = Form(default=MODEL_NAME),
language: str = Form(default="en"),
response_format: str = Form(default="json"),
authorization: str | None = Header(default=None),
):
check_api_key(authorization)
if model != MODEL_NAME:
raise HTTPException(
status_code=400,
detail=f"Unsupported model '{model}', expected '{MODEL_NAME}'",
)
if language not in TASK_CODE:
raise HTTPException(status_code=400, detail="language must be en or zh")
fd, input_path = tempfile.mkstemp(suffix="_upload")
os.close(fd)
wav_path = ""
try:
payload = await file.read()
with open(input_path, "wb") as f:
f.write(payload)
wav_path = convert_to_wav(input_path)
text = transcribe_file(wav_path, language)
finally:
if os.path.exists(input_path):
os.unlink(input_path)
if wav_path and os.path.exists(wav_path):
os.unlink(wav_path)
if response_format == "text":
return PlainTextResponse(text)
if response_format == "verbose_json":
return {
"task": "transcribe",
"language": language,
"model": MODEL_NAME,
"text": text,
"segments": [],
}
return {"text": text}
@app.post("/v1/vision/understand")
async def vision_understand(
file: UploadFile = File(...),
prompt: str = Form(default="Describe this image in English."),
model: str = Form(default=VLM_MODEL_NAME),
response_format: str = Form(default="json"),
authorization: str | None = Header(default=None),
):
check_api_key(authorization)
if model != VLM_MODEL_NAME:
raise HTTPException(
status_code=400,
detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
)
fd, image_path = tempfile.mkstemp(suffix="_image")
os.close(fd)
try:
payload = await file.read()
with open(image_path, "wb") as f:
f.write(payload)
text = run_vlm(image_path, prompt)
finally:
if os.path.exists(image_path):
os.unlink(image_path)
if response_format == "text":
return PlainTextResponse(text)
return {"text": text, "model": VLM_MODEL_NAME}
@app.post("/v1/chat/completions")
async def chat_completions(
body: dict[str, Any] = Body(...),
authorization: str | None = Header(default=None),
):
check_api_key(authorization)
model = str(body.get("model", VLM_MODEL_NAME))
if model != VLM_MODEL_NAME:
raise HTTPException(
status_code=400,
detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
)
messages = body.get("messages")
if not isinstance(messages, list) or not messages:
raise HTTPException(status_code=400, detail="messages must be a non-empty list")
prompt, image_url = extract_prompt_and_image(messages)
image_path = image_url_to_file(image_url)
try:
text = run_vlm(image_path, prompt)
finally:
if os.path.exists(image_path):
os.unlink(image_path)
return {
"id": "chatcmpl-rk-vl-1",
"object": "chat.completion",
"created": int(time.time()),
"model": VLM_MODEL_NAME,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
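
if __name__ == "__main__":
    # Convenience entry point for local testing. This file does not specify how
    # the app is normally served, so uvicorn on port 8000 is an assumption.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)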