initial
This commit is contained in:
505
app/server.py
Normal file
505
app/server.py
Normal file
@@ -0,0 +1,505 @@
|
||||
import base64
import hmac
import os
import re
import subprocess
import tempfile
import time
import urllib.request
import uuid
from contextlib import asynccontextmanager
from typing import Any

import numpy as np
import onnxruntime as ort
import scipy.signal
import soundfile as sf
from fastapi import Body, FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.responses import PlainTextResponse
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
N_MELS = 80
|
||||
MAX_MEL_FRAMES = 2000
|
||||
END_TOKEN = 50257
|
||||
TASK_CODE = {"en": 50259, "zh": 50260}
|
||||
TIMESTAMP_BEGIN = 50364
|
||||
|
||||
MODEL_NAME = os.getenv("MODEL_NAME", "whisper-base-onnx")
|
||||
API_KEY = os.getenv("STT_API_KEY", "")
|
||||
MAX_DECODE_TOKENS = int(os.getenv("MAX_DECODE_TOKENS", "128"))
|
||||
|
||||
VLM_ENABLED = os.getenv("VLM_ENABLED", "false").lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
VLM_MODEL_NAME = os.getenv("VLM_MODEL_NAME", "qwen3-vl-2b-rkllm")
|
||||
VLM_DEMO_BIN = os.getenv(
|
||||
"VLM_DEMO_BIN", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/demo"
|
||||
)
|
||||
VLM_LIB_DIR = os.getenv(
|
||||
"VLM_LIB_DIR", "/opt/rkllm-root/quickstart/demo_Linux_aarch64/lib"
|
||||
)
|
||||
VLM_ENCODER_MODEL_PATH = os.getenv(
|
||||
"VLM_ENCODER_MODEL_PATH", "/opt/rkllm-root/models/qwen3-vl-2b_vision_rk3588.rknn"
|
||||
)
|
||||
VLM_LLM_MODEL_PATH = os.getenv(
|
||||
"VLM_LLM_MODEL_PATH",
|
||||
"/opt/rkllm-root/models/qwen3-vl-2b-instruct_w8a8_rk3588.rkllm",
|
||||
)
|
||||
VLM_CORE_NUM = int(os.getenv("VLM_CORE_NUM", "3"))
|
||||
VLM_MAX_NEW_TOKENS = int(os.getenv("VLM_MAX_NEW_TOKENS", "256"))
|
||||
VLM_MAX_CONTEXT_LEN = int(os.getenv("VLM_MAX_CONTEXT_LEN", "4096"))
|
||||
VLM_IMG_START = os.getenv("VLM_IMG_START", "<|vision_start|>")
|
||||
VLM_IMG_END = os.getenv("VLM_IMG_END", "<|vision_end|>")
|
||||
VLM_IMG_CONTENT = os.getenv("VLM_IMG_CONTENT", "<|image_pad|>")
|
||||
VLM_TIMEOUT_SEC = int(os.getenv("VLM_TIMEOUT_SEC", "300"))
|
||||
|
||||
ENCODER_MODEL_PATH = os.getenv(
|
||||
"ENCODER_MODEL_PATH", "/models/whisper_encoder_base_20s.onnx"
|
||||
)
|
||||
DECODER_MODEL_PATH = os.getenv(
|
||||
"DECODER_MODEL_PATH", "/models/whisper_decoder_base_20s.onnx"
|
||||
)
|
||||
MEL_FILTERS_PATH = os.getenv("MEL_FILTERS_PATH", "/models/mel_80_filters.txt")
|
||||
VOCAB_EN_PATH = os.getenv("VOCAB_EN_PATH", "/models/vocab_en.txt")
|
||||
VOCAB_ZH_PATH = os.getenv("VOCAB_ZH_PATH", "/models/vocab_zh.txt")
|
||||
|
||||
STATE: dict[str, Any] = {
|
||||
"encoder": None,
|
||||
"decoder": None,
|
||||
"mel_filters": None,
|
||||
"vocab_en": {},
|
||||
"vocab_zh": {},
|
||||
}
|
||||
|
||||
ANSI_RE = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
|
||||
|
||||
|
||||
def read_vocab(path: str) -> dict[str, str]:
    """Load a token-id -> token-text vocabulary from a text file.

    Each non-empty line holds a token id, optionally followed by a single
    space and the token text (which may itself contain further spaces).
    Lines with no text map to the empty string.
    """
    mapping: dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            entry = raw_line.rstrip("\n")
            if not entry:
                continue
            token_id, _, token_text = entry.partition(" ")
            mapping[token_id] = token_text
    return mapping
