131 lines
4.6 KiB
Python
131 lines
4.6 KiB
Python
from django.contrib.auth import authenticate
from django.contrib.auth.models import User
from rest_framework import status, viewsets
from rest_framework.decorators import action, api_view, permission_classes, throttle_classes
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.throttling import AnonRateThrottle
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import RefreshToken

from wrestleDesk.pagination import StandardResultsSetPagination

from .models import UserPreferences
from .permissions import HasUserManagementAccess
from .serializers import (
    LoginSerializer,
    PasswordChangeSerializer,
    RegisterSerializer,
    UserCreateSerializer,
    UserListSerializer,
    UserPreferencesSerializer,
    UserSerializer,
    UserUpdateSerializer,
)
|
|
|
|
|
|
class AuthRateThrottle(AnonRateThrottle):
    """Per-client rate limit (5 requests/minute) for credential endpoints.

    Used by ``login``, ``register`` and ``refresh_token``; all three share
    this single budget per client address.
    """

    # A dedicated scope gives this throttle its own cache key
    # ("throttle_auth_<ident>") instead of sharing "throttle_anon_<ident>"
    # with any plain AnonRateThrottle used elsewhere. Because ``rate`` is
    # set explicitly, no DEFAULT_THROTTLE_RATES entry is needed for it.
    scope = 'auth'
    rate = '5/minute'

    def get_cache_key(self, request, view):
        """Key the throttle on the client address for every request.

        ``AnonRateThrottle`` skips authenticated requests entirely. On
        credential endpoints that would let any logged-in caller attempt
        logins without limit, so the authenticated check is dropped here.
        """
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request),
        }
|
|
|
|
|
|
@api_view(['POST'])
@permission_classes([AllowAny])
@throttle_classes([AuthRateThrottle])
def login(request):
    """Authenticate with username/password and issue a JWT pair.

    Responds 200 with ``access`` and ``refresh`` tokens plus the serialized
    user on success, 401 when the credentials are rejected, and 400 when
    the payload fails ``LoginSerializer`` validation.
    """
    serializer = LoginSerializer(data=request.data)
    if not serializer.is_valid():
        # NOTE(review): unlike ``register``, validation errors are nested
        # under "detail"; kept as-is so existing clients keep working.
        return Response({'detail': serializer.errors}, status=status.HTTP_400_BAD_REQUEST)

    # Passing the request lets authentication backends, and receivers of
    # Django's ``user_login_failed`` signal, see the incoming request.
    user = authenticate(
        request=request,
        username=serializer.validated_data['username'],
        password=serializer.validated_data['password'],
    )
    if user is None:
        return Response(
            {'detail': 'Invalid credentials'},
            status=status.HTTP_401_UNAUTHORIZED,
        )

    refresh = RefreshToken.for_user(user)
    return Response({
        'access': str(refresh.access_token),
        'refresh': str(refresh),
        'user': UserSerializer(user).data,
    })
|
|
|
|
|
|
@api_view(['POST'])
@permission_classes([AllowAny])
@throttle_classes([AuthRateThrottle])
def register(request):
    """Create an account and sign it in straight away.

    Returns 201 with a fresh JWT pair and the serialized new user, or 400
    with the ``RegisterSerializer`` validation errors.
    """
    payload = RegisterSerializer(data=request.data)
    if not payload.is_valid():
        return Response(payload.errors, status=status.HTTP_400_BAD_REQUEST)

    new_user = payload.save()
    tokens = RefreshToken.for_user(new_user)
    body = {
        'access': str(tokens.access_token),
        'refresh': str(tokens),
        'user': UserSerializer(new_user).data,
    }
    return Response(body, status=status.HTTP_201_CREATED)
