Spaces:
Sleeping
Sleeping
File size: 8,548 Bytes
"""
Response compression middleware for FastAPI.

This module provides GZip and, optionally, Brotli compression
for API responses, which can significantly improve transfer speeds.

Usage:
    from compression import CompressionMiddleware

    app = FastAPI()
    app.add_middleware(CompressionMiddleware, minimum_size=500)
"""
import gzip
import io
from typing import Callable, Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp
# Check if Brotli is available
try:
import brotli
BROTLI_AVAILABLE = True
except ImportError:
BROTLI_AVAILABLE = False
class CompressionMiddleware(BaseHTTPMiddleware):
    """
    Middleware that compresses response bodies using GZip or Brotli.

    Features:
    - Negotiates the encoding from the request's Accept-Encoding header
      (q-values, the ``*`` wildcard and the ``x-gzip`` alias are honoured)
    - Skips small responses (configurable threshold)
    - Skips responses that already carry a Content-Encoding
    - Passes streaming responses through untouched
    - Preserves repeated response headers (e.g. several Set-Cookie lines)
    """

    # Content types that should be compressed
    COMPRESSIBLE_TYPES = frozenset([
        "application/json",
        "application/xml",
        "text/html",
        "text/plain",
        "text/css",
        "text/javascript",
        "application/javascript",
        "application/x-javascript",
        "image/svg+xml",
    ])

    # Content types that should NOT be compressed (already compressed)
    NON_COMPRESSIBLE_TYPES = frozenset([
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/webp",
        "application/zip",
        "application/gzip",
        "application/pdf",
        "video/mp4",
        "audio/mpeg",
    ])

    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int = 500,
        compression_level: int = 6,
        prefer_brotli: bool = True
    ):
        """
        Initialize compression middleware.

        Args:
            app: The ASGI application
            minimum_size: Minimum response size to compress (bytes)
            compression_level: GZip compression level (1-9); Brotli uses a
                fixed quality and ignores this value
            prefer_brotli: Prefer Brotli over GZip if available
        """
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level
        # Brotli can only be preferred when the optional module imported.
        self.prefer_brotli = prefer_brotli and BROTLI_AVAILABLE

    async def dispatch(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        """
        Run the downstream app and compress its response when worthwhile.

        The response is returned untouched when the client accepts no
        supported encoding, when it is streaming or already encoded, or when
        its content type is not compressible. Otherwise the body is buffered
        and replaced by its compressed form if that is actually smaller.

        Args:
            request: The incoming request
            call_next: Callable that invokes the rest of the application

        Returns:
            The original response, or a rebuilt one carrying the (possibly
            compressed) buffered body.
        """
        compression = self._get_best_compression(
            request.headers.get("Accept-Encoding", "")
        )
        response = await call_next(request)
        if not compression:
            return response

        # Skip if streaming or already compressed
        if isinstance(response, StreamingResponse):
            return response
        # NOTE(review): depending on the Starlette version, call_next() wraps
        # every response in an internal streaming wrapper, so the isinstance
        # check alone cannot be relied on. A response without Content-Length
        # has an unknown (possibly unbounded) size and must not be buffered.
        if "content-length" not in response.headers:
            return response
        if response.headers.get("Content-Encoding"):
            return response

        # Check content type (parameters such as charset are ignored)
        content_type = response.headers.get("Content-Type", "")
        base_content_type = content_type.split(";")[0].strip().lower()
        if not self._should_compress(base_content_type):
            return response

        # The body iterator can only be consumed once, so from here on a new
        # response has to be built even if the body ends up uncompressed.
        body = b"".join([chunk async for chunk in response.body_iterator])
        raw_headers = list(response.raw_headers)

        if len(body) >= self.minimum_size:
            compressed = self._compress(body, compression)
            # Only use compressed if smaller
            if len(compressed) < len(body):
                raw_headers = self._compressed_headers(
                    raw_headers, compression, len(compressed)
                )
                body = compressed

        rebuilt = Response(
            content=body,
            status_code=response.status_code,
            media_type=response.media_type
        )
        # Copy the raw header list instead of going through dict(): a dict
        # keeps one value per name (dropping extra Set-Cookie lines) and its
        # lower-cased keys used to leave a stale content-length next to the
        # newly added "Content-Length".
        rebuilt.raw_headers = raw_headers
        return rebuilt

    @staticmethod
    def _compressed_headers(raw_headers, encoding: str, length: int) -> list:
        """
        Return a copy of raw_headers adjusted for a compressed body.

        Existing Content-Encoding and Content-Length entries are replaced,
        and Accept-Encoding is merged into any existing Vary value.

        Args:
            raw_headers: Iterable of (name, value) byte pairs
            encoding: Content-Encoding token that was applied
            length: Size of the compressed body in bytes
        """
        vary_values = []
        updated = []
        for name, value in raw_headers:
            lowered = name.lower()
            if lowered == b"vary":
                vary_values.append(value.decode("latin-1"))
            elif lowered not in (b"content-encoding", b"content-length"):
                updated.append((name, value))

        vary = ", ".join(vary_values)
        vary_tokens = {token.strip().lower() for token in vary.split(",")}
        if not vary_tokens & {"*", "accept-encoding"}:
            vary = f"{vary}, Accept-Encoding" if vary.strip() else "Accept-Encoding"

        updated.append((b"content-encoding", encoding.encode("latin-1")))
        updated.append((b"content-length", str(length).encode("latin-1")))
        updated.append((b"vary", vary.encode("latin-1")))
        return updated

    @staticmethod
    def _parse_accept_encoding(accept_encoding: str) -> dict:
        """
        Parse an Accept-Encoding value into a {coding: q-value} mapping.

        Codings are lower-cased. A missing or malformed q parameter counts
        as 1.0.
        """
        qualities = {}
        for item in accept_encoding.lower().split(","):
            coding, _, params = item.partition(";")
            coding = coding.strip()
            if not coding:
                continue
            quality = 1.0
            for param in params.split(";"):
                key, _, value = param.partition("=")
                if key.strip() == "q":
                    try:
                        quality = float(value.strip())
                    except ValueError:
                        quality = 1.0
            qualities[coding] = quality
        return qualities

    def _get_best_compression(self, accept_encoding: str) -> Optional[str]:
        """
        Determine the best compression method.

        Args:
            accept_encoding: Raw Accept-Encoding header value

        Returns:
            "br" or "gzip", or None when the client accepts neither. A
            coding listed with q=0 is treated as refused.
        """
        qualities = self._parse_accept_encoding(accept_encoding)
        # "*" sets the q-value of every coding not listed explicitly.
        wildcard = qualities.get("*", 0.0)

        def accepted(*names: str) -> bool:
            listed = [qualities[name] for name in names if name in qualities]
            return (max(listed) if listed else wildcard) > 0

        brotli_accepted = BROTLI_AVAILABLE and accepted("br")

        # Check for Brotli first if preferred
        if self.prefer_brotli and brotli_accepted:
            return "br"
        # "x-gzip" is the legacy alias of "gzip"
        if accepted("gzip", "x-gzip"):
            return "gzip"
        # Fallback to Brotli if available
        if brotli_accepted:
            return "br"
        return None

    def _should_compress(self, content_type: str) -> bool:
        """
        Check if content type should be compressed.

        Args:
            content_type: Lower-cased media type without parameters
        """
        if content_type in self.NON_COMPRESSIBLE_TYPES:
            return False
        if content_type in self.COMPRESSIBLE_TYPES:
            return True
        # Compress text and JSON-like types
        return content_type.startswith(("text/", "application/json"))

    def _compress(self, data: bytes, method: str) -> bytes:
        """
        Compress data using specified method.

        Args:
            data: Uncompressed body
            method: "br" for Brotli (when the module is available); any
                other value results in GZip

        Returns:
            The compressed bytes.
        """
        if method == "br" and BROTLI_AVAILABLE:
            return brotli.compress(data, quality=4)
        # GZip compression
        return gzip.compress(data, compresslevel=self.compression_level)
class ETaggerMiddleware(BaseHTTPMiddleware):
    """
    Middleware that adds ETag headers for cache validation.

    This enables browsers to use If-None-Match headers
    for conditional requests, reducing bandwidth.
    """

    # Headers of the full response that are repeated on a 304 reply.
    _NOT_MODIFIED_HEADERS = frozenset([
        b"cache-control",
        b"content-location",
        b"date",
        b"expires",
        b"vary",
    ])

    def __init__(
        self,
        app: ASGIApp,
        weak_etag: bool = True
    ):
        """
        Initialize ETag middleware.

        Args:
            app: The ASGI application
            weak_etag: Use weak ETags (recommended for dynamic content)
        """
        super().__init__(app)
        self.weak_etag = weak_etag

    async def dispatch(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        """
        Tag GET/HEAD responses with an ETag and answer matching validators.

        The ETag is the first 16 hex digits of the MD5 of the buffered body.
        When the request's If-None-Match matches it and the response is a
        2xx, an empty 304 is returned instead of the body.

        Args:
            request: The incoming request
            call_next: Callable that invokes the rest of the application

        Returns:
            The untouched response, a 304 response, or a rebuilt response
            carrying the new ETag header.
        """
        import hashlib

        # Only for GET and HEAD requests
        if request.method not in ("GET", "HEAD"):
            return await call_next(request)

        response = await call_next(request)

        # Skip if already has ETag or is streaming. Return the response that
        # was already produced: calling call_next() again would run the
        # downstream handler a second time.
        if response.headers.get("ETag") or isinstance(response, StreamingResponse):
            return response
        # NOTE(review): depending on the Starlette version, call_next() wraps
        # every response in an internal streaming wrapper, so the isinstance
        # check alone cannot be relied on. A response without Content-Length
        # has an unknown (possibly unbounded) size and must not be buffered.
        if "content-length" not in response.headers:
            return response

        # Get body
        body = b"".join([chunk async for chunk in response.body_iterator])

        # Calculate ETag
        hash_value = hashlib.md5(body).hexdigest()[:16]
        etag = f'W/"{hash_value}"' if self.weak_etag else f'"{hash_value}"'
        etag_header = (b"etag", etag.encode("latin-1"))

        # Check If-None-Match header; only successful responses may be
        # replaced by a 304, so an error is never hidden behind it.
        if_none_match = request.headers.get("If-None-Match")
        if (
            if_none_match
            and 200 <= response.status_code < 300
            and self._etag_matches(if_none_match, etag)
        ):
            not_modified = Response(status_code=304)
            not_modified.raw_headers = [
                (name, value)
                for name, value in response.raw_headers
                if name.lower() in self._NOT_MODIFIED_HEADERS
            ] + [etag_header]
            return not_modified

        # Add ETag to response. The raw header list is copied rather than
        # converted with dict(), which would keep only one value per header
        # name and drop additional Set-Cookie lines.
        tagged = Response(
            content=body,
            status_code=response.status_code,
            media_type=response.media_type
        )
        tagged.raw_headers = list(response.raw_headers) + [etag_header]
        return tagged

    @staticmethod
    def _etag_matches(if_none_match: str, etag: str) -> bool:
        """
        Check an If-None-Match header value against an entity tag.

        The header may be ``*`` or a comma-separated list of tags. Tags are
        compared weakly, i.e. a leading ``W/`` is ignored on both sides.
        """
        def opaque(tag: str) -> str:
            tag = tag.strip()
            return tag[2:] if tag.startswith("W/") else tag

        if if_none_match.strip() == "*":
            return True
        wanted = opaque(etag)
        return any(
            opaque(candidate) == wanted
            for candidate in if_none_match.split(",")
        )
def get_compression_stats(original_size: int, compressed_size: int) -> dict:
    """
    Summarize how much a payload shrank after compression.

    Args:
        original_size: Size before compression, in bytes
        compressed_size: Size after compression, in bytes

    Returns:
        Dictionary with both sizes, the number of bytes saved and the
        saving as a percentage string with one decimal place. The
        percentage is reported as 0 when original_size is not positive.
    """
    if original_size > 0:
        percent_saved = (1 - compressed_size / original_size) * 100
    else:
        # Avoid dividing by zero for an empty original payload.
        percent_saved = 0

    stats = dict(
        original_size=original_size,
        compressed_size=compressed_size,
        saved_bytes=original_size - compressed_size,
    )
    stats["compression_ratio"] = f"{percent_saved:.1f}%"
    return stats
|