DEV Community

Santhosh Balasa

Custom Middleware for FastAPI

This article describes a custom middleware designed for FastAPI applications. The middleware serves several purposes:

  1. Token Authentication: It validates JWT bearer tokens on /api routes against the signing keys published at the Azure AD JWKS endpoint, so only authenticated users reach the API.
  2. Content-Type Handling: It serves static files, including JavaScript bundles, with a content type guessed from the file extension.
  3. Serving UI Files: It serves the built UI files and falls back to index.html when no file or route matches the request.

Components

Environment Variables

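The middleware reads the following environment variables and derives one further setting from them:

  • AZURE_CLIENT_ID: Azure client ID (read, but not used by the code below).
  • AZURE_CLIENT_SECRET: Azure client secret (read, but not used by the code below).
  • AZURE_TENANT_ID: Azure tenant ID, used to build the JWKS URL and the expected issuer.
  • EXPECTED_AUDIENCE: Expected audience (the token's aud claim).
  • EXPECTED_ISSUER: Not read from the environment. It is derived from AZURE_TENANT_ID as https://sts.windows.net/{AZURE_TENANT_ID}/.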
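If AZURE_TENANT_ID is unset, the module does not fail. It quietly builds the issuer "https://sts.windows.net/None/", and every token is then rejected. The sketch below shows one way to read the settings and fail fast at startup instead. `AzureAuthSettings` and `load_settings` are hypothetical names, not part of the original code. The issuer and JWKS URL are built exactly as in the middleware.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class AzureAuthSettings:
    # Hypothetical settings holder for the values the middleware needs.
    tenant_id: str
    audience: str

    @property
    def expected_issuer(self) -> str:
        # Same construction as EXPECTED_ISSUER in the middleware module.
        return f"https://sts.windows.net/{self.tenant_id}/"

    @property
    def jwks_uri(self) -> str:
        # Same URL that validate_token downloads the signing keys from.
        return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"


def load_settings(env: Mapping[str, str] = os.environ) -> AzureAuthSettings:
    """Read the required variables, failing fast if any are missing or empty."""
    required = ("AZURE_TENANT_ID", "EXPECTED_AUDIENCE")
    missing = [name for name in required if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
    return AzureAuthSettings(tenant_id=env["AZURE_TENANT_ID"], audience=env["EXPECTED_AUDIENCE"])


settings = load_settings({"AZURE_TENANT_ID": "my-tenant", "EXPECTED_AUDIENCE": "api://my-app"})
print(settings.expected_issuer)  # https://sts.windows.net/my-tenant/
```

Calling `load_settings()` once at import time surfaces configuration mistakes when the app starts. Otherwise they show up as a stream of 401 responses.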

Functions

error_response(error_msg: str, status_code: int) -> JSONResponse

This function logs the error message and returns a JSONResponse with the given status code and a body of the form {"error": {"message": error_msg}}. It is used to produce consistent error responses.

validate_token(token: str) -> bool

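This function fetches the tenant's public signing keys from the Azure AD JWKS endpoint and picks the key whose kid matches the token header. It then verifies the token's RS256 signature, audience, issuer and, by default, expiry. It returns True if the token is valid, and False otherwise.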
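validate_token downloads the JWKS document on every call. The sketch below illustrates two things: how the kid is read from a token header without verification, and how the key lookup can be served from a small time-based cache that refetches once when it sees an unknown kid (for example after key rotation). `unverified_kid`, `JwksCache` and the injected `fetch` callable are hypothetical names used only for illustration. PyJWT also ships `PyJWKClient`, which does this lookup and caching for you.

```python
import base64
import json
import time
from typing import Any, Callable, Dict, Optional


def unverified_kid(token: str) -> Optional[str]:
    """Return the "kid" from a JWT header WITHOUT verifying the signature."""
    header_b64 = token.split(".")[0]
    # JWT segments are unpadded base64url; restore padding before decoding.
    padded = header_b64 + "=" * (-len(header_b64) % 4)
    header = json.loads(base64.urlsafe_b64decode(padded))
    return header.get("kid")


class JwksCache:
    """Keeps a JWKS document for `ttl` seconds and refetches once on an unknown kid."""

    def __init__(self, fetch: Callable[[], Dict[str, Any]], ttl: float = 3600.0):
        self._fetch = fetch  # hypothetical: any callable returning the parsed JWKS JSON
        self._ttl = ttl
        self._jwks: Optional[Dict[str, Any]] = None
        self._fetched_at = 0.0

    def _refresh(self) -> None:
        self._jwks = self._fetch()
        self._fetched_at = time.monotonic()

    def _lookup(self, kid: str) -> Optional[Dict[str, Any]]:
        for key in (self._jwks or {}).get("keys", []):
            if key.get("kid") == kid:
                return key
        return None

    def get_key(self, kid: str) -> Optional[Dict[str, Any]]:
        # Fetch on first use or once the cached document has expired.
        if self._jwks is None or time.monotonic() - self._fetched_at > self._ttl:
            self._refresh()
        key = self._lookup(kid)
        if key is None:
            # The signing keys may have rotated since the last fetch: retry once.
            self._refresh()
            key = self._lookup(kid)
        return key
```

The returned JWK dict can then go to `jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))` exactly as in the original function. Note that every unknown kid still triggers a fetch. A production version might rate-limit those refetches.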

Custom Middleware Class

class CustomMiddleware(BaseHTTPMiddleware)

This class contains the core logic of the middleware.

  • dispatch Method: This method handles three cases. It authenticates requests to /api routes. It serves existing UI files with a content type guessed from their extension. When the downstream application returns 404, it falls back to index.html.
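dispatch picks the content type of each served file with `mimetypes.guess_type`. On some systems that table comes from local configuration. On Windows it is read from the registry and can map ".js" to "text/plain". Browsers refuse to run ES module scripts served with a non-JavaScript type. Assuming that is the concern behind the "Content-Type Handling" goal, a minimal fix is to pin the types the UI bundle relies on once, at module import:

```python
import mimetypes

# Override whatever the platform's MIME table says for the UI bundle's assets,
# so mimetypes.guess_type() in dispatch always returns these values.
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")

content_type, _ = mimetypes.guess_type("ui/dist/assets/index.js")
print(content_type)  # text/javascript
```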

Workflow

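  1. Token Validation: If the request path starts with "/api", the middleware requires an Authorization header of the form "Bearer <token>" and validates the token with validate_token. A missing header or an invalid token yields 401, and a header without the Bearer prefix yields 400. Valid requests are passed on to the application.
  2. File Serving: Otherwise, the middleware checks whether the path matches a file under ui/dist. If it does, the middleware serves it with a content type guessed from the file extension.
  3. Fallback to index.html: If no file matches, the request is passed on to the application. If that returns 404, the middleware serves ui/dist/index.html instead, so client-side routes of a single-page app still load.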
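The branching in this workflow can be sketched as a pure function with no I/O, which makes the routing rules easy to test in isolation. `route_request` and its return values are hypothetical and only mirror the decisions dispatch makes. The real middleware receives Starlette's case-insensitive headers, whereas this sketch uses a plain dict.

```python
from typing import Mapping, Optional, Tuple

API_PREFIX = "/api"
BEARER = "Bearer "


def route_request(path: str, headers: Mapping[str, str]) -> Tuple[str, Optional[str]]:
    """Mirror the middleware's decisions.

    Returns:
      ("reject", "401")     API call without an Authorization header
      ("reject", "400")     API call whose header lacks the Bearer prefix
      ("validate", token)   API call carrying a Bearer token to check
      ("static", None)      any other path: file lookup, then index.html on 404
    """
    if not path.startswith(API_PREFIX):
        return ("static", None)
    auth = headers.get("Authorization")
    if auth is None:
        return ("reject", "401")
    if not auth.startswith(BEARER):
        return ("reject", "400")
    # Strip only the leading scheme, keeping the rest of the header intact.
    return ("validate", auth[len(BEARER):])
```

A plain `startswith("/api")` check, as in the original, also matches paths such as "/apidocs". Checking for "/api/" (or an exact "/api") avoids protecting or exposing routes by accident.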

Code

"""
Custom Middleware for FastAPI Application

This module provides a custom middleware class for FastAPI applications, focused on handling authentication and serving UI files.

- Token Authentication: Validates JWT tokens using the Azure AD JWKS endpoint.
- Content-Type Handling: Sets the content type of served files (e.g. JavaScript) from their extension.
- Serving UI Files: Serves UI files and provides a fallback to serve the index.html file for other requests.

Environment Variables:
    AZURE_CLIENT_ID: Azure client ID (read, but not used below).
    AZURE_CLIENT_SECRET: Azure client secret (read, but not used below).
    AZURE_TENANT_ID: Azure tenant ID.
    EXPECTED_AUDIENCE: Expected audience for the JWT token.

Derived values:
    EXPECTED_ISSUER: Expected issuer for the JWT token, built from AZURE_TENANT_ID.

Functions:
    error_response(error_msg: str, status_code: int) -> JSONResponse
    validate_token(token: str) -> bool

Classes:
    CustomMiddleware(BaseHTTPMiddleware): Custom middleware class.

Usage:
    Register it on your FastAPI app with app.add_middleware(CustomMiddleware).
"""

import os
import jwt
import json
import logging
import requests
import mimetypes


from starlette.requests import Request
from typing import Callable, Coroutine, Any
from starlette.responses import Response, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

UI_DIR = "ui/dist"
UI_PATH = "ui/dist/index.html"
AZURE_CLIENT_ID = os.getenv("AZURE_CLIENT_ID")
AZURE_CLIENT_SECRET = os.getenv("AZURE_CLIENT_SECRET")
AZURE_TENANT_ID = os.getenv("AZURE_TENANT_ID")
EXPECTED_AUDIENCE = os.getenv("EXPECTED_AUDIENCE")
EXPECTED_ISSUER = f"https://sts.windows.net/{AZURE_TENANT_ID}/"
logger = logging.getLogger("CustomMiddleware")


def error_response(error_msg: str, status_code: int) -> JSONResponse:
    """
    Logs the error message and returns a JSONResponse with the error message.

    Args:
        error_msg (str): the error message to be logged and returned in the JSONResponse
        status_code (int): the HTTP status code

    Returns:
        JSONResponse: JSONResponse containing the error message
    """
    logger.error(error_msg)
    return JSONResponse(
        content={"error": {"message": error_msg}},
        status_code=status_code,
    )


def validate_token(token: str) -> bool:
    """
    Validate a JWT token using the Azure AD JWKS endpoint.

    Args:
        token (str): The JWT token to be validated.

    Returns:
        bool: True if the token is valid, False otherwise.
    """
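    jwks_uri = f"https://login.microsoftonline.com/{AZURE_TENANT_ID}/discovery/v2.0/keys"

    try:
        # Download the tenant's public signing keys; the timeout keeps a slow
        # endpoint from hanging the request indefinitely.
        jwks_response = requests.get(jwks_uri, timeout=10)
        jwks_response.raise_for_status()
        jwks = jwks_response.json()

        # The (unverified) header names the key that signed the token.
        header = jwt.get_unverified_header(token)

        signing_key = None
        for key in jwks["keys"]:
            if key["kid"] == header["kid"]:
                signing_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
                break

        if signing_key is None:
            logger.error("Token validation failed: no matching signing key")
            return False

        # Verifies the signature, expiry, audience and issuer in one call.
        jwt.decode(
            token,
            signing_key,
            algorithms=["RS256"],
            audience=EXPECTED_AUDIENCE,
            issuer=EXPECTED_ISSUER,
        )
        return True
    except Exception as e:
        # Malformed tokens, network errors and failed checks all end up here.
        logger.error(f"Token validation failed: {e}")
        return False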


class CustomMiddleware(BaseHTTPMiddleware):
    """
    Custom middleware that authenticates /api requests, serves UI files with a
    guessed content type, and serves index.html as a fallback for unmatched requests.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Coroutine[Any, Any, Response]],
    ) -> Response:
        """
        Authenticate /api requests, serve UI files from UI_DIR, or serve
        index.html as a fallback when the application returns 404.

        Args:
            request (Request): The incoming request.
            call_next (Callable): The next middleware or handler in the stack.

        Returns:
            Response: The generated response.
        """
        logger.debug(f"Request URL path: {request.url.path}")
        if request.url.path.startswith("/api"):
            try:
                if "Authorization" not in request.headers:
                    return error_response("Unauthorized", 401)
                token_header = request.headers["Authorization"]
                if token_header.startswith("Bearer "):
                    token = token_header[len("Bearer "):]
                    if not validate_token(token):
                        return error_response("Unauthorized", 401)
                else:
                    return error_response("Token should begin with Bearer", 400)
            except Exception as e:
                return error_response(f"{e}", 400)
            response = await call_next(request)
            return response

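        # Resolve the requested path and refuse anything outside UI_DIR
        # (e.g. "/../secret"), which would otherwise allow path traversal.
        ui_root = os.path.realpath(UI_DIR)
        file_path = os.path.realpath(os.path.join(ui_root, request.url.path.lstrip("/")))

        if file_path.startswith(ui_root + os.sep) and os.path.isfile(file_path):
            with open(file_path, "rb") as file:
                content_type, _ = mimetypes.guess_type(file_path)
                return Response(content=file.read(), media_type=content_type)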

        response = await call_next(request)
        if response.status_code == 404:
            with open(UI_PATH, "rb") as file:
                return Response(content=file.read(), media_type="text/html")
        return response


Conclusion

This custom middleware combines authentication and static file serving in a single FastAPI component. It validates each /api request's JWT against Azure AD's published signing keys, expected audience and expected issuer, which keeps unauthenticated callers out of the API. Serving the UI build from the same middleware lets one application host both the API and the single-page front end.
