506 lines
15 KiB
Python
506 lines
15 KiB
Python
import base64
import hmac
import os
import re
import subprocess
import tempfile
import time
import urllib.request
from contextlib import asynccontextmanager
from typing import Any

import numpy as np
import onnxruntime as ort
import scipy.signal
import soundfile as sf
from fastapi import Body, FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.responses import PlainTextResponse
|
|
|
|
# --- Whisper feature-extraction constants ---
SAMPLE_RATE = 16000  # Hz; all input audio is resampled to this rate
N_FFT = 400  # STFT window length (25 ms at 16 kHz)
HOP_LENGTH = 160  # STFT hop length (10 ms at 16 kHz)
N_MELS = 80  # number of mel bands the encoder expects
MAX_MEL_FRAMES = 2000  # encoder input width in frames (20 s at 10 ms/frame)
END_TOKEN = 50257  # decoder stop token (<|endoftext|>)
TASK_CODE = {"en": 50259, "zh": 50260}  # per-language decoder prompt token ids
TIMESTAMP_BEGIN = 50364  # first timestamp token id; used to filter decoder output

# --- STT service configuration (all environment-overridable) ---
MODEL_NAME = os.getenv("MODEL_NAME", "whisper-base-onnx")  # model id accepted by the API
API_KEY = os.getenv("STT_API_KEY", "")  # empty string disables auth entirely
MAX_DECODE_TOKENS = int(os.getenv("MAX_DECODE_TOKENS", "128"))  # greedy-decode step cap

# --- Optional vision-language model (VLM) configuration ---
# The VLM endpoints shell out to an external RKLLM demo binary.
VLM_ENABLED = os.getenv("VLM_ENABLED", "false").lower() in {
    "1",
    "true",
    "yes",
    "on",
}
VLM_MODEL_NAME = os.getenv("VLM_MODEL_NAME", "qwen3-vl-2b-rkllm")
VLM_DEMO_BIN = os.getenv(
    "VLM_DEMO_BIN", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/demo"
)
VLM_LIB_DIR = os.getenv(
    "VLM_LIB_DIR", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/lib"
)
VLM_ENCODER_MODEL_PATH = os.getenv(
    "VLM_ENCODER_MODEL_PATH", "/opt/rkllm-root/models/qwen3-vl-2b_vision_rk3588.rknn"
)
VLM_LLM_MODEL_PATH = os.getenv(
    "VLM_LLM_MODEL_PATH",
    "/opt/rkllm-root/models/qwen3-vl-2b-instruct_w8a8_rk3588.rkllm",
)
VLM_CORE_NUM = int(os.getenv("VLM_CORE_NUM", "3"))  # core count passed to the demo
VLM_MAX_NEW_TOKENS = int(os.getenv("VLM_MAX_NEW_TOKENS", "256"))
VLM_MAX_CONTEXT_LEN = int(os.getenv("VLM_MAX_CONTEXT_LEN", "4096"))
# Prompt markers the demo binary uses to delimit image content.
VLM_IMG_START = os.getenv("VLM_IMG_START", "<|vision_start|>")
VLM_IMG_END = os.getenv("VLM_IMG_END", "<|vision_end|>")
VLM_IMG_CONTENT = os.getenv("VLM_IMG_CONTENT", "<|image_pad|>")
VLM_TIMEOUT_SEC = int(os.getenv("VLM_TIMEOUT_SEC", "300"))  # subprocess timeout (s)

# --- Whisper model/asset paths ---
ENCODER_MODEL_PATH = os.getenv(
    "ENCODER_MODEL_PATH", "/models/whisper_encoder_base_20s.onnx"
)
DECODER_MODEL_PATH = os.getenv(
    "DECODER_MODEL_PATH", "/models/whisper_decoder_base_20s.onnx"
)
MEL_FILTERS_PATH = os.getenv("MEL_FILTERS_PATH", "/models/mel_80_filters.txt")
VOCAB_EN_PATH = os.getenv("VOCAB_EN_PATH", "/models/vocab_en.txt")
VOCAB_ZH_PATH = os.getenv("VOCAB_ZH_PATH", "/models/vocab_zh.txt")