|
|
|
|
|
|
@api_view(['POST'])
@permission_classes([AllowAny])
@throttle_classes([AuthRateThrottle])
def refresh_token(request):
    """Exchange a refresh token for a new access token.

    Expects ``{"refresh": "<token>"}`` in the request body. Responds 200
    with ``{"access": ...}``, 400 when no refresh token is supplied, and
    401 when the token fails validation.

    NOTE(review): unlike simplejwt's own TokenRefreshView, this does not
    rotate or blacklist the refresh token even if ROTATE_REFRESH_TOKENS is
    enabled; confirm that is intended.
    """
    # Named ``raw_token`` so it does not shadow this view function's name.
    raw_token = request.data.get('refresh')
    if not raw_token:
        return Response(
            {'detail': 'Refresh token required'},
            status=status.HTTP_400_BAD_REQUEST,
        )

    try:
        refresh = RefreshToken(raw_token)
    except TokenError:
        # Only token problems (malformed, bad signature, expired, ...) map
        # to 401; unexpected failures must surface as server errors rather
        # than being disguised as an invalid token.
        return Response(
            {'detail': 'Invalid refresh token'},
            status=status.HTTP_401_UNAUTHORIZED,
        )

    return Response({
        'access': str(refresh.access_token),
    })
|
|
|
|
|
|
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def me(request):
    """Return the serialized representation of the authenticated caller."""
    payload = UserSerializer(request.user).data
    return Response(payload)
|
|
|
|
|
|
@api_view(['GET', 'PATCH'])
@permission_classes([IsAuthenticated])
def user_preferences(request):
    """Read (GET) or partially update (PATCH) the caller's preferences.

    A ``UserPreferences`` row is created on first access, so both methods
    always operate on an existing record. PATCH returns the updated
    preferences, or 400 with the serializer's validation errors.
    """
    # Both methods need the row; fetch (or lazily create) it exactly once.
    prefs, _ = UserPreferences.objects.get_or_create(user=request.user)

    if request.method == 'GET':
        return Response(UserPreferencesSerializer(prefs).data)

    # @api_view only routes GET and PATCH here, so this is the PATCH path.
    serializer = UserPreferencesSerializer(prefs, data=request.data, partial=True)
    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
    serializer.save()
    return Response(serializer.data)
|
|
|
|
|
|
class UserManagementViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for administering user accounts.

    Access is gated by ``HasUserManagementAccess``. Results are paginated
    with ``StandardResultsSetPagination`` and ordered newest-first. The
    extra ``set_password`` detail action (POST) sets a user's password.
    """

    # Base queryset; routers also fall back to it to derive a basename
    # when none is given explicitly.
    queryset = User.objects.all().select_related('profile')
    permission_classes = [HasUserManagementAccess]
    pagination_class = StandardResultsSetPagination

    def get_serializer_class(self):
        """Pick the serializer for the current action.

        ``create`` and ``update``/``partial_update`` use dedicated write
        serializers; every other action (list, retrieve, destroy, custom
        actions) falls back to ``UserListSerializer``.
        """
        if self.action == 'create':
            return UserCreateSerializer
        if self.action in ('update', 'partial_update'):
            return UserUpdateSerializer
        return UserListSerializer

    def get_queryset(self):
        """Return the base queryset ordered newest-first by ``date_joined``."""
        # Build on the class-level ``queryset`` (DRF re-clones it per
        # request) instead of restating the query, so the two can't drift.
        return super().get_queryset().order_by('-date_joined')

    @action(detail=True, methods=['post'])
    def set_password(self, request, pk=None):
        """Set a new password for the user identified by ``pk``.

        Validates the body with ``PasswordChangeSerializer``. Responds with
        ``{"status": "password set"}`` on success, or 400 with the
        validation errors.
        """
        user = self.get_object()
        serializer = PasswordChangeSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        # set_password hashes the raw value; persist only the password
        # column so concurrent edits to other fields are not overwritten.
        user.set_password(serializer.validated_data['password'])
        user.save(update_fields=['password'])
        return Response({'status': 'password set'})
|