|
||||
|
||||
|
||||
def load_mel_filters(path: str) -> np.ndarray:
    """Read the 80-band mel filterbank from a plain-text matrix file.

    Returns a float32 array shaped (80, 201): 80 mel bands over the
    201 FFT bins of a 400-point STFT.
    """
    values = np.loadtxt(path, dtype=np.float32)
    return np.reshape(values, (80, 201))
|
||||
|
||||
|
||||
def ensure_sample_rate(waveform: np.ndarray, source_rate: int) -> np.ndarray:
    """Resample *waveform* to SAMPLE_RATE; pass through unchanged when it
    is already at the target rate."""
    if source_rate != SAMPLE_RATE:
        new_length = int(round(len(waveform) * SAMPLE_RATE / source_rate))
        waveform = scipy.signal.resample(waveform, new_length).astype(np.float32)
    return waveform
|
||||
|
||||
|
||||
def to_mono(waveform: np.ndarray) -> np.ndarray:
    """Collapse a (samples, channels) signal to mono by averaging channels;
    1-D input is returned as-is."""
    return waveform if waveform.ndim == 1 else waveform.mean(axis=1)
|
||||
|
||||
|
||||
def log_mel_spectrogram(audio: np.ndarray, mel_filters: np.ndarray) -> np.ndarray:
    """Compute Whisper-style log-mel features for a 16 kHz mono signal.

    Pipeline: Hann-window STFT (400-pt FFT, 160-sample hop) -> power
    spectrum -> mel projection -> log10, clamped to an 8-decade dynamic
    range below the peak and rescaled to roughly [-1, 1].
    """
    _, _, spectrum = scipy.signal.stft(
        audio,
        fs=SAMPLE_RATE,
        window="hann",
        nperseg=N_FFT,
        noverlap=N_FFT - HOP_LENGTH,
        nfft=N_FFT,
        boundary=None,
        padded=False,
    )
    power = np.abs(spectrum).astype(np.float32) ** 2
    # Drop the final STFT frame, mirroring the reference frame count.
    if power.shape[1] > 0:
        power = power[:, :-1]
    mel = mel_filters @ power
    log_mel = np.log10(np.clip(mel, 1e-10, None))
    log_mel = np.maximum(log_mel, log_mel.max() - 8.0)
    return ((log_mel + 4.0) / 4.0).astype(np.float32)
|
||||
|
||||
|
||||
def pad_or_trim(mel: np.ndarray) -> np.ndarray:
    """Zero-pad or truncate a mel spectrogram to exactly MAX_MEL_FRAMES
    frames and prepend a batch axis: (N_MELS, T) -> (1, N_MELS, MAX_MEL_FRAMES)."""
    padded = np.zeros((N_MELS, MAX_MEL_FRAMES), dtype=np.float32)
    usable = min(mel.shape[1], MAX_MEL_FRAMES)
    padded[:, :usable] = mel[:, :usable]
    return padded[np.newaxis, ...]