# Populated once at startup by the lifespan handler; read by request handlers.
STATE: dict[str, Any] = {
    "encoder": None,
    "decoder": None,
    "mel_filters": None,
    "vocab_en": {},
    "vocab_zh": {},
}

# Matches ANSI CSI escape sequences (colors/cursor moves) in subprocess output.
ANSI_RE = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
|
|
|
|
|
|
def read_vocab(path: str) -> dict[str, str]:
    """Parse a vocab file into a {token_id: token_text} mapping.

    Each non-empty line has the form "<id> <text>"; the text may itself
    contain spaces. A line with only an id maps to the empty string.
    """
    entries: dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.rstrip("\n")
            if not stripped:
                continue
            tok_id, _, tok_text = stripped.partition(" ")
            entries[tok_id] = tok_text
    return entries
|
|
|
|
|
|
def load_mel_filters(path: str) -> np.ndarray:
    """Load the 80-band mel filterbank from a whitespace text file.

    Returns a float32 array of shape (80, 201) — 80 mel bands over
    201 FFT bins.
    """
    return np.loadtxt(path, dtype=np.float32).reshape((80, 201))
|
|
|
|
|
|
def ensure_sample_rate(waveform: np.ndarray, source_rate: int) -> np.ndarray:
    """Resample *waveform* to SAMPLE_RATE when *source_rate* differs.

    Uses scipy's FFT-based resampling; a signal already at SAMPLE_RATE is
    returned unchanged.
    """
    if source_rate != SAMPLE_RATE:
        new_length = int(round(len(waveform) * SAMPLE_RATE / source_rate))
        waveform = scipy.signal.resample(waveform, new_length).astype(np.float32)
    return waveform
|
|
|
|
|
|
def to_mono(waveform: np.ndarray) -> np.ndarray:
    """Collapse a (samples, channels) signal to mono by averaging channels.

    A 1-D input is already mono and is returned as-is.
    """
    return waveform if waveform.ndim == 1 else waveform.mean(axis=1)
|
|
|
|
|
|
def log_mel_spectrogram(audio: np.ndarray, mel_filters: np.ndarray) -> np.ndarray:
    """Compute Whisper-style log-mel features of shape (N_MELS, frames).

    Pipeline: power spectrogram -> mel projection -> log10 with a floor,
    clamped to within 8 of the peak, then shifted/scaled by (x + 4) / 4.
    """
    _, _, stft = scipy.signal.stft(
        audio,
        fs=SAMPLE_RATE,
        window="hann",
        nperseg=N_FFT,
        noverlap=N_FFT - HOP_LENGTH,
        nfft=N_FFT,
        boundary=None,
        padded=False,
    )
    magnitudes = np.abs(stft).astype(np.float32) ** 2
    # Drop the final STFT frame — presumably to match Whisper's reference
    # frame-count convention; confirm against the original preprocessing.
    if magnitudes.shape[1] > 0:
        magnitudes = magnitudes[:, :-1]
    mel_spec = mel_filters @ magnitudes
    # Floor at 1e-10 before log10 so silent frames don't produce -inf.
    log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
    # Limit dynamic range to 8 log-decades below the global peak.
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec.astype(np.float32)
|
|
|
|
|
|
def pad_or_trim(mel: np.ndarray) -> np.ndarray:
    """Zero-pad or truncate mel frames to exactly MAX_MEL_FRAMES columns.

    Returns a float32 array with a leading batch axis:
    (1, N_MELS, MAX_MEL_FRAMES).
    """
    padded = np.zeros((N_MELS, MAX_MEL_FRAMES), dtype=np.float32)
    keep = min(mel.shape[1], MAX_MEL_FRAMES)
    padded[:, :keep] = mel[:, :keep]
    return padded[np.newaxis, ...]
