"""FastAPI service that serves a MONAI TorchScript model downloaded from MinIO."""
import os
import tempfile
import torch
import logging
import boto3
from botocore.client import Config
from fastapi import FastAPI, HTTPException, UploadFile, File
from contextlib import asynccontextmanager
from pydantic_settings import BaseSettings
import numpy as np
import torch.nn.functional as F

# MONAI dependencies required for pre-processing and inference.
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, ScaleIntensityRange, SpatialPad, Spacing,Resize,CenterSpatialCrop, Orientation, NormalizeIntensity

# --- Logging setup ---
# Log through the "uvicorn" logger so messages appear alongside the server's own output.
logger = logging.getLogger("uvicorn")
logging.basicConfig(level=logging.INFO)


# --- 1. Settings ---
class Settings(BaseSettings):
    """Service configuration.

    Every field can be overridden by an environment variable of the same
    name or by an entry in the ``.env`` file configured in ``model_config``.
    """

    # MinIO (S3-compatible) endpoint and credentials used to fetch the model.
    MINIO_ENDPOINT: str = "http://localhost:9000"
    # NOTE(review): credentials are hard-coded defaults in source; presumably
    # meant for local development only - supply real values via env/.env.
    MINIO_ACCESS_KEY: str = "minio_admin"
    MINIO_SECRET_KEY: str = "minio_p@ssw0rd!"
    # Bucket and object key of the TorchScript model file.
    MODEL_BUCKET: str = "models"
    MODEL_FILE: str = "spleen_ct_spleen_model.ts"
    # Torch device for the model and input tensors; CUDA when available at import time.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Read overrides from a UTF-8 encoded ".env" file.
    model_config = {'env_file': '.env', 'env_file_encoding': 'utf-8'}