|
||||
|
||||
|
||||
def decode_tokens(vocab: dict[str, str], token_ids: list[int], language: str) -> str:
    """Map token ids to text through *vocab* and normalize the result.

    GPT-2 style space markers (U+0120) become real spaces; end-of-text
    markers and newlines are stripped.  Chinese vocab entries are stored
    base64-encoded, so for language == "zh" the joined text is decoded
    (best effort: undecodable input is returned unchanged).
    """
    joined = "".join(vocab.get(str(token), "") for token in token_ids)
    cleaned = (
        joined.replace("\u0120", " ")
        .replace("<|endoftext|>", "")
        .replace("\n", "")
        .strip()
    )
    if language == "zh":
        try:
            cleaned = base64.b64decode(cleaned).decode("utf-8", errors="replace")
        except Exception:
            pass
    return cleaned
|
||||
|
||||
|
||||
def transcribe_file(path: str, language: str) -> str:
    """Transcribe the audio file at *path* with the loaded Whisper ONNX models.

    Reads the file, normalizes it to 16 kHz mono, computes log-mel features,
    runs the encoder once, then greedily decodes up to MAX_DECODE_TOKENS
    tokens and maps them to text via the language's vocabulary.
    Assumes STATE has been populated by the lifespan hook and that
    *language* is a key of TASK_CODE ("en" or "zh") — callers validate this.
    """
    waveform, sr = sf.read(path)
    waveform = to_mono(np.asarray(waveform, dtype=np.float32))
    waveform = ensure_sample_rate(waveform, sr)
    mel = log_mel_spectrogram(waveform, STATE["mel_filters"])
    encoder_input = pad_or_trim(mel)

    # Single encoder pass; result is reused for every decoder step.
    encoded = STATE["encoder"].run(None, {"x": encoder_input})[0]

    # Decoder prefix: presumably <|startoftranscript|>, language code,
    # <|transcribe|>, <|notimestamps|> — TODO confirm against the tokenizer.
    tokens = [50258, TASK_CODE[language], 50359, 50363]
    emitted: list[int] = []

    # Greedy decoding: re-run the decoder on the growing token prefix and
    # take the argmax of the last position each step.
    for _ in range(MAX_DECODE_TOKENS):
        decoder_out = STATE["decoder"].run(
            None,
            {
                "tokens": np.asarray([tokens], dtype=np.int64),
                "audio": encoded,
            },
        )[0]
        next_token = int(decoder_out[0, -1].argmax())
        if next_token == END_TOKEN:
            break
        tokens.append(next_token)
        # Only non-timestamp tokens contribute to the transcript text.
        # NOTE(review): `<=` also admits TIMESTAMP_BEGIN (50364) itself,
        # which looks like the first timestamp token — confirm whether
        # `<` was intended; vocab lookup likely yields "" for it anyway.
        if next_token <= TIMESTAMP_BEGIN:
            emitted.append(next_token)

    vocab = STATE["vocab_en"] if language == "en" else STATE["vocab_zh"]
    return decode_tokens(vocab, emitted, language)
|
||||
|
||||
|
||||
def convert_to_wav(src_path: str) -> str:
    """Transcode *src_path* to a mono 16 kHz WAV via ffmpeg.

    Returns the path of a new temporary .wav file; the caller owns (and
    must delete) it on success.  Raises HTTPException(400) when ffmpeg
    cannot decode the input; in that case the temporary file is removed
    here so it does not leak.
    """
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    cmd = [
        "ffmpeg",
        "-y",
        "-v",
        "error",
        "-i",
        src_path,
        "-ac",
        "1",
        "-ar",
        str(SAMPLE_RATE),
        out_path,
    ]
    try:
        # capture_output so ffmpeg's stderr can be surfaced in the error detail.
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        return out_path
    except subprocess.CalledProcessError as exc:
        # Bug fix: the temp file used to leak on ffmpeg failure.
        if os.path.exists(out_path):
            os.unlink(out_path)
        detail = exc.stderr.strip() if exc.stderr else str(exc)
        raise HTTPException(
            status_code=400, detail=f"Failed to decode audio: {detail}"
        ) from exc