|
|
|
|
|
|
def decode_tokens(vocab: dict[str, str], token_ids: list[int], language: str) -> str:
    """Render decoder token ids as text using *vocab*.

    Unknown ids contribute nothing. The BPE space marker (U+0120) becomes a
    space and end-of-text/newlines are stripped. For "zh", the vocab stores
    base64-encoded UTF-8, so the joined string is decoded; on any decode
    failure the raw text is returned unchanged (best effort).
    """
    joined = "".join(vocab.get(str(tid), "") for tid in token_ids)
    joined = joined.replace("\u0120", " ").replace("<|endoftext|>", "")
    cleaned = joined.replace("\n", "").strip()
    if language == "zh":
        try:
            cleaned = base64.b64decode(cleaned).decode("utf-8", errors="replace")
        except Exception:
            pass
    return cleaned
|
|
|
|
|
|
def transcribe_file(path: str, language: str) -> str:
    """Transcribe the audio file at *path* with the ONNX Whisper pipeline.

    *language* must be a key of TASK_CODE ("en" or "zh"); the endpoint
    validates this before calling. Reads models/filters/vocabs from STATE,
    which the lifespan handler populates at startup.
    """
    waveform, sr = sf.read(path)
    waveform = to_mono(np.asarray(waveform, dtype=np.float32))
    waveform = ensure_sample_rate(waveform, sr)
    mel = log_mel_spectrogram(waveform, STATE["mel_filters"])
    encoder_input = pad_or_trim(mel)

    # Run the encoder once; its output is reused for every decode step.
    encoded = STATE["encoder"].run(None, {"x": encoder_input})[0]

    # Initial decoder prompt. 50258/50359/50363 are presumably Whisper's
    # start-of-transcript / transcribe-task / no-timestamps special tokens —
    # confirm against the tokenizer these models were exported with.
    tokens = [50258, TASK_CODE[language], 50359, 50363]
    emitted: list[int] = []

    # Greedy decoding with no KV cache: the full token sequence is re-fed to
    # the decoder on every step, capped at MAX_DECODE_TOKENS steps.
    for _ in range(MAX_DECODE_TOKENS):
        decoder_out = STATE["decoder"].run(
            None,
            {
                "tokens": np.asarray([tokens], dtype=np.int64),
                "audio": encoded,
            },
        )[0]
        next_token = int(decoder_out[0, -1].argmax())
        if next_token == END_TOKEN:
            break
        tokens.append(next_token)
        # Keep only non-timestamp tokens for the transcript text.
        # NOTE(review): `<=` also admits TIMESTAMP_BEGIN itself (the first
        # timestamp token); presumably harmless if the vocab files omit
        # timestamp ids (decode_tokens maps unknown ids to "") — confirm.
        if next_token <= TIMESTAMP_BEGIN:
            emitted.append(next_token)

    vocab = STATE["vocab_en"] if language == "en" else STATE["vocab_zh"]
    return decode_tokens(vocab, emitted, language)
|
|
|
|
|
|
def convert_to_wav(src_path: str) -> str:
    """Transcode *src_path* to a mono, SAMPLE_RATE-Hz WAV via ffmpeg.

    Returns the path of a new temporary WAV file; the caller is responsible
    for deleting it on the success path. Raises HTTPException(400) when
    ffmpeg cannot decode the input (the temp file is cleaned up first).
    """
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    cmd = [
        "ffmpeg",
        "-y",
        "-v",
        "error",
        "-i",
        src_path,
        "-ac",
        "1",
        "-ar",
        str(SAMPLE_RATE),
        out_path,
    ]
    try:
        # Capture stderr so a decode failure can be reported to the client
        # instead of only landing in the server log.
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        return out_path
    except subprocess.CalledProcessError as exc:
        # Don't leak the temp file when ffmpeg fails.
        if os.path.exists(out_path):
            os.unlink(out_path)
        detail = exc.stderr.strip() if exc.stderr else str(exc)
        raise HTTPException(
            status_code=400, detail=f"Failed to decode audio: {detail}"
        ) from exc