settings = Settings()
# Currently loaded TorchScript module; None when no model is loaded.
model = None
# --- 2. Load Model Function ---
def load_monai_model():
    """Download the TorchScript model from MinIO and make it the active global ``model``.

    The object ``settings.MODEL_FILE`` in ``settings.MODEL_BUCKET`` is
    downloaded into a temporary directory, loaded with ``torch.jit.load``
    onto ``settings.DEVICE`` and switched to eval mode. The global is replaced
    only after loading succeeds, so a failed reload leaves the previously
    loaded model in place.

    Raises:
        HTTPException: status 500 wrapping any download or load error.
    """
    global model
    try:
        logger.info(f"Loading model '{settings.MODEL_FILE}' from MinIO...")
        s3 = boto3.client(
            "s3",
            endpoint_url=settings.MINIO_ENDPOINT,
            aws_access_key_id=settings.MINIO_ACCESS_KEY,
            aws_secret_access_key=settings.MINIO_SECRET_KEY,
            config=Config(signature_version="s3v4", connect_timeout=5, read_timeout=10),
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            # MODEL_FILE is an object key that may carry "/" prefixes; keep only
            # its last component so the local target sits directly in temp_dir
            # (download_file does not create intermediate directories).
            local_name = os.path.basename(settings.MODEL_FILE) or "model.ts"
            local_path = os.path.join(temp_dir, local_name)
            s3.download_file(settings.MODEL_BUCKET, settings.MODEL_FILE, local_path)
            model_loaded = torch.jit.load(local_path, map_location=settings.DEVICE)
        model_loaded.eval()
        model = model_loaded
        logger.info(f"Model '{settings.MODEL_FILE}' loaded successfully on {settings.DEVICE}")
    except Exception as e:
        # logger.exception keeps the traceback for diagnosing MinIO / TorchScript failures.
        logger.exception(f"Model load failed: {e}")
        raise HTTPException(status_code=500, detail=f"Model Initialization Failed: {e}") from e


# --- 3. Lifespan Event Handler (replaces @app.on_event) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and drop the reference on shutdown.

    An exception from ``load_monai_model`` propagates out of the startup
    phase, so the application does not start without a model.
    """
    global model
    # Startup
    load_monai_model()
    yield
    # Shutdown: release the module reference.
    model = None
    logger.info("Model unloaded from memory.")


# --- 4. Create FastAPI App ---
app = FastAPI(
    title="MONAI Model API",
    description="FastAPI serving MONAI TorchScript model from MinIO",
    version="1.0.0",
    lifespan=lifespan,
)


# --- 5. Root Endpoint ---
@app.get("/")
async def read_root():
    """Report service status, whether a model is loaded, its file name and device."""
    return {
        "status": "Service Running",
        "model_loaded": model is not None,
        "model_name": settings.MODEL_FILE,
        "device": settings.DEVICE,
    }


# --- 6. Reload Endpoint ---
@app.post("/reload")
async def reload_model():
    """Reload the model from MinIO without restarting the service.

    The download and TorchScript load are blocking, so they run in a worker
    thread instead of stalling the event loop. On failure the previously
    loaded model (if any) keeps serving requests.

    Raises:
        HTTPException: 500 if the reload fails.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        await run_in_threadpool(load_monai_model)
    except HTTPException:
        # Already carries a meaningful status/detail; re-wrapping it would
        # produce a "500: ..." string as the detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": f"Model '{settings.MODEL_FILE}' reloaded successfully"}
# --- 7. MONAI Inference Endpoint ---

# Voxel spacing (mm) every input is resampled to before inference.
_TARGET_SPACING_MM = (1.5, 1.5, 2.0)
# Patch size fed to the network by sliding-window inference.
_ROI_SIZE = (96, 96, 96)
# Spleen-volume thresholds (cm³) used for the size classification.
_NORMAL_SPLEEN_MAX_CM3 = 350.0
_SPLENOMEGALY_THRESHOLD_CM3 = 450.0


def _safe_upload_name(filename):
    """Reduce a client-supplied upload name to a bare file name.

    The name is untrusted: an absolute path or ``..`` components would make
    ``os.path.join`` point outside the temporary directory. The base name (and
    therefore the extension) is kept so the image can still be recognised as
    NIfTI by its suffix.
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        name = "upload.nii.gz"
    return name


def _load_and_preprocess(input_path, device):
    """Load a NIfTI file and return a batched float32 tensor (1, C, *spatial) on ``device``.

    Steps: load channel-first -> RAS orientation -> resample to
    ``_TARGET_SPACING_MM`` -> clip the HU window [-57, 164] and scale it to
    [0, 1]. The spacing, 96³ ROI and intensity window follow the MONAI spleen
    CT segmentation pipeline; if the deployed model was trained with a
    different intensity pipeline, adjust the ScaleIntensityRange step.
    No centre crop is applied: the whole field of view is kept so the full
    spleen is segmented, and sliding-window inference handles any size.
    """
    image = LoadImage(image_only=True, ensure_channel_first=True, reader="NibabelReader")(input_path)

    # Loaded images carry their spacing as metadata; fall back gracefully if absent.
    original_spacing = getattr(image, "pixdim", None)
    if original_spacing is not None:
        logger.info(f"Original Voxel Spacing (mm): {tuple(float(s) for s in original_spacing)}")

    transform = Compose([
        Orientation(axcodes="RAS"),
        Spacing(pixdim=_TARGET_SPACING_MM, mode="bilinear"),  # resampling happens here
        ScaleIntensityRange(a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True),
    ])
    img_data = transform(image)
    logger.info(f"Input shape after transform: {img_data.shape}, dtype: {img_data.dtype}, min={img_data.min().item():.4f}, max={img_data.max().item():.4f}")
    return torch.as_tensor(img_data, dtype=torch.float32, device=device).unsqueeze(0)


def _segment(net, input_tensor):
    """Run sliding-window inference and return the label map as a numpy array.

    Multi-channel output is reduced with argmax over the channel axis (the
    argmax of logits equals the argmax of their softmax, so softmax is
    skipped); single-channel output is sigmoid-thresholded at 0.5.

    Raises:
        RuntimeError: if the network returns zero output channels.
    """
    with torch.no_grad():
        logits = sliding_window_inference(
            inputs=input_tensor,
            roi_size=_ROI_SIZE,
            sw_batch_size=4,
            predictor=net,
            overlap=0.5,
            mode="gaussian",
        )
        n_channels = logits.shape[1]
        if n_channels > 1:
            labels = torch.argmax(logits, dim=1)[0]
        elif n_channels == 1:
            labels = torch.sigmoid(logits)[0, 0] > 0.5
        else:
            raise RuntimeError(f"Unexpected model output channel count: {n_channels}")
    return labels.cpu().numpy()


def _classify_spleen_volume(volume_cm3):
    """Map a spleen volume in cm³ to the diagnosis label returned by the API."""
    if volume_cm3 < _NORMAL_SPLEEN_MAX_CM3:
        return "Normal Spleen Size"
    if volume_cm3 < _SPLENOMEGALY_THRESHOLD_CM3:
        return "Borderline Enlarged"
    return "Splenomegaly Detected"


@app.post("/inference/spleen")
async def spleen_segmentation(file: UploadFile = File(...)):
    """Segment the spleen in an uploaded NIfTI volume and estimate its size.

    The volume is resampled to ``_TARGET_SPACING_MM`` and segmented over its
    full field of view. Spleen volume = (voxels labelled 1) x (resampled voxel
    volume), classified as normal / borderline / splenomegaly against fixed
    cm³ thresholds.

    Raises:
        HTTPException: 503 if no model is loaded; 500 if the upload cannot be
            saved, pre-processing or inference fails, or the label map is not 3-D.
    """
    # Blocking CPU/GPU work runs in a worker thread so the event loop stays responsive.
    from fastapi.concurrency import run_in_threadpool

    # Local reference so a concurrent /reload or shutdown cannot swap the model mid-request.
    net = model
    if net is None:
        raise HTTPException(status_code=503, detail="Model is not loaded. Please wait or check logs.")

    with tempfile.TemporaryDirectory() as temp_dir:
        # 1. Save the upload under a sanitised name inside a private temp dir.
        input_path = os.path.join(temp_dir, _safe_upload_name(file.filename))
        try:
            content = await file.read()
            with open(input_path, "wb") as f:
                f.write(content)
            logger.info(f"Received file: {file.filename} saved to {input_path}")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to read/save uploaded file: {e}") from e

        # 2. Pre-processing.
        try:
            input_tensor = await run_in_threadpool(_load_and_preprocess, input_path, settings.DEVICE)
        except Exception as e:
            logger.error(f"Pre-processing failed: {e}")
            raise HTTPException(status_code=500, detail=f"Image Pre-processing Error: {e}") from e

        # 3. Inference.
        try:
            segmentation_map = await run_in_threadpool(_segment, net, input_tensor)
        except Exception as e:
            logger.exception("Inference failed")
            raise HTTPException(status_code=500, detail=f"Inference Error: {e}") from e

    logger.info(f"Unique labels in segmentation map: {np.unique(segmentation_map)}")

    # 4. Post-processing: volume from the resampled grid, then diagnosis.
    if not isinstance(segmentation_map, np.ndarray) or segmentation_map.ndim != 3:
        logger.error("Segmentation map is not a 3D numpy array.")
        raise HTTPException(status_code=500, detail="Post-processing failed to produce a valid 3D map.")

    spleen_voxels = int(np.count_nonzero(segmentation_map == 1))
    # The label map lives on the resampled grid, so use the target spacing.
    resampled_voxel_volume_mm3 = float(np.prod(_TARGET_SPACING_MM))
    spleen_volume_cm3 = spleen_voxels * resampled_voxel_volume_mm3 / 1000.0  # mm³ → cm³
    diagnosis = _classify_spleen_volume(spleen_volume_cm3)
    logger.info(f"Spleen volume: {spleen_volume_cm3:.2f} cm³ → {diagnosis}")

    # 5. Response.
    return {
        "filename": file.filename,
        "status": "Success",
        "spleen_voxels_count": spleen_voxels,
        "resampled_voxel_volume_cm3": round(resampled_voxel_volume_mm3 / 1000.0, 6),
        "estimated_spleen_volume_cm3": round(spleen_volume_cm3, 2),
        "diagnosis": diagnosis,
        "splenomegaly_threshold_cm3": _SPLENOMEGALY_THRESHOLD_CM3,
        "message": "Segmentation, volume calculation using resampled spacing, and splenomegaly analysis complete.",
    }