# 104 lines
# 3.0 KiB
# Python

import os
import tempfile
import torch
import logging
import boto3
from botocore.client import Config
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from pydantic_settings import BaseSettings
# --- Logging setup ---
# Configure the root logger at INFO level, then bind this module's logger to
# the logger named "uvicorn" so every message below goes through it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn")
# --- 1. Settings ---
class Settings(BaseSettings):
    """Service configuration: MinIO connection, model location and device.

    Each field carries a default; ``model_config`` names ``.env`` (UTF-8) as
    the env file to read settings from.
    """

    # MinIO connection parameters.
    # NOTE(review): the credential defaults below are hard-coded in source;
    # presumably production overrides them via the environment -- confirm.
    MINIO_ENDPOINT: str = "http://localhost:9000"
    MINIO_ACCESS_KEY: str = "minio_admin"
    MINIO_SECRET_KEY: str = "minio_p@ssw0rd!"

    # Bucket and object key of the TorchScript model file.
    MODEL_BUCKET: str = "models"
    MODEL_FILE: str = "spleen_ct_spleen_model.ts"

    # Default device is decided once, when the class body is evaluated.
    DEVICE: str = "cpu" if not torch.cuda.is_available() else "cuda"

    model_config = dict(env_file=".env", env_file_encoding="utf-8")
# Module-level state: `model` holds the currently loaded TorchScript module
# (None until a load succeeds); `settings` is the single Settings instance
# read by the functions below.
model = None
settings = Settings()
# --- 2. Load Model Function ---
def load_monai_model():
    """Download the TorchScript model from MinIO and install it as ``model``.

    The object ``settings.MODEL_FILE`` is fetched from
    ``settings.MODEL_BUCKET`` into a temporary directory, deserialized with
    ``torch.jit.load`` onto ``settings.DEVICE`` and switched to eval mode.
    The module-level ``model`` is rebound only after all of that succeeds, so
    a failed (re)load leaves any previously loaded model in place.

    Raises:
        HTTPException: status 500 when the download or deserialization fails;
            the underlying exception is chained as ``__cause__``.
    """
    global model
    try:
        logger.info(f"Loading model '{settings.MODEL_FILE}' from MinIO...")
        s3 = boto3.client(
            "s3",
            endpoint_url=settings.MINIO_ENDPOINT,
            aws_access_key_id=settings.MINIO_ACCESS_KEY,
            aws_secret_access_key=settings.MINIO_SECRET_KEY,
            config=Config(signature_version="s3v4", connect_timeout=5, read_timeout=10),
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            # The object key may carry "/" prefixes or be absolute; keep only
            # its last component so the target always sits directly inside
            # temp_dir (joining the raw key could point at a missing
            # sub-directory or escape temp_dir altogether).
            local_path = os.path.join(temp_dir, os.path.basename(settings.MODEL_FILE))
            s3.download_file(settings.MODEL_BUCKET, settings.MODEL_FILE, local_path)
            # Load while the file still exists; temp_dir is removed on exit.
            model_loaded = torch.jit.load(local_path, map_location=settings.DEVICE)
            model_loaded.eval()
        model = model_loaded
        logger.info(f"Model '{settings.MODEL_FILE}' loaded successfully on {settings.DEVICE}")
    except Exception as e:
        # logger.exception logs at ERROR level and also records the traceback.
        logger.exception(f"Model load failed: {e}")
        raise HTTPException(status_code=500, detail=f"Model Initialization Failed: {e}") from e
# --- 3. Lifespan Event Handler (replaces @app.on_event) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before the app starts serving and drop it on shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        HTTPException: propagated from ``load_monai_model`` when the initial
            load fails; nothing is yielded in that case.
    """
    global model
    # Startup
    load_monai_model()
    try:
        yield
    finally:
        # Shutdown: runs even when an exception is thrown into the generator
        # at the yield point, so the model reference is always released.
        model = None
        logger.info("Model unloaded from memory.")
# --- 4. Create FastAPI App ---
# The application object; `lifespan` is passed in so the model load and
# unload hooks are attached to it.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="MONAI Model API",
    description="FastAPI serving MONAI TorchScript model from MinIO",
)
# --- 5. Root Endpoint ---
@app.get("/")
async def read_root():
    """Return a status summary: whether a model is loaded, its file name and device."""
    is_loaded = model is not None
    summary = {"status": "Service Running"}
    summary["model_loaded"] = is_loaded
    summary["model_name"] = settings.MODEL_FILE
    summary["device"] = settings.DEVICE
    return summary
# --- 6. Reload Endpoint ---
@app.post("/reload")
async def reload_model():
    """Reload the model from MinIO without restarting the service.

    Returns:
        A dict with a ``message`` confirming the reload.

    Raises:
        HTTPException: status 500 when the reload fails. An HTTPException
            raised by ``load_monai_model`` is passed through unchanged.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    try:
        # load_monai_model performs a blocking download and deserialization;
        # run it in the worker thread pool so the event loop is not stalled
        # for the duration of the reload.
        await run_in_threadpool(load_monai_model)
    except HTTPException:
        # Already a well-formed 500 from the loader; re-wrapping it through
        # str(e) would discard or mangle its detail message.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": f"Model '{settings.MODEL_FILE}' reloaded successfully"}