|
|
|
|
|
|
def check_api_key(authorization: str | None) -> None:
    """Validate a Bearer token from the Authorization header against API_KEY.

    No-op when no API key is configured (API_KEY empty). Raises
    HTTPException(401) for a missing/malformed header or a wrong token.
    """
    if not API_KEY:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing Bearer token")
    token = authorization.split(" ", 1)[1].strip()
    # Constant-time comparison to avoid leaking key bytes via timing.
    if not hmac.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
|
|
|
|
|
|
def validate_vlm_enabled() -> None:
    """Raise unless the VLM feature is enabled and all its files exist.

    Raises HTTPException(503) when VLM_ENABLED is false, and
    HTTPException(500) when the demo binary or a model file is missing.
    """
    if not VLM_ENABLED:
        raise HTTPException(
            status_code=503,
            detail="VLM endpoint is disabled. Set VLM_ENABLED=true.",
        )
    for required_path in (VLM_DEMO_BIN, VLM_ENCODER_MODEL_PATH, VLM_LLM_MODEL_PATH):
        if not os.path.exists(required_path):
            raise HTTPException(
                status_code=500, detail=f"Missing VLM file: {required_path}"
            )
|
|
|
|
|
|
def image_url_to_file(url: str) -> str:
    """Materialize an image URL into a temporary .jpg file on disk.

    Supports data: URLs (base64 payload after the comma) and http(s)://
    URLs (fetched with a 30 s timeout). Returns the temp file path; the
    caller must delete it. Raises HTTPException(400) for unsupported
    schemes or fetch/decode failures; the temp file is removed on failure.
    """
    fd, out_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if url.startswith("data:"):
            # data:[<mediatype>][;base64],<payload> — only the payload is used.
            payload = url.split(",", 1)[1]
            image_bytes = base64.b64decode(payload)
            with open(out_path, "wb") as f:
                f.write(image_bytes)
            return out_path

        if url.startswith(("http://", "https://")):
            # NOTE(review): fetching caller-supplied URLs is SSRF-prone;
            # restrict targets if this service faces untrusted clients.
            with urllib.request.urlopen(url, timeout=30) as resp:
                image_bytes = resp.read()
            with open(out_path, "wb") as f:
                f.write(image_bytes)
            return out_path

        # Message now matches the schemes actually accepted above.
        raise HTTPException(
            status_code=400,
            detail="Unsupported image_url. Use data:, http:// or https:// URL.",
        )
    except HTTPException:
        if os.path.exists(out_path):
            os.unlink(out_path)
        raise
    except Exception as exc:
        if os.path.exists(out_path):
            os.unlink(out_path)
        raise HTTPException(
            status_code=400, detail=f"Failed to load image_url: {exc}"
        ) from exc
|
|
|
|
|
|
def clean_vlm_output(text: str) -> str:
    """Strip ANSI escapes and demo chat markers from the VLM binary's stdout."""
    cleaned = ANSI_RE.sub("", text)
    # Keep only what follows the final "robot:" marker, if any.
    if "robot:" in cleaned:
        cleaned = cleaned.rsplit("robot:", 1)[1]
    # Drop anything after a subsequent user turn.
    if "\nuser:" in cleaned:
        cleaned, _, _ = cleaned.partition("\nuser:")
    return cleaned.strip()
