230 lines
9.7 KiB
Python
230 lines
9.7 KiB
Python
import logging
import os
import tempfile
from contextlib import asynccontextmanager

import boto3
import numpy as np
import torch
import torch.nn.functional as F
from botocore.client import Config
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic_settings import BaseSettings

# MONAI dependencies required for pre-processing and inference
from monai.inferers import sliding_window_inference
from monai.transforms import (
    CenterSpatialCrop,
    Compose,
    EnsureChannelFirst,
    LoadImage,
    NormalizeIntensity,
    Orientation,
    Resize,
    ScaleIntensityRange,
    Spacing,
    SpatialPad,
)

# --- Logging setup ---
# Configure the root logger at INFO, then log through uvicorn's named logger
# (presumably so application messages share the server's handlers/format).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn")

# --- 1. Settings ---
class Settings(BaseSettings):
    """Service configuration, overridable via environment variables or a ``.env`` file."""

    # S3-compatible (MinIO) endpoint and credentials used to fetch the model.
    # NOTE(review): these defaults are development credentials committed to source;
    # supply real values through the environment / .env in any shared deployment.
    MINIO_ENDPOINT: str = "http://localhost:9000"
    MINIO_ACCESS_KEY: str = "minio_admin"
    MINIO_SECRET_KEY: str = "minio_p@ssw0rd!"

    # Bucket and object key of the TorchScript model to serve.
    MODEL_BUCKET: str = "models"
    MODEL_FILE: str = "spleen_ct_spleen_model.ts"

    # Evaluated once when the class body runs: "cuda" if a GPU is visible, else "cpu".
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # pydantic-settings forbids unknown keys by default, so a .env that also holds
    # variables for other services (e.g. docker-compose settings) would make
    # Settings() raise at import time; "ignore" drops such keys instead.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
    }

# Process-wide configuration, resolved once when the module is imported.
settings = Settings()

# Currently served TorchScript module; stays None until load_monai_model() succeeds.
model: "torch.jit.ScriptModule | None" = None

# --- 2. Load Model Function ---
def load_monai_model():
    """Download the TorchScript model from MinIO and install it as the global ``model``.

    The global is rebound only after the download and ``torch.jit.load`` have both
    succeeded, so a failed reload leaves the previously loaded model in service.

    Raises:
        HTTPException: status 500 when the download or deserialisation fails.
    """
    global model

    logger.info(f"Loading model '{settings.MODEL_FILE}' from MinIO...")
    try:
        s3 = boto3.client(
            "s3",
            endpoint_url=settings.MINIO_ENDPOINT,
            aws_access_key_id=settings.MINIO_ACCESS_KEY,
            aws_secret_access_key=settings.MINIO_SECRET_KEY,
            config=Config(signature_version="s3v4", connect_timeout=5, read_timeout=10),
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            # MODEL_FILE is an object key and may carry "folder/" prefixes; only the last
            # component is used locally so the download target's directory exists.
            local_path = os.path.join(temp_dir, os.path.basename(settings.MODEL_FILE))
            s3.download_file(settings.MODEL_BUCKET, settings.MODEL_FILE, local_path)
            # The module is fully read into memory here, so the temp file can go away.
            loaded = torch.jit.load(local_path, map_location=settings.DEVICE)
        loaded.eval()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Model load failed: {e}")
        raise HTTPException(status_code=500, detail=f"Model Initialization Failed: {e}") from e

    model = loaded
    logger.info(f"Model '{settings.MODEL_FILE}' loaded successfully on {settings.DEVICE}")

# --- 3. Lifespan Event Handler (replaces @app.on_event) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving requests; drop the global reference on shutdown."""
    global model

    load_monai_model()  # startup phase
    yield               # application serves requests while suspended here

    # Shutdown phase: release the module-level reference to the model.
    model = None
    logger.info("Model unloaded from memory.")

# --- 4. Create FastAPI App ---
# The lifespan handler above loads the model at startup and clears it at shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="MONAI Model API",
    version="1.0.0",
    description="FastAPI serving MONAI TorchScript model from MinIO",
)

# --- 5. Root Endpoint ---
@app.get("/")
async def read_root():
    """Status probe: report whether a model is loaded, which file, and on which device."""
    status = {"status": "Service Running"}
    status["model_loaded"] = model is not None
    status["model_name"] = settings.MODEL_FILE
    status["device"] = settings.DEVICE
    return status