|
||||
|
||||
|
||||
def check_api_key(authorization: str | None) -> None:
    """Validate the request's Bearer token against the configured API key.

    No-op when STT_API_KEY is unset (auth disabled).  Raises
    HTTPException(401) for a missing/malformed Authorization header or a
    wrong token.
    """
    if not API_KEY:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing Bearer token")
    token = authorization.split(" ", 1)[1].strip()
    # Constant-time comparison so key bytes cannot be probed via timing.
    if not hmac.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
|
||||
def validate_vlm_enabled() -> None:
    """Fail fast when the VLM endpoint is disabled or its assets are absent.

    Raises HTTPException(503) when VLM_ENABLED is false, and
    HTTPException(500) for the first missing demo binary or model file.
    """
    if not VLM_ENABLED:
        raise HTTPException(
            status_code=503,
            detail="VLM endpoint is disabled. Set VLM_ENABLED=true.",
        )
    for required_path in (VLM_DEMO_BIN, VLM_ENCODER_MODEL_PATH, VLM_LLM_MODEL_PATH):
        if not os.path.exists(required_path):
            raise HTTPException(
                status_code=500, detail=f"Missing VLM file: {required_path}"
            )
|
||||
|
||||
|
||||
def image_url_to_file(url: str) -> str:
    """Materialize an image URL into a temporary .jpg file and return its path.

    Supports data: URLs (base64 payload after the comma) and http(s):// URLs
    (fetched with a 30 s timeout).  The caller owns the returned file on
    success; on any failure the temp file is removed here before the
    HTTPException propagates.
    Raises HTTPException(400) for unsupported schemes or load failures.
    """
    fd, out_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if url.startswith("data:"):
            image_bytes = base64.b64decode(url.split(",", 1)[1])
        elif url.startswith(("http://", "https://")):
            with urllib.request.urlopen(url, timeout=30) as resp:
                image_bytes = resp.read()
        else:
            raise HTTPException(
                status_code=400,
                detail="Unsupported image_url. Use data: or https:// URL.",
            )
        with open(out_path, "wb") as f:
            f.write(image_bytes)
        return out_path
    except HTTPException:
        if os.path.exists(out_path):
            os.unlink(out_path)
        raise
    except Exception as exc:
        if os.path.exists(out_path):
            os.unlink(out_path)
        # Chain the cause so the original failure is preserved in tracebacks.
        raise HTTPException(
            status_code=400, detail=f"Failed to load image_url: {exc}"
        ) from exc
|
||||
|
||||
|
||||
def clean_vlm_output(text: str) -> str:
    """Strip ANSI escape codes and CLI chat framing from the demo's stdout.

    Keeps only the text after the final "robot:" marker and before any
    subsequent "\nuser:" prompt echo, then trims surrounding whitespace.
    """
    text = ANSI_RE.sub("", text)
    marker = text.rfind("robot:")
    if marker != -1:
        text = text[marker + len("robot:"):]
    cut = text.find("\nuser:")
    if cut != -1:
        text = text[:cut]
    return text.strip()