|
|
|
|
|
|
def run_vlm(image_path: str, prompt: str) -> str:
    """Invoke the external RKLLM demo binary on one image + prompt.

    The demo reads the prompt from stdin and its REPL is terminated with
    "exit". Its stdout is cleaned of ANSI codes and chat markers before
    being returned. Raises HTTPException(503/500) when VLM prerequisites
    are missing, 504 on timeout, and 500 on failure or empty output.
    """
    validate_vlm_enabled()
    # The demo expects an <image> placeholder in the prompt text.
    llm_input = prompt if prompt.strip().startswith("<image>") else f"<image>{prompt}"
    cmd = [
        VLM_DEMO_BIN,
        image_path,
        VLM_ENCODER_MODEL_PATH,
        VLM_LLM_MODEL_PATH,
        str(VLM_MAX_NEW_TOKENS),
        str(VLM_MAX_CONTEXT_LEN),
        str(VLM_CORE_NUM),
        VLM_IMG_START,
        VLM_IMG_END,
        VLM_IMG_CONTENT,
    ]
    # The demo binary ships its own shared libraries; prepend their dir.
    env = os.environ.copy()
    current_ld = env.get("LD_LIBRARY_PATH", "")
    env["LD_LIBRARY_PATH"] = (
        f"{VLM_LIB_DIR}:{current_ld}" if current_ld else VLM_LIB_DIR
    )

    try:
        proc = subprocess.run(
            cmd,
            input=f"{llm_input}\nexit\n",
            text=True,
            capture_output=True,
            check=True,
            env=env,
            timeout=VLM_TIMEOUT_SEC,
        )
    except subprocess.TimeoutExpired as exc:
        # Chain the cause so tracebacks show the underlying subprocess error.
        raise HTTPException(status_code=504, detail=f"VLM timed out: {exc}") from exc
    except subprocess.CalledProcessError as exc:
        message = exc.stderr.strip() if exc.stderr else str(exc)
        raise HTTPException(
            status_code=500, detail=f"VLM execution failed: {message}"
        ) from exc

    output = clean_vlm_output(proc.stdout)
    if not output:
        raise HTTPException(status_code=500, detail="VLM returned empty output")
    return output
|
|
|
|
|
|
def extract_prompt_and_image(messages: list[dict[str, Any]]) -> tuple[str, str]:
    """Pull the text prompt and image URL from OpenAI-style chat messages.

    Scans user messages from newest to oldest; the first one yielding a
    prompt and/or image wins. Content may be a plain string (text only) or
    a list of typed parts ("text" / "image_url", with image_url given as a
    dict {"url": ...} or a bare string). Falls back to a default prompt;
    raises HTTPException(400) when no image_url is present.
    """
    prompt = ""
    image_url = ""
    for msg in reversed(messages):
        if msg.get("role") != "user":
            continue
        content = msg.get("content")
        if isinstance(content, str):
            prompt = content.strip()
        elif isinstance(content, list):
            text_parts: list[str] = []
            for part in content:
                # Tolerate malformed non-dict entries instead of raising
                # AttributeError on part.get().
                if not isinstance(part, dict):
                    continue
                if part.get("type") == "text" and part.get("text"):
                    text_parts.append(str(part["text"]))
                if part.get("type") == "image_url":
                    image_data = part.get("image_url")
                    if isinstance(image_data, dict):
                        image_url = str(image_data.get("url", ""))
                    elif isinstance(image_data, str):
                        image_url = image_data
            prompt = "\n".join(p for p in text_parts if p.strip()).strip()
        if prompt or image_url:
            break

    if not prompt:
        prompt = "Describe this image in English."
    if not image_url:
        raise HTTPException(
            status_code=400,
            detail="messages must include image_url content in the user message",
        )
    return prompt, image_url
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan: load ONNX sessions, mel filters, and vocabs into STATE.

    Fails fast with RuntimeError at startup if any required asset file is
    missing, so the server never comes up half-configured.
    """
    for path in [
        ENCODER_MODEL_PATH,
        DECODER_MODEL_PATH,
        MEL_FILTERS_PATH,
        VOCAB_EN_PATH,
        VOCAB_ZH_PATH,
    ]:
        if not os.path.exists(path):
            raise RuntimeError(f"Required file not found: {path}")

    # CPU-only inference sessions for the Whisper encoder/decoder.
    STATE["encoder"] = ort.InferenceSession(
        ENCODER_MODEL_PATH, providers=["CPUExecutionProvider"]
    )
    STATE["decoder"] = ort.InferenceSession(
        DECODER_MODEL_PATH, providers=["CPUExecutionProvider"]
    )
    STATE["mel_filters"] = load_mel_filters(MEL_FILTERS_PATH)
    STATE["vocab_en"] = read_vocab(VOCAB_EN_PATH)
    STATE["vocab_zh"] = read_vocab(VOCAB_ZH_PATH)
    yield
|
|
|
|
|
|
app = FastAPI(title="RK Whisper STT API", version="0.1.0", lifespan=lifespan)
|
|
|
|
|
|
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe plus a summary of the configured models."""
    status: dict[str, Any] = {"ok": True, "model": MODEL_NAME}
    status["encoder"] = ENCODER_MODEL_PATH
    status["decoder"] = DECODER_MODEL_PATH
    status["vlm_enabled"] = VLM_ENABLED
    status["vlm_model"] = VLM_MODEL_NAME
    return status
