104 lines
3.0 KiB
Python
104 lines
3.0 KiB
Python
import os
|
|
import tempfile
|
|
import torch
|
|
import logging
|
|
import boto3
|
|
from botocore.client import Config
|
|
from fastapi import FastAPI, HTTPException
|
|
from contextlib import asynccontextmanager
|
|
from pydantic_settings import BaseSettings
|
|
|
|
|
|
# --- Logging setup ---
|
|
# Configure the root logger at INFO, then grab the logger named "uvicorn"
# that the rest of this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn")
|
|
|
|
|
|
# --- 1. Settings ---
|
|
class Settings(BaseSettings):
    """Runtime configuration for the model-serving API.

    Every field declares a default value. ``model_config`` additionally
    names a UTF-8 encoded ``.env`` file as a settings source.
    """

    # Extra settings source besides the defaults declared below.
    model_config = dict(env_file=".env", env_file_encoding="utf-8")

    # Connection details for the MinIO server that stores the model.
    # NOTE(review): the credential defaults are hard-coded in source;
    # confirm they are overridden in any non-local deployment.
    MINIO_ENDPOINT: str = "http://localhost:9000"
    MINIO_ACCESS_KEY: str = "minio_admin"
    MINIO_SECRET_KEY: str = "minio_p@ssw0rd!"

    # Bucket and object name of the TorchScript model file.
    MODEL_BUCKET: str = "models"
    MODEL_FILE: str = "spleen_ct_spleen_model.ts"

    # Default device is chosen once, when the class body is evaluated:
    # "cuda" if torch reports CUDA as available, otherwise "cpu".
    DEVICE: str = "cpu" if not torch.cuda.is_available() else "cuda"
|
|
|
|
|
|
# Single shared settings instance, built once at import time.
settings = Settings()

# Currently loaded TorchScript model; stays None until a load succeeds and is
# reset to None on application shutdown.
model: "torch.jit.ScriptModule | None" = None
|
|
|
|
|
|
# --- 2. Load Model Function ---
|
|
def load_monai_model():
    """Load the TorchScript model from MinIO into the module-level ``model``.

    Downloads ``settings.MODEL_FILE`` from ``settings.MODEL_BUCKET`` into a
    temporary directory, deserializes it with ``torch.jit.load`` onto
    ``settings.DEVICE`` and switches it to eval mode. The global ``model`` is
    rebound only after all of that succeeded, so a failed load leaves the
    previously assigned value untouched.

    Raises:
        HTTPException: status 500, wrapping (and chained from) any error
            raised while creating the client, downloading the object, or
            deserializing the model.
    """
    global model
    try:
        logger.info("Loading model '%s' from MinIO...", settings.MODEL_FILE)

        s3 = boto3.client(
            "s3",
            endpoint_url=settings.MINIO_ENDPOINT,
            aws_access_key_id=settings.MINIO_ACCESS_KEY,
            aws_secret_access_key=settings.MINIO_SECRET_KEY,
            config=Config(signature_version="s3v4", connect_timeout=5, read_timeout=10),
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            # The object key may carry a prefix such as "a/b/model.ts". Joining
            # the raw key onto temp_dir would point into a sub-directory that
            # does not exist, so only the final path component is used locally.
            local_name = os.path.basename(settings.MODEL_FILE) or "model.ts"
            local_path = os.path.join(temp_dir, local_name)
            s3.download_file(settings.MODEL_BUCKET, settings.MODEL_FILE, local_path)

            model_loaded = torch.jit.load(local_path, map_location=settings.DEVICE)
            model_loaded.eval()
            model = model_loaded

        logger.info(
            "Model '%s' loaded successfully on %s",
            settings.MODEL_FILE,
            settings.DEVICE,
        )

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Model load failed: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Model Initialization Failed: {e}"
        ) from e
|
|
|
|
|
|
# --- 3. Lifespan Event Handler (replaces @app.on_event) ---
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model before yielding, drop it after.

    The code before ``yield`` runs as the startup phase and calls
    ``load_monai_model()``. The code after ``yield`` runs as the shutdown
    phase, clears the global ``model`` reference and logs that it was
    unloaded.
    """
    global model

    # --- startup ---
    load_monai_model()

    yield

    # --- shutdown: release the reference to the loaded model ---
    model = None
    logger.info("Model unloaded from memory.")
|
|
|
|
|
|
# --- 4. Create FastAPI App ---
|
|
# The application object; `lifespan` wires model loading/unloading into
# startup and shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="MONAI Model API",
    version="1.0.0",
    description="FastAPI serving MONAI TorchScript model from MinIO",
)
|
|
|
|
|
|
# --- 5. Root Endpoint ---
|
|
@app.get("/")
async def read_root():
    # Status endpoint: reports whether a model is currently loaded, plus the
    # configured model file name and device.
    is_loaded = model is not None
    payload = {
        "status": "Service Running",
        "model_loaded": is_loaded,
        "model_name": settings.MODEL_FILE,
        "device": settings.DEVICE,
    }
    return payload
|
|
|
|
|
|
# --- 6. Reload Endpoint ---
|
|
@app.post("/reload")
async def reload_model():
    """Reload the model from MinIO without restarting the service.

    Returns a JSON message naming the reloaded model file. Responds with
    HTTP 500 if the reload fails.
    """
    # Imported locally so this handler does not depend on a file-level import;
    # fastapi.concurrency is part of the FastAPI package the file already uses.
    from fastapi.concurrency import run_in_threadpool

    try:
        # load_monai_model() does a blocking download and model
        # deserialization; run it in a worker thread so the event loop keeps
        # serving other requests meanwhile.
        await run_in_threadpool(load_monai_model)
    except HTTPException:
        # Already a well-formed HTTP error from the loader; re-wrapping it
        # would turn the detail into "500: Model Initialization Failed: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": f"Model '{settings.MODEL_FILE}' reloaded successfully"}
|