|
||||
|
||||
|
||||
def run_vlm(image_path: str, prompt: str) -> str:
    """Run the RKLLM multimodal demo binary on one image + prompt.

    The demo reads the prompt from stdin; "exit" is appended so the process
    terminates after answering.  Returns the cleaned model answer.
    Raises HTTPException: 504 on timeout, 500 on a non-zero exit or empty
    output (and whatever validate_vlm_enabled raises when misconfigured).
    """
    validate_vlm_enabled()
    # The demo expects the <image> placeholder to precede the question.
    if prompt.strip().startswith("<image>"):
        llm_input = prompt
    else:
        llm_input = f"<image>{prompt}"
    cmd = [
        VLM_DEMO_BIN,
        image_path,
        VLM_ENCODER_MODEL_PATH,
        VLM_LLM_MODEL_PATH,
        str(VLM_MAX_NEW_TOKENS),
        str(VLM_MAX_CONTEXT_LEN),
        str(VLM_CORE_NUM),
        VLM_IMG_START,
        VLM_IMG_END,
        VLM_IMG_CONTENT,
    ]
    # Prepend the demo's bundled libraries to the dynamic loader path.
    env = os.environ.copy()
    existing = env.get("LD_LIBRARY_PATH", "")
    env["LD_LIBRARY_PATH"] = f"{VLM_LIB_DIR}:{existing}" if existing else VLM_LIB_DIR

    try:
        proc = subprocess.run(
            cmd,
            input=f"{llm_input}\nexit\n",
            text=True,
            capture_output=True,
            check=True,
            env=env,
            timeout=VLM_TIMEOUT_SEC,
        )
    except subprocess.TimeoutExpired as exc:
        raise HTTPException(status_code=504, detail=f"VLM timed out: {exc}")
    except subprocess.CalledProcessError as exc:
        message = exc.stderr.strip() if exc.stderr else str(exc)
        raise HTTPException(status_code=500, detail=f"VLM execution failed: {message}")

    output = clean_vlm_output(proc.stdout)
    if not output:
        raise HTTPException(status_code=500, detail="VLM returned empty output")
    return output
|
||||
|
||||
|
||||
def extract_prompt_and_image(messages: list[dict[str, Any]]) -> tuple[str, str]:
    """Pull the text prompt and image URL from the most recent user message.

    Walks *messages* newest-first and stops at the first user message that
    yields any text or image.  Supports both plain-string content and the
    OpenAI parts list ({"type": "text"} / {"type": "image_url"}, where
    image_url may be a dict with a "url" key or a bare string).
    A missing prompt falls back to a default English instruction; a missing
    image raises HTTPException(400).
    """
    prompt, image_url = "", ""
    for message in reversed(messages):
        if message.get("role") != "user":
            continue
        content = message.get("content")
        if isinstance(content, str):
            prompt = content.strip()
        elif isinstance(content, list):
            collected: list[str] = []
            for part in content:
                kind = part.get("type")
                if kind == "text" and part.get("text"):
                    collected.append(str(part["text"]))
                if kind == "image_url":
                    url_field = part.get("image_url")
                    if isinstance(url_field, dict):
                        image_url = str(url_field.get("url", ""))
                    elif isinstance(url_field, str):
                        image_url = url_field
            prompt = "\n".join([p for p in collected if p.strip()]).strip()
        if prompt or image_url:
            break

    if not prompt:
        prompt = "Describe this image in English."
    if not image_url:
        raise HTTPException(
            status_code=400,
            detail="messages must include image_url content in the user message",
        )
    return prompt, image_url
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: load STT models and vocabularies at startup.

    Verifies that every required model/vocab file exists (RuntimeError
    aborts startup otherwise), then populates the module-level STATE dict
    with the ONNX encoder/decoder sessions, mel filterbank, and the
    English/Chinese vocabularies.  Nothing is torn down after yield —
    the ONNX sessions need no explicit close.
    """
    for path in [
        ENCODER_MODEL_PATH,
        DECODER_MODEL_PATH,
        MEL_FILTERS_PATH,
        VOCAB_EN_PATH,
        VOCAB_ZH_PATH,
    ]:
        if not os.path.exists(path):
            raise RuntimeError(f"Required file not found: {path}")

    # CPU-only inference sessions; loaded once and shared by all requests.
    STATE["encoder"] = ort.InferenceSession(
        ENCODER_MODEL_PATH, providers=["CPUExecutionProvider"]
    )
    STATE["decoder"] = ort.InferenceSession(
        DECODER_MODEL_PATH, providers=["CPUExecutionProvider"]
    )
    STATE["mel_filters"] = load_mel_filters(MEL_FILTERS_PATH)
    STATE["vocab_en"] = read_vocab(VOCAB_EN_PATH)
    STATE["vocab_zh"] = read_vocab(VOCAB_ZH_PATH)
    yield
|
||||
|
||||
|
||||
app = FastAPI(title="RK Whisper STT API", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe: report the configured model names, model paths, and
    whether the VLM endpoint is enabled."""
    status: dict[str, Any] = {
        "ok": True,
        "model": MODEL_NAME,
        "encoder": ENCODER_MODEL_PATH,
        "decoder": DECODER_MODEL_PATH,
        "vlm_enabled": VLM_ENABLED,
        "vlm_model": VLM_MODEL_NAME,
    }
    return status
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
async def transcriptions(
    file: UploadFile = File(...),
    model: str = Form(default=MODEL_NAME),
    language: str = Form(default="en"),
    response_format: str = Form(default="json"),
    authorization: str | None = Header(default=None),
):
    """OpenAI-compatible speech-to-text endpoint.

    Accepts any ffmpeg-decodable upload, converts it to 16 kHz mono WAV,
    and runs the Whisper ONNX pipeline.  `response_format` selects plain
    text, verbose JSON, or the default {"text": ...} JSON.  Temporary
    files are always cleaned up, even on failure.
    """
    check_api_key(authorization)
    if model != MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{MODEL_NAME}'",
        )
    if language not in TASK_CODE:
        raise HTTPException(status_code=400, detail="language must be en or zh")

    fd, input_path = tempfile.mkstemp(suffix="_upload")
    os.close(fd)
    wav_path = ""

    try:
        upload_bytes = await file.read()
        with open(input_path, "wb") as out_file:
            out_file.write(upload_bytes)
        wav_path = convert_to_wav(input_path)
        text = transcribe_file(wav_path, language)
    finally:
        for temp_path in (input_path, wav_path):
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)

    if response_format == "text":
        return PlainTextResponse(text)
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language,
            "model": MODEL_NAME,
            "text": text,
            "segments": [],
        }
    return {"text": text}