# --- 6. Reload Endpoint ---
@app.post("/reload")
async def reload_model():
    """Reload the model from MinIO without restarting the service.

    Raises:
        HTTPException: status 500 if the reload fails; the previous model stays loaded
        because ``load_monai_model`` only rebinds the global on success.
    """
    try:
        # The MinIO download and torch.jit.load are blocking; run them in the worker
        # threadpool so the event loop keeps serving other requests meanwhile.
        await run_in_threadpool(load_monai_model)
    except HTTPException:
        # Already a 500 with a descriptive detail; re-wrapping it would turn the detail
        # into "500: ..." via str(HTTPException).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": f"Model '{settings.MODEL_FILE}' reloaded successfully"}

# --- 7. MONAI Inference Endpoint ---

# Spleen-volume cut-offs (cm³) used for the size classification.
_BORDERLINE_THRESHOLD_CM3 = 350.0
_SPLENOMEGALY_THRESHOLD_CM3 = 450.0

# Resampling grid (mm) applied before inference, and the model's sliding-window patch size.
_TARGET_SPACING_MM = (1.5, 1.5, 2.0)
_ROI_SIZE = (96, 96, 96)

# Name used when the client sends no usable file name.
_DEFAULT_UPLOAD_NAME = "upload.nii.gz"