|
|
|
|
|
|
@app.post("/v1/audio/transcriptions")
async def transcriptions(
    file: UploadFile = File(...),
    model: str = Form(default=MODEL_NAME),
    language: str = Form(default="en"),
    response_format: str = Form(default="json"),
    authorization: str | None = Header(default=None),
):
    """OpenAI-compatible transcription endpoint.

    Accepts any audio format ffmpeg can decode. Supports response_format
    "json" (default), "text" (plain body), and "verbose_json" (segments
    are always empty — this pipeline produces no segment timing).
    """
    check_api_key(authorization)
    if model != MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{MODEL_NAME}'",
        )
    if language not in TASK_CODE:
        raise HTTPException(status_code=400, detail="language must be en or zh")

    # Spool the upload to disk so ffmpeg can read it, then normalize to WAV.
    fd, input_path = tempfile.mkstemp(suffix="_upload")
    os.close(fd)
    wav_path = ""

    try:
        payload = await file.read()
        with open(input_path, "wb") as f:
            f.write(payload)
        wav_path = convert_to_wav(input_path)
        text = transcribe_file(wav_path, language)
    finally:
        # Both temp files are removed even when conversion/transcription fails.
        if os.path.exists(input_path):
            os.unlink(input_path)
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)

    if response_format == "text":
        return PlainTextResponse(text)
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language,
            "model": MODEL_NAME,
            "text": text,
            "segments": [],
        }
    return {"text": text}
|
|
|
|
|
|
@app.post("/v1/vision/understand")
async def vision_understand(
    file: UploadFile = File(...),
    prompt: str = Form(default="Describe this image in English."),
    model: str = Form(default=VLM_MODEL_NAME),
    response_format: str = Form(default="json"),
    authorization: str | None = Header(default=None),
):
    """Run the VLM on an uploaded image and return its answer to *prompt*."""
    check_api_key(authorization)
    if model != VLM_MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
        )

    fd, image_path = tempfile.mkstemp(suffix="_image")
    os.close(fd)
    try:
        # Persist the upload so the external demo binary can read it from disk.
        with open(image_path, "wb") as out:
            out.write(await file.read())
        text = run_vlm(image_path, prompt)
    finally:
        if os.path.exists(image_path):
            os.unlink(image_path)

    if response_format == "text":
        return PlainTextResponse(text)
    return {"text": text, "model": VLM_MODEL_NAME}
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def chat_completions(
    body: dict[str, Any] = Body(...),
    authorization: str | None = Header(default=None),
):
    """Minimal OpenAI chat-completions shim over the local VLM.

    Only single-turn, non-streaming requests are handled; the newest user
    message must carry an image_url part. Token usage is not tracked, so
    the usage block is always zeros.
    """
    check_api_key(authorization)

    model = str(body.get("model", VLM_MODEL_NAME))
    if model != VLM_MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
        )
    messages = body.get("messages")
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")

    prompt, image_url = extract_prompt_and_image(messages)
    image_path = image_url_to_file(image_url)
    try:
        text = run_vlm(image_path, prompt)
    finally:
        # The downloaded/decoded image is always removed, even on VLM failure.
        if os.path.exists(image_path):
            os.unlink(image_path)

    return {
        "id": "chatcmpl-rk-vl-1",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": VLM_MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
|