|
||||
|
||||
|
||||
@app.post("/v1/vision/understand")
async def vision_understand(
    file: UploadFile = File(...),
    prompt: str = Form(default="Describe this image in English."),
    model: str = Form(default=VLM_MODEL_NAME),
    response_format: str = Form(default="json"),
    authorization: str | None = Header(default=None),
):
    """Describe an uploaded image with the RKLLM vision-language model.

    Writes the upload to a temp file, runs the VLM demo on it, and returns
    the answer as plain text or JSON depending on `response_format`.
    The temp file is removed even when the VLM call fails.
    """
    check_api_key(authorization)
    if model != VLM_MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
        )

    fd, image_path = tempfile.mkstemp(suffix="_image")
    os.close(fd)
    try:
        image_bytes = await file.read()
        with open(image_path, "wb") as out_file:
            out_file.write(image_bytes)
        text = run_vlm(image_path, prompt)
    finally:
        if os.path.exists(image_path):
            os.unlink(image_path)

    if response_format == "text":
        return PlainTextResponse(text)
    return {"text": text, "model": VLM_MODEL_NAME}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(
    body: dict[str, Any] = Body(...),
    authorization: str | None = Header(default=None),
):
    """Minimal OpenAI chat-completions shim over the VLM demo binary.

    Expects a single-turn request whose latest user message carries an
    image_url part (data: or http(s): URL) plus optional text.  Token
    usage is not tracked, so the usage block is zeroed.
    Raises HTTPException(400) for a wrong model or malformed messages.
    """
    check_api_key(authorization)

    model = str(body.get("model", VLM_MODEL_NAME))
    if model != VLM_MODEL_NAME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model '{model}', expected '{VLM_MODEL_NAME}'",
        )
    messages = body.get("messages")
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")

    prompt, image_url = extract_prompt_and_image(messages)
    image_path = image_url_to_file(image_url)
    try:
        text = run_vlm(image_path, prompt)
    finally:
        if os.path.exists(image_path):
            os.unlink(image_path)

    return {
        # Bug fix: the id was a hard-coded constant; OpenAI clients expect a
        # unique completion id per response.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": VLM_MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
|
||||
Reference in New Issue
Block a user