236 lines
10 KiB
Python

import os
import tempfile
import torch
import logging
import boto3
from botocore.client import Config
from fastapi import FastAPI, HTTPException, UploadFile, File
from contextlib import asynccontextmanager
from pydantic_settings import BaseSettings
import numpy as np
import torch.nn.functional as F
# MONAI Dependencies ที่จำเป็นสำหรับการประมวลผล
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, ScaleIntensityRange, SpatialPad, Spacing,Resize,CenterSpatialCrop, Orientation, NormalizeIntensity
# --- Logging setup ---
logger = logging.getLogger("uvicorn")
logging.basicConfig(level=logging.INFO)
# --- 1. Settings ---
class Settings(BaseSettings):
MINIO_ENDPOINT: str = "http://localhost:9000"
MINIO_ACCESS_KEY: str = "minio_admin"
MINIO_SECRET_KEY: str = "minio_p@ssw0rd!"
MODEL_BUCKET: str = "models"
MODEL_FILE: str = "spleen_ct_spleen_model.ts"
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
model_config = {'env_file': '.env', 'env_file_encoding': 'utf-8'}
settings = Settings()
model = None
# --- 2. Load Model Function ---
def load_monai_model():
"""โหลด TorchScript model จาก MinIO"""
global model
try:
logger.info(f"Loading model '{settings.MODEL_FILE}' from MinIO...")
s3 = boto3.client(
"s3",
endpoint_url=settings.MINIO_ENDPOINT,
aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=settings.MINIO_SECRET_KEY,
config=Config(signature_version="s3v4", connect_timeout=5, read_timeout=10)
)
with tempfile.TemporaryDirectory() as temp_dir:
local_path = os.path.join(temp_dir, settings.MODEL_FILE)
s3.download_file(settings.MODEL_BUCKET, settings.MODEL_FILE, local_path)
model_loaded = torch.jit.load(local_path, map_location=settings.DEVICE)
model_loaded.eval()
model = model_loaded
logger.info(f"Model '{settings.MODEL_FILE}' loaded successfully on {settings.DEVICE}")
except Exception as e:
logger.error(f"Model load failed: {e}")
raise HTTPException(status_code=500, detail=f"Model Initialization Failed: {e}")
# --- 3. Lifespan Event Handler (แทน @app.on_event) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
global model
# Startup
load_monai_model()
yield
# Shutdown (optional cleanup)
model = None
logger.info("Model unloaded from memory.")
# --- 4. Create FastAPI App ---
app = FastAPI(
title="MONAI Model API",
description="FastAPI serving MONAI TorchScript model from MinIO",
version="1.0.0",
lifespan=lifespan
)
# --- 5. Root Endpoint ---
@app.get("/")
async def read_root():
return {
"status": "Service Running",
"model_loaded": model is not None,
"model_name": settings.MODEL_FILE,
"device": settings.DEVICE,
}
# --- 6. Reload Endpoint ---
@app.post("/reload")
async def reload_model():
"""รีโหลดโมเดลจาก MinIO โดยไม่ต้อง restart service"""
try:
load_monai_model()
return {"message": f"Model '{settings.MODEL_FILE}' reloaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# --- 7. MONAI Inference Endpoint ---
@app.post("/inference/spleen")
async def spleen_segmentation(file: UploadFile = File(...)):
"""
รับไฟล์ภาพทางการแพทย์ (NIfTI) และดำเนินการ Segmentation ม้าม
พร้อมคำนวณปริมาตรจริงและวิเคราะห์ภาวะม้ามโต
"""
if model is None:
raise HTTPException(status_code=503, detail="Model is not loaded. Please wait or check logs.")
# เกณฑ์การวิเคราะห์ม้ามโต (Splenomegaly Thresholds)
SPLENOMEGALY_THRESHOLD_CM3 = 450.0
# 1. บันทึกไฟล์ที่ได้รับชั่วคราว
with tempfile.TemporaryDirectory() as temp_dir:
input_path = os.path.join(temp_dir, file.filename)
try:
content = await file.read()
with open(input_path, "wb") as f:
f.write(content)
logger.info(f"Received file: {file.filename} saved to {input_path}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to read/save uploaded file: {e}")
# 2. Pre-processing, Load Image และดึง Voxel Spacing (NEW)
try:
target_spacing = (1.5, 1.5, 2.0)
target_size = (96, 96, 96) # roi_size ที่โมเดลคาดหวัง
# --- ดึง Original Spacing ก่อน Transform ---
try:
nifti_img = nib.load(input_path)
# Dims [1, 2, 3] คือ Spacing สำหรับ x, y, z (มักเป็น mm)
original_spacing_mm = tuple(nifti_img.header['pixdim'][1:4].tolist())
logger.info(f"Original Voxel Spacing (mm): {original_spacing_mm}")
# คำนวณปริมาตรของ 1 Voxel ในภาพเดิม (mm³)
original_voxel_volume_mm3 = float(np.prod(original_spacing_mm))
except Exception as e:
logger.warning(f"Failed to load NIfTI Header/Spacing via nibabel. Falling back to target_spacing for volume calculation. Error: {e}")
# หากดึง Spacing เดิมไม่ได้ ให้ใช้ Spacing ของภาพที่ Resample แล้วเป็นค่าประมาณ
original_voxel_volume_mm3 = float(np.prod(target_spacing))
# --- MONAI Transforms (Resampling to target_spacing เกิดขึ้นที่นี่) ---
transform = Compose([
LoadImage(image_only=True, ensure_channel_first=True, reader='NibabelReader'),
Orientation(axcodes='RAS'),
Spacing(pixdim=target_spacing, mode='bilinear'), # <--- Resampling ที่นี่!
NormalizeIntensity(subtrahend=None, divisor=None, channel_wise=False),
SpatialPad(spatial_size=target_size),
CenterSpatialCrop(roi_size=target_size),
])
img_data = transform(input_path)
logger.info(f"Input shape after transform: {img_data.shape}, dtype: {img_data.dtype}, min={img_data.min().item():.4f}, max={img_data.max().item():.4f}")
input_tensor = torch.as_tensor(img_data, dtype=torch.float32, device=settings.DEVICE).unsqueeze(0)
except Exception as e:
logger.error(f"Pre-processing failed: {e}")
raise HTTPException(status_code=500, detail=f"Image Pre-processing Error: {e}")
# 3. Inference (Same)
# 4. Post-processing (Same)
model.eval()
with torch.no_grad():
roi_size = (96, 96, 96)
sw_batch_size = 4
prediction_raw = sliding_window_inference(
inputs=input_tensor, roi_size=roi_size, sw_batch_size=sw_batch_size,
predictor=model, overlap=0.5, mode="gaussian"
)
# ... (Softmax/Sigmoid/Argmax Logic - เหมือนเดิม) ...
if prediction_raw.shape[1] > 1:
prediction_prob = torch.softmax(prediction_raw, dim=1)
segmentation_map = torch.argmax(prediction_prob, dim=1).cpu().numpy()[0]
elif prediction_raw.shape[1] == 1:
prediction_prob = torch.sigmoid(prediction_raw)
segmentation_map = (prediction_prob > 0.5).cpu().numpy()[0, 0]
else:
raise RuntimeError(f"Unexpected model output channel count: {prediction_raw.shape[1]}")
unique_labels = np.unique(segmentation_map)
logger.info(f"Unique labels in segmentation map: {unique_labels}")
# 5. Post-processing (คำนวณสถิติและปริมาตรจริง) (MODIFIED)
if not isinstance(segmentation_map, np.ndarray) or segmentation_map.ndim != 3:
logger.error("Segmentation map is not a 3D numpy array.")
raise RuntimeError("Post-processing failed to produce a valid 3D map.")
spleen_voxels = int(np.sum(segmentation_map == 1))
# ต้อง Inverse Transform Segmentation Map กลับไปยัง Spacing เดิม
# อย่างไรก็ตาม ในการใช้งานจริง มักใช้ Voxel Volume ของ Resampled Image (ซึ่งมี Spacing คงที่)
# เนื่องจาก Monai/Inferer มักจะทำการ Resample ก่อน และ Volume Calculation ในงานวิจัย
# ส่วนใหญ่จะใช้น้ำหนัก Spacing หลัง Resample แล้ว (target_spacing) เพื่อรักษาความสม่ำเสมอ
# เราจะใช้ target_spacing เพื่อความสอดคล้องกับ Segmentation Map ที่ได้
resampled_voxel_volume_mm3 = float(np.prod(target_spacing))
# ปริมาตรม้าม (cm³) คำนวณจาก Voxel ที่ Segmented และ Spacing หลัง Resample
spleen_volume_cm3 = (spleen_voxels * resampled_voxel_volume_mm3) / 1000.0 # mm³ → cm³
# วินิจฉัย Splenomegaly
if spleen_volume_cm3 < 350:
diagnosis = "Normal Spleen Size"
elif 350 <= spleen_volume_cm3 < SPLENOMEGALY_THRESHOLD_CM3:
diagnosis = "Borderline Enlarged"
else:
diagnosis = "Splenomegaly Detected"
logger.info(f"Spleen volume: {spleen_volume_cm3:.2f} cm³ → {diagnosis}")
# 6. ส่งผลลัพธ์กลับ
return {
"filename": file.filename,
"status": "Success",
"spleen_voxels_count": spleen_voxels,
"resampled_voxel_volume_cm3": round(resampled_voxel_volume_mm3 / 1000.0, 6),
"estimated_spleen_volume_cm3": round(spleen_volume_cm3, 2),
"diagnosis": diagnosis,
"splenomegaly_threshold_cm3": SPLENOMEGALY_THRESHOLD_CM3,
"message": "Segmentation, volume calculation using resampled spacing, and splenomegaly analysis complete."
}