"""Viewsets for trainings, training exercises and attendance records."""

from django.core.exceptions import ValidationError as DjangoValidationError
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import filters, status, viewsets
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from wrestleDesk.pagination import StandardResultsSetPagination

from .models import Attendance, Training, TrainingExercise
from .serializers import (
    AttendanceSerializer,
    TrainingDetailSerializer,
    TrainingExerciseSerializer,
    TrainingSerializer,
)


class TrainingExerciseViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``TrainingExercise`` rows.

    Requires an authenticated user. List responses are paginated, can be
    filtered by ``training`` and ordered by ``order`` or ``id``.
    """

    queryset = TrainingExercise.objects.select_related('training', 'exercise').all()
    serializer_class = TrainingExerciseSerializer
    pagination_class = StandardResultsSetPagination
    permission_classes = [IsAuthenticated]
    filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
    filterset_fields = ['training']
    ordering_fields = ['order', 'id']


class TrainingViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``Training`` objects.

    Requires an authenticated user. List responses are paginated and support:

    * exact filters on ``group``, ``is_completed``, ``date`` and ``location``;
    * ``?search=`` over ``notes``;
    * ``?ordering=`` on ``date`` / ``created_at``;
    * an inclusive date range via ``?date_from=`` and ``?date_to=``.

    ``retrieve`` uses ``TrainingDetailSerializer``; every other action uses
    ``TrainingSerializer``.
    """

    queryset = (
        Training.objects.select_related('location')
        .prefetch_related('trainers', 'attendances')
        .all()
    )
    serializer_class = TrainingSerializer
    pagination_class = StandardResultsSetPagination
    permission_classes = [IsAuthenticated]
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    filterset_fields = ['group', 'is_completed', 'date', 'location']
    search_fields = ['notes']
    ordering_fields = ['date', 'created_at']

    # (query parameter, ORM lookup) pairs for the optional date range.
    _DATE_RANGE_PARAMS = (
        ('date_from', 'date__gte'),
        ('date_to', 'date__lte'),
    )

    def get_serializer_class(self):
        """Use the detail serializer for ``retrieve``, the default otherwise."""
        if self.action == 'retrieve':
            return TrainingDetailSerializer
        return super().get_serializer_class()

    def get_queryset(self):
        """Return trainings, narrowed by the optional ``date_from``/``date_to`` bounds.

        Both bounds are inclusive. Empty or missing parameters are ignored.

        Raises:
            rest_framework.exceptions.ValidationError: if Django rejects a bound
                as an invalid value for the ``date`` field. DRF renders this as
                a 400 response instead of letting the error surface as a 500.
        """
        queryset = super().get_queryset()
        params = self.request.query_params
        for param, lookup in self._DATE_RANGE_PARAMS:
            value = params.get(param)
            if not value:
                continue
            try:
                # Django validates the value against the field while building
                # the lookup, so malformed input raises here, not at query time.
                queryset = queryset.filter(**{lookup: value})
            except (DjangoValidationError, ValueError, TypeError) as exc:
                raise ValidationError({param: ['Enter a valid date.']}) from exc
        return queryset


class AttendanceViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``Attendance`` records.

    Requires an authenticated user. Listing is only allowed for one training
    at a time: ``?training=<id>`` is mandatory for ``list``. Results can also
    be filtered by ``wrestler`` and ordered by ``created_at``.
    """

    queryset = Attendance.objects.select_related('training', 'wrestler').all()
    serializer_class = AttendanceSerializer
    pagination_class = StandardResultsSetPagination
    permission_classes = [IsAuthenticated]
    filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
    filterset_fields = ['training', 'wrestler']
    ordering_fields = ['created_at']

    def list(self, request, *args, **kwargs):
        """List attendances, rejecting requests without ``?training=`` with 400.

        Filtering by the given training is performed by ``DjangoFilterBackend``
        (``training`` is in ``filterset_fields``). That backend also rejects
        malformed or unknown ids with a 400 response. A manual
        ``filter(training_id=...)`` in ``get_queryset`` would run first and
        raise a 500 on non-numeric input, so it is deliberately not used.
        """
        training_id = self.request.query_params.get('training')
        if not training_id:
            return Response(
                {'error': 'training query parameter is required'},
                status=status.HTTP_400_BAD_REQUEST,
            )
        return super().list(request, *args, **kwargs)