def _safe_upload_name(filename) -> str:
    """Return a bare file name for client-supplied *filename* that cannot leave the temp dir.

    ``os.path.join(temp_dir, name)`` honours absolute paths and ``..`` components, so
    only the final path component is kept (backslashes are treated as separators too).
    Falls back to a default NIfTI name when nothing usable remains.
    """
    name = os.path.basename(str(filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        return _DEFAULT_UPLOAD_NAME
    return name


def _classify_spleen_volume(
    volume_cm3: float,
    borderline_cm3: float = _BORDERLINE_THRESHOLD_CM3,
    splenomegaly_cm3: float = _SPLENOMEGALY_THRESHOLD_CM3,
) -> str:
    """Map a spleen volume (cm³) to a diagnosis label using half-open threshold bands."""
    if volume_cm3 < borderline_cm3:
        return "Normal Spleen Size"
    if volume_cm3 < splenomegaly_cm3:
        return "Borderline Enlarged"
    return "Splenomegaly Detected"


def _logits_to_label_map(prediction_raw: torch.Tensor) -> np.ndarray:
    """Turn network output shaped (1, C, X, Y, Z) into a 3-D integer label map.

    Multi-channel output is reduced by argmax over channels. Softmax is monotonic, so it
    would not change the argmax and is skipped to avoid a full-volume copy. Single-channel
    output is treated as a sigmoid foreground probability thresholded at 0.5.

    Raises:
        RuntimeError: if the output has no channel axis or zero channels, or the
        resulting map is not 3-D.
    """
    if prediction_raw.ndim < 2:
        raise RuntimeError(f"Unexpected model output shape: {tuple(prediction_raw.shape)}")
    channels = prediction_raw.shape[1]
    if channels > 1:
        labels = torch.argmax(prediction_raw, dim=1)[0]
    elif channels == 1:
        labels = (torch.sigmoid(prediction_raw) > 0.5)[0, 0].long()
    else:
        raise RuntimeError(f"Unexpected model output channel count: {channels}")

    label_map = labels.cpu().numpy()
    if label_map.ndim != 3:
        logger.error("Segmentation map is not a 3D numpy array.")
        raise RuntimeError("Post-processing failed to produce a valid 3D map.")
    return label_map


def _log_original_spacing(image) -> None:
    """Best-effort log of the loaded image's voxel spacing (mm) before resampling.

    Relies on MONAI's MetaTensor ``pixdim`` property; any failure only logs a warning.
    """
    try:
        spacing = tuple(float(s) for s in image.pixdim)
    except Exception as e:
        logger.warning(f"Could not determine original voxel spacing: {e}")
        return
    logger.info(f"Original Voxel Spacing (mm): {spacing}")


def _segment_and_measure(predictor, input_path: str) -> dict:
    """Pre-process the NIfTI at *input_path*, segment it with *predictor* and measure the spleen.

    Synchronous and CPU/GPU heavy; the endpoint runs it in the threadpool.

    Returns:
        dict with ``spleen_voxels`` (int, voxels labelled 1), ``voxel_volume_mm3``
        (float, on the resampled grid), ``spleen_volume_cm3`` (float) and
        ``diagnosis`` (str).

    Raises:
        HTTPException: status 500 if pre-processing or inference fails.
    """
    # 1. Pre-processing: load, reorient, resample to the target grid, normalise.
    try:
        image = LoadImage(image_only=True, ensure_channel_first=True, reader="NibabelReader")(input_path)
        _log_original_spacing(image)

        preprocess = Compose([
            Orientation(axcodes="RAS"),
            Spacing(pixdim=_TARGET_SPACING_MM, mode="bilinear"),
            # NOTE(review): ScaleIntensityRange is imported but unused. CT spleen models
            # (e.g. the MONAI spleen bundle) are typically trained on a clipped HU window
            # scaled to [0, 1], not z-normalised; confirm this matches the served model.
            NormalizeIntensity(subtrahend=None, divisor=None, channel_wise=False),
            # Guarantees every spatial dim is at least one ROI. There is deliberately no
            # crop: sliding-window inference covers the whole volume, and a centre crop
            # would cut off spleen tissue and under-report its volume.
            SpatialPad(spatial_size=_ROI_SIZE),
        ])
        img_data = preprocess(image)

        logger.info(f"Input shape after transform: {img_data.shape}, dtype: {img_data.dtype}, min={img_data.min().item():.4f}, max={img_data.max().item():.4f}")
        input_tensor = torch.as_tensor(img_data, dtype=torch.float32, device=settings.DEVICE).unsqueeze(0)
    except Exception as e:
        logger.exception(f"Pre-processing failed: {e}")
        raise HTTPException(status_code=500, detail=f"Image Pre-processing Error: {e}") from e

    # 2. Inference over the full volume, then reduction to a label map.
    try:
        with torch.no_grad():
            prediction_raw = sliding_window_inference(
                inputs=input_tensor,
                roi_size=_ROI_SIZE,
                sw_batch_size=4,
                predictor=predictor,
                overlap=0.5,
                mode="gaussian",
            )
        segmentation_map = _logits_to_label_map(prediction_raw)
    except Exception as e:
        logger.exception(f"Inference failed: {e}")
        raise HTTPException(status_code=500, detail=f"Inference Error: {e}") from e

    logger.info(f"Unique labels in segmentation map: {np.unique(segmentation_map)}")

    # 3. Volume: the map lives on the resampled grid, so each voxel has the target spacing.
    spleen_voxels = int(np.count_nonzero(segmentation_map == 1))
    voxel_volume_mm3 = float(np.prod(_TARGET_SPACING_MM))
    spleen_volume_cm3 = spleen_voxels * voxel_volume_mm3 / 1000.0  # mm³ -> cm³

    diagnosis = _classify_spleen_volume(spleen_volume_cm3)
    logger.info(f"Spleen volume: {spleen_volume_cm3:.2f} cm³ → {diagnosis}")

    return {
        "spleen_voxels": spleen_voxels,
        "voxel_volume_mm3": voxel_volume_mm3,
        "spleen_volume_cm3": spleen_volume_cm3,
        "diagnosis": diagnosis,
    }


@app.post("/inference/spleen")
async def spleen_segmentation(file: UploadFile = File(...)):
    """Segment the spleen in an uploaded medical image (NIfTI) and assess its size.

    Estimates the spleen volume from the segmented voxels on the resampled grid and
    classifies it against the splenomegaly thresholds.

    Raises:
        HTTPException: 503 if no model is loaded; 500 if saving, pre-processing or
        inference fails.
    """
    # Snapshot the global: a concurrent /reload may rebind it while this request runs.
    predictor = model
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model is not loaded. Please wait or check logs.")

    with tempfile.TemporaryDirectory() as temp_dir:
        # 1. Save the upload; the client-supplied name is sanitised so it stays in temp_dir.
        input_path = os.path.join(temp_dir, _safe_upload_name(file.filename))
        try:
            content = await file.read()
            with open(input_path, "wb") as f:
                f.write(content)
            logger.info(f"Received file: {file.filename} saved to {input_path}")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to read/save uploaded file: {e}") from e

        # 2. Heavy work off the event loop; the temp file must outlive this call.
        result = await run_in_threadpool(_segment_and_measure, predictor, input_path)

    # 3. Response.
    return {
        "filename": file.filename,
        "status": "Success",
        "spleen_voxels_count": result["spleen_voxels"],
        "resampled_voxel_volume_cm3": round(result["voxel_volume_mm3"] / 1000.0, 6),
        "estimated_spleen_volume_cm3": round(result["spleen_volume_cm3"], 2),
        "diagnosis": result["diagnosis"],
        "splenomegaly_threshold_cm3": _SPLENOMEGALY_THRESHOLD_CM3,
        "message": "Segmentation, volume calculation using resampled spacing, and splenomegaly analysis complete.",
    }