157 lines
8.2 KiB
Python
157 lines
8.2 KiB
Python
from rest_framework import viewsets, permissions, status
|
|
from rest_framework.response import Response
|
|
from rest_framework.decorators import action
|
|
from drf_spectacular.utils import extend_schema, extend_schema_view
|
|
|
|
from ..models import AiModel
|
|
from ..serializers.ai_model_serializer import AiModelSerializer
|
|
from ..repositories.ai_model_repository import AiModelRepository
|
|
from ..services.ai_model_service import AiModelService, ConnectionError
|
|
|
|
from permissions.permission_classes import IsAdminOrManager
|
|
from rest_framework.permissions import IsAuthenticated
|
|
from rest_framework.parsers import MultiPartParser
|
|
|
|
from rest_framework.throttling import UserRateThrottle
|
|
|
|
# Dependency Injection: สร้าง Instance ของ Repository และ Service
|
|
# สำหรับโปรเจกต์ขนาดเล็กสามารถทำแบบนี้ได้
|
|
repo = AiModelRepository()
|
|
service = AiModelService(repository=repo)
|
|
|
|
@extend_schema_view(
|
|
# 1. การดำเนินการ CRUD ปกติ: จัดอยู่ใน Model Registry
|
|
list=extend_schema(tags=['2. Model Registry & Metadata Management']),
|
|
retrieve=extend_schema(tags=['2. Model Registry & Metadata Management']),
|
|
create=extend_schema(tags=['2. Model Registry & Metadata Management']),
|
|
update=extend_schema(tags=['2. Model Registry & Metadata Management']),
|
|
partial_update=extend_schema(tags=['2. Model Registry & Metadata Management']),
|
|
destroy=extend_schema(tags=['2. Model Registry & Metadata Management']),
|
|
)
|
|
class AiModelRegistryViewSet(viewsets.ModelViewSet):
|
|
queryset = AiModel.objects.all()
|
|
serializer_class = AiModelSerializer
|
|
permission_classes = [permissions.IsAdminUser]
|
|
|
|
# -----------------------------------------------
|
|
# Override Create/List (เรียกใช้ Service Layer)
|
|
# -----------------------------------------------
|
|
def get_queryset(self):
|
|
# List/Retrieve จะเรียก Service Layer แทนการเรียก ORM โดยตรง
|
|
return service.get_all_models()
|
|
|
|
def create(self, request, *args, **kwargs):
|
|
serializer = self.get_serializer(data=request.data)
|
|
serializer.is_valid(raise_exception=True)
|
|
|
|
try:
|
|
new_model = service.create_model(serializer.validated_data)
|
|
return Response(self.get_serializer(new_model).data, status=status.HTTP_201_CREATED)
|
|
except Exception as e:
|
|
return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
|
|
|
|
def get_permissions(self):
|
|
# การทำงานที่เกี่ยวข้องกับการเขียน/แก้ไข/ลบ ถือเป็นงาน Admin/Manager
|
|
if self.action in ['create', 'update', 'partial_update', 'destroy', 'set_status']:
|
|
# จำกัดสิทธิ์ ต้องเป็น Admin หรือ Manager เท่านั้น
|
|
return [IsAdminOrManager()]
|
|
|
|
# การทำงานที่เกี่ยวข้องกับการอ่านข้อมูล (List, Retrieve, Test Connection)
|
|
# ถือเป็นงานสำหรับผู้ใช้งานที่ล็อกอินแล้ว (Viewer ขึ้นไป)
|
|
elif self.action in ['list', 'retrieve', 'test_connection']:
|
|
return [IsAuthenticated()]
|
|
|
|
# Default สำหรับเมธอดอื่น ๆ หรือ Custom Action ที่ไม่ได้ระบุ
|
|
return [IsAuthenticated()]
|
|
|
|
# -----------------------------------------------
|
|
# Custom Action: ทดสอบการเชื่อมต่อ
|
|
# -----------------------------------------------
|
|
@extend_schema(tags=['3. MLOps Control & Service Orchestration'])
|
|
@action(detail=True, methods=['post'], url_path='test-connection')
|
|
def test_connection(self, request, pk=None):
|
|
try:
|
|
result = service.test_connection(pk=int(pk))
|
|
return Response(result)
|
|
except ValueError as e:
|
|
return Response({"detail": str(e)}, status=status.HTTP_404_NOT_FOUND)
|
|
except ConnectionError as e:
|
|
# Response ด้วย error ที่ชัดเจนจากการเชื่อมต่อ
|
|
return Response({"status": "error", "detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
|
|
|
|
@extend_schema(tags=['3. MLOps Control & Service Orchestration'])
|
|
@action(detail=True, methods=['patch'], url_path='set-status')
|
|
def set_status(self, request, pk=None):
|
|
new_status = request.data.get('status')
|
|
if not new_status or new_status not in [choice[0] for choice in AiModel.status_choices]:
|
|
return Response({"detail": "Invalid status provided."}, status=status.HTTP_400_BAD_REQUEST)
|
|
|
|
updated_model = service.set_status(pk=int(pk), new_status=new_status)
|
|
if updated_model:
|
|
return Response(self.get_serializer(updated_model).data)
|
|
return Response({"detail": "Model not found."}, status=status.HTTP_404_NOT_FOUND)
|
|
|
|
# -----------------------------------------------
|
|
# Custom Action: Run Inference (Proxy)
|
|
# -----------------------------------------------
|
|
@extend_schema(
|
|
tags=['3. MLOps Control & Service Orchestration'],
|
|
# ระบุ Request Body เป็นประเภทไฟล์สำหรับ MultiPart
|
|
request={
|
|
'multipart/form-data': {
|
|
'type': 'object',
|
|
'properties': {
|
|
'file': {'type': 'string', 'format': 'binary', 'description': 'DICOM/NIfTI file for inference'},
|
|
},
|
|
'required': ['file']
|
|
}
|
|
},
|
|
# ระบุ Response Schema (ผลลัพธ์จาก AI Service)
|
|
responses={
|
|
200: {'description': 'Inference result', 'content': {'application/json': {'schema': {'type': 'object'}}}},
|
|
403: {'description': 'Model is not ACTIVE or insufficient permissions'},
|
|
404: {'description': 'Model not found'},
|
|
500: {'description': 'AI Service Connection/Internal Error'}
|
|
}
|
|
)
|
|
@action(
|
|
detail=True,
|
|
methods=['post'],
|
|
url_path='run-inference',
|
|
parser_classes=[MultiPartParser],
|
|
permission_classes=[IsAuthenticated],
|
|
throttle_classes=[UserRateThrottle] # ถ้าผู้ใช้เรียกเกิน 50 ครั้ง/นาที จะถูกปฏิเสธด้วย 429 Too Many Requests
|
|
)
|
|
def run_inference(self, request, pk=None):
|
|
"""
|
|
Endpoint: POST /api/v1/models/{pk}/run-inference/
|
|
ทำหน้าที่รับไฟล์แล้ว Proxy ไปยัง AI Service ภายนอก
|
|
"""
|
|
try:
|
|
model_id = int(pk)
|
|
except (TypeError, ValueError):
|
|
return Response({"detail": "Invalid Model ID format."}, status=status.HTTP_400_BAD_REQUEST)
|
|
|
|
# 1. ดึงไฟล์ (Frontend ส่งเป็น 'file')
|
|
file_data = request.FILES.get('file')
|
|
if not file_data:
|
|
return Response({"detail": "File 'file' is required."}, status=status.HTTP_400_BAD_REQUEST)
|
|
|
|
# 2. เรียก Service Layer
|
|
try:
|
|
result = service.run_inference(
|
|
pk=model_id,
|
|
file_data=file_data,
|
|
user_id=request.user.id # ส่ง User ID สำหรับ Audit Log
|
|
)
|
|
return Response(result, status=status.HTTP_200_OK)
|
|
|
|
except ValueError as e: # Model not found
|
|
return Response({"detail": str(e)}, status=status.HTTP_404_NOT_FOUND)
|
|
except PermissionError as e: # Model status INACTIVE
|
|
return Response({"detail": str(e)}, status=status.HTTP_403_FORBIDDEN)
|
|
except ConnectionError as e: # AI Service Failure
|
|
# HTTP_503_SERVICE_UNAVAILABLE หรือ HTTP_504_GATEWAY_TIMEOUT อาจเหมาะสมกว่า
|
|
return Response({"detail": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
|
except Exception as e:
|
|
return Response({"detail": f"An unexpected error occurred: {e}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) |