diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..f650f84 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,20 @@ +.git +.github +.mypy_cache +.pytest_cache +.ruff_cache +.venv +__pycache__ +*.py[cod] +*.pyo +*.pyd +*.sqlite +*.db +.env +.env.* +!.env.example +Dockerfile +docker-compose*.yml +docs +tests +README.md diff --git a/.env.example b/.env.example index b019a4c..e7d437b 100644 --- a/.env.example +++ b/.env.example @@ -1,14 +1,47 @@ APP_NAME=Todo Modulith API APP_ENV=production -DATABASE_URL=postgresql+asyncpg://postgres:postgres@127.0.0.1:5432/todo_db +POSTGRES_USER=postgres +POSTGRES_PASSWORD= +POSTGRES_DB=todo_db +REDIS_PASSWORD= -REDIS_URL=redis://:password@127.0.0.1:6379/0 +DATABASE_URL= +DATABASE_POOL_SIZE=20 +DATABASE_MAX_OVERFLOW=10 +DATABASE_POOL_TIMEOUT=30 +DATABASE_POOL_RECYCLE=3600 -SECRET_KEY=your-super-secret-production-key-here +REDIS_URL= + +SECRET_KEY= + +MAX_REQUEST_SIZE_MB=5242880 #5mb ALGORITHM=HS256 +JWT_ISSUER=todo-modulith-api +JWT_AUDIENCE=todo-modulith-client ACCESS_TOKEN_EXPIRE_MINUTES=30 REFRESH_TOKEN_EXPIRE_MINUTES=10080 -RATE_LIMIT="100/minute" \ No newline at end of file +RATE_LIMIT="100/minute" + +CORS_ALLOW_ORIGINS=http://localhost:3000 +CORS_ALLOW_METHODS=* +CORS_ALLOW_HEADERS=* + +SECURITY_CONTENT_SECURITY_POLICY=default-src 'self'; frame-ancestors 'none' + +IDEMPOTENCY_TTL_SECONDS=86400 + +ACCOUNT_LOCKOUT_MAX_ATTEMPTS=5 +ACCOUNT_LOCKOUT_WINDOW_MINUTES=15 +ACCOUNT_LOCKOUT_DURATION_MINUTES=15 + +LOG_FORMAT=json + +SEED_ADMIN_EMAIL= +SEED_ADMIN_PASSWORD= +SEED_ADMIN_USERNAME=admin +SEED_ADMIN_FULLNAME=System Administrator +SEED_DEVELOPMENT_USERS_PASSWORD= diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..2401706 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,93 @@ +name: CI + +on: + pull_request: + push: + branches: + - main + tags: + - "v*.*.*" + +permissions: + contents: read + packages: write + +env: + IMAGE_NAME: ghcr.io/${{ github.repository }} + PYTHON_VERSION: "3.14" + POETRY_VERSION: "2.4.1" + +jobs: + verify: + name: Test and lint + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install Poetry + run: pipx install poetry==${{ env.POETRY_VERSION }} + + - name: Configure Poetry cache + uses: actions/cache@v4 + with: + path: ~/.cache/pypoetry + key: poetry-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} + restore-keys: | + poetry-${{ runner.os }}-${{ env.PYTHON_VERSION }}- + + - name: Install dependencies + run: poetry install --with dev --no-interaction --no-ansi --no-root + + - name: Run lint + run: poetry run ruff check src tests scripts + + - name: Run tests + run: poetry run pytest -q + + docker: + name: Build and publish image + runs-on: ubuntu-latest + needs: verify + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + if: github.event_name == 'push' + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=sha- + + - name: Build image + uses: docker/build-push-action@v6 + with: + context: . + target: runtime + push: ${{ github.event_name == 'push' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/Dockerfile b/Dockerfile index aeb8e94..a71dacf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,25 +1,50 @@ -FROM python:3.11-slim +FROM python:3.14-slim AS builder + +ENV POETRY_VERSION=2.4.1 \ + POETRY_NO_INTERACTION=1 \ + POETRY_VIRTUALENVS_CREATE=1 \ + POETRY_VIRTUALENVS_IN_PROJECT=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 WORKDIR /app -# Install Poetry -RUN pip install poetry +RUN apt-get update \ + && apt-get install --no-install-recommends -y build-essential \ + && pip install --no-cache-dir "poetry==${POETRY_VERSION}" \ + && rm -rf /var/lib/apt/lists/* + +COPY pyproject.toml poetry.lock ./ + +RUN poetry install --only main --no-root --no-ansi -# Copy dependency files -COPY pyproject.toml poetry.lock* ./ -# Install dependencies without dev tools -RUN poetry config virtualenvs.create false \ - && poetry install --no-interaction --no-ansi --no-root +FROM python:3.14-slim AS runtime -# Copy source code +ENV APP_ENV=production \ + PATH="/app/.venv/bin:${PATH}" \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +WORKDIR /app + +RUN groupadd --system app \ + && useradd --system --gid app --home-dir /app --shell /usr/sbin/nologin app + +COPY --from=builder /app/.venv /app/.venv +COPY alembic.ini ./ +COPY alembic ./alembic +COPY scripts ./scripts COPY src ./src -# Expose port +RUN chmod +x /app/scripts/start.sh \ + && chown -R app:app /app + +USER app + EXPOSE 8000 -# Run application -COPY start.sh ./script/start.sh -RUN chmod +x start.sh +HEALTHCHECK --interval=30s --timeout=5s --start-period=30s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=3).read()" || exit 1 -CMD ["./script/start.sh"] \ No newline at end of file +CMD ["/app/scripts/start.sh"] diff --git a/Makefile b/Makefile index e818aee..c333ea9 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ COMPOSE_FILE := docker-compose.yml .DEFAULT_GOAL := help -.PHONY: help install run test lint import-check check migrate downgrade revision db-up db-down db-logs clean +.PHONY: help install run test lint import-check security-scan check migrate seed downgrade revision db-up db-down db-logs clean help: @echo "[make:help] Available commands:" @@ -19,8 +19,10 @@ help: @echo " [make:test] Run pytest" @echo " [make:lint] Run Ruff checks" @echo " [make:import-check] Verify src.main imports" + @echo " [make:security-scan] Run dependency vulnerability scan with pip-audit" @echo " [make:check] Run tests, lint, and import check" @echo " [make:migrate] Apply Alembic migrations" + @echo " [make:seed] Seed baseline database records" @echo " [make:downgrade] Roll back one Alembic migration" @echo " [make:revision] Create an Alembic migration: make revision name=\"describe change\"" @echo " [make:db-up] Start Docker Compose services" @@ -42,12 +44,20 @@ test: lint: @echo "[make:lint] Running Ruff checks" - @$(RUFF) check src tests + @$(RUFF) check src tests scripts import-check: @echo "[make:import-check] Verifying src.main imports" @PYTHONDONTWRITEBYTECODE=1 $(PYTHON) -c "import src.main; print('import ok')" +security-scan: + @echo "[make:security-scan] Running dependency vulnerability scan" + @if ! command -v pip-audit >/dev/null 2>&1; then \ + echo "[make:security-scan] pip-audit is not installed. Install it with: pip install pip-audit"; \ + exit 1; \ + fi + @pip-audit + check: test lint import-check @echo "[make:check] All checks completed" @@ -55,6 +65,10 @@ migrate: @echo "[make:migrate] Applying Alembic migrations" @$(ALEMBIC) upgrade head +seed: + @echo "[make:seed] Running database seeders" + @$(PYTHON) scripts/seed.py + downgrade: @echo "[make:downgrade] Rolling back one Alembic migration" @$(ALEMBIC) downgrade -1 diff --git a/README.md b/README.md index 6fc8a64..efbd827 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ The API is currently versioned under `/api/v1`. - [Docker Notes](#docker-notes) - [Development Guide](#development-guide) - [Troubleshooting](#troubleshooting) +- [Security TODO](#security-todo) - [Known Notes](#known-notes) ## Features @@ -229,16 +230,24 @@ Expected values: ```env APP_NAME=Todo Modulith API -DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/todo_db -SECRET_KEY=your-super-secret-production-key-here +POSTGRES_USER=postgres +POSTGRES_PASSWORD= +POSTGRES_DB=todo_db +REDIS_PASSWORD= +DATABASE_URL= +REDIS_URL= +SECRET_KEY= ALGORITHM=HS256 +JWT_ISSUER=todo-modulith-api +JWT_AUDIENCE=todo-modulith-client ACCESS_TOKEN_EXPIRE_MINUTES=30 +REFRESH_TOKEN_EXPIRE_MINUTES=10080 ``` For local development without Docker, point `DATABASE_URL` at your local PostgreSQL host, for example: ```env -DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/todo_db +DATABASE_URL=postgresql+asyncpg://postgres@localhost:5432/todo_db ``` ## Local Setup @@ -319,12 +328,46 @@ With Make: ```bash make migrate +make seed make revision name="add todo due date" make downgrade ``` Important: migration autogeneration depends on importing all SQLAlchemy models in `alembic/env.py`, so new module models must be imported there or through a central model registry. +Seed baseline authorization data after applying migrations: + +```bash +make seed +``` + +The seeder is idempotent. It creates default authorization resources, the default `admin` and `user` roles, default permissions, role-permission links, and matching Casbin policies without duplicating existing records. + +To seed an initial admin user, set these environment variables before running `make seed`: + +```env +SEED_ADMIN_EMAIL=admin@example.com +SEED_ADMIN_PASSWORD= +SEED_ADMIN_USERNAME=admin +SEED_ADMIN_FULLNAME=System Administrator +``` + +If `SEED_ADMIN_EMAIL` or `SEED_ADMIN_PASSWORD` is empty, user seeding is skipped. Existing users are not modified. + +When `APP_ENV=development`, the seeder can also create demo users with different roles. Set a shared development password before running `make seed`: + +```env +SEED_DEVELOPMENT_USERS_PASSWORD= +``` + +Development demo accounts: + +- `user@example.com` with the `user` role +- `manager@example.com` with the `manager` role +- `viewer@example.com` with the `viewer` role + +These users are skipped outside development and are not updated if they already exist. + ## Testing and Quality Checks Run tests: @@ -348,7 +391,7 @@ make check Current check set: - `pytest -q` -- `ruff check src tests` +- `ruff check src tests scripts` - import check for `src.main` ## Makefile Commands @@ -362,6 +405,7 @@ make lint make import-check make check make migrate +make seed make downgrade make revision name="migration message" make db-up @@ -476,9 +520,53 @@ Authorization: Bearer The token must contain a `sub` claim with a valid user id. +## Security TODO + +Legend: `Implemented` means code exists in the repository. `Partial` means code exists but still needs a fix, test, or production hardening. + +| Category | Recommended | Current Status | Notes | +| --- | --- | --- | --- | +| JWT Authentication | Required | Implemented | `AuthenticationMiddleware` validates bearer tokens for non-public routes. | +| Refresh Token Rotation | Required | Implemented | Refresh flow revokes the old refresh token and persists a new token. | +| RBAC + Permissions | Required | Implemented | Casbin-backed role and permission checks are wired through route dependencies. | +| Rate Limiting (Redis-backed) | Required | Implemented | Redis-backed limiter reads the configured `RATE_LIMIT` value. | +| Security Headers Middleware | Required | Implemented | Adds `X-Content-Type-Options`, `X-Frame-Options`, CSP `frame-ancestors`, `Referrer-Policy`, and `Permissions-Policy`. | +| CORS Configuration | Required | Implemented | CORS origins, methods, and headers are environment-driven through settings. | +| Request ID Middleware | Required | Implemented | Generates or propagates `X-Request-ID` and stores it on request state. | +| Audit Logging | Required | Implemented | Adds global endpoint audit logging, domain audit events, and separate persisted error traces. | +| Structured Logging | Required | Implemented | Logs request ID, method, path, status, latency, and user context when available. | +| Global Exception Handling | Required | Implemented | Domain exceptions are registered explicitly and `Exception` is used only as the fallback handler. | +| Input Validation | Required | Implemented | Pydantic schemas and application validation functions are used across user and todo flows. | +| Password Hashing (Argon2 or bcrypt) | Required | Implemented | User auth service uses bcrypt hashing. | +| Account Lockout | Required | Implemented | Tracks failed logins and temporarily locks accounts after configured thresholds. | +| Token Revocation | Required | Implemented | Refresh tokens are revoked on rotation/logout, and access tokens are denylisted in Redis until expiry. | +| OpenAPI Authentication | Required | Implemented | Swagger OAuth2 auth is configured, and docs/OpenAPI endpoints are disabled when `APP_ENV=production`. | +| Health Check Endpoint | Required | Implemented | `/health` endpoint returns service health. | +| Readiness/Liveness Endpoints | Required | Implemented | Adds `/live` and `/ready` operational endpoints. | +| Request Size Limiting | Required | Implemented | `LimitRequestSizeMiddleware` rejects oversized write requests. | +| Idempotency Support (for applicable POST endpoints) | Optional but valuable | Implemented | Supports `Idempotency-Key` replay caching for POST responses. | +| Database Migrations | Required | Implemented | Alembic is configured with migration commands in the README and Makefile. | +| Dependency Injection | Required | Implemented | FastAPI dependencies wire repositories, handlers, auth, authorization, and database sessions. | +| Configuration via Environment Variables | Required | Implemented | Pydantic settings read `.env` and reject the default secret key in production. | + +### Next Implementation Checklist + +- [x] Fix and verify rate limit configuration wiring. +- [x] Add security headers middleware. +- [x] Add request ID middleware. +- [x] Add structured request logging. +- [x] Add audit logging for sensitive actions. +- [x] Add account lockout or equivalent failed-login protection. +- [x] Disable or authenticate `/docs`, `/redoc`, and `/openapi.json` in production. +- [x] Add readiness and liveness endpoints. +- [x] Add production config validation for secrets and unsafe defaults. +- [x] Harden CORS through environment-driven allowed origins, methods, and headers. +- [x] Review exception responses to avoid leaking token parsing details or internal exception messages. +- [x] Add automated tests for request size limits, rate limiting, auth failures, authorization failures, CORS, security headers, and request IDs. +- [x] Add dependency vulnerability scanning to local or CI checks, for example `pip-audit` or an equivalent Poetry-compatible scanner. + ## Known Notes -- `alembic/env.py` currently prints metadata debug output during migrations. - `src/core/lifespan.py` still calls `Base.metadata.create_all`; with Alembic in place, production environments normally rely on migrations instead. - The project has a Pydantic v2 deprecation warning for class-based settings config. - The Dockerfile start script path needs alignment before relying on Docker builds. diff --git a/alembic/env.py b/alembic/env.py index 884b547..87a7004 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -13,6 +13,9 @@ from src.core.authorization.infrastructure.models.permission_model import ( PermissionModel, # noqa: F401 ) +from src.core.authorization.infrastructure.models.resource_model import ( + AuthorizationResourceModel, # noqa: F401 +) from src.core.authorization.infrastructure.models.role_model import ( RoleModel, # noqa: F401 ) @@ -23,6 +26,15 @@ UserHasRoleModel, # noqa: F401 ) from src.core.config.setting import get_settings +from src.core.security.infrastructure.models.audit_log_model import ( + AuditLogModel, # noqa: F401 +) +from src.core.security.infrastructure.models.error_trace_model import ( + ErrorTraceModel, # noqa: F401 +) +from src.core.security.infrastructure.models.login_attempt_model import ( + LoginAttemptModel, # noqa: F401 +) from src.modules.todo.infrastructure.models.todo_model import TodoModel # noqa: F401 from src.modules.user.infrastructure.models.refresh_token_model import ( RefreshTokenModel, # noqa: F401 @@ -32,10 +44,6 @@ settings = get_settings() -print( - "🔍 ALEMBIC DEBUG: Tables found in metadata ->", list(Base.metadata.tables.keys()) -) - config = context.config if config.config_file_name is not None: diff --git a/alembic/versions/b2f4c7d9a1e0_add_security_audit_and_login_attempts.py b/alembic/versions/b2f4c7d9a1e0_add_security_audit_and_login_attempts.py new file mode 100644 index 0000000..f130ae6 --- /dev/null +++ b/alembic/versions/b2f4c7d9a1e0_add_security_audit_and_login_attempts.py @@ -0,0 +1,111 @@ +"""add security audit and login attempts + +Revision ID: b2f4c7d9a1e0 +Revises: aa90557ef712 +Create Date: 2026-06-19 00:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = "b2f4c7d9a1e0" +down_revision: Union[str, Sequence[str], None] = "aa90557ef712" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "audit_logs", + sa.Column("action", sa.String(length=120), nullable=False), + sa.Column("actor_id", sa.String(length=64), nullable=True), + sa.Column("resource_type", sa.String(length=80), nullable=True), + sa.Column("resource_id", sa.String(length=64), nullable=True), + sa.Column("request_id", sa.String(length=120), nullable=True), + sa.Column("meta", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_audit_logs_action", "audit_logs", ["action"], unique=False) + op.create_index("ix_audit_logs_actor_id", "audit_logs", ["actor_id"], unique=False) + op.create_index( + "ix_audit_logs_created_at", "audit_logs", ["created_at"], unique=False + ) + op.create_index( + "ix_audit_logs_request_id", "audit_logs", ["request_id"], unique=False + ) + + op.create_table( + "login_attempts", + sa.Column("email", sa.String(length=255), nullable=False), + sa.Column("occurred_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("locked_until", sa.DateTime(timezone=True), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_login_attempts_email", "login_attempts", ["email"], unique=False) + op.create_index( + "ix_login_attempts_locked_until", + "login_attempts", + ["locked_until"], + unique=False, + ) + op.create_index( + "ix_login_attempts_occurred_at", + "login_attempts", + ["occurred_at"], + unique=False, + ) + + op.create_table( + "error_traces", + sa.Column("error_type", sa.String(length=120), nullable=False), + sa.Column("message", sa.Text(), nullable=False), + sa.Column("traceback", sa.Text(), nullable=False), + sa.Column("method", sa.String(length=12), nullable=False), + sa.Column("path", sa.String(length=500), nullable=False), + sa.Column("actor_id", sa.String(length=64), nullable=True), + sa.Column("request_id", sa.String(length=120), nullable=True), + sa.Column("meta", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_error_traces_actor_id", "error_traces", ["actor_id"], unique=False + ) + op.create_index( + "ix_error_traces_created_at", "error_traces", ["created_at"], unique=False + ) + op.create_index( + "ix_error_traces_error_type", "error_traces", ["error_type"], unique=False + ) + op.create_index("ix_error_traces_path", "error_traces", ["path"], unique=False) + op.create_index( + "ix_error_traces_request_id", "error_traces", ["request_id"], unique=False + ) + + +def downgrade() -> None: + op.drop_index("ix_error_traces_request_id", table_name="error_traces") + op.drop_index("ix_error_traces_path", table_name="error_traces") + op.drop_index("ix_error_traces_error_type", table_name="error_traces") + op.drop_index("ix_error_traces_created_at", table_name="error_traces") + op.drop_index("ix_error_traces_actor_id", table_name="error_traces") + op.drop_table("error_traces") + + op.drop_index("ix_login_attempts_occurred_at", table_name="login_attempts") + op.drop_index("ix_login_attempts_locked_until", table_name="login_attempts") + op.drop_index("ix_login_attempts_email", table_name="login_attempts") + op.drop_table("login_attempts") + + op.drop_index("ix_audit_logs_request_id", table_name="audit_logs") + op.drop_index("ix_audit_logs_created_at", table_name="audit_logs") + op.drop_index("ix_audit_logs_actor_id", table_name="audit_logs") + op.drop_index("ix_audit_logs_action", table_name="audit_logs") + op.drop_table("audit_logs") diff --git a/alembic/versions/c7a1b9e5d4f2_rename_authorization_description_columns.py b/alembic/versions/c7a1b9e5d4f2_rename_authorization_description_columns.py new file mode 100644 index 0000000..490460c --- /dev/null +++ b/alembic/versions/c7a1b9e5d4f2_rename_authorization_description_columns.py @@ -0,0 +1,34 @@ +"""rename authorization description columns + +Revision ID: c7a1b9e5d4f2 +Revises: b2f4c7d9a1e0 +Create Date: 2026-06-19 00:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op + + +revision: str = "c7a1b9e5d4f2" +down_revision: Union[str, Sequence[str], None] = "b2f4c7d9a1e0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _rename_column(table_name: str, old_name: str, new_name: str) -> None: + with op.batch_alter_table(table_name) as batch_op: + batch_op.alter_column(old_name, new_column_name=new_name) + + +def upgrade() -> None: + """Upgrade schema.""" + _rename_column("permissions", "descpription", "description") + _rename_column("roles", "descpription", "description") + + +def downgrade() -> None: + """Downgrade schema.""" + _rename_column("roles", "description", "descpription") + _rename_column("permissions", "description", "descpription") diff --git a/alembic/versions/d9a7c3f2b6e1_add_authorization_resources.py b/alembic/versions/d9a7c3f2b6e1_add_authorization_resources.py new file mode 100644 index 0000000..6bbb8f5 --- /dev/null +++ b/alembic/versions/d9a7c3f2b6e1_add_authorization_resources.py @@ -0,0 +1,124 @@ +"""add authorization resources + +Revision ID: d9a7c3f2b6e1 +Revises: c7a1b9e5d4f2 +Create Date: 2026-06-19 00:00:00.000000 + +""" + +from typing import Sequence, Union +from uuid import uuid4 + +from alembic import op +import sqlalchemy as sa + + +revision: str = "d9a7c3f2b6e1" +down_revision: Union[str, Sequence[str], None] = "c7a1b9e5d4f2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table( + "authorization_resources", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("key", sa.String(length=100), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_authorization_resources_key"), + "authorization_resources", + ["key"], + unique=True, + ) + + with op.batch_alter_table("permissions") as batch_op: + batch_op.add_column(sa.Column("resource_id", sa.Uuid(), nullable=True)) + batch_op.create_index( + op.f("ix_permissions_resource_id"), + ["resource_id"], + unique=False, + ) + + bind = op.get_bind() + resources = [ + row[0] + for row in bind.execute( + sa.text("select distinct resource from permissions where resource is not null") + ) + ] + + resource_ids = {} + for resource in resources: + resource_id = uuid4() + resource_ids[resource] = resource_id + bind.execute( + sa.text( + """ + insert into authorization_resources + (id, key, name, description) + values + (:id, :key, :name, :description) + """ + ), + { + "id": resource_id, + "key": resource, + "name": resource.replace("_", " ").title(), + "description": f"{resource} resources", + }, + ) + + for resource, resource_id in resource_ids.items(): + bind.execute( + sa.text( + """ + update permissions + set resource_id = :resource_id + where resource = :resource + """ + ), + {"resource_id": resource_id, "resource": resource}, + ) + + with op.batch_alter_table("permissions") as batch_op: + batch_op.create_foreign_key( + "fk_permissions_resource_id_authorization_resources", + "authorization_resources", + ["resource_id"], + ["id"], + ) + + +def downgrade() -> None: + """Downgrade schema.""" + with op.batch_alter_table("permissions") as batch_op: + batch_op.drop_constraint( + "fk_permissions_resource_id_authorization_resources", + type_="foreignkey", + ) + batch_op.drop_index(op.f("ix_permissions_resource_id")) + batch_op.drop_column("resource_id") + + op.drop_index( + op.f("ix_authorization_resources_key"), + table_name="authorization_resources", + ) + op.drop_table("authorization_resources") diff --git a/docker-compose.yml b/docker-compose.yml index cc5239f..e710572 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,25 +1,77 @@ services: api: - build: . + build: + context: . + target: runtime + image: fastapi-modulith:local + restart: unless-stopped ports: - - "8000:8000" + - "${APP_PORT:-8000}:8000" env_file: - .env + environment: + APP_ENV: production + DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:?Set POSTGRES_PASSWORD in .env}@db:5432/${POSTGRES_DB:-todo_db} + REDIS_URL: redis://:${REDIS_PASSWORD:?Set REDIS_PASSWORD in .env}@redis:6379/0 depends_on: - - db - volumes: - - ./src:/app/src # For hot-reloading during dev + db: + condition: service_healthy + redis: + condition: service_healthy + healthcheck: + test: + [ + "CMD", + "python", + "-c", + "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=3).read()", + ] + interval: 30s + timeout: 5s + retries: 3 + start_period: 30s db: - image: postgres:15-alpine + image: postgres:17-alpine + restart: unless-stopped environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: todo_db - ports: - - "5432:5432" + POSTGRES_USER: ${POSTGRES_USER:-postgres} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?Set POSTGRES_PASSWORD in .env} + POSTGRES_DB: ${POSTGRES_DB:-todo_db} volumes: - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"] + interval: 10s + timeout: 5s + retries: 5 + + redis: + image: redis:8-alpine + restart: unless-stopped + command: + [ + "redis-server", + "--appendonly", + "yes", + "--requirepass", + "${REDIS_PASSWORD:?Set REDIS_PASSWORD in .env}", + ] + volumes: + - redis_data:/data + healthcheck: + test: + [ + "CMD", + "redis-cli", + "-a", + "${REDIS_PASSWORD:?Set REDIS_PASSWORD in .env}", + "ping", + ] + interval: 10s + timeout: 5s + retries: 5 volumes: - postgres_data: \ No newline at end of file + postgres_data: + redis_data: diff --git a/docs/QUERY_OPTIMIZATION.md b/docs/QUERY_OPTIMIZATION.md new file mode 100644 index 0000000..8678993 --- /dev/null +++ b/docs/QUERY_OPTIMIZATION.md @@ -0,0 +1,632 @@ +# Query Optimization Guide + +> Best practices for optimizing PostgreSQL queries in a FastAPI + SQLAlchemy application. + +--- + +# Table of Contents + +1. General Principles +2. Indexing +3. Query Design +4. SQLAlchemy Best Practices +5. Pagination +6. Avoid N+1 Queries +7. Batch Operations +8. Transactions +9. Query Analysis +10. PostgreSQL Optimization +11. Connection Pool +12. Caching +13. Soft Delete Optimization +14. Full Text Search +15. Monitoring +16. Performance Checklist + +--- + +# General Principles + +## Only Select Required Columns + +❌ Bad + +```python +select(User) +``` + +✅ Good + +```python +select( + User.id, + User.email, +) +``` + +or + +```python +select(User.id, User.email) +``` + +Avoid loading unnecessary columns. + +--- + +## Avoid SELECT * + +Never write + +```sql +SELECT * +FROM users; +``` + +Prefer + +```sql +SELECT + id, + email, + fullname +FROM users; +``` + +--- + +## Filter Early + +Good + +```sql +SELECT * +FROM users +WHERE deleted_at IS NULL; +``` + +Avoid filtering in Python. + +--- + +# Indexing + +## Primary Key + +```sql +PRIMARY KEY(id) +``` + +Already indexed. + +--- + +## Foreign Keys + +Always create indexes. + +```sql +CREATE INDEX idx_user_role_user +ON user_roles(user_id); +``` + +--- + +## Frequently Filtered Columns + +Example + +```sql +email +username +status +deleted_at +created_at +``` + +--- + +## Composite Index + +Instead of + +```sql +WHERE organization_id = ? +AND deleted_at IS NULL +``` + +Use + +```sql +CREATE INDEX idx_org_deleted +ON users( + organization_id, + deleted_at +); +``` + +--- + +## Partial Index + +Excellent for soft delete. + +```sql +CREATE INDEX idx_users_active +ON users(email) +WHERE deleted_at IS NULL; +``` + +--- + +# Query Design + +## Use EXISTS + +Instead of + +```sql +SELECT COUNT(*) +``` + +Use + +```sql +SELECT EXISTS( + SELECT 1 +); +``` + +SQLAlchemy + +```python +from sqlalchemy import exists +``` + +--- + +## LIMIT + +Always limit results. + +```sql +LIMIT 50 +``` + +Never fetch millions of rows. + +--- + +## ORDER BY Indexed Columns + +Good + +```sql +ORDER BY created_at DESC +``` + +with + +```sql +INDEX(created_at) +``` + +--- + +# SQLAlchemy Best Practices + +## Use select() + +Prefer + +```python +stmt = select(User) +``` + +Avoid legacy Query API. + +--- + +## Load Only Needed Fields + +```python +stmt = select( + User.id, + User.email, +) +``` + +--- + +## scalars() + +Good + +```python +users = await session.scalars(stmt) +``` + +instead of + +```python +await session.execute(stmt) +``` + +when selecting ORM models. + +--- + +## one_or_none() + +Use + +```python +result.scalar_one_or_none() +``` + +instead of + +```python +all() +``` + +when expecting a single record. + +--- + +# Pagination + +Avoid + +```sql +OFFSET 100000 +``` + +Prefer Cursor Pagination. + +Example + +```sql +WHERE id > ? +ORDER BY id +LIMIT 50 +``` + +--- + +# Avoid N+1 Queries + +Bad + +```python +for user in users: + print(user.roles) +``` + +Good + +```python +select(User).options( + selectinload(User.roles) +) +``` + +or + +```python +joinedload() +``` + +depending on the use case. + +--- + +# Batch Operations + +Instead of + +```python +for user in users: + session.add(user) +``` + +Use + +```python +session.add_all(users) +``` + +--- + +Bulk Update + +```python +update(User) +``` + +--- + +Bulk Delete + +```python +delete(User) +``` + +--- + +# Transactions + +Keep transactions short. + +Good + +``` +Begin + +Update + +Commit +``` + +Avoid + +``` +Begin + +HTTP Request + +Redis + +Email + +Commit +``` + +--- + +# Query Analysis + +Use + +```sql +EXPLAIN ANALYZE +``` + +Example + +```sql +EXPLAIN ANALYZE +SELECT * +FROM users +WHERE email='john@test.com'; +``` + +Look for + +- Seq Scan +- Bitmap Heap Scan +- Index Scan + +Prefer + +``` +Index Scan +``` + +--- + +# PostgreSQL Optimization + +Vacuum + +```sql +VACUUM ANALYZE; +``` + +Auto Vacuum should be enabled. + +--- + +Update Statistics + +```sql +ANALYZE users; +``` + +--- + +Avoid Huge JSON Columns + +Store large blobs separately. + +--- + +Use UUID + +Prefer + +``` +UUID v7 +``` + +or + +``` +UUID v4 +``` + +instead of sequential integers when appropriate. + +--- + +# Connection Pool + +Recommended + +```python +create_async_engine( + DATABASE_URL, + pool_size=20, + max_overflow=40, + pool_pre_ping=True, + pool_recycle=1800, +) +``` + +Do not create an engine per request. + +--- + +# Caching + +Cache + +- User Profile +- Roles +- Permissions +- Settings + +Use Redis. + +Avoid caching mutable transactional data unless invalidation is handled. + +--- + +# Soft Delete + +Always filter + +```sql +deleted_at IS NULL +``` + +Create index + +```sql +CREATE INDEX idx_deleted +ON users(deleted_at); +``` + +Better + +```sql +CREATE INDEX idx_active_users +ON users(email) +WHERE deleted_at IS NULL; +``` + +--- + +# Full Text Search + +Instead of + +```sql +LIKE '%john%' +``` + +Use + +```sql +GIN Index +``` + +and + +```sql +tsvector +``` + +Example + +```sql +CREATE INDEX idx_users_search +ON users +USING gin(search_vector); +``` + +--- + +# Monitoring + +Enable + +```sql +pg_stat_statements +``` + +Useful query + +```sql +SELECT + query, + calls, + total_exec_time, + mean_exec_time +FROM pg_stat_statements +ORDER BY total_exec_time DESC; +``` + +--- + +Check Slow Queries + +```sql +log_min_duration_statement = 500 +``` + +--- + +# Performance Checklist + +## SQL + +- [ ] No SELECT * +- [ ] Uses LIMIT +- [ ] Uses indexes +- [ ] No unnecessary ORDER BY +- [ ] No unnecessary DISTINCT +- [ ] Uses EXISTS instead of COUNT where appropriate + +--- + +## SQLAlchemy + +- [ ] Uses AsyncSession +- [ ] Uses select() +- [ ] Uses scalars() +- [ ] Uses selectinload()/joinedload() +- [ ] Avoids N+1 queries +- [ ] Loads only required columns + +--- + +## PostgreSQL + +- [ ] EXPLAIN ANALYZE checked +- [ ] Indexes created +- [ ] Partial indexes for soft delete +- [ ] Foreign key indexes +- [ ] Composite indexes +- [ ] VACUUM enabled +- [ ] ANALYZE updated + +--- + +## API + +- [ ] Cursor pagination +- [ ] Request caching +- [ ] Rate limiting +- [ ] Query timeout +- [ ] Connection pooling + +--- + +## Production + +- [ ] pg_stat_statements enabled +- [ ] Slow query logging +- [ ] Metrics (Prometheus/Grafana) +- [ ] Redis cache +- [ ] Read replicas (if needed) \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 9741ca0..a7385db 100644 --- a/poetry.lock +++ b/poetry.lock @@ -690,6 +690,28 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<1.0)"] +[[package]] +name = "httpcore2" +version = "2.4.0" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "httpcore2-2.4.0-py3-none-any.whl", hash = "sha256:5218779da5d6e3c2013ac706121abfb3815d450e0613495c0de50264dce58242"}, + {file = "httpcore2-2.4.0.tar.gz", hash = "sha256:3093a8ab8980d9f910b9cb4351df9186a0ad2350a6284a9107ac9a362a584422"}, +] + +[package.dependencies] +h11 = ">=0.16" +truststore = ">=0.10" + +[package.extras] +asyncio = ["anyio (>=4.5.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + [[package]] name = "httptools" version = "0.8.0" @@ -775,6 +797,31 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx2" +version = "2.4.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "httpx2-2.4.0-py3-none-any.whl", hash = "sha256:425acd99297829599decf6701386dd84db3542597d36d3e2e4def930ecd57fd9"}, + {file = "httpx2-2.4.0.tar.gz", hash = "sha256:32e0734b61eb0824b3f56a9e98d6d92d381a3ef12c0045aa917ee63df6c411ef"}, +] + +[package.dependencies] +anyio = "*" +httpcore2 = "2.4.0" +idna = ">=3.18" +truststore = ">=0.10" + +[package.extras] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<16)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0) ; python_version <= \"3.13\""] + [[package]] name = "idna" version = "3.18" @@ -1748,6 +1795,18 @@ anyio = ">=3.6.2,<5" [package.extras] full = ["httpx (>=0.27.0,<0.29.0)", "httpx2 (>=2.0.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] +[[package]] +name = "truststore" +version = "0.10.4" +description = "Verify certificates using native system trust stores" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "truststore-0.10.4-py3-none-any.whl", hash = "sha256:adaeaecf1cbb5f4de3b1959b42d41f6fab57b2b1666adb59e89cb0b53361d981"}, + {file = "truststore-0.10.4.tar.gz", hash = "sha256:9d91bd436463ad5e4ee4aba766628dd6cd7010cf3e2461756b3303710eebc301"}, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -2060,4 +2119,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.14,<4.0" -content-hash = "b17d5872c4517e4e2a1f2005ab90769081eb05ececef562fdcc0b74af10270c5" +content-hash = "4ee3e01b31cef121c12b68b75f108185eeb015abba83fd268ccee3cda302ef2a" diff --git a/pyproject.toml b/pyproject.toml index c9c58b3..25e29e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,5 +37,6 @@ dev = [ "httpx (>=0.28.1,<0.29.0)", "ruff (>=0.15.17,<0.16.0)", "mypy (>=2.1.0,<3.0.0)", - "alembic (>=1.18.4,<2.0.0)" + "alembic (>=1.18.4,<2.0.0)", + "httpx2 (>=2.4.0,<3.0.0)" ] diff --git a/scripts/seed.py b/scripts/seed.py new file mode 100644 index 0000000..b87c6b1 --- /dev/null +++ b/scripts/seed.py @@ -0,0 +1,31 @@ +import asyncio +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from src.core.seed.runner import run_seeders # noqa: E402 + + +async def main() -> None: + result = await run_seeders() + authorization = result.authorization + print( + "[seed:authorization] " + f"resources_created={authorization.resources_created} " + f"roles_created={authorization.roles_created} " + f"permissions_created={authorization.permissions_created} " + f"role_permissions_created={authorization.role_permissions_created} " + f"policies_created={authorization.policies_created}" + ) + user = result.user + print( + "[seed:user] " + f"users_created={user.users_created} " + f"roles_assigned={user.roles_assigned}" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/start.sh b/scripts/start.sh index 80c2ef1..d6d26bb 100644 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -1,13 +1,8 @@ -#!/bin/bash -# start.sh +#!/usr/bin/env bash +set -euo pipefail echo "Running database migrations..." -poetry run alembic upgrade head - -if [ $? -ne 0 ]; then - echo "Migration failed!" - exit 1 -fi +alembic upgrade head echo "Starting FastAPI application..." -exec poetry run uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4 \ No newline at end of file +exec uvicorn src.main:app --host 0.0.0.0 --port "${PORT:-8000}" --workers "${WEB_CONCURRENCY:-4}" diff --git a/src/core/authorization/domain/service.py b/src/core/authorization/domain/service.py index b2fed2c..3d0537c 100644 --- a/src/core/authorization/domain/service.py +++ b/src/core/authorization/domain/service.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod +from datetime import datetime from uuid import UUID +from src.core.utils.cursor import CursorDirection from src.modules.authorization.domain.entities.permission import Permission from src.modules.authorization.domain.entities.role import Role @@ -38,6 +40,16 @@ async def get_role(self, role_id: UUID) -> Role | None: async def list_roles(self) -> list[Role]: pass + @abstractmethod + async def list_roles_cursor( + self, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Role], bool]: + pass + @abstractmethod async def create_permission(self, permission: Permission) -> Permission: pass @@ -58,6 +70,16 @@ async def get_permission(self, permission_id: UUID) -> Permission | None: async def list_permissions(self) -> list[Permission]: pass + @abstractmethod + async def list_permissions_cursor( + self, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Permission], bool]: + pass + @abstractmethod async def assign_permission_to_role( self, diff --git a/src/core/authorization/infrastructure/models/permission_model.py b/src/core/authorization/infrastructure/models/permission_model.py index 07b060d..fe816aa 100644 --- a/src/core/authorization/infrastructure/models/permission_model.py +++ b/src/core/authorization/infrastructure/models/permission_model.py @@ -1,4 +1,6 @@ -from sqlalchemy import String, UniqueConstraint +from uuid import UUID + +from sqlalchemy import ForeignKey, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column from src.shared.database.mixin.timestamp import SoftDeleteMixin, TimeStampMixin @@ -12,6 +14,10 @@ class PermissionModel(Base, TimeStampMixin, SoftDeleteMixin): ) key: Mapped[str] = mapped_column(String(255), unique=True, index=True) + resource_id: Mapped[UUID] = mapped_column( + ForeignKey("authorization_resources.id"), + index=True, + ) resource: Mapped[str] = mapped_column(String(100), index=True) action: Mapped[str] = mapped_column(String(100), index=True) - descpription: Mapped[str] = mapped_column(String(255), nullable=True) + description: Mapped[str | None] = mapped_column(String(255), nullable=True) diff --git a/src/core/authorization/infrastructure/models/resource_model.py b/src/core/authorization/infrastructure/models/resource_model.py new file mode 100644 index 0000000..95c2082 --- /dev/null +++ b/src/core/authorization/infrastructure/models/resource_model.py @@ -0,0 +1,13 @@ +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column + +from src.shared.database.mixin.timestamp import SoftDeleteMixin, TimeStampMixin +from src.shared.database.model import Base + + +class AuthorizationResourceModel(Base, TimeStampMixin, SoftDeleteMixin): + __tablename__ = "authorization_resources" + + key: Mapped[str] = mapped_column(String(100), unique=True, index=True) + name: Mapped[str] = mapped_column(String(150), nullable=False) + description: Mapped[str | None] = mapped_column(String(255), nullable=True) diff --git a/src/core/authorization/infrastructure/models/role_model.py b/src/core/authorization/infrastructure/models/role_model.py index 990c8d1..7be2550 100644 --- a/src/core/authorization/infrastructure/models/role_model.py +++ b/src/core/authorization/infrastructure/models/role_model.py @@ -9,4 +9,4 @@ class RoleModel(Base, TimeStampMixin, SoftDeleteMixin): __tablename__ = "roles" name: Mapped[str] = mapped_column(String(100), unique=True, index=True) - descpription: Mapped[str] = mapped_column(String(255), nullable=True) + description: Mapped[str | None] = mapped_column(String(255), nullable=True) diff --git a/src/core/authorization/infrastructure/repositories/casbin_policy_repository.py b/src/core/authorization/infrastructure/repositories/casbin_policy_repository.py index 8c48705..7d0696e 100644 --- a/src/core/authorization/infrastructure/repositories/casbin_policy_repository.py +++ b/src/core/authorization/infrastructure/repositories/casbin_policy_repository.py @@ -1,22 +1,28 @@ +from datetime import datetime from uuid import UUID -from sqlalchemy import delete, select +from sqlalchemy import and_, delete, or_, select from sqlalchemy.ext.asyncio import AsyncSession +from src.core.utils.cursor import CursorDirection from src.core.authorization.infrastructure.models.casbin_rule_model import ( CasbinRuleModel, ) from src.core.authorization.infrastructure.models.permission_model import ( PermissionModel, ) +from src.core.authorization.infrastructure.models.resource_model import ( + AuthorizationResourceModel, +) +from src.core.authorization.infrastructure.models.role_model import RoleModel from src.core.authorization.infrastructure.models.role_permission_model import ( RolePermissionModel, ) -from src.core.authorization.infrastructure.models.role_model import RoleModel from src.core.authorization.infrastructure.models.user_has_role_model import ( UserHasRoleModel, ) from src.modules.authorization.domain.entities.permission import Permission +from src.modules.authorization.domain.entities.resource import AuthorizationResource from src.modules.authorization.domain.entities.role import Role @@ -29,6 +35,10 @@ async def load_policy_lines(self) -> list[str]: rules = result.scalars().all() return [self._to_policy_line(rule) for rule in rules] + async def list_policies(self) -> list[tuple[str, ...]]: + result = await self._db.execute(select(CasbinRuleModel)) + return [self._to_policy_tuple(rule) for rule in result.scalars().all()] + async def add_policy(self, ptype: str, *values: str) -> None: existing = await self._db.execute( select(CasbinRuleModel).where( @@ -98,18 +108,38 @@ async def get_roles_for_subject(self, subject: str) -> list[str]: ) return list(result.scalars().all()) + async def create_resource( + self, + resource: AuthorizationResource, + ) -> AuthorizationResource: + model = AuthorizationResourceModel( + id=resource.id, + key=resource.key, + name=resource.name, + description=resource.description, + ) + self._db.add(model) + await self._db.flush() + return self._resource_from_model(model) + + async def list_resources(self) -> list[AuthorizationResource]: + result = await self._db.execute(select(AuthorizationResourceModel)) + return [self._resource_from_model(model) for model in result.scalars().all()] + async def create_role(self, role: Role) -> Role: model = RoleModel( id=role.id, name=role.name, - descpription=role.description, + description=role.description, ) self._db.add(model) await self._db.flush() return self._role_from_model(model) async def get_role(self, role_id: UUID) -> Role | None: - result = await self._db.execute(select(RoleModel).where(RoleModel.id == role_id)) + result = await self._db.execute( + select(RoleModel).where(RoleModel.id == role_id) + ) model = result.scalar_one_or_none() if model is None: return None @@ -119,15 +149,43 @@ async def list_roles(self) -> list[Role]: result = await self._db.execute(select(RoleModel)) return [self._role_from_model(model) for model in result.scalars().all()] + async def list_roles_cursor( + self, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Role], bool]: + query = select(RoleModel) + query = self._apply_cursor_pagination( + query, + RoleModel, + cursor_created_at, + cursor_id, + direction, + ).limit(limit + 1) + + result = await self._db.execute(query) + models = list(result.scalars().all()) + has_more = len(models) > limit + models = models[:limit] + + if direction == CursorDirection.DIRECTION_PREV: + models = list(reversed(models)) + + return [self._role_from_model(model) for model in models], has_more + async def update_role(self, role: Role) -> Role | None: - result = await self._db.execute(select(RoleModel).where(RoleModel.id == role.id)) + result = await self._db.execute( + select(RoleModel).where(RoleModel.id == role.id) + ) model = result.scalar_one_or_none() if model is None: return None old_name = model.name model.name = role.name - model.descpription = role.description + model.description = role.description await self._db.flush() if old_name != role.name: @@ -156,12 +214,14 @@ async def delete_role(self, role_id: UUID) -> None: await self._db.flush() async def create_permission(self, permission: Permission) -> Permission: + resource = await self._get_or_create_resource(permission.resource) model = PermissionModel( id=permission.id, key=permission.key, + resource_id=resource.id, resource=permission.resource, action=permission.action, - descpription=permission.description, + description=permission.description, ) self._db.add(model) await self._db.flush() @@ -178,10 +238,33 @@ async def get_permission(self, permission_id: UUID) -> Permission | None: async def list_permissions(self) -> list[Permission]: result = await self._db.execute(select(PermissionModel)) - return [ - self._permission_from_model(model) - for model in result.scalars().all() - ] + return [self._permission_from_model(model) for model in result.scalars().all()] + + async def list_permissions_cursor( + self, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Permission], bool]: + query = select(PermissionModel) + query = self._apply_cursor_pagination( + query, + PermissionModel, + cursor_created_at, + cursor_id, + direction, + ).limit(limit + 1) + + result = await self._db.execute(query) + models = list(result.scalars().all()) + has_more = len(models) > limit + models = models[:limit] + + if direction == CursorDirection.DIRECTION_PREV: + models = list(reversed(models)) + + return [self._permission_from_model(model) for model in models], has_more async def update_permission(self, permission: Permission) -> Permission | None: result = await self._db.execute( @@ -192,10 +275,12 @@ async def update_permission(self, permission: Permission) -> Permission | None: return None old_key = model.key + resource = await self._get_or_create_resource(permission.resource) model.key = permission.key + model.resource_id = resource.id model.resource = permission.resource model.action = permission.action - model.descpription = permission.description + model.description = permission.description await self._db.flush() if old_key != permission.key: @@ -250,6 +335,18 @@ async def assign_permission_to_role( await self.add_policy("p", role.name, permission.key) + async def list_role_permissions(self) -> list[tuple[str, str]]: + result = await self._db.execute( + select(RoleModel.name, PermissionModel.key) + .join(RolePermissionModel, RolePermissionModel.role_id == RoleModel.id) + .join( + PermissionModel, PermissionModel.id == RolePermissionModel.permission_id + ) + ) + return [ + (role_name, permission_key) for role_name, permission_key in result.all() + ] + async def remove_permission_from_role( self, role_id: UUID, @@ -273,6 +370,32 @@ def _to_policy_line(self, rule: CasbinRuleModel) -> str: populated = [value for value in values if value is not None] return ", ".join([rule.ptype, *populated]) + def _to_policy_tuple(self, rule: CasbinRuleModel) -> tuple[str, ...]: + values = [rule.v0, rule.v1, rule.v2, rule.v3, rule.v4, rule.v5] + populated = [value for value in values if value is not None] + return (rule.ptype, *populated) + + async def _get_or_create_resource( + self, resource_key: str + ) -> AuthorizationResourceModel: + result = await self._db.execute( + select(AuthorizationResourceModel).where( + AuthorizationResourceModel.key == resource_key, + ) + ) + model = result.scalar_one_or_none() + if model is not None: + return model + + model = AuthorizationResourceModel( + key=resource_key, + name=resource_key.replace("_", " ").title(), + description=f"{resource_key} resources", + ) + self._db.add(model) + await self._db.flush() + return model + def _value_at(self, values: tuple[str, ...], index: int) -> str | None: if index >= len(values): return None @@ -318,7 +441,20 @@ def _role_from_model(self, model: RoleModel) -> Role: return Role( id=model.id, name=model.name, - description=model.descpription, + description=model.description, + created_at=model.created_at.isoformat(), + updated_at=model.updated_at.isoformat(), + ) + + def _resource_from_model( + self, + model: AuthorizationResourceModel, + ) -> AuthorizationResource: + return AuthorizationResource( + id=model.id, + key=model.key, + name=model.name, + description=model.description, ) def _permission_from_model(self, model: PermissionModel) -> Permission: @@ -327,5 +463,41 @@ def _permission_from_model(self, model: PermissionModel) -> Permission: key=model.key, resource=model.resource, action=model.action, - description=model.descpription, + description=model.description, + created_at=model.created_at.isoformat(), + updated_at=model.updated_at.isoformat(), ) + + def _apply_cursor_pagination( + self, + query, + model, + cursor_created_at: datetime | None, + cursor_id: UUID | None, + direction: CursorDirection, + ): + if cursor_created_at and cursor_id: + if direction == CursorDirection.DIRECTION_NEXT: + query = query.where( + or_( + model.created_at < cursor_created_at, + and_( + model.created_at == cursor_created_at, + model.id < cursor_id, + ), + ) + ) + return query.order_by(model.created_at.desc(), model.id.desc()) + + query = query.where( + or_( + model.created_at > cursor_created_at, + and_( + model.created_at == cursor_created_at, + model.id > cursor_id, + ), + ) + ) + return query.order_by(model.created_at.asc(), model.id.asc()) + + return query.order_by(model.created_at.desc(), model.id.desc()) diff --git a/src/core/authorization/infrastructure/services/casbin_authorization_service.py b/src/core/authorization/infrastructure/services/casbin_authorization_service.py index b6448c0..727cd9e 100644 --- a/src/core/authorization/infrastructure/services/casbin_authorization_service.py +++ b/src/core/authorization/infrastructure/services/casbin_authorization_service.py @@ -1,3 +1,4 @@ +from datetime import datetime from uuid import UUID from src.core.authorization.domain.service import AuthorizationService @@ -5,6 +6,7 @@ SQLAlchemyCasbinPolicyRepository, ) from src.core.authorization.permissions import permission_key +from src.core.utils.cursor import CursorDirection from src.modules.authorization.domain.entities.permission import Permission from src.modules.authorization.domain.entities.role import Role @@ -55,6 +57,20 @@ async def get_role(self, role_id: UUID) -> Role | None: async def list_roles(self) -> list[Role]: return await self._policy_repository.list_roles() + async def list_roles_cursor( + self, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Role], bool]: + return await self._policy_repository.list_roles_cursor( + cursor_created_at=cursor_created_at, + cursor_id=cursor_id, + limit=limit, + direction=direction, + ) + async def create_permission(self, permission: Permission) -> Permission: return await self._policy_repository.create_permission(permission) @@ -70,6 +86,20 @@ async def get_permission(self, permission_id: UUID) -> Permission | None: async def list_permissions(self) -> list[Permission]: return await self._policy_repository.list_permissions() + async def list_permissions_cursor( + self, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Permission], bool]: + return await self._policy_repository.list_permissions_cursor( + cursor_created_at=cursor_created_at, + cursor_id=cursor_id, + limit=limit, + direction=direction, + ) + async def assign_permission_to_role( self, role_id: UUID, diff --git a/src/core/authorization/permissions.py b/src/core/authorization/permissions.py index 3739549..faa6980 100644 --- a/src/core/authorization/permissions.py +++ b/src/core/authorization/permissions.py @@ -1,7 +1,48 @@ -TODO_RESOURCE = "todo" -USER_RESOURCE = "user" -ROLE_RESOURCE = "role" -PERMISSION_RESOURCE = "permission" +from dataclasses import dataclass + + +@dataclass(frozen=True) +class AuthorizationResourceDefinition: + key: str + name: str + description: str + + +@dataclass(frozen=True) +class AuthorizationRoleDefinition: + name: str + description: str + + +DEFAULT_RESOURCES = ( + AuthorizationResourceDefinition( + key="todo", + name="Todo", + description="Todo task resources", + ), + AuthorizationResourceDefinition( + key="user", + name="User", + description="User account resources", + ), + AuthorizationResourceDefinition( + key="role", + name="Role", + description="Authorization role resources", + ), + AuthorizationResourceDefinition( + key="permission", + name="Permission", + description="Authorization permission resources", + ), +) + +DEFAULT_RESOURCE_KEYS = {resource.key: resource.key for resource in DEFAULT_RESOURCES} + +TODO_RESOURCE = DEFAULT_RESOURCE_KEYS["todo"] +USER_RESOURCE = DEFAULT_RESOURCE_KEYS["user"] +ROLE_RESOURCE = DEFAULT_RESOURCE_KEYS["role"] +PERMISSION_RESOURCE = DEFAULT_RESOURCE_KEYS["permission"] CREATE_ACTION = "create" READ_ACTION = "read" @@ -11,6 +52,27 @@ DEFAULT_USER_ROLE = "user" ADMIN_ROLE = "admin" +MANAGER_ROLE = "manager" +VIEWER_ROLE = "viewer" + +DEFAULT_ROLES = ( + AuthorizationRoleDefinition( + name=ADMIN_ROLE, + description="Administrator with full platform access", + ), + AuthorizationRoleDefinition( + name=DEFAULT_USER_ROLE, + description="Default authenticated user", + ), + AuthorizationRoleDefinition( + name=MANAGER_ROLE, + description="Manager user with todo management access", + ), + AuthorizationRoleDefinition( + name=VIEWER_ROLE, + description="Read-only user", + ), +) def permission_key(resource: str, action: str) -> str: @@ -24,4 +86,11 @@ def permission_key(resource: str, action: str) -> str: ("p", DEFAULT_USER_ROLE, permission_key(TODO_RESOURCE, UPDATE_ACTION)), ("p", DEFAULT_USER_ROLE, permission_key(TODO_RESOURCE, DELETE_ACTION)), ("p", DEFAULT_USER_ROLE, permission_key(USER_RESOURCE, ME_ACTION)), + ("p", MANAGER_ROLE, permission_key(TODO_RESOURCE, CREATE_ACTION)), + ("p", MANAGER_ROLE, permission_key(TODO_RESOURCE, READ_ACTION)), + ("p", MANAGER_ROLE, permission_key(TODO_RESOURCE, UPDATE_ACTION)), + ("p", MANAGER_ROLE, permission_key(TODO_RESOURCE, DELETE_ACTION)), + ("p", MANAGER_ROLE, permission_key(USER_RESOURCE, ME_ACTION)), + ("p", VIEWER_ROLE, permission_key(TODO_RESOURCE, READ_ACTION)), + ("p", VIEWER_ROLE, permission_key(USER_RESOURCE, ME_ACTION)), ) diff --git a/src/core/bootstrap/exception.py b/src/core/bootstrap/exception.py index 65c9fba..b91405e 100644 --- a/src/core/bootstrap/exception.py +++ b/src/core/bootstrap/exception.py @@ -3,6 +3,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from src.core.exceptions.handler import ( + DOMAIN_EXCEPTION_MAP, domain_exception_handler, global_exception_handler, http_exception_handler, @@ -14,8 +15,7 @@ def register_exception(app: FastAPI): app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(StarletteHTTPException, http_exception_handler) - # Catches custom domain exceptions - app.add_exception_handler(Exception, domain_exception_handler) + for exception_type in DOMAIN_EXCEPTION_MAP: + app.add_exception_handler(exception_type, domain_exception_handler) - # Fallback (Note: order matters, put specific ones first) app.add_exception_handler(Exception, global_exception_handler) diff --git a/src/core/bootstrap/middleware.py b/src/core/bootstrap/middleware.py index 18d2bf0..a554b55 100644 --- a/src/core/bootstrap/middleware.py +++ b/src/core/bootstrap/middleware.py @@ -1,15 +1,35 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from src.core.config.setting import get_settings +from src.core.middleware.audit_logging import AuditLoggingMiddleware from src.core.middleware.auth import AuthenticationMiddleware +from src.core.middleware.csp import CSPMiddleware +from src.core.middleware.idempotency import IdempotencyMiddleware +from src.core.middleware.request_id import RequestIDMiddleware +from src.core.middleware.request_size import LimitRequestSizeMiddleware +from src.core.middleware.security_headers import SecurityHeadersMiddleware +from src.core.middleware.structured_logging import StructuredLoggingMiddleware + +settings = get_settings() def register_middleware(app: FastAPI): + app.add_middleware( + LimitRequestSizeMiddleware, + max_upload_size=settings.MAX_REQUEST_SIZE_MB, + ) + app.add_middleware(SecurityHeadersMiddleware) + app.add_middleware(CSPMiddleware) app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:3000"], # Your frontend URL + allow_origins=settings.cors_allow_origins, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=settings.cors_allow_methods, + allow_headers=settings.cors_allow_headers, ) + app.add_middleware(IdempotencyMiddleware) + app.add_middleware(StructuredLoggingMiddleware) + app.add_middleware(AuditLoggingMiddleware) app.add_middleware(AuthenticationMiddleware) + app.add_middleware(RequestIDMiddleware) diff --git a/src/core/config/setting.py b/src/core/config/setting.py index 8549297..466ff65 100644 --- a/src/core/config/setting.py +++ b/src/core/config/setting.py @@ -1,18 +1,116 @@ from functools import lru_cache +from typing import ClassVar +from pydantic import Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): - APP_NAME: str = "Todo Modulith API" - APP_ENV: str = "development" - DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost:5432/todo_db" - REDIS_URL: str = "redis://:eYVX7EwVmmxKPCDmwMtyKVge8oLd2t81@127.0.0.1:6379/0" - SECRET_KEY: str = "super-secret-key-change-in-production" - ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 - REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 - RATE_LIMIT: str = "100/minute" + DEFAULT_SECRET_KEY: ClassVar[str] = "super-secret-key-change-in-production" + + APP_NAME: str = Field(alias="APP_NAME", default="Todo Modulith API") + APP_ENV: str = Field(alias="APP_ENV", default="development") + DATABASE_URL: str = Field( + alias="DATABASE_URL", + default="postgresql+asyncpg://postgres@localhost:5432/todo_db", + ) + REDIS_URL: str = Field( + alias="REDIS_URL", + default="redis://127.0.0.1:6379/0", + ) + SECRET_KEY: str = Field(alias="SECRET_KEY", default=DEFAULT_SECRET_KEY) + ALGORITHM: str = Field(alias="ALGORITHM", default="HS256") + JWT_ISSUER: str = Field(alias="JWT_ISSUER", default="todo-modulith-api") + JWT_AUDIENCE: str = Field(alias="JWT_AUDIENCE", default="todo-modulith-client") + ACCESS_TOKEN_EXPIRE_MINUTES: int = Field( + alias="ACCESS_TOKEN_EXPIRE_MINUTES", default=30 + ) + REFRESH_TOKEN_EXPIRE_MINUTES: int = Field( + alias="REFRESH_TOKEN_EXPIRE_MINUTES", default=10080 + ) + RATE_LIMIT: str = Field(alias="RATE_LIMIT", default="100/minute") + CORS_ALLOW_ORIGINS: str = Field( + alias="CORS_ALLOW_ORIGINS", + default="http://localhost:3000", + ) + CORS_ALLOW_METHODS: str = Field(alias="CORS_ALLOW_METHODS", default="*") + CORS_ALLOW_HEADERS: str = Field(alias="CORS_ALLOW_HEADERS", default="*") + SECURITY_CONTENT_SECURITY_POLICY: str = Field( + alias="SECURITY_CONTENT_SECURITY_POLICY", + default="default-src 'self'; frame-ancestors 'none'", + ) + IDEMPOTENCY_TTL_SECONDS: int = Field(alias="IDEMPOTENCY_TTL_SECONDS", default=86400) + ACCOUNT_LOCKOUT_MAX_ATTEMPTS: int = Field( + alias="ACCOUNT_LOCKOUT_MAX_ATTEMPTS", default=5 + ) + ACCOUNT_LOCKOUT_WINDOW_MINUTES: int = Field( + alias="ACCOUNT_LOCKOUT_WINDOW_MINUTES", default=15 + ) + ACCOUNT_LOCKOUT_DURATION_MINUTES: int = Field( + alias="ACCOUNT_LOCKOUT_DURATION_MINUTES", default=15 + ) + LOG_FORMAT: str = Field(alias="LOG_FORMAT", default="json") + SEED_ADMIN_EMAIL: str = Field(alias="SEED_ADMIN_EMAIL", default="") + SEED_ADMIN_PASSWORD: str = Field(alias="SEED_ADMIN_PASSWORD", default="") + SEED_ADMIN_USERNAME: str = Field(alias="SEED_ADMIN_USERNAME", default="admin") + SEED_ADMIN_FULLNAME: str = Field( + alias="SEED_ADMIN_FULLNAME", + default="System Administrator", + ) + SEED_DEVELOPMENT_USERS_PASSWORD: str = Field( + alias="SEED_DEVELOPMENT_USERS_PASSWORD", + default="", + ) + MAX_REQUEST_SIZE_MB: int = Field( + alias="MAX_REQUEST_SIZE_MB", default=5 * 1024 * 1024 + ) + DATABASE_POOL_SIZE: int = Field(alias="DATABASE_POOL_SIZE", default=20) + DATABASE_MAX_OVERFLOW: int = Field(alias="DATABASE_MAX_OVERFLOW", default=10) + DATABASE_POOL_TIMEOUT: int = Field(alias="DATABASE_POOL_TIMEOUT", default=30) + DATABASE_POOL_RECYCLE: int = Field(alias="DATABASE_POOL_RECYCLE", default=3600) + + @property + def is_production(self) -> bool: + return self.APP_ENV.lower() == "production" + + @property + def cors_allow_origins(self) -> list[str]: + return self._split_csv(self.CORS_ALLOW_ORIGINS) + + @property + def cors_allow_methods(self) -> list[str]: + return self._split_csv(self.CORS_ALLOW_METHODS) + + @property + def cors_allow_headers(self) -> list[str]: + return self._split_csv(self.CORS_ALLOW_HEADERS) + + @staticmethod + def _split_csv(value: str) -> list[str]: + return [item.strip() for item in value.split(",") if item.strip()] + + @model_validator(mode="after") + def validate_production_security(self): + if not self.is_production: + return self + + if self.SECRET_KEY == self.DEFAULT_SECRET_KEY: + raise ValueError("SECRET_KEY must be changed in production") + if not self.DATABASE_URL.strip(): + raise ValueError("DATABASE_URL must be set in production") + if not self.REDIS_URL.strip(): + raise ValueError("REDIS_URL must be set in production") + if not self.JWT_ISSUER.strip(): + raise ValueError("JWT_ISSUER must be set in production") + if not self.JWT_AUDIENCE.strip(): + raise ValueError("JWT_AUDIENCE must be set in production") + if self.ACCESS_TOKEN_EXPIRE_MINUTES <= 0: + raise ValueError("ACCESS_TOKEN_EXPIRE_MINUTES must be positive") + if self.REFRESH_TOKEN_EXPIRE_MINUTES <= 0: + raise ValueError("REFRESH_TOKEN_EXPIRE_MINUTES must be positive") + if "*" in self.cors_allow_origins: + raise ValueError("CORS_ALLOW_ORIGINS cannot be wildcard in production") + return self model_config = SettingsConfigDict( env_file=".env", diff --git a/src/core/database/postgres/session.py b/src/core/database/postgres/session.py index 8b2b4f5..374693a 100644 --- a/src/core/database/postgres/session.py +++ b/src/core/database/postgres/session.py @@ -9,7 +9,16 @@ settings = get_settings() -engine = create_async_engine(settings.DATABASE_URL, echo=False, future=True) +engine = create_async_engine( + settings.DATABASE_URL, + echo=False, + future=True, + pool_size=settings.DATABASE_POOL_SIZE, + max_overflow=settings.DATABASE_MAX_OVERFLOW, + pool_pre_ping=True, + pool_timeout=settings.DATABASE_POOL_TIMEOUT, + pool_recycle=settings.DATABASE_POOL_RECYCLE, +) AsyncSessionLocal = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False, autoflush=False ) diff --git a/src/core/dependency/auth.py b/src/core/dependency/auth.py index 391351b..bba3255 100644 --- a/src/core/dependency/auth.py +++ b/src/core/dependency/auth.py @@ -12,6 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer( tokenUrl="/api/v1/auth/login", refreshUrl="/api/v1/auth/refresh", + auto_error=False, ) diff --git a/src/core/dependency/rate_limit.py b/src/core/dependency/rate_limit.py index 6fc5123..43b523a 100644 --- a/src/core/dependency/rate_limit.py +++ b/src/core/dependency/rate_limit.py @@ -1,4 +1,6 @@ -from fastapi import Request +from types import SimpleNamespace + +from fastapi import Request, Response from fastapi_limiter import FastAPILimiter from fastapi_limiter.depends import RateLimiter @@ -20,6 +22,31 @@ settings = get_settings() +def _rate_limiter_request(request: Request) -> Request: + if "app" not in request.scope: + return request + + routes = [] + for route in request.app.routes: + if hasattr(route, "path") and hasattr(route, "methods"): + routes.append(route) + continue + + effective_route_contexts = getattr(route, "effective_route_contexts", None) + if effective_route_contexts is None: + continue + + routes.extend( + nested_route + for nested_route in effective_route_contexts() + if hasattr(nested_route, "path") and hasattr(nested_route, "methods") + ) + + scope = dict(request.scope) + scope["app"] = SimpleNamespace(routes=routes) + return Request(scope, receive=request.receive) + + async def custom_identifier(request: Request) -> str: """Smart identifier: User ID > Proxy IP > Direct IP""" user_id = getattr(request.state, "user_id", None) @@ -46,14 +73,14 @@ async def close_rate_limiter(): await redis_client.aclose() -async def apply_global_rate_limit(request: Request): +async def apply_global_rate_limit(request: Request, response: Response): if request.url.path in EXEMPT_PATHS: return - limit_str = settings.GLOBAL_RATE_LIMIT + limit_str = settings.RATE_LIMIT times_str, period = limit_str.split("/") times = int(times_str) seconds = 60 if "minute" in period else 1 limiter = RateLimiter(times=times, seconds=seconds) - await limiter(request) + await limiter(_rate_limiter_request(request), response) diff --git a/src/core/middleware/audit_logging.py b/src/core/middleware/audit_logging.py new file mode 100644 index 0000000..b18b5d4 --- /dev/null +++ b/src/core/middleware/audit_logging.py @@ -0,0 +1,160 @@ +import logging +import traceback as traceback_module +from collections.abc import Callable + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +from src.core.database.postgres.session import AsyncSessionLocal +from src.core.security.audit import ( + AuditEvent, + AuditService, + ErrorTrace, + ErrorTraceService, +) +from src.core.security.infrastructure.repositories.audit_log_repository import ( + SQLAlchemyAuditRepository, +) +from src.core.security.infrastructure.repositories.error_trace_repository import ( + SQLAlchemyErrorTraceRepository, +) + +logger = logging.getLogger(__name__) + +EXCLUDED_AUDIT_PATHS = frozenset( + { + "/health", + "/live", + "/ready", + "/docs", + "/docs/", + "/redoc", + "/redoc/", + "/openapi.json", + } +) + + +class AuditLoggingMiddleware(BaseHTTPMiddleware): + def __init__( + self, + app, + audit_service_factory: Callable[[], AuditService] | None = None, + error_trace_service_factory: Callable[[], ErrorTraceService] | None = None, + ): + super().__init__(app) + self._audit_service_factory = audit_service_factory + self._error_trace_service_factory = error_trace_service_factory + + async def dispatch(self, request: Request, call_next): + if self._is_excluded(request): + return await call_next(request) + + try: + response = await call_next(request) + except Exception as exc: + await self._record_error_trace(request, exc) + raise + + await self._record_audit_event(request, response.status_code) + return response + + def _is_excluded(self, request: Request) -> bool: + return request.url.path.rstrip("/") in { + path.rstrip("/") for path in EXCLUDED_AUDIT_PATHS + } + + async def _record_audit_event(self, request: Request, status_code: int) -> None: + event = AuditEvent( + action=f"{request.method} {request.url.path}", + actor_id=getattr(request.state, "user_id", None), + resource_type=self._resource_type(request), + resource_id=self._resource_id(request), + request_id=getattr(request.state, "request_id", None), + metadata={ + "method": request.method, + "path": request.url.path, + "status_code": status_code, + "client_ip": self._client_ip(request), + "user_agent": request.headers.get("User-Agent"), + }, + ) + await self._safe_record_audit(event) + + async def _record_error_trace(self, request: Request, exc: Exception) -> None: + trace = ErrorTrace( + error_type=type(exc).__name__, + message=str(exc), + traceback=traceback_module.format_exc(), + method=request.method, + path=request.url.path, + actor_id=getattr(request.state, "user_id", None), + request_id=getattr(request.state, "request_id", None), + metadata={ + "client_ip": self._client_ip(request), + "user_agent": request.headers.get("User-Agent"), + }, + ) + await self._safe_record_error_trace(trace) + + async def _safe_record_audit(self, event: AuditEvent) -> None: + try: + factory = self._audit_service_factory or self._default_audit_service_factory + service = factory() + await service.record(event) + except Exception: + logger.exception("failed to record audit event") + + async def _safe_record_error_trace(self, trace: ErrorTrace) -> None: + try: + factory = ( + self._error_trace_service_factory + or self._default_error_trace_service_factory + ) + service = factory() + await service.record(trace) + except Exception: + logger.exception("failed to record error trace") + + def _default_audit_service_factory(self) -> AuditService: + return _SessionBackedAuditService() + + def _default_error_trace_service_factory(self) -> ErrorTraceService: + return _SessionBackedErrorTraceService() + + @staticmethod + def _resource_type(request: Request) -> str | None: + parts = [part for part in request.url.path.split("/") if part] + if len(parts) >= 3 and parts[0] == "api" and parts[1].startswith("v"): + return parts[2] + return parts[0] if parts else None + + @staticmethod + def _resource_id(request: Request) -> str | None: + for key in ("id", "todo_id", "role_id", "permission_id", "user_id"): + if key in request.path_params: + return str(request.path_params[key]) + return None + + @staticmethod + def _client_ip(request: Request) -> str | None: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + return request.client.host if request.client else None + + +class _SessionBackedAuditService(AuditService): + async def record(self, event: AuditEvent) -> None: + async with AsyncSessionLocal() as session: + repository = SQLAlchemyAuditRepository(session) + await repository.save(event) + await session.commit() + + +class _SessionBackedErrorTraceService(ErrorTraceService): + async def record(self, trace: ErrorTrace) -> None: + async with AsyncSessionLocal() as session: + repository = SQLAlchemyErrorTraceRepository(session) + await repository.save(trace) + await session.commit() diff --git a/src/core/middleware/auth.py b/src/core/middleware/auth.py index c547df5..8fe3de5 100644 --- a/src/core/middleware/auth.py +++ b/src/core/middleware/auth.py @@ -5,10 +5,14 @@ from starlette.responses import JSONResponse, Response from src.core.security.jwt import JWTService +from src.core.security.token_revocation import TokenRevocationService +from src.shared.exceptions.credential_exception import InvalidCredentialsError PUBLIC_PATHS = frozenset( { "/health", + "/live", + "/ready", "/docs", "/docs/", "/redoc", @@ -43,16 +47,28 @@ async def dispatch( try: payload = JWTService.decode_token(token) + JWTService.require_token_type(payload, JWTService.ACCESS_TOKEN_TYPE) + if await TokenRevocationService.is_access_token_revoked(token): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Token has been revoked"}, + ) + user_id = payload.get("sub") if not user_id: raise ValueError("Token missing 'sub' claim") request.state.user_id = user_id request.state.token_payload = payload - except JWTError as e: + except JWTError: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Invalid or expired token"}, + ) + except InvalidCredentialsError as e: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": f"Invalid or expired token: {str(e)}"}, + content={"detail": str(e)}, ) except Exception: return JSONResponse( diff --git a/src/core/middleware/csp.py b/src/core/middleware/csp.py new file mode 100644 index 0000000..d52cd3f --- /dev/null +++ b/src/core/middleware/csp.py @@ -0,0 +1,26 @@ +from starlette.middleware.base import BaseHTTPMiddleware + + +class CSPMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + response = await call_next(request) + + if request.url.path.startswith("/docs") or request.url.path.startswith( + "/redoc" + ): + response.headers["Content-Security-Policy"] = ( + "default-src 'self'; " + "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " + "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " + "img-src 'self' data: https:; " + "font-src 'self' https:;" + ) + else: + response.headers["Content-Security-Policy"] = ( + "default-src 'self'; " + "object-src 'none'; " + "base-uri 'self'; " + "frame-ancestors 'none';" + ) + + return response diff --git a/src/core/middleware/idempotency.py b/src/core/middleware/idempotency.py new file mode 100644 index 0000000..5fb2a72 --- /dev/null +++ b/src/core/middleware/idempotency.py @@ -0,0 +1,99 @@ +import hashlib +import json + +from fastapi import Request, status +from fastapi.responses import JSONResponse, Response +from starlette.middleware.base import BaseHTTPMiddleware + +from src.core.config.setting import get_settings +from src.core.database.redis.client import get_redis_client + +settings = get_settings() + + +class IdempotencyMiddleware(BaseHTTPMiddleware): + def __init__(self, app, redis=None): + super().__init__(app) + self._redis = redis + + async def dispatch(self, request: Request, call_next): + if request.method != "POST": + return await call_next(request) + + idempotency_key = request.headers.get("Idempotency-Key") + if not idempotency_key: + return await call_next(request) + + request_body = await request.body() + request_body_hash = hashlib.sha256(request_body).hexdigest() + redis = self._redis or await get_redis_client() + cache_key = self._cache_key(request, idempotency_key) + + cached = await redis.get(cache_key) + if cached: + cached_response = json.loads(cached) + if cached_response.get("request_body_hash") != request_body_hash: + return JSONResponse( + status_code=status.HTTP_409_CONFLICT, + content={ + "detail": ( + "Idempotency-Key was already used with a different " + "request body" + ) + }, + ) + return Response( + content=cached_response["body"], + status_code=cached_response["status_code"], + media_type=cached_response.get("media_type"), + headers={"X-Idempotent-Replay": "true"}, + ) + + request = self._rebuild_request(request, request_body) + response = await call_next(request) + body = await self._response_body(response) + + if 200 <= response.status_code < 300: + await redis.setex( + cache_key, + settings.IDEMPOTENCY_TTL_SECONDS, + json.dumps( + { + "body": body.decode(), + "status_code": response.status_code, + "media_type": response.media_type, + "request_body_hash": request_body_hash, + } + ), + ) + + return Response( + content=body, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + @staticmethod + def _cache_key(request: Request, idempotency_key: str) -> str: + auth_scope = request.headers.get("Authorization", "anonymous") + raw = f"{auth_scope}:{request.url.path}:{idempotency_key}" + digest = hashlib.sha256(raw.encode()).hexdigest() + return f"idempotency:{digest}" + + @staticmethod + def _rebuild_request(request: Request, body: bytes) -> Request: + async def receive(): + return {"type": "http.request", "body": body, "more_body": False} + + return Request(request.scope, receive) + + @staticmethod + async def _response_body(response) -> bytes: + if hasattr(response, "body"): + return response.body + + body = b"" + async for chunk in response.body_iterator: + body += chunk + return body diff --git a/src/core/middleware/request_id.py b/src/core/middleware/request_id.py new file mode 100644 index 0000000..0fc92b8 --- /dev/null +++ b/src/core/middleware/request_id.py @@ -0,0 +1,16 @@ +from uuid import uuid4 + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +REQUEST_ID_HEADER = "X-Request-ID" + + +class RequestIDMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4()) + request.state.request_id = request_id + + response = await call_next(request) + response.headers[REQUEST_ID_HEADER] = request_id + return response diff --git a/src/core/middleware/request_size.py b/src/core/middleware/request_size.py new file mode 100644 index 0000000..c18b6da --- /dev/null +++ b/src/core/middleware/request_size.py @@ -0,0 +1,29 @@ +from fastapi import Request, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + + +class LimitRequestSizeMiddleware(BaseHTTPMiddleware): + def __init__(self, app, max_upload_size: int): + super().__init__(app) + self.max_upload_size = max_upload_size + + async def dispatch(self, request: Request, call_next): + if request.method in ("POST", "PUT", "PATCH"): + content_length = request.headers.get("content-length") + + if content_length: + if int(content_length) > self.max_upload_size: + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content={ + "detail": "Request payload exceeds maximum allowed size." + }, + ) + else: + return JSONResponse( + status_code=status.HTTP_411_LENGTH_REQUIRED, + content={"detail": "Content-Length header is required."}, + ) + + return await call_next(request) diff --git a/src/core/middleware/security_headers.py b/src/core/middleware/security_headers.py new file mode 100644 index 0000000..a57dd2e --- /dev/null +++ b/src/core/middleware/security_headers.py @@ -0,0 +1,23 @@ +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +from src.core.config.setting import get_settings + +settings = get_settings() + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + response.headers.setdefault("X-Content-Type-Options", "nosniff") + response.headers.setdefault("X-Frame-Options", "DENY") + response.headers.setdefault("Referrer-Policy", "no-referrer") + response.headers.setdefault( + "Permissions-Policy", + "camera=(), microphone=(), geolocation=()", + ) + response.headers.setdefault( + "Content-Security-Policy", + settings.SECURITY_CONTENT_SECURITY_POLICY, + ) + return response diff --git a/src/core/middleware/structured_logging.py b/src/core/middleware/structured_logging.py new file mode 100644 index 0000000..bfe3dc9 --- /dev/null +++ b/src/core/middleware/structured_logging.py @@ -0,0 +1,103 @@ +import json +import logging +import time +from datetime import datetime, timezone + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +logger = logging.getLogger(__name__) + + +class JsonLogFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + payload = { + "timestamp": datetime.fromtimestamp( + record.created, + tz=timezone.utc, + ).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + + for field in ( + "method", + "path", + "status_code", + "latency_ms", + "request_id", + "user_id", + "error_type", + ): + if hasattr(record, field): + payload[field] = getattr(record, field) + + return json.dumps(payload, default=str) + + +def configure_logging(log_format: str = "json") -> None: + formatter: logging.Formatter + if log_format == "json": + formatter = JsonLogFormatter() + else: + formatter = logging.Formatter( + "%(asctime)s %(levelname)s %(name)s %(message)s" + ) + + root_logger = logging.getLogger() + if not root_logger.handlers: + root_logger.addHandler(logging.StreamHandler()) + + for handler in root_logger.handlers: + handler.setFormatter(formatter) + + +class StructuredLoggingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + started_at = time.perf_counter() + try: + response = await call_next(request) + except Exception as exc: + self._log_request( + request=request, + started_at=started_at, + status_code=500, + level=logging.ERROR, + message="request failed", + error_type=type(exc).__name__, + ) + raise + + self._log_request( + request=request, + started_at=started_at, + status_code=response.status_code, + level=logging.INFO, + message="request completed", + ) + return response + + @staticmethod + def _log_request( + request: Request, + started_at: float, + status_code: int, + level: int, + message: str, + error_type: str | None = None, + ) -> None: + latency_ms = round((time.perf_counter() - started_at) * 1000, 2) + logger.log( + level, + message, + extra={ + "method": request.method, + "path": request.url.path, + "status_code": status_code, + "latency_ms": latency_ms, + "request_id": getattr(request.state, "request_id", None), + "user_id": getattr(request.state, "user_id", None), + "error_type": error_type, + }, + ) diff --git a/src/core/routers/admin.py b/src/core/routers/admin.py index e039316..9cd5606 100644 --- a/src/core/routers/admin.py +++ b/src/core/routers/admin.py @@ -1,6 +1,43 @@ -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, status +from fastapi.responses import JSONResponse -router = APIRouter(prefix="/api/admin") +from src.core.database.postgres.session import engine +from src.core.database.redis.client import get_redis_client + +router = APIRouter() + + +@router.get("/live", include_in_schema=False) +async def live(): + return {"status": "alive"} + + +@router.get("/ready", include_in_schema=False) +async def ready(): + checks = {"database": "ok", "redis": "ok"} + http_status = status.HTTP_200_OK + + try: + async with engine.connect() as connection: + await connection.exec_driver_sql("SELECT 1") + except Exception: + checks["database"] = "unavailable" + http_status = status.HTTP_503_SERVICE_UNAVAILABLE + + try: + redis = await get_redis_client() + await redis.ping() + except Exception: + checks["redis"] = "unavailable" + http_status = status.HTTP_503_SERVICE_UNAVAILABLE + + return JSONResponse( + status_code=http_status, + content={ + "status": "ready" if http_status == status.HTTP_200_OK else "not_ready", + "checks": checks, + }, + ) def register_router(app: FastAPI): diff --git a/src/core/schemas/response.py b/src/core/schemas/response.py index 3a3fcdc..0f36e54 100644 --- a/src/core/schemas/response.py +++ b/src/core/schemas/response.py @@ -37,6 +37,30 @@ class PaginatedResponse(BaseModel, Generic[T]): data: List[T] +# ========================================== +# Cursor-Based Pagination +# ========================================== +class CursorMeta(BaseModel): + next_cursor: Optional[str] = Field( + default=None, description="Cursor for the next page. None if no more items." + ) + prev_cursor: Optional[str] = Field( + default=None, description="Cursor for the previous page. None if at the start." + ) + has_next: bool = Field( + ..., description="True if there are more items after this page" + ) + has_prev: bool = Field(..., description="True if there are items before this page") + limit: int = Field(..., description="Number of items per page") + + +class CursorPaginatedResponse(BaseModel, Generic[T]): + success: bool = Field(default=True) + message: str = Field(default="Cursor-paginated data retrieved successfully") + meta: CursorMeta + data: List[T] + + # ========================================== # 3. Standard Error Response # ========================================== diff --git a/src/core/security/account_lockout.py b/src/core/security/account_lockout.py new file mode 100644 index 0000000..9603389 --- /dev/null +++ b/src/core/security/account_lockout.py @@ -0,0 +1,63 @@ +from datetime import datetime, timedelta, timezone +from typing import Protocol + +from src.core.config.setting import Settings, get_settings + + +class LoginAttemptRepository(Protocol): + async def count_failures_since(self, email: str, since: datetime) -> int: + pass + + async def record_failure( + self, + email: str, + occurred_at: datetime, + locked_until: datetime | None = None, + ) -> None: + pass + + async def get_locked_until(self, email: str) -> datetime | None: + pass + + async def clear(self, email: str) -> None: + pass + + +class AccountLockoutService: + def __init__( + self, + repository: LoginAttemptRepository | None = None, + settings: Settings | None = None, + ): + self._repository = repository + self._settings = settings or get_settings() + + async def ensure_login_allowed(self, email: str) -> None: + if self._repository is None: + return + + locked_until = await self._repository.get_locked_until(email) + if locked_until and locked_until > datetime.now(timezone.utc): + raise ValueError("Account is temporarily locked") + + async def record_failed_login(self, email: str) -> None: + if self._repository is None: + return + + now = datetime.now(timezone.utc) + since = now - timedelta(minutes=self._settings.ACCOUNT_LOCKOUT_WINDOW_MINUTES) + recent_failures = await self._repository.count_failures_since(email, since) + + locked_until = None + if recent_failures + 1 >= self._settings.ACCOUNT_LOCKOUT_MAX_ATTEMPTS: + locked_until = now + timedelta( + minutes=self._settings.ACCOUNT_LOCKOUT_DURATION_MINUTES + ) + + await self._repository.record_failure(email, now, locked_until) + + async def record_successful_login(self, email: str) -> None: + if self._repository is None: + return + + await self._repository.clear(email) diff --git a/src/core/security/audit.py b/src/core/security/audit.py new file mode 100644 index 0000000..98b2abf --- /dev/null +++ b/src/core/security/audit.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Protocol +from uuid import UUID, uuid4 + + +@dataclass +class AuditEvent: + action: str + actor_id: str | None = None + resource_type: str | None = None + resource_id: str | None = None + request_id: str | None = None + metadata: dict = field(default_factory=dict) + id: UUID = field(default_factory=uuid4) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +class AuditRepository(Protocol): + async def save(self, event: AuditEvent) -> AuditEvent: + pass + + +class AuditService: + def __init__(self, repository: AuditRepository | None = None): + self._repository = repository + + async def record(self, event: AuditEvent) -> None: + if self._repository is None: + return + await self._repository.save(event) + + +@dataclass +class ErrorTrace: + error_type: str + message: str + traceback: str + method: str + path: str + actor_id: str | None = None + request_id: str | None = None + metadata: dict = field(default_factory=dict) + id: UUID = field(default_factory=uuid4) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +class ErrorTraceRepository(Protocol): + async def save(self, trace: ErrorTrace) -> ErrorTrace: + pass + + +class ErrorTraceService: + def __init__(self, repository: ErrorTraceRepository | None = None): + self._repository = repository + + async def record(self, trace: ErrorTrace) -> None: + if self._repository is None: + return + await self._repository.save(trace) diff --git a/src/core/security/infrastructure/models/audit_log_model.py b/src/core/security/infrastructure/models/audit_log_model.py new file mode 100644 index 0000000..ab6ae01 --- /dev/null +++ b/src/core/security/infrastructure/models/audit_log_model.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from sqlalchemy import DateTime, JSON, String +from sqlalchemy.orm import Mapped, mapped_column + +from src.shared.database.model import Base + + +class AuditLogModel(Base): + __tablename__ = "audit_logs" + + action: Mapped[str] = mapped_column(String(120), index=True) + actor_id: Mapped[str | None] = mapped_column(String(64), index=True, nullable=True) + resource_type: Mapped[str | None] = mapped_column(String(80), nullable=True) + resource_id: Mapped[str | None] = mapped_column(String(64), nullable=True) + request_id: Mapped[str | None] = mapped_column(String(120), index=True, nullable=True) + meta: Mapped[dict] = mapped_column(JSON, default=dict) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True) diff --git a/src/core/security/infrastructure/models/error_trace_model.py b/src/core/security/infrastructure/models/error_trace_model.py new file mode 100644 index 0000000..5d54759 --- /dev/null +++ b/src/core/security/infrastructure/models/error_trace_model.py @@ -0,0 +1,20 @@ +from datetime import datetime + +from sqlalchemy import DateTime, JSON, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from src.shared.database.model import Base + + +class ErrorTraceModel(Base): + __tablename__ = "error_traces" + + error_type: Mapped[str] = mapped_column(String(120), index=True) + message: Mapped[str] = mapped_column(Text) + traceback: Mapped[str] = mapped_column(Text) + method: Mapped[str] = mapped_column(String(12)) + path: Mapped[str] = mapped_column(String(500), index=True) + actor_id: Mapped[str | None] = mapped_column(String(64), index=True, nullable=True) + request_id: Mapped[str | None] = mapped_column(String(120), index=True, nullable=True) + meta: Mapped[dict] = mapped_column(JSON, default=dict) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True) diff --git a/src/core/security/infrastructure/models/login_attempt_model.py b/src/core/security/infrastructure/models/login_attempt_model.py new file mode 100644 index 0000000..9205658 --- /dev/null +++ b/src/core/security/infrastructure/models/login_attempt_model.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from sqlalchemy import DateTime, String +from sqlalchemy.orm import Mapped, mapped_column + +from src.shared.database.model import Base + + +class LoginAttemptModel(Base): + __tablename__ = "login_attempts" + + email: Mapped[str] = mapped_column(String(255), index=True) + occurred_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True) + locked_until: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + index=True, + nullable=True, + ) diff --git a/src/core/security/infrastructure/repositories/audit_log_repository.py b/src/core/security/infrastructure/repositories/audit_log_repository.py new file mode 100644 index 0000000..245a0ff --- /dev/null +++ b/src/core/security/infrastructure/repositories/audit_log_repository.py @@ -0,0 +1,22 @@ +from src.core.security.audit import AuditEvent +from src.core.security.infrastructure.models.audit_log_model import AuditLogModel + + +class SQLAlchemyAuditRepository: + def __init__(self, db): + self._db = db + + async def save(self, event: AuditEvent) -> AuditEvent: + self._db.add( + AuditLogModel( + id=event.id, + action=event.action, + actor_id=event.actor_id, + resource_type=event.resource_type, + resource_id=event.resource_id, + request_id=event.request_id, + meta=event.metadata, + created_at=event.created_at, + ) + ) + return event diff --git a/src/core/security/infrastructure/repositories/error_trace_repository.py b/src/core/security/infrastructure/repositories/error_trace_repository.py new file mode 100644 index 0000000..addb36b --- /dev/null +++ b/src/core/security/infrastructure/repositories/error_trace_repository.py @@ -0,0 +1,24 @@ +from src.core.security.audit import ErrorTrace +from src.core.security.infrastructure.models.error_trace_model import ErrorTraceModel + + +class SQLAlchemyErrorTraceRepository: + def __init__(self, db): + self._db = db + + async def save(self, trace: ErrorTrace) -> ErrorTrace: + self._db.add( + ErrorTraceModel( + id=trace.id, + error_type=trace.error_type, + message=trace.message, + traceback=trace.traceback, + method=trace.method, + path=trace.path, + actor_id=trace.actor_id, + request_id=trace.request_id, + meta=trace.metadata, + created_at=trace.created_at, + ) + ) + return trace diff --git a/src/core/security/infrastructure/repositories/login_attempt_repository.py b/src/core/security/infrastructure/repositories/login_attempt_repository.py new file mode 100644 index 0000000..90660be --- /dev/null +++ b/src/core/security/infrastructure/repositories/login_attempt_repository.py @@ -0,0 +1,54 @@ +from datetime import datetime + +from sqlalchemy import delete, func, select + +from src.core.security.infrastructure.models.login_attempt_model import ( + LoginAttemptModel, +) + + +class SQLAlchemyLoginAttemptRepository: + def __init__(self, db): + self._db = db + + async def count_failures_since(self, email: str, since: datetime) -> int: + result = await self._db.execute( + select(func.count()) + .select_from(LoginAttemptModel) + .where( + LoginAttemptModel.email == email, + LoginAttemptModel.occurred_at >= since, + ) + ) + return int(result.scalar_one()) + + async def record_failure( + self, + email: str, + occurred_at: datetime, + locked_until: datetime | None = None, + ) -> None: + self._db.add( + LoginAttemptModel( + email=email, + occurred_at=occurred_at, + locked_until=locked_until, + ) + ) + + async def get_locked_until(self, email: str) -> datetime | None: + result = await self._db.execute( + select(LoginAttemptModel.locked_until) + .where( + LoginAttemptModel.email == email, + LoginAttemptModel.locked_until.is_not(None), + ) + .order_by(LoginAttemptModel.locked_until.desc()) + .limit(1) + ) + return result.scalar_one_or_none() + + async def clear(self, email: str) -> None: + await self._db.execute( + delete(LoginAttemptModel).where(LoginAttemptModel.email == email) + ) diff --git a/src/core/security/jwt.py b/src/core/security/jwt.py index 4b6c356..b1349df 100644 --- a/src/core/security/jwt.py +++ b/src/core/security/jwt.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta, timezone +from uuid import uuid4 from jose import JWTError, jwt @@ -9,29 +10,57 @@ class JWTService: + ACCESS_TOKEN_TYPE = "access" + REFRESH_TOKEN_TYPE = "refresh" + @staticmethod - def create_access_token(data: dict) -> str: + def _create_token(data: dict, token_type: str, expires_delta: timedelta) -> str: + now = datetime.now(timezone.utc) to_encode = data.copy() - expire = datetime.now(timezone.utc) + timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + to_encode.update( + { + "iss": settings.JWT_ISSUER, + "sub": str(to_encode["sub"]), + "aud": settings.JWT_AUDIENCE, + "exp": now + expires_delta, + "nbf": now, + "iat": now, + "jti": str(uuid4()), + "token_type": token_type, + } ) - to_encode.update({"exp": expire}) return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + @staticmethod + def create_access_token(data: dict) -> str: + return JWTService._create_token( + data=data, + token_type=JWTService.ACCESS_TOKEN_TYPE, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + ) + @staticmethod def create_refresh_token(data: dict) -> str: - to_encode = data.copy() - expire = datetime.now(timezone.utc) + timedelta( - minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES, + return JWTService._create_token( + data=data, + token_type=JWTService.REFRESH_TOKEN_TYPE, + expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES), ) - to_encode.update({"exp": expire}) - return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) @staticmethod def decode_token(token: str) -> dict: try: return jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM], + issuer=settings.JWT_ISSUER, + audience=settings.JWT_AUDIENCE, ) except JWTError: raise InvalidCredentialsError("Invalid or expired token") + + @staticmethod + def require_token_type(payload: dict, expected_token_type: str) -> None: + if payload.get("token_type") != expected_token_type: + raise InvalidCredentialsError(f"{expected_token_type.title()} token required") diff --git a/src/core/security/token_revocation.py b/src/core/security/token_revocation.py new file mode 100644 index 0000000..7b3a2e8 --- /dev/null +++ b/src/core/security/token_revocation.py @@ -0,0 +1,45 @@ +from datetime import datetime, timezone + +from src.core.database.redis.client import get_redis_client +from src.core.security.jwt import JWTService + + +class TokenRevocationService: + KEY_PREFIX = "revoked_access_token" + + @staticmethod + def _key(jti: str) -> str: + return f"{TokenRevocationService.KEY_PREFIX}:{jti}" + + @staticmethod + async def revoke_access_token(token: str, redis=None) -> None: + try: + payload = JWTService.decode_token(token) + except Exception: + return + + jti = payload.get("jti") + exp = payload.get("exp") + if not jti or not exp: + return + + ttl = int(exp - datetime.now(timezone.utc).timestamp()) + if ttl <= 0: + return + + redis_client = redis or await get_redis_client() + await redis_client.setex(TokenRevocationService._key(jti), ttl, "1") + + @staticmethod + async def is_access_token_revoked(token: str, redis=None) -> bool: + try: + payload = JWTService.decode_token(token) + except Exception: + return False + + jti = payload.get("jti") + if not jti: + return False + + redis_client = redis or await get_redis_client() + return bool(await redis_client.exists(TokenRevocationService._key(jti))) diff --git a/src/core/seed/__init__.py b/src/core/seed/__init__.py new file mode 100644 index 0000000..b5e329a --- /dev/null +++ b/src/core/seed/__init__.py @@ -0,0 +1 @@ +"""Database seeders.""" diff --git a/src/core/seed/authorization.py b/src/core/seed/authorization.py new file mode 100644 index 0000000..d87948b --- /dev/null +++ b/src/core/seed/authorization.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass +from typing import Protocol + +from src.core.authorization.permissions import ( + DEFAULT_RESOURCES, + DEFAULT_ROLES, + DEFAULT_POLICIES, +) +from src.modules.authorization.domain.entities.permission import Permission +from src.modules.authorization.domain.entities.resource import AuthorizationResource +from src.modules.authorization.domain.entities.role import Role + + +class AuthorizationSeedRepository(Protocol): + async def list_resources(self) -> list[AuthorizationResource]: + raise NotImplementedError + + async def create_resource( + self, + resource: AuthorizationResource, + ) -> AuthorizationResource: + raise NotImplementedError + + async def list_roles(self) -> list[Role]: + raise NotImplementedError + + async def create_role(self, role: Role) -> Role: + raise NotImplementedError + + async def list_permissions(self) -> list[Permission]: + raise NotImplementedError + + async def create_permission(self, permission: Permission) -> Permission: + raise NotImplementedError + + async def assign_permission_to_role( + self, + role_id, + permission_id, + ) -> None: + raise NotImplementedError + + async def add_policy(self, ptype: str, *values: str) -> None: + raise NotImplementedError + + +@dataclass(frozen=True) +class AuthorizationSeedResult: + resources_created: int = 0 + roles_created: int = 0 + permissions_created: int = 0 + role_permissions_created: int = 0 + policies_created: int = 0 + + +async def seed_authorization( + repository: AuthorizationSeedRepository, +) -> AuthorizationSeedResult: + existing_resources = { + resource.key: resource for resource in await repository.list_resources() + } + existing_roles = {role.name: role for role in await repository.list_roles()} + existing_permissions = { + permission.key: permission for permission in await repository.list_permissions() + } + existing_role_permissions = await _load_role_permissions(repository) + existing_policies = await _load_policies(repository) + + resources_created = 0 + for resource_definition in DEFAULT_RESOURCES: + if resource_definition.key in existing_resources: + continue + + resource = await repository.create_resource( + AuthorizationResource.create( + key=resource_definition.key, + name=resource_definition.name, + description=resource_definition.description, + ) + ) + existing_resources[resource.key] = resource + resources_created += 1 + + roles_created = 0 + for name, description in _default_roles().items(): + if name in existing_roles: + continue + + role = await repository.create_role(Role.create(name=name, description=description)) + existing_roles[role.name] = role + roles_created += 1 + + permissions_created = 0 + for key in _default_permission_keys(): + if key in existing_permissions: + continue + + resource, action = key.split(":", 1) + permission = await repository.create_permission( + Permission.create( + key=key, + resource=resource, + action=action, + description=f"Allows {action} access on {resource}", + ) + ) + existing_permissions[permission.key] = permission + permissions_created += 1 + + role_permissions_created = 0 + for _, role_name, permission_key in _permission_policies(): + role_permission = (role_name, permission_key) + if role_permission in existing_role_permissions: + continue + + await repository.assign_permission_to_role( + existing_roles[role_name].id, + existing_permissions[permission_key].id, + ) + existing_role_permissions.add(role_permission) + role_permissions_created += 1 + + policies_created = 0 + for policy in DEFAULT_POLICIES: + if policy in existing_policies: + continue + + ptype, *values = policy + await repository.add_policy(ptype, *values) + existing_policies.add(policy) + policies_created += 1 + + return AuthorizationSeedResult( + resources_created=resources_created, + roles_created=roles_created, + permissions_created=permissions_created, + role_permissions_created=role_permissions_created, + policies_created=policies_created, + ) + + +def _default_roles() -> dict[str, str]: + return { + role.name: role.description + for role in DEFAULT_ROLES + } + + +def _default_permission_keys() -> list[str]: + return [ + permission_key + for _, _, permission_key in _permission_policies() + ] + + +def _permission_policies() -> list[tuple[str, str, str]]: + return [ + policy + for policy in DEFAULT_POLICIES + if policy[0] == "p" and policy[2] != "*" + ] + + +async def _load_role_permissions( + repository: AuthorizationSeedRepository, +) -> set[tuple[str, str]]: + if not hasattr(repository, "list_role_permissions"): + return set() + + return set(await repository.list_role_permissions()) + + +async def _load_policies( + repository: AuthorizationSeedRepository, +) -> set[tuple[str, ...]]: + if not hasattr(repository, "list_policies"): + return set() + + return set(await repository.list_policies()) diff --git a/src/core/seed/runner.py b/src/core/seed/runner.py new file mode 100644 index 0000000..dfd4555 --- /dev/null +++ b/src/core/seed/runner.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.authorization.infrastructure.repositories.casbin_policy_repository import ( + SQLAlchemyCasbinPolicyRepository, +) +from src.core.authorization.infrastructure.services.casbin_authorization_service import ( + CasbinAuthorizationService, +) +from src.core.config.setting import get_settings +from src.core.database.postgres.session import AsyncSessionLocal +from src.core.seed.authorization import AuthorizationSeedResult, seed_authorization +from src.core.seed.user import SeedUserConfig, UserSeedResult, seed_user +from src.modules.user.infrastructure.repositories.user_repository import ( + SQLAlchemyUserRepository, +) + + +@dataclass(frozen=True) +class SeedResult: + authorization: AuthorizationSeedResult + user: UserSeedResult + + +async def run_seeders() -> SeedResult: + async with AsyncSessionLocal() as session: + try: + result = await run_seeders_with_session(session) + await session.commit() + return result + except Exception: + await session.rollback() + raise + + +async def run_seeders_with_session(session: AsyncSession) -> SeedResult: + authorization_repository = SQLAlchemyCasbinPolicyRepository(session) + authorization = await seed_authorization(authorization_repository) + user = await seed_user( + user_repository=SQLAlchemyUserRepository(session), + authorization_service=CasbinAuthorizationService(authorization_repository), + config=SeedUserConfig.from_settings(get_settings()), + ) + return SeedResult(authorization=authorization, user=user) diff --git a/src/core/seed/user.py b/src/core/seed/user.py new file mode 100644 index 0000000..0f10f32 --- /dev/null +++ b/src/core/seed/user.py @@ -0,0 +1,162 @@ +from dataclasses import dataclass +from typing import Protocol + +from src.core.authorization.permissions import ( + ADMIN_ROLE, + DEFAULT_USER_ROLE, + MANAGER_ROLE, + VIEWER_ROLE, +) +from src.core.security.password import PasswordSerrvice +from src.modules.user.domain.entities.user import User + + +class SeedUserRepository(Protocol): + async def get_by_email(self, email: str) -> User | None: + raise NotImplementedError + + async def save(self, user: User) -> User: + raise NotImplementedError + + +class SeedAuthorizationService(Protocol): + async def assign_role(self, subject: str, role: str) -> None: + raise NotImplementedError + + +@dataclass(frozen=True) +class SeedUserConfig: + app_env: str + admin_email: str + admin_password: str + admin_username: str | None = None + admin_fullname: str | None = None + development_users_password: str = "" + + @classmethod + def from_settings(cls, settings) -> "SeedUserConfig": + return cls( + app_env=settings.APP_ENV, + admin_email=settings.SEED_ADMIN_EMAIL, + admin_password=settings.SEED_ADMIN_PASSWORD, + admin_username=settings.SEED_ADMIN_USERNAME, + admin_fullname=settings.SEED_ADMIN_FULLNAME, + development_users_password=settings.SEED_DEVELOPMENT_USERS_PASSWORD, + ) + + @property + def has_admin_credentials(self) -> bool: + return bool(self.admin_email.strip() and self.admin_password.strip()) + + @property + def should_seed_development_users(self) -> bool: + return ( + self.app_env.lower() == "development" + and bool(self.development_users_password.strip()) + ) + + +@dataclass(frozen=True) +class UserSeedResult: + users_created: int = 0 + roles_assigned: int = 0 + + +async def seed_user( + user_repository: SeedUserRepository, + authorization_service: SeedAuthorizationService, + config: SeedUserConfig, +) -> UserSeedResult: + users_created = 0 + roles_assigned = 0 + + if not config.has_admin_credentials: + admin_result = UserSeedResult() + else: + admin_result = await _seed_one_user( + user_repository=user_repository, + authorization_service=authorization_service, + email=config.admin_email, + password=config.admin_password, + username=config.admin_username, + fullname=config.admin_fullname, + role=ADMIN_ROLE, + ) + users_created += admin_result.users_created + roles_assigned += admin_result.roles_assigned + + if config.should_seed_development_users: + for development_user in _development_users(config.development_users_password): + result = await _seed_one_user( + user_repository=user_repository, + authorization_service=authorization_service, + email=development_user.email, + password=development_user.password, + username=development_user.username, + fullname=development_user.fullname, + role=development_user.role, + ) + users_created += result.users_created + roles_assigned += result.roles_assigned + + return UserSeedResult(users_created=users_created, roles_assigned=roles_assigned) + + +@dataclass(frozen=True) +class DevelopmentSeedUser: + email: str + password: str + username: str + fullname: str + role: str + + +def _development_users(password: str) -> tuple[DevelopmentSeedUser, ...]: + return ( + DevelopmentSeedUser( + email="user@example.com", + password=password, + username="user", + fullname="Default User", + role=DEFAULT_USER_ROLE, + ), + DevelopmentSeedUser( + email="manager@example.com", + password=password, + username="manager", + fullname="Todo Manager", + role=MANAGER_ROLE, + ), + DevelopmentSeedUser( + email="viewer@example.com", + password=password, + username="viewer", + fullname="Todo Viewer", + role=VIEWER_ROLE, + ), + ) + + +async def _seed_one_user( + user_repository: SeedUserRepository, + authorization_service: SeedAuthorizationService, + email: str, + password: str, + username: str | None, + fullname: str | None, + role: str, +) -> UserSeedResult: + existing = await user_repository.get_by_email(email) + if existing is not None: + return UserSeedResult() + + user = User.create( + email=email, + password=PasswordSerrvice.hash(password), + username=username, + fullname=fullname, + ) + saved_user = await user_repository.save(user) + await authorization_service.assign_role(str(saved_user.id), role) + + return UserSeedResult(users_created=1, roles_assigned=1) diff --git a/src/core/utils/cursor.py b/src/core/utils/cursor.py new file mode 100644 index 0000000..ab83a52 --- /dev/null +++ b/src/core/utils/cursor.py @@ -0,0 +1,36 @@ +import base64 +import json +from datetime import datetime +from enum import Enum +from uuid import UUID + + +class CursorDirection(Enum): + DIRECTION_NEXT = "next" + DIRECTION_PREV = "prev" + + +def encode_cursor(created_at: datetime, id: UUID, dir: CursorDirection) -> str: + """ + Encode a cursor from timestamp and ID. + Format: base64(json({"t": "ISO_TIMESTAMP", "id": "UUID"})) + """ + cursor_data = {"t": created_at.isoformat(), "id": str(id), "dir": dir.value} + json_str = json.dumps(cursor_data) + return base64.urlsafe_b64encode(json_str.encode()).decode() + + +def decode_cursor(cursor: str) -> tuple[datetime, UUID, CursorDirection]: + """ + Decode a cursor back to timestamp and ID. + Returns: (created_at, id) + """ + try: + json_str = base64.urlsafe_b64decode(cursor.encode()).decode() + cursor_data = json.loads(json_str) + created_at = datetime.fromisoformat(cursor_data["t"]) + dir = CursorDirection(cursor_data["dir"]) + id = UUID(cursor_data["id"]) + return created_at, id, dir + except Exception as e: + raise ValueError(f"Invalid cursor format: {e}") diff --git a/src/main.py b/src/main.py index afe1eeb..9dcbb7f 100644 --- a/src/main.py +++ b/src/main.py @@ -7,29 +7,41 @@ from src.core.bootstrap.middleware import register_middleware from src.core.config.setting import get_settings from src.core.dependency.rate_limit import apply_global_rate_limit +from src.core.middleware.structured_logging import configure_logging settings = get_settings() -app = FastAPI( - title=settings.APP_NAME, - version="1.0.0", - lifespan=lifespan.lifespan, - swagger_ui_parameters={ - "persistAuthorization": True, - "displayRequestDuration": True, - "filter": True, - "deepLinking": True, - "tryItOutEnabled": True, - }, - dependencies=[Depends(apply_global_rate_limit)], -) - -register_exception(app=app) -register_middleware(app=app) -v1_router.register_router(app=app) -admin_router.register_router(app=app) - - -@app.get("/health", tags=["Health Check"]) -def health_check(): - return {"status": "healthy"} + +def create_app(app_settings=settings) -> FastAPI: + configure_logging(app_settings.LOG_FORMAT) + + app = FastAPI( + title=app_settings.APP_NAME, + version="1.0.0", + lifespan=lifespan.lifespan, + docs_url=None if app_settings.is_production else "/docs", + redoc_url=None if app_settings.is_production else "/redoc", + openapi_url=None if app_settings.is_production else "/openapi.json", + swagger_ui_parameters={ + "persistAuthorization": True, + "displayRequestDuration": True, + "filter": True, + "deepLinking": True, + "tryItOutEnabled": True, + }, + dependencies=[Depends(apply_global_rate_limit)], + ) + + register_exception(app=app) + register_middleware(app=app) + v1_router.register_router(app=app) + admin_router.register_router(app=app) + + @app.get("/health", tags=["Health Check"]) + def health_check(): + return {"status": "healthy"} + + return app + + +app = create_app(settings) diff --git a/src/modules/authorization/domain/entities/permission.py b/src/modules/authorization/domain/entities/permission.py index 2105cbf..6c1545c 100644 --- a/src/modules/authorization/domain/entities/permission.py +++ b/src/modules/authorization/domain/entities/permission.py @@ -9,10 +9,16 @@ class Permission: resource: str action: str description: str | None = None + created_at: str | None = None + updated_at: str | None = None @classmethod def create( - cls, key: str, resource: str, action: str, description: str | None + cls, + key: str, + resource: str, + action: str, + description: str | None, ) -> "Permission": return cls( id=uuid4(), diff --git a/src/modules/authorization/domain/entities/resource.py b/src/modules/authorization/domain/entities/resource.py new file mode 100644 index 0000000..6f0eb8b --- /dev/null +++ b/src/modules/authorization/domain/entities/resource.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from uuid import UUID, uuid4 + + +@dataclass +class AuthorizationResource: + id: UUID + key: str + name: str + description: str | None = None + + @classmethod + def create( + cls, + key: str, + name: str, + description: str | None = None, + ) -> "AuthorizationResource": + return cls( + id=uuid4(), + key=key, + name=name, + description=description, + ) diff --git a/src/modules/authorization/domain/entities/role.py b/src/modules/authorization/domain/entities/role.py index 929a78d..4a77949 100644 --- a/src/modules/authorization/domain/entities/role.py +++ b/src/modules/authorization/domain/entities/role.py @@ -6,7 +6,9 @@ class Role: id: UUID name: str - description: str | None + description: str | None = None + created_at: str | None = None + updated_at: str | None = None @classmethod def create(cls, name: str, description: str | None = None) -> "Role": diff --git a/src/modules/authorization/presenter/routers/permission_router.py b/src/modules/authorization/presenter/routers/permission_router.py index a60b9f9..b68f0db 100644 --- a/src/modules/authorization/presenter/routers/permission_router.py +++ b/src/modules/authorization/presenter/routers/permission_router.py @@ -1,8 +1,10 @@ +from datetime import datetime +from typing import Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status -from core.authorization.dependencies import require_permission +from src.core.authorization.dependencies import require_permission from src.core.authorization.infrastructure.services.casbin_authorization_service import ( CasbinAuthorizationService, ) @@ -15,7 +17,12 @@ permission_key, ) from src.core.database.postgres.session import get_unit_of_work -from src.core.schemas.response import PaginatedResponse, SuccessResponse +from src.core.schemas.response import ( + CursorMeta, + CursorPaginatedResponse, + SuccessResponse, +) +from src.core.utils.cursor import CursorDirection, decode_cursor, encode_cursor from src.modules.authorization.domain.entities.permission import Permission from src.modules.authorization.presenter.dependency import ( get_casbin_authorization_service, @@ -60,19 +67,62 @@ async def create_permission( @router.get( "/", - response_model=PaginatedResponse[PermissionResponse], + response_model=CursorPaginatedResponse[PermissionResponse], dependencies=[Depends(require_permission(PERMISSION_RESOURCE, READ_ACTION))], ) async def list_permissions( + cursor: Optional[str] = Query( + None, description="Cursor for pagination (from previous response)" + ), + limit: int = Query(10, ge=1, le=100, description="Number of items per page"), service: CasbinAuthorizationService = Depends(get_casbin_authorization_service), ): - return PaginatedResponse( + cursor_created_at = None + cursor_id = None + direction = CursorDirection.DIRECTION_NEXT + if cursor: + cursor_created_at, cursor_id, direction = decode_cursor(cursor) + + permissions, has_more = await service.list_permissions_cursor( + cursor_created_at=cursor_created_at, + cursor_id=cursor_id, + limit=limit, + direction=direction, + ) + + next_cursor = None + prev_cursor = None + + if has_more and permissions: + last_item = permissions[-1] + next_cursor = encode_cursor( + _created_at_datetime(last_item.created_at), + last_item.id, + CursorDirection.DIRECTION_NEXT, + ) + + if cursor and permissions: + first_item = permissions[0] + prev_cursor = encode_cursor( + _created_at_datetime(first_item.created_at), + first_item.id, + CursorDirection.DIRECTION_PREV, + ) + + return CursorPaginatedResponse( success=True, message="fetch permission success", data=[ _permission_response(permission) - for permission in await service.list_permissions() + for permission in permissions ], + meta=CursorMeta( + next_cursor=next_cursor, + prev_cursor=prev_cursor, + has_next=has_more, + has_prev=cursor is not None, + limit=limit, + ), ) @@ -158,4 +208,14 @@ def _permission_response(permission: Permission | None) -> PermissionResponse: resource=permission.resource, action=permission.action, description=permission.description, + created_at=permission.created_at, + updated_at=permission.updated_at, ) + + +def _created_at_datetime(created_at: datetime | str | None) -> datetime: + if isinstance(created_at, datetime): + return created_at + if isinstance(created_at, str): + return datetime.fromisoformat(created_at) + raise HTTPException(status_code=500, detail="Permission timestamp is missing") diff --git a/src/modules/authorization/presenter/routers/role_router.py b/src/modules/authorization/presenter/routers/role_router.py index c8cff10..af2bdf3 100644 --- a/src/modules/authorization/presenter/routers/role_router.py +++ b/src/modules/authorization/presenter/routers/role_router.py @@ -1,20 +1,27 @@ +from datetime import datetime +from typing import Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status -from core.authorization.dependencies import require_permission -from core.authorization.permissions import ( +from src.core.authorization.dependencies import require_permission +from src.core.authorization.infrastructure.services.casbin_authorization_service import ( + CasbinAuthorizationService, +) +from src.core.authorization.permissions import ( CREATE_ACTION, DELETE_ACTION, READ_ACTION, ROLE_RESOURCE, UPDATE_ACTION, ) -from src.core.authorization.infrastructure.services.casbin_authorization_service import ( - CasbinAuthorizationService, -) from src.core.database.postgres.session import get_unit_of_work -from src.core.schemas.response import PaginatedResponse, SuccessResponse +from src.core.schemas.response import ( + CursorMeta, + CursorPaginatedResponse, + SuccessResponse, +) +from src.core.utils.cursor import CursorDirection, decode_cursor, encode_cursor from src.modules.authorization.domain.entities.role import Role from src.modules.authorization.presenter.dependency import ( get_casbin_authorization_service, @@ -51,16 +58,59 @@ async def create_role( @router.get( "/", - response_model=PaginatedResponse[RoleResponse], + response_model=CursorPaginatedResponse[RoleResponse], dependencies=[Depends(require_permission(ROLE_RESOURCE, READ_ACTION))], ) async def list_roles( + cursor: Optional[str] = Query( + None, description="Cursor for pagination (from previous response)" + ), + limit: int = Query(10, ge=1, le=100, description="Number of items per page"), service: CasbinAuthorizationService = Depends(get_casbin_authorization_service), ): - return SuccessResponse( + cursor_created_at = None + cursor_id = None + direction = CursorDirection.DIRECTION_NEXT + if cursor: + cursor_created_at, cursor_id, direction = decode_cursor(cursor) + + roles, has_more = await service.list_roles_cursor( + cursor_created_at=cursor_created_at, + cursor_id=cursor_id, + limit=limit, + direction=direction, + ) + + next_cursor = None + prev_cursor = None + + if has_more and roles: + last_item = roles[-1] + next_cursor = encode_cursor( + _created_at_datetime(last_item.created_at), + last_item.id, + CursorDirection.DIRECTION_NEXT, + ) + + if cursor and roles: + first_item = roles[0] + prev_cursor = encode_cursor( + _created_at_datetime(first_item.created_at), + first_item.id, + CursorDirection.DIRECTION_PREV, + ) + + return CursorPaginatedResponse( message="fetch role success", success=True, - data=[_role_response(role) for role in await service.list_roles()], + data=[_role_response(role) for role in roles], + meta=CursorMeta( + next_cursor=next_cursor, + prev_cursor=prev_cursor, + has_next=has_more, + has_prev=cursor is not None, + limit=limit, + ), ) @@ -168,4 +218,14 @@ def _role_response(role: Role | None) -> RoleResponse: id=str(role.id), name=role.name, description=role.description, + created_at=role.created_at, + updated_at=role.updated_at, ) + + +def _created_at_datetime(created_at: datetime | str | None) -> datetime: + if isinstance(created_at, datetime): + return created_at + if isinstance(created_at, str): + return datetime.fromisoformat(created_at) + raise HTTPException(status_code=500, detail="Role timestamp is missing") diff --git a/src/modules/authorization/presenter/schema/response.py b/src/modules/authorization/presenter/schema/response.py index 52583e2..814c37e 100644 --- a/src/modules/authorization/presenter/schema/response.py +++ b/src/modules/authorization/presenter/schema/response.py @@ -4,7 +4,9 @@ class RoleResponse(BaseModel): id: str name: str - description: str + description: str | None + created_at: str + updated_at: str class PermissionResponse(BaseModel): @@ -12,4 +14,6 @@ class PermissionResponse(BaseModel): key: str resource: str action: str - description: str + description: str | None + created_at: str + updated_at: str diff --git a/src/modules/todo/application/list_todo/handler.py b/src/modules/todo/application/list_todo/handler.py index c6ea0b2..bec39f1 100644 --- a/src/modules/todo/application/list_todo/handler.py +++ b/src/modules/todo/application/list_todo/handler.py @@ -1,3 +1,7 @@ +from datetime import datetime +from uuid import UUID + +from src.core.utils.cursor import CursorDirection from src.modules.todo.application.list_todo.query import GetTodosQuery from src.modules.todo.application.list_todo.validation import validate_get_todos_query from src.modules.todo.domain.entities.todo import Todo @@ -12,3 +16,27 @@ async def execute(self, command: GetTodosQuery) -> list[Todo]: validate_get_todos_query(command) return await self.todo_repo.get_all_by_user(command.user_id) + + +class GetTodosCursorQuery: + def __init__(self, todo_repo: TodoRepository): + self.todo_repo = todo_repo + + async def execute( + self, + user_id: UUID, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Todo], bool]: + """ + Returns: (items, has_more) + """ + return await self.todo_repo.get_by_user_cursor( + user_id=user_id, + cursor_created_at=cursor_created_at, + cursor_id=cursor_id, + limit=limit, + direction=direction, + ) diff --git a/src/modules/todo/domain/repositories/todo_repository.py b/src/modules/todo/domain/repositories/todo_repository.py index 4af6210..d4e6301 100644 --- a/src/modules/todo/domain/repositories/todo_repository.py +++ b/src/modules/todo/domain/repositories/todo_repository.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod +from datetime import datetime from uuid import UUID +from src.core.utils.cursor import CursorDirection from src.modules.todo.domain.entities.todo import Todo @@ -13,6 +15,21 @@ async def get_by_id(self, todo_id: UUID) -> Todo | None: async def get_all_by_user(self, user_id: UUID) -> list[Todo]: pass + @abstractmethod + async def get_by_user_cursor( + self, + user_id: UUID, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Todo], bool]: + """ + Get todos with cursor pagination. + Returns: (items, has_more) + """ + pass + @abstractmethod async def save(self, todo: Todo) -> Todo: pass diff --git a/src/modules/todo/infrastructure/repositories/todo_repository.py b/src/modules/todo/infrastructure/repositories/todo_repository.py index f62e6d4..592774c 100644 --- a/src/modules/todo/infrastructure/repositories/todo_repository.py +++ b/src/modules/todo/infrastructure/repositories/todo_repository.py @@ -1,8 +1,10 @@ +from datetime import datetime from uuid import UUID -from sqlalchemy import delete, select +from sqlalchemy import and_, delete, or_, select from sqlalchemy.ext.asyncio import AsyncSession +from src.core.utils.cursor import CursorDirection from src.modules.todo.domain.entities.todo import Todo from src.modules.todo.domain.repositories.todo_repository import TodoRepository from src.modules.todo.infrastructure.models.todo_model import TodoModel @@ -25,6 +27,70 @@ async def get_by_id(self, todo_id: UUID) -> Todo | None: user_id=model.user_id, ) + async def get_by_user_cursor( + self, + user_id: UUID, + cursor_created_at: datetime | None = None, + cursor_id: UUID | None = None, + limit: int = 10, + direction: CursorDirection = CursorDirection.DIRECTION_NEXT, + ) -> tuple[list[Todo], bool]: + """ + Cursor pagination logic: + - If direction="next": Get items AFTER the cursor (older items) + - If direction="prev": Get items BEFORE the cursor (newer items) + """ + query = select(TodoModel).where( + TodoModel.user_id == user_id, + TodoModel.deleted_at.is_(None), + ) + + # Apply cursor filter if provided + if cursor_created_at and cursor_id: + if direction == CursorDirection.DIRECTION_NEXT: + # Get items older than cursor (created_at < cursor OR (created_at == cursor AND id < cursor_id)) + query = query.where( + or_( + TodoModel.created_at < cursor_created_at, + and_( + TodoModel.created_at == cursor_created_at, + TodoModel.id < cursor_id, + ), + ) + ) + query = query.order_by(TodoModel.created_at.desc(), TodoModel.id.desc()) + else: + # Get items newer than cursor (created_at > cursor OR (created_at == cursor AND id > cursor_id)) + query = query.where( + or_( + TodoModel.created_at > cursor_created_at, + and_( + TodoModel.created_at == cursor_created_at, + TodoModel.id > cursor_id, + ), + ) + ) + query = query.order_by(TodoModel.created_at.asc(), TodoModel.id.asc()) + else: + # No cursor, just get the first page + query = query.order_by(TodoModel.created_at.desc(), TodoModel.id.desc()) + + # Fetch limit + 1 to check if there are more items + query = query.limit(limit + 1) + + result = await self.db.execute(query) + models = result.scalars().all() + + # Check if there are more items + has_more = len(models) > limit + models = models[:limit] # Trim to actual limit + + # If we fetched "prev", reverse to maintain consistent order (newest first) + if direction == CursorDirection.DIRECTION_PREV: + models = list(reversed(models)) + + return [self._to_entity(m) for m in models], has_more + async def get_all_by_user(self, user_id: UUID) -> list[Todo]: result = await self.db.execute( select(TodoModel).where(TodoModel.user_id == user_id) diff --git a/src/modules/todo/presentation/dependency.py b/src/modules/todo/presentation/dependency.py index f1aa941..a6ca8a4 100644 --- a/src/modules/todo/presentation/dependency.py +++ b/src/modules/todo/presentation/dependency.py @@ -4,7 +4,9 @@ from src.core.database.postgres.session import get_db, get_unit_of_work from src.modules.todo.application.create_todo.handler import CreateTodoHandler from src.modules.todo.application.delete_todo.handler import DeleteTodoHandler -from src.modules.todo.application.list_todo.handler import GetTodosQueryHandler +from src.modules.todo.application.list_todo.handler import ( + GetTodosCursorQuery, +) from src.modules.todo.application.update_todo.handler import UpdateTodoHandler from src.modules.todo.domain.repositories.todo_repository import TodoRepository from src.modules.todo.infrastructure.repositories.todo_repository import ( @@ -38,7 +40,7 @@ def get_delete_todo_handler( return DeleteTodoHandler(repo, unit_of_work) -def get_get_todos_query_handler( +def get_todos_query_handler( repo: TodoRepository = Depends(get_todo_repository), -) -> GetTodosQueryHandler: - return GetTodosQueryHandler(repo) +) -> GetTodosCursorQuery: + return GetTodosCursorQuery(repo) diff --git a/src/modules/todo/presentation/routers/todo_router.py b/src/modules/todo/presentation/routers/todo_router.py index 79cb743..ea18414 100644 --- a/src/modules/todo/presentation/routers/todo_router.py +++ b/src/modules/todo/presentation/routers/todo_router.py @@ -1,9 +1,8 @@ +from typing import Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status -from core.schemas.response import PaginatedResponse, SuccessResponse -from modules.todo.presentation.schemas.response import TodoResponse from src.core.authorization.dependencies import require_permission from src.core.authorization.permissions import ( CREATE_ACTION, @@ -12,10 +11,18 @@ TODO_RESOURCE, UPDATE_ACTION, ) +from src.core.schemas.response import ( + CursorMeta, + CursorPaginatedResponse, + SuccessResponse, +) +from src.core.utils.cursor import CursorDirection, decode_cursor, encode_cursor from src.modules.todo.application.create_todo.command import CreateTodoCommand from src.modules.todo.application.create_todo.handler import CreateTodoHandler from src.modules.todo.application.delete_todo.handler import DeleteTodoHandler -from src.modules.todo.application.list_todo.handler import GetTodosQueryHandler +from src.modules.todo.application.list_todo.handler import ( + GetTodosCursorQuery, +) from src.modules.todo.application.list_todo.query import GetTodosQuery from src.modules.todo.application.update_todo.command import UpdateTodoCommand from src.modules.todo.application.update_todo.handler import UpdateTodoHandler @@ -26,9 +33,10 @@ from src.modules.todo.presentation.dependency import ( get_create_todo_handler, get_delete_todo_handler, - get_get_todos_query_handler, + get_todos_query_handler, get_update_todo_handler, ) +from src.modules.todo.presentation.schemas.response import TodoResponse router = APIRouter(prefix="/todos", tags=["Todos"]) @@ -55,24 +63,69 @@ async def create_todo( ) -@router.get("/", response_model=PaginatedResponse[TodoResponse]) +@router.get("/", response_model=CursorPaginatedResponse[TodoResponse]) async def get_todos( + cursor: Optional[str] = Query( + None, description="Cursor for pagination (from previous response)" + ), + limit: int = Query(10, ge=1, le=100, description="Number of items per page"), current_user: dict = Depends(require_permission(TODO_RESOURCE, READ_ACTION)), - query: GetTodosQueryHandler = Depends(get_get_todos_query_handler), + query: GetTodosCursorQuery = Depends(get_todos_query_handler), ): + cursor_created_at = None + cursor_id = None + direction = None + if cursor: + cursor_created_at, cursor_id, direction = decode_cursor(cursor) + command = GetTodosQuery(user_id=current_user.get("id")) - todos = await query.execute(command=command) - return PaginatedResponse( - message="fetch todo success", - success=True, - data=[ - TodoResponse( - id=str(todo.id), - title=todo.title, - is_completed=todo.is_completed, - ) - for todo in todos - ], + todos, has_more = await query.execute( + user_id=command.user_id, + cursor_created_at=cursor_created_at, + cursor_id=cursor_id, + limit=limit, + direction=direction, + ) + response_todos = [ + TodoResponse( + id=str(t.id), + title=t.title, + description=t.description, + is_completed=t.is_completed, + created_at=t.created_at.isoformat(), + ) + for t in todos + ] + + next_cursor = None + prev_cursor = None + + if has_more and len(todos) > 0: + last_item = todos[-1] + next_cursor = encode_cursor( + last_item.created_at, + last_item.id, + CursorDirection.DIRECTION_NEXT, + ) + + if cursor and len(todos) > 0: + first_item = todos[0] + prev_cursor = encode_cursor( + first_item.created_at, + first_item.id, + CursorDirection.DIRECTION_PREV, + ) + + return CursorPaginatedResponse( + message="Todos retrieved successfully", + meta=CursorMeta( + next_cursor=next_cursor, + prev_cursor=prev_cursor, + has_next=has_more, + has_prev=cursor is not None, + limit=limit, + ), + data=response_todos, ) diff --git a/src/modules/user/application/login_user/handler.py b/src/modules/user/application/login_user/handler.py index 726c5ce..0df548e 100644 --- a/src/modules/user/application/login_user/handler.py +++ b/src/modules/user/application/login_user/handler.py @@ -2,12 +2,13 @@ from datetime import datetime, timedelta, timezone from src.core.config.setting import get_settings +from src.core.security.account_lockout import AccountLockoutService +from src.core.security.audit import AuditEvent, AuditService from src.core.security.jwt import JWTService from src.core.security.password import PasswordSerrvice from src.modules.user.application.login_user.command import LoginUserCommand from src.modules.user.application.login_user.validation import validate_login_user_command from src.modules.user.domain.entities.refresh_token import RefreshToken -from src.modules.user.domain.exceptions.user_exception import UserNotFoundError from src.modules.user.domain.repositories.refresh_token_repository import ( RefreshTokenRepository, ) @@ -24,21 +25,38 @@ def __init__( user_repository: UserRepository, refresh_token_repository: RefreshTokenRepository, unit_of_work: UnitOfWork, + account_lockout_service: AccountLockoutService | None = None, + audit_service: AuditService | None = None, ): self._user_repository = user_repository self._refresh_token_repository = refresh_token_repository self._unit_of_work = unit_of_work + self._account_lockout_service = account_lockout_service + self._audit_service = audit_service async def execute(self, command: LoginUserCommand) -> dict[str, str]: validate_login_user_command(command) + if self._account_lockout_service is not None: + try: + await self._account_lockout_service.ensure_login_allowed( + command.username + ) + except ValueError: + async with self._unit_of_work: + await self._audit_login(command.username, "locked") + await self._unit_of_work.commit() + raise InvalidCredentialsError("Account is temporarily locked") + user = await self._user_repository.get_by_email(command.username) if user is None: - raise UserNotFoundError + await self._record_failed_login(command.username) + raise InvalidCredentialsError("Incorrect email or password") if not user or not PasswordSerrvice.verify_password( command.password, user.password ): + await self._record_failed_login(command.username) raise InvalidCredentialsError("Incorrect email or password") access_token = JWTService.create_access_token(data={"sub": str(user.id)}) @@ -50,13 +68,44 @@ async def execute(self, command: LoginUserCommand) -> dict[str, str]: ) async with self._unit_of_work: + if self._account_lockout_service is not None: + await self._account_lockout_service.record_successful_login( + command.username + ) + new_rt = RefreshToken.create( user_id=user.id, token_hash=token_hash, expires_at=expires_at ) await self._refresh_token_repository.save(new_rt) + await self._audit_login(command.username, "success", actor_id=str(user.id)) await self._unit_of_work.commit() return { "access_token": access_token, "refresh_token": refresh_token_str, } + + async def _record_failed_login(self, email: str) -> None: + async with self._unit_of_work: + if self._account_lockout_service is not None: + await self._account_lockout_service.record_failed_login(email) + await self._audit_login(email, "failure") + await self._unit_of_work.commit() + + async def _audit_login( + self, + email: str, + result: str, + actor_id: str | None = None, + ) -> None: + if self._audit_service is None: + return + await self._audit_service.record( + AuditEvent( + action="user.login", + actor_id=actor_id, + resource_type="user", + resource_id=actor_id, + metadata={"email": email, "result": result}, + ) + ) diff --git a/src/modules/user/application/logout_user/command.py b/src/modules/user/application/logout_user/command.py index e5ccab8..ec624a4 100644 --- a/src/modules/user/application/logout_user/command.py +++ b/src/modules/user/application/logout_user/command.py @@ -3,3 +3,4 @@ class LogoutUserCommand(BaseModel): user_id: str + access_token: str diff --git a/src/modules/user/application/logout_user/handler.py b/src/modules/user/application/logout_user/handler.py index 128e472..9e1287e 100644 --- a/src/modules/user/application/logout_user/handler.py +++ b/src/modules/user/application/logout_user/handler.py @@ -5,19 +5,29 @@ from src.modules.user.domain.repositories.refresh_token_repository import ( RefreshTokenRepository, ) +from src.core.security.token_revocation import TokenRevocationService from src.shared.unit_of_work import UnitOfWork class LogoutUserCommandHandler: def __init__( - self, refresh_token_repository: RefreshTokenRepository, unit_of_work: UnitOfWork + self, + refresh_token_repository: RefreshTokenRepository, + unit_of_work: UnitOfWork, + token_revocation_service: TokenRevocationService, ): self._refresh_token_repository = refresh_token_repository self._unit_of_work = unit_of_work + self._token_revocation_service = token_revocation_service - async def excute(self, command: LogoutUserCommand) -> None: + async def execute(self, command: LogoutUserCommand) -> None: validate_logout_user_command(command) async with self._unit_of_work: - self._refresh_token_repository.revoke_by_user_id(command.user_id) + await self._refresh_token_repository.revoke_by_user_id(command.user_id) + await self._token_revocation_service.revoke_access_token( + command.access_token + ) await self._unit_of_work.commit() + + excute = execute diff --git a/src/modules/user/application/logout_user/validation.py b/src/modules/user/application/logout_user/validation.py index f7b0cfa..44ed3b6 100644 --- a/src/modules/user/application/logout_user/validation.py +++ b/src/modules/user/application/logout_user/validation.py @@ -8,3 +8,6 @@ def validate_logout_user_command(command: LogoutUserCommand) -> None: UUID(command.user_id) except ValueError as exc: raise ValueError("User id must be a valid UUID") from exc + + if not command.access_token.strip(): + raise ValueError("Access token is required") diff --git a/src/modules/user/application/refresh_token/handler.py b/src/modules/user/application/refresh_token/handler.py index 6768fb7..b4f6eec 100644 --- a/src/modules/user/application/refresh_token/handler.py +++ b/src/modules/user/application/refresh_token/handler.py @@ -11,7 +11,10 @@ from src.modules.user.domain.repositories.refresh_token_repository import ( RefreshTokenRepository, ) -from src.shared.exceptions.credential_exception import InvalidRefreshTokenError +from src.shared.exceptions.credential_exception import ( + InvalidCredentialsError, + InvalidRefreshTokenError, +) from src.shared.unit_of_work import UnitOfWork settings = get_settings() @@ -42,6 +45,15 @@ async def execute(self, command: RefreshTokenCommand) -> dict: if stored_token.expires_at < datetime.now(timezone.utc): raise InvalidRefreshTokenError("Refresh token has expired") + try: + payload = JWTService.decode_token(command.token) + JWTService.require_token_type(payload, JWTService.REFRESH_TOKEN_TYPE) + except InvalidCredentialsError: + raise InvalidRefreshTokenError("Invalid refresh token") + + if payload.get("sub") != str(stored_token.user_id): + raise InvalidRefreshTokenError("Invalid refresh token") + async with self._unit_of_work: stored_token.revoke() await self._refresh_token_repo.save(stored_token) diff --git a/src/modules/user/infrastructure/service/auth_service.py b/src/modules/user/infrastructure/service/auth_service.py index 47e8da3..94c49ee 100644 --- a/src/modules/user/infrastructure/service/auth_service.py +++ b/src/modules/user/infrastructure/service/auth_service.py @@ -1,12 +1,7 @@ -from datetime import datetime, timedelta, timezone - -from jose import JWTError, jwt from passlib.context import CryptContext -from src.core.config.setting import get_settings -from src.modules.user.domain.exceptions import InvalidCredentialsError +from src.core.security.jwt import JWTService -settings = get_settings() pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -15,17 +10,7 @@ def verifyy_password(self, password: str, hashed_password: str) -> bool: return pwd_context.verify(password, hashed_password) def create_access_token(self, data: dict) -> str: - to_encode = data.copy() - expire = datetime.now(timezone.utc) + timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES - ) - to_encode.update({"exp": expire}) - return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + return JWTService.create_access_token(data) def decode_token(self, token: str) -> dict: - try: - return jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] - ) - except JWTError: - raise InvalidCredentialsError("Invalid or expired token") + return JWTService.decode_token(token) diff --git a/src/modules/user/presentation/dependency.py b/src/modules/user/presentation/dependency.py index 8ad6144..eb58461 100644 --- a/src/modules/user/presentation/dependency.py +++ b/src/modules/user/presentation/dependency.py @@ -4,6 +4,15 @@ from src.core.authorization.dependencies import get_authorization_service from src.core.authorization.domain.service import AuthorizationService from src.core.database.postgres.session import get_db, get_unit_of_work +from src.core.security.account_lockout import AccountLockoutService +from src.core.security.audit import AuditService +from src.core.security.infrastructure.repositories.audit_log_repository import ( + SQLAlchemyAuditRepository, +) +from src.core.security.infrastructure.repositories.login_attempt_repository import ( + SQLAlchemyLoginAttemptRepository, +) +from src.core.security.token_revocation import TokenRevocationService from src.modules.user.application.detail_user.handler import DetailUserQueryHandler from src.modules.user.application.login_user.handler import LoginUserCommandHandler from src.modules.user.application.logout_user.handler import LogoutUserCommandHandler @@ -36,6 +45,20 @@ def get_refresh_token_repository( return SQLAlchemyRefreshTokenRepository(db) +def get_token_revocation_service() -> TokenRevocationService: + return TokenRevocationService() + + +def get_audit_service(db: AsyncSession = Depends(get_db)) -> AuditService: + return AuditService(SQLAlchemyAuditRepository(db)) + + +def get_account_lockout_service( + db: AsyncSession = Depends(get_db), +) -> AccountLockoutService: + return AccountLockoutService(SQLAlchemyLoginAttemptRepository(db)) + + def get_register_handler( repo: UserRepository = Depends(get_user_repository), unit_of_work: UnitOfWork = Depends(get_unit_of_work), @@ -48,8 +71,18 @@ def get_login_handler( user_repo: UserRepository = Depends(get_user_repository), refresh_token_repo: RefreshTokenRepository = Depends(get_refresh_token_repository), unit_of_work: UnitOfWork = Depends(get_unit_of_work), + account_lockout_service: AccountLockoutService = Depends( + get_account_lockout_service + ), + audit_service: AuditService = Depends(get_audit_service), ) -> LoginUserCommandHandler: - return LoginUserCommandHandler(user_repo, refresh_token_repo, unit_of_work) + return LoginUserCommandHandler( + user_repo, + refresh_token_repo, + unit_of_work, + account_lockout_service, + audit_service, + ) def get_user_detail_handler( @@ -68,5 +101,12 @@ def get_refresh_token_handler( def get_logout_handler( refresh_token_repo: RefreshTokenRepository = Depends(get_refresh_token_repository), unit_of_work: UnitOfWork = Depends(get_unit_of_work), + token_revocation_service: TokenRevocationService = Depends( + get_token_revocation_service + ), ) -> LogoutUserCommandHandler: - return LogoutUserCommandHandler(refresh_token_repo, unit_of_work) + return LogoutUserCommandHandler( + refresh_token_repo, + unit_of_work, + token_revocation_service, + ) diff --git a/src/modules/user/presentation/routers/user_router.py b/src/modules/user/presentation/routers/user_router.py index 1b30d8b..c7fe64a 100644 --- a/src/modules/user/presentation/routers/user_router.py +++ b/src/modules/user/presentation/routers/user_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordRequestForm from src.core.authorization.dependencies import require_permission @@ -130,7 +130,15 @@ async def get_me( @router.post("/logout", status_code=status.HTTP_204_NO_CONTENT) async def logout( + request: Request, current_user: dict = Depends(require_permission(USER_RESOURCE, UPDATE_ACTION)), handler: LogoutUserCommandHandler = Depends(get_logout_handler), ): - await handler.excute(LogoutUserCommand(user_id=str(current_user.get("id")))) + auth_header = request.headers.get("Authorization", "") + access_token = auth_header.removeprefix("Bearer ").strip() + await handler.execute( + LogoutUserCommand( + user_id=str(current_user.get("id")), + access_token=access_token, + ) + ) diff --git a/tests/core/test_admin_authorization_hardening.py b/tests/core/test_admin_authorization_hardening.py new file mode 100644 index 0000000..7d1d2c9 --- /dev/null +++ b/tests/core/test_admin_authorization_hardening.py @@ -0,0 +1,16 @@ +from src.modules.authorization.presenter.routers.permission_router import ( + router as permission_router, +) +from src.modules.authorization.presenter.routers.role_router import router as role_router + + +def test_role_and_permission_management_routes_require_dependencies(): + routes = [ + route + for route in [*role_router.routes, *permission_router.routes] + if hasattr(route, "dependencies") + ] + + assert routes + for route in routes: + assert route.dependencies, f"{route.path} has no authorization dependency" diff --git a/tests/core/test_auth_middleware.py b/tests/core/test_auth_middleware.py new file mode 100644 index 0000000..6764b79 --- /dev/null +++ b/tests/core/test_auth_middleware.py @@ -0,0 +1,83 @@ +import asyncio + +from src.core.middleware.auth import AuthenticationMiddleware +from src.core.security.jwt import JWTService +from src.core.security.token_revocation import TokenRevocationService +from starlette.requests import Request +from starlette.responses import JSONResponse + + +def test_authentication_middleware_rejects_revoked_access_token(monkeypatch): + async def run(): + async def revoked_token(_token): + return True + + monkeypatch.setattr( + TokenRevocationService, + "is_access_token_revoked", + revoked_token, + ) + + token = JWTService.create_access_token({"sub": "user-id"}) + request = Request( + { + "type": "http", + "method": "GET", + "path": "/protected", + "headers": [(b"authorization", f"Bearer {token}".encode())], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + async def call_next(_request): + return JSONResponse({"ok": True}) + + response = await AuthenticationMiddleware(None).dispatch(request, call_next) + + assert response.status_code == 401 + + asyncio.run(run()) + + +def test_authentication_middleware_rejects_refresh_token_on_protected_endpoint( + monkeypatch, +): + async def run(): + async def active_token(_token): + return False + + monkeypatch.setattr( + TokenRevocationService, + "is_access_token_revoked", + active_token, + ) + + token = JWTService.create_refresh_token({"sub": "user-id"}) + call_next_called = False + request = Request( + { + "type": "http", + "method": "GET", + "path": "/protected", + "headers": [(b"authorization", f"Bearer {token}".encode())], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + async def call_next(_request): + nonlocal call_next_called + call_next_called = True + return JSONResponse({"ok": True}) + + response = await AuthenticationMiddleware(None).dispatch(request, call_next) + + assert response.status_code == 401 + assert call_next_called is False + + asyncio.run(run()) diff --git a/tests/core/test_authorization_cursor_pagination.py b/tests/core/test_authorization_cursor_pagination.py new file mode 100644 index 0000000..fbace7a --- /dev/null +++ b/tests/core/test_authorization_cursor_pagination.py @@ -0,0 +1,112 @@ +import asyncio +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +from src.core.utils.cursor import CursorDirection, decode_cursor, encode_cursor +from src.modules.authorization.domain.entities.permission import Permission +from src.modules.authorization.domain.entities.role import Role +from src.modules.authorization.presenter.routers.permission_router import ( + list_permissions, +) +from src.modules.authorization.presenter.routers.role_router import list_roles + + +class FakeAuthorizationService: + def __init__(self, roles=None, permissions=None): + self.roles = roles or [] + self.permissions = permissions or [] + + async def list_roles_cursor( + self, + cursor_created_at=None, + cursor_id=None, + limit=10, + direction=None, + ): + return self.roles[:limit], len(self.roles) > limit + + async def list_permissions_cursor( + self, + cursor_created_at=None, + cursor_id=None, + limit=10, + direction=None, + ): + return self.permissions[:limit], len(self.permissions) > limit + + +def test_list_roles_uses_cursor_paginated_response(): + async def run(): + now = datetime.now(UTC) + roles = [ + Role( + id=uuid4(), + name=f"role-{index}", + description=None, + created_at=(now - timedelta(minutes=index)).isoformat(), + updated_at=(now - timedelta(minutes=index)).isoformat(), + ) + for index in range(3) + ] + + response = await list_roles( + cursor=None, + limit=2, + service=FakeAuthorizationService(roles), + ) + + assert response.data[0].id == str(roles[0].id) + assert len(response.data) == 2 + assert response.meta.limit == 2 + assert response.meta.has_next is True + assert response.meta.has_prev is False + + cursor_created_at, cursor_id, direction = decode_cursor( + response.meta.next_cursor + ) + assert cursor_created_at == datetime.fromisoformat(roles[1].created_at) + assert cursor_id == roles[1].id + assert direction == CursorDirection.DIRECTION_NEXT + + asyncio.run(run()) + + +def test_list_permissions_exposes_previous_cursor_when_cursor_is_provided(): + async def run(): + now = datetime.now(UTC) + permissions = [ + Permission( + id=uuid4(), + key=f"todo:action-{index}", + resource="todo", + action=f"action-{index}", + description=None, + created_at=(now - timedelta(minutes=index)).isoformat(), + updated_at=(now - timedelta(minutes=index)).isoformat(), + ) + for index in range(2) + ] + cursor = encode_cursor( + datetime.fromisoformat(permissions[0].created_at), + permissions[0].id, + CursorDirection.DIRECTION_NEXT, + ) + + response = await list_permissions( + cursor=cursor, + limit=2, + service=FakeAuthorizationService(permissions=permissions), + ) + + assert len(response.data) == 2 + assert response.meta.has_next is False + assert response.meta.has_prev is True + + cursor_created_at, cursor_id, direction = decode_cursor( + response.meta.prev_cursor + ) + assert cursor_created_at == datetime.fromisoformat(permissions[0].created_at) + assert cursor_id == permissions[0].id + assert direction == CursorDirection.DIRECTION_PREV + + asyncio.run(run()) diff --git a/tests/core/test_authorization_models.py b/tests/core/test_authorization_models.py index f935c0c..7944a90 100644 --- a/tests/core/test_authorization_models.py +++ b/tests/core/test_authorization_models.py @@ -4,6 +4,9 @@ from src.core.authorization.infrastructure.models.permission_model import ( PermissionModel, ) +from src.core.authorization.infrastructure.models.resource_model import ( + AuthorizationResourceModel, +) from src.core.authorization.infrastructure.models.role_model import ( RoleModel, ) @@ -16,6 +19,7 @@ def test_authorization_tables_are_registered_in_metadata(): + assert AuthorizationResourceModel.__tablename__ == "authorization_resources" assert RoleModel.__tablename__ == "roles" assert PermissionModel.__tablename__ == "permissions" assert RolePermissionModel.__tablename__ == "role_permissions" @@ -24,10 +28,15 @@ def test_authorization_tables_are_registered_in_metadata(): def test_authorization_models_have_expected_columns(): - assert {"name"}.issubset(RoleModel.__table__.columns.keys()) - assert {"key", "resource", "action"}.issubset( + assert {"key", "name", "description"}.issubset( + AuthorizationResourceModel.__table__.columns.keys() + ) + assert {"name", "description"}.issubset(RoleModel.__table__.columns.keys()) + assert "descpription" not in RoleModel.__table__.columns.keys() + assert {"key", "resource_id", "resource", "action", "description"}.issubset( PermissionModel.__table__.columns.keys() ) + assert "descpription" not in PermissionModel.__table__.columns.keys() assert {"role_id", "permission_id"}.issubset( RolePermissionModel.__table__.columns.keys() ) diff --git a/tests/core/test_database_migration_polish.py b/tests/core/test_database_migration_polish.py new file mode 100644 index 0000000..c175f3e --- /dev/null +++ b/tests/core/test_database_migration_polish.py @@ -0,0 +1,28 @@ +from pathlib import Path + + +def test_alembic_env_does_not_emit_debug_prints(): + env_content = Path("alembic/env.py").read_text() + + assert "print(" not in env_content + assert "ALEMBIC DEBUG" not in env_content + + +def test_authorization_description_typo_is_migrated_forward(): + migration_content = Path( + "alembic/versions/c7a1b9e5d4f2_rename_authorization_description_columns.py" + ).read_text() + + assert "descpription" in migration_content + assert "description" in migration_content + assert "rename_column" in migration_content + + +def test_authorization_resources_are_migrated_forward(): + migration_content = Path( + "alembic/versions/d9a7c3f2b6e1_add_authorization_resources.py" + ).read_text() + + assert "authorization_resources" in migration_content + assert "resource_id" in migration_content + assert "permissions" in migration_content diff --git a/tests/core/test_global_audit_logging.py b/tests/core/test_global_audit_logging.py new file mode 100644 index 0000000..4ac6162 --- /dev/null +++ b/tests/core/test_global_audit_logging.py @@ -0,0 +1,117 @@ +import asyncio + +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse + +from src.core.middleware.audit_logging import AuditLoggingMiddleware + + +class FakeAuditService: + def __init__(self): + self.events = [] + + async def record(self, event): + self.events.append(event) + + +class FakeErrorTraceService: + def __init__(self): + self.traces = [] + + async def record(self, trace): + self.traces.append(trace) + + +def build_request(path="/api/v1/todos/123", method="PATCH"): + request = Request( + { + "type": "http", + "method": method, + "path": path, + "headers": [ + (b"user-agent", b"pytest"), + (b"x-forwarded-for", b"10.0.0.1"), + ], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + "path_params": {"todo_id": "123"}, + } + ) + request.state.request_id = "request-1" + request.state.user_id = "user-1" + return request + + +def test_global_audit_logging_records_api_request(): + async def run(): + audit_service = FakeAuditService() + middleware = AuditLoggingMiddleware( + None, + audit_service_factory=lambda: audit_service, + ) + + async def call_next(_request): + return JSONResponse({"ok": True}, status_code=202) + + response = await middleware.dispatch(build_request(), call_next) + + assert response.status_code == 202 + event = audit_service.events[0] + assert event.action == "PATCH /api/v1/todos/123" + assert event.actor_id == "user-1" + assert event.resource_type == "todos" + assert event.resource_id == "123" + assert event.request_id == "request-1" + assert event.metadata["status_code"] == 202 + assert event.metadata["client_ip"] == "10.0.0.1" + assert event.metadata["user_agent"] == "pytest" + + asyncio.run(run()) + + +def test_global_audit_logging_skips_operational_paths(): + async def run(): + audit_service = FakeAuditService() + middleware = AuditLoggingMiddleware( + None, + audit_service_factory=lambda: audit_service, + ) + + async def call_next(_request): + return JSONResponse({"status": "healthy"}) + + await middleware.dispatch(build_request(path="/health", method="GET"), call_next) + + assert audit_service.events == [] + + asyncio.run(run()) + + +def test_global_audit_logging_records_error_trace_and_reraises(): + async def run(): + audit_service = FakeAuditService() + error_trace_service = FakeErrorTraceService() + middleware = AuditLoggingMiddleware( + None, + audit_service_factory=lambda: audit_service, + error_trace_service_factory=lambda: error_trace_service, + ) + + async def call_next(_request): + raise RuntimeError("database unavailable") + + with pytest.raises(RuntimeError, match="database unavailable"): + await middleware.dispatch(build_request(), call_next) + + trace = error_trace_service.traces[0] + assert trace.error_type == "RuntimeError" + assert trace.message == "database unavailable" + assert "RuntimeError: database unavailable" in trace.traceback + assert trace.request_id == "request-1" + assert trace.actor_id == "user-1" + assert trace.path == "/api/v1/todos/123" + + asyncio.run(run()) diff --git a/tests/core/test_jwt_claims.py b/tests/core/test_jwt_claims.py new file mode 100644 index 0000000..308a46d --- /dev/null +++ b/tests/core/test_jwt_claims.py @@ -0,0 +1,48 @@ +from datetime import datetime, timezone + +from src.core.security.jwt import JWTService + + +def test_access_token_uses_standard_claim_payload(): + token = JWTService.create_access_token({"sub": "user-id"}) + + claims = JWTService.decode_token(token) + + assert claims["sub"] == "user-id" + assert claims["iss"] + assert claims["aud"] + assert claims["token_type"] == "access" + assert claims["jti"] + assert claims["iat"] + assert claims["nbf"] + assert claims["exp"] + assert claims["nbf"] <= claims["iat"] + assert claims["iat"] <= claims["exp"] + + +def test_refresh_token_uses_standard_claim_payload(): + token = JWTService.create_refresh_token({"sub": "user-id"}) + + claims = JWTService.decode_token(token) + + assert claims["sub"] == "user-id" + assert claims["iss"] + assert claims["aud"] + assert claims["token_type"] == "refresh" + assert claims["jti"] + assert claims["iat"] + assert claims["nbf"] + assert claims["exp"] + assert claims["nbf"] <= claims["iat"] + assert claims["iat"] <= claims["exp"] + + +def test_token_claim_expiry_is_timezone_aware_epoch_timestamp(): + token = JWTService.create_access_token({"sub": "user-id"}) + + claims = JWTService.decode_token(token) + expires_at = datetime.fromtimestamp(claims["exp"], tz=timezone.utc) + issued_at = datetime.fromtimestamp(claims["iat"], tz=timezone.utc) + + assert expires_at.tzinfo == timezone.utc + assert issued_at.tzinfo == timezone.utc diff --git a/tests/core/test_security_not_implemented.py b/tests/core/test_security_not_implemented.py new file mode 100644 index 0000000..1e71084 --- /dev/null +++ b/tests/core/test_security_not_implemented.py @@ -0,0 +1,378 @@ +import asyncio +import json +import logging +from datetime import datetime, timezone + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette.requests import Request +from starlette.responses import JSONResponse + +from src.core.config.setting import Settings +from src.core.middleware.auth import AuthenticationMiddleware +from src.core.middleware.idempotency import IdempotencyMiddleware +from src.core.middleware.request_id import RequestIDMiddleware +from src.core.middleware.security_headers import SecurityHeadersMiddleware +from src.core.middleware.structured_logging import StructuredLoggingMiddleware +from src.core.routers.admin import register_router as register_admin_router +from src.core.security.account_lockout import AccountLockoutService +from src.core.security.audit import AuditEvent, AuditService + + +def test_security_headers_middleware_adds_expected_headers(): + async def run(): + request = Request( + { + "type": "http", + "method": "GET", + "path": "/health", + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + async def call_next(_request): + return JSONResponse({"ok": True}) + + response = await SecurityHeadersMiddleware(None).dispatch(request, call_next) + + assert response.headers["x-content-type-options"] == "nosniff" + assert response.headers["x-frame-options"] == "DENY" + assert response.headers["referrer-policy"] == "no-referrer" + assert "default-src" in response.headers["content-security-policy"] + + asyncio.run(run()) + + +def test_request_id_middleware_propagates_existing_request_id(): + async def run(): + request = Request( + { + "type": "http", + "method": "GET", + "path": "/health", + "headers": [(b"x-request-id", b"request-123")], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + async def call_next(received_request): + assert received_request.state.request_id == "request-123" + return JSONResponse({"ok": True}) + + response = await RequestIDMiddleware(None).dispatch(request, call_next) + + assert response.headers["x-request-id"] == "request-123" + + asyncio.run(run()) + + +def test_structured_logging_middleware_logs_request_context(caplog): + async def run(): + caplog.set_level(logging.INFO, logger="src.core.middleware.structured_logging") + request = Request( + { + "type": "http", + "method": "GET", + "path": "/api/v1/todos/", + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + request.state.request_id = "request-123" + request.state.user_id = "user-123" + + async def call_next(_request): + return JSONResponse({"ok": True}, status_code=202) + + await StructuredLoggingMiddleware(None).dispatch(request, call_next) + + asyncio.run(run()) + + record = caplog.records[0] + assert record.method == "GET" + assert record.path == "/api/v1/todos/" + assert record.status_code == 202 + assert record.request_id == "request-123" + assert record.user_id == "user-123" + + +def test_structured_logging_middleware_logs_exception_context(caplog): + async def run(): + caplog.set_level(logging.ERROR, logger="src.core.middleware.structured_logging") + request = Request( + { + "type": "http", + "method": "POST", + "path": "/api/v1/todos/", + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + request.state.request_id = "request-456" + request.state.user_id = "user-456" + + async def call_next(_request): + raise RuntimeError("write failed") + + with pytest.raises(RuntimeError, match="write failed"): + await StructuredLoggingMiddleware(None).dispatch(request, call_next) + + asyncio.run(run()) + + record = caplog.records[0] + assert record.method == "POST" + assert record.path == "/api/v1/todos/" + assert record.status_code == 500 + assert record.request_id == "request-456" + assert record.user_id == "user-456" + assert record.error_type == "RuntimeError" + + +def test_structured_logging_formats_record_as_json(): + from src.core.middleware.structured_logging import JsonLogFormatter + + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="request completed", + args=(), + exc_info=None, + ) + record.method = "GET" + record.path = "/api/v1/todos/" + record.status_code = 200 + record.latency_ms = 1.5 + record.request_id = "request-1" + record.user_id = "user-1" + record.error_type = None + + payload = json.loads(JsonLogFormatter().format(record)) + + assert payload["message"] == "request completed" + assert payload["method"] == "GET" + assert payload["path"] == "/api/v1/todos/" + assert payload["status_code"] == 200 + assert payload["request_id"] == "request-1" + + +def test_admin_router_exposes_liveness_and_readiness(): + app = FastAPI() + register_admin_router(app) + client = TestClient(app) + + assert client.get("/live").json() == {"status": "alive"} + ready_response = client.get("/ready") + assert ready_response.status_code in (200, 503) + assert "checks" in ready_response.json() + + +def test_authentication_middleware_returns_generic_invalid_token_error(): + async def run(): + request = Request( + { + "type": "http", + "method": "GET", + "path": "/protected", + "headers": [(b"authorization", b"Bearer invalid-token")], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + async def call_next(_request): + return JSONResponse({"ok": True}) + + response = await AuthenticationMiddleware(None).dispatch(request, call_next) + body = json.loads(response.body.decode()) + + assert response.status_code == 401 + assert body["detail"] == "Invalid or expired token" + + asyncio.run(run()) + + +class FakeAuditRepository: + def __init__(self): + self.events = [] + + async def save(self, event): + self.events.append(event) + return event + + +def test_audit_service_persists_sensitive_event(): + async def run(): + repo = FakeAuditRepository() + service = AuditService(repo) + event = AuditEvent( + action="user.login", + actor_id="user-1", + resource_type="user", + resource_id="user-1", + request_id="request-1", + metadata={"result": "success"}, + ) + + await service.record(event) + + assert repo.events == [event] + + asyncio.run(run()) + + +class FakeLoginAttemptRepository: + def __init__(self): + self.failures = {} + self.locked_until = {} + self.cleared = [] + + async def count_failures_since(self, email, since): + return self.failures.get(email, 0) + + async def record_failure(self, email, occurred_at, locked_until=None): + self.failures[email] = self.failures.get(email, 0) + 1 + if locked_until is not None: + self.locked_until[email] = locked_until + + async def get_locked_until(self, email): + return self.locked_until.get(email) + + async def clear(self, email): + self.cleared.append(email) + self.failures.pop(email, None) + self.locked_until.pop(email, None) + + +def test_account_lockout_locks_after_configured_failures(): + async def run(): + repo = FakeLoginAttemptRepository() + settings = Settings( + ACCOUNT_LOCKOUT_MAX_ATTEMPTS=2, + ACCOUNT_LOCKOUT_WINDOW_MINUTES=5, + ACCOUNT_LOCKOUT_DURATION_MINUTES=15, + ) + service = AccountLockoutService(repo, settings) + + await service.record_failed_login("person@example.com") + await service.record_failed_login("person@example.com") + + assert repo.locked_until["person@example.com"] > datetime.now(timezone.utc) + with pytest.raises(ValueError, match="temporarily locked"): + await service.ensure_login_allowed("person@example.com") + + asyncio.run(run()) + + +class FakeRedis: + def __init__(self): + self.values = {} + + async def get(self, key): + return self.values.get(key) + + async def set(self, key, value, ex=None, nx=False): + if nx and key in self.values: + return False + self.values[key] = value + return True + + async def setex(self, key, ttl, value): + self.values[key] = value + + +def build_post_request(body: bytes): + async def receive(): + return {"type": "http.request", "body": body, "more_body": False} + + return Request( + { + "type": "http", + "method": "POST", + "path": "/api/v1/todos/", + "headers": [ + (b"idempotency-key", b"create-todo-1"), + (b"authorization", b"Bearer token"), + (b"content-type", b"application/json"), + ], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + }, + receive, + ) + + +def test_idempotency_middleware_replays_cached_post_response(): + async def run(): + redis = FakeRedis() + calls = 0 + async def call_next(_request): + nonlocal calls + calls += 1 + return JSONResponse({"created": True}, status_code=201) + + middleware = IdempotencyMiddleware(None, redis=redis) + first = await middleware.dispatch( + build_post_request(b'{"title":"first"}'), + call_next, + ) + second = await middleware.dispatch( + build_post_request(b'{"title":"first"}'), + call_next, + ) + + assert first.status_code == 201 + assert second.status_code == 201 + assert json.loads(second.body.decode()) == {"created": True} + assert calls == 1 + + asyncio.run(run()) + + +def test_idempotency_middleware_rejects_same_key_with_different_body(): + async def run(): + redis = FakeRedis() + calls = 0 + + async def call_next(_request): + nonlocal calls + calls += 1 + return JSONResponse({"created": True}, status_code=201) + + middleware = IdempotencyMiddleware(None, redis=redis) + first = await middleware.dispatch( + build_post_request(b'{"title":"first"}'), + call_next, + ) + second = await middleware.dispatch( + build_post_request(b'{"title":"second"}'), + call_next, + ) + + assert first.status_code == 201 + assert second.status_code == 409 + assert json.loads(second.body.decode())["detail"] == ( + "Idempotency-Key was already used with a different request body" + ) + assert calls == 1 + + asyncio.run(run()) diff --git a/tests/core/test_security_todo.py b/tests/core/test_security_todo.py new file mode 100644 index 0000000..3af81ea --- /dev/null +++ b/tests/core/test_security_todo.py @@ -0,0 +1,200 @@ +import pytest +from fastapi import FastAPI, Request, Response + +from src.core.bootstrap.exception import register_exception +from src.core.bootstrap.middleware import register_middleware +from src.core.config.setting import Settings +from src.core.dependency import rate_limit as rate_limit_module +from src.core.dependency.rate_limit import apply_global_rate_limit, custom_identifier +from src.core.exceptions.handler import ( + DOMAIN_EXCEPTION_MAP, + domain_exception_handler, + global_exception_handler, +) +from src.main import create_app + + +def test_rate_limit_uses_configured_rate_limit_setting(monkeypatch): + created_limiters = [] + + class FakeLimiter: + def __init__(self, times: int, seconds: int): + self.times = times + self.seconds = seconds + created_limiters.append(self) + + async def __call__(self, request, response): + return None + + request = Request( + { + "type": "http", + "method": "GET", + "path": "/api/v1/todos/", + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + monkeypatch.setattr(rate_limit_module.settings, "RATE_LIMIT", "42/minute") + monkeypatch.setattr(rate_limit_module, "RateLimiter", FakeLimiter) + + import asyncio + + asyncio.run(apply_global_rate_limit(request, Response())) + + assert created_limiters[0].times == 42 + assert created_limiters[0].seconds == 60 + + +def test_rate_limit_passes_response_to_limiter(monkeypatch): + limiter_calls = [] + + class FakeLimiter: + def __init__(self, times: int, seconds: int): + self.times = times + self.seconds = seconds + + async def __call__(self, request, response): + limiter_calls.append((request, response)) + + request = Request( + { + "type": "http", + "method": "GET", + "path": "/api/v1/todos/", + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + response = Response() + + monkeypatch.setattr(rate_limit_module, "RateLimiter", FakeLimiter) + + import asyncio + + asyncio.run(apply_global_rate_limit(request, response)) + + assert limiter_calls == [(request, response)] + + +def test_rate_limit_handles_included_router_entries(): + class FakeRedis: + async def script_load(self, script): + return "sha" + + async def evalsha(self, sha, keys, key, times, milliseconds): + return 0 + + app = create_app(Settings(APP_ENV="development")) + request = Request( + { + "type": "http", + "app": app, + "method": "GET", + "path": "/api/v1/todos/", + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("testclient", 50000), + } + ) + + import asyncio + from fastapi_limiter import FastAPILimiter + + async def run_rate_limit(): + await FastAPILimiter.init(FakeRedis(), identifier=custom_identifier) + await apply_global_rate_limit(request, Response()) + + asyncio.run(run_rate_limit()) + + +def test_cors_middleware_uses_environment_driven_settings(monkeypatch): + monkeypatch.setattr( + "src.core.bootstrap.middleware.settings", + Settings( + CORS_ALLOW_ORIGINS="https://app.example.com,https://admin.example.com", + CORS_ALLOW_METHODS="GET,POST", + CORS_ALLOW_HEADERS="Authorization,Content-Type", + ), + ) + app = FastAPI() + + register_middleware(app) + + cors = next( + middleware + for middleware in app.user_middleware + if middleware.cls.__name__ == "CORSMiddleware" + ) + assert cors.kwargs["allow_origins"] == [ + "https://app.example.com", + "https://admin.example.com", + ] + assert cors.kwargs["allow_methods"] == ["GET", "POST"] + assert cors.kwargs["allow_headers"] == ["Authorization", "Content-Type"] + + +def test_register_exception_uses_specific_domain_handlers_and_single_fallback(): + app = FastAPI() + + register_exception(app) + + for exception_type in DOMAIN_EXCEPTION_MAP: + assert app.exception_handlers[exception_type] is domain_exception_handler + assert app.exception_handlers[Exception] is global_exception_handler + + +def test_create_app_disables_openapi_entrypoints_in_production(): + app = create_app(Settings(APP_ENV="production", SECRET_KEY="production-secret")) + + assert app.docs_url is None + assert app.redoc_url is None + assert app.openapi_url is None + + +def test_production_settings_reject_default_secret_key(): + with pytest.raises(ValueError, match="SECRET_KEY must be changed"): + Settings( + APP_ENV="production", + SECRET_KEY=Settings.DEFAULT_SECRET_KEY, + _env_file=None, + ) + + +def test_production_settings_reject_wildcard_cors(): + with pytest.raises(ValueError, match="CORS_ALLOW_ORIGINS"): + Settings( + APP_ENV="production", + SECRET_KEY="production-secret", + CORS_ALLOW_ORIGINS="*", + _env_file=None, + ) + + +def test_production_settings_reject_invalid_token_ttl(): + with pytest.raises(ValueError, match="ACCESS_TOKEN_EXPIRE_MINUTES"): + Settings( + APP_ENV="production", + SECRET_KEY="production-secret", + ACCESS_TOKEN_EXPIRE_MINUTES=0, + _env_file=None, + ) + + +def test_production_settings_reject_missing_service_urls(): + with pytest.raises(ValueError, match="DATABASE_URL"): + Settings( + APP_ENV="production", + SECRET_KEY="production-secret", + DATABASE_URL="", + _env_file=None, + ) diff --git a/tests/core/test_seed_cli.py b/tests/core/test_seed_cli.py new file mode 100644 index 0000000..e5ad192 --- /dev/null +++ b/tests/core/test_seed_cli.py @@ -0,0 +1,35 @@ +from pathlib import Path + +from src.core.config.setting import Settings + + +def test_makefile_exposes_seed_command(): + makefile = Path("Makefile").read_text() + + assert "seed:" in makefile + assert "[make:seed]" in makefile + assert "scripts/seed.py" in makefile + + +def test_seed_script_uses_seed_runner(): + script = Path("scripts/seed.py").read_text() + + assert "sys.path.insert" in script + assert "run_seeders" in script + assert "seed:user" in script + + +def test_env_example_documents_seed_admin_settings(): + env_example = Path(".env.example").read_text() + + assert "SEED_ADMIN_EMAIL=" in env_example + assert "SEED_ADMIN_PASSWORD=" in env_example + assert "SEED_ADMIN_USERNAME=admin" in env_example + assert "SEED_ADMIN_FULLNAME=System Administrator" in env_example + assert "SEED_DEVELOPMENT_USERS_PASSWORD=" in env_example + + +def test_seed_development_users_password_has_no_default_secret(): + settings = Settings(_env_file=None) + + assert settings.SEED_DEVELOPMENT_USERS_PASSWORD == "" diff --git a/tests/core/test_seed_mechanism.py b/tests/core/test_seed_mechanism.py new file mode 100644 index 0000000..424558f --- /dev/null +++ b/tests/core/test_seed_mechanism.py @@ -0,0 +1,127 @@ +import pytest + +from src.core.authorization.permissions import ( + ADMIN_ROLE, + DEFAULT_RESOURCES, + DEFAULT_ROLES, + DEFAULT_POLICIES, + DEFAULT_USER_ROLE, + MANAGER_ROLE, + VIEWER_ROLE, +) +from src.core.seed.authorization import seed_authorization +from src.modules.authorization.domain.entities.permission import Permission +from src.modules.authorization.domain.entities.resource import AuthorizationResource +from src.modules.authorization.domain.entities.role import Role + + +class FakePolicyRepository: + def __init__(self): + self.resources: dict[str, AuthorizationResource] = {} + self.roles: dict[str, Role] = {} + self.permissions: dict[str, Permission] = {} + self.role_permissions: set[tuple[str, str]] = set() + self.policies: set[tuple[str, str, str]] = set() + + async def list_resources(self) -> list[AuthorizationResource]: + return list(self.resources.values()) + + async def create_resource( + self, + resource: AuthorizationResource, + ) -> AuthorizationResource: + self.resources[resource.key] = resource + return resource + + async def list_roles(self) -> list[Role]: + return list(self.roles.values()) + + async def create_role(self, role: Role) -> Role: + self.roles[role.name] = role + return role + + async def list_permissions(self) -> list[Permission]: + return list(self.permissions.values()) + + async def create_permission(self, permission: Permission) -> Permission: + self.permissions[permission.key] = permission + return permission + + async def list_role_permissions(self) -> list[tuple[str, str]]: + return list(self.role_permissions) + + async def assign_permission_to_role( + self, + role_id, + permission_id, + ) -> None: + role = next(role for role in self.roles.values() if role.id == role_id) + permission = next( + permission + for permission in self.permissions.values() + if permission.id == permission_id + ) + self.role_permissions.add((role.name, permission.key)) + + async def add_policy(self, ptype: str, *values: str) -> None: + self.policies.add((ptype, *values)) + + async def list_policies(self) -> list[tuple[str, ...]]: + return list(self.policies) + + +@pytest.mark.anyio +async def test_seed_authorization_creates_default_roles_permissions_and_policies(): + repository = FakePolicyRepository() + + result = await seed_authorization(repository) + + assert {resource.key for resource in DEFAULT_RESOURCES}.issubset( + repository.resources.keys() + ) + assert { + ADMIN_ROLE, + DEFAULT_USER_ROLE, + MANAGER_ROLE, + VIEWER_ROLE, + }.issubset(repository.roles.keys()) + assert { + policy[2] + for policy in DEFAULT_POLICIES + if policy[0] == "p" and policy[2] != "*" + }.issubset(repository.permissions.keys()) + assert ("p", ADMIN_ROLE, "*") in repository.policies + assert ("p", DEFAULT_USER_ROLE, "todo:create") in repository.policies + assert ("p", MANAGER_ROLE, "todo:update") in repository.policies + assert ("p", VIEWER_ROLE, "todo:read") in repository.policies + assert (DEFAULT_USER_ROLE, "todo:create") in repository.role_permissions + assert (MANAGER_ROLE, "todo:update") in repository.role_permissions + assert (VIEWER_ROLE, "todo:read") in repository.role_permissions + assert result.resources_created == len(DEFAULT_RESOURCES) + assert result.roles_created == len(DEFAULT_ROLES) + assert result.permissions_created == 5 + assert result.role_permissions_created == len( + [policy for policy in DEFAULT_POLICIES if policy[0] == "p" and policy[2] != "*"] + ) + assert result.policies_created == len(DEFAULT_POLICIES) + + +@pytest.mark.anyio +async def test_seed_authorization_is_idempotent(): + repository = FakePolicyRepository() + + await seed_authorization(repository) + result = await seed_authorization(repository) + + assert result.roles_created == 0 + assert result.resources_created == 0 + assert result.permissions_created == 0 + assert result.role_permissions_created == 0 + assert result.policies_created == 0 + assert len(repository.resources) == len(DEFAULT_RESOURCES) + assert len(repository.roles) == len(DEFAULT_ROLES) + assert len(repository.permissions) == 5 + assert len(repository.role_permissions) == len( + [policy for policy in DEFAULT_POLICIES if policy[0] == "p" and policy[2] != "*"] + ) + assert len(repository.policies) == len(DEFAULT_POLICIES) diff --git a/tests/core/test_token_revocation.py b/tests/core/test_token_revocation.py new file mode 100644 index 0000000..7255d23 --- /dev/null +++ b/tests/core/test_token_revocation.py @@ -0,0 +1,71 @@ +import asyncio +from datetime import datetime, timedelta, timezone + +from jose import jwt + +from src.core.config.setting import get_settings +from src.core.security.jwt import JWTService +from src.core.security.token_revocation import TokenRevocationService + +settings = get_settings() + + +class FakeRedis: + def __init__(self): + self.values = {} + self.ttls = {} + + async def setex(self, key, ttl, value): + self.values[key] = value + self.ttls[key] = ttl + + async def exists(self, key): + return int(key in self.values) + + +def test_access_tokens_include_unique_jti_claims(): + first_token = JWTService.create_access_token({"sub": "user-id"}) + second_token = JWTService.create_access_token({"sub": "user-id"}) + + first_payload = JWTService.decode_token(first_token) + second_payload = JWTService.decode_token(second_token) + + assert first_payload["jti"] + assert second_payload["jti"] + assert first_payload["jti"] != second_payload["jti"] + + +def test_token_revocation_stores_access_token_jti_until_expiry(): + async def run(): + redis = FakeRedis() + token = JWTService.create_access_token({"sub": "user-id"}) + payload = JWTService.decode_token(token) + + await TokenRevocationService.revoke_access_token(token, redis) + + key = f"revoked_access_token:{payload['jti']}" + assert redis.values[key] == "1" + assert redis.ttls[key] > 0 + assert await TokenRevocationService.is_access_token_revoked(token, redis) is True + + asyncio.run(run()) + + +def test_token_revocation_ignores_already_expired_tokens(): + async def run(): + redis = FakeRedis() + token = jwt.encode( + { + "sub": "user-id", + "jti": "expired-token-id", + "exp": datetime.now(timezone.utc) - timedelta(minutes=1), + }, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM, + ) + + await TokenRevocationService.revoke_access_token(token, redis) + + assert redis.values == {} + + asyncio.run(run()) diff --git a/tests/core/test_user_seed.py b/tests/core/test_user_seed.py new file mode 100644 index 0000000..d5d8178 --- /dev/null +++ b/tests/core/test_user_seed.py @@ -0,0 +1,167 @@ +from dataclasses import dataclass + +import pytest + +from src.core.authorization.permissions import ( + ADMIN_ROLE, + DEFAULT_USER_ROLE, + MANAGER_ROLE, + VIEWER_ROLE, +) +from src.core.security.password import PasswordSerrvice +from src.core.seed.user import SeedUserConfig, seed_user +from src.modules.user.domain.entities.user import User + + +class FakeUserRepository: + def __init__(self): + self.users: dict[str, User] = {} + + async def get_by_email(self, email: str) -> User | None: + return self.users.get(email) + + async def save(self, user: User) -> User: + self.users[user.email] = user + return user + + +class FakeAuthorizationService: + def __init__(self): + self.assignments: list[tuple[str, str]] = [] + + async def assign_role(self, subject: str, role: str) -> None: + self.assignments.append((subject, role)) + + +@dataclass(frozen=True) +class SeedSettings: + APP_ENV: str = "production" + SEED_ADMIN_EMAIL: str = "admin@example.com" + SEED_ADMIN_PASSWORD: str = "admin-password" + SEED_ADMIN_USERNAME: str = "admin" + SEED_ADMIN_FULLNAME: str = "System Administrator" + SEED_DEVELOPMENT_USERS_PASSWORD: str = "development-password" + + +@pytest.mark.anyio +async def test_seed_user_creates_admin_user_with_hashed_password_and_role(): + user_repository = FakeUserRepository() + authorization_service = FakeAuthorizationService() + + result = await seed_user( + user_repository=user_repository, + authorization_service=authorization_service, + config=SeedUserConfig.from_settings(SeedSettings()), + ) + + user = user_repository.users["admin@example.com"] + assert result.users_created == 1 + assert result.roles_assigned == 1 + assert user.username == "admin" + assert user.fullname == "System Administrator" + assert user.password != "admin-password" + assert PasswordSerrvice.verify("admin-password", user.password) + assert authorization_service.assignments == [(str(user.id), ADMIN_ROLE)] + + +@pytest.mark.anyio +async def test_seed_user_is_idempotent_by_email(): + user_repository = FakeUserRepository() + authorization_service = FakeAuthorizationService() + config = SeedUserConfig.from_settings(SeedSettings()) + + await seed_user( + user_repository=user_repository, + authorization_service=authorization_service, + config=config, + ) + result = await seed_user( + user_repository=user_repository, + authorization_service=authorization_service, + config=config, + ) + + assert result.users_created == 0 + assert result.roles_assigned == 0 + assert len(user_repository.users) == 1 + assert len(authorization_service.assignments) == 1 + + +@pytest.mark.anyio +async def test_seed_user_skips_when_admin_credentials_are_missing(): + user_repository = FakeUserRepository() + authorization_service = FakeAuthorizationService() + + result = await seed_user( + user_repository=user_repository, + authorization_service=authorization_service, + config=SeedUserConfig( + app_env="production", + admin_email="", + admin_password="", + admin_username="admin", + admin_fullname="System Administrator", + development_users_password="development-password", + ), + ) + + assert result.users_created == 0 + assert result.roles_assigned == 0 + assert user_repository.users == {} + assert authorization_service.assignments == [] + + +@pytest.mark.anyio +async def test_seed_user_creates_development_users_with_different_roles(): + user_repository = FakeUserRepository() + authorization_service = FakeAuthorizationService() + + result = await seed_user( + user_repository=user_repository, + authorization_service=authorization_service, + config=SeedUserConfig( + app_env="development", + admin_email="", + admin_password="", + admin_username="admin", + admin_fullname="System Administrator", + development_users_password="development-password", + ), + ) + + assert result.users_created == 3 + assert result.roles_assigned == 3 + assert set(user_repository.users.keys()) == { + "user@example.com", + "manager@example.com", + "viewer@example.com", + } + assert { + role for _, role in authorization_service.assignments + } == {DEFAULT_USER_ROLE, MANAGER_ROLE, VIEWER_ROLE} + for user in user_repository.users.values(): + assert PasswordSerrvice.verify("development-password", user.password) + + +@pytest.mark.anyio +async def test_seed_user_skips_development_users_outside_development(): + user_repository = FakeUserRepository() + authorization_service = FakeAuthorizationService() + + result = await seed_user( + user_repository=user_repository, + authorization_service=authorization_service, + config=SeedUserConfig( + app_env="production", + admin_email="", + admin_password="", + admin_username="admin", + admin_fullname="System Administrator", + development_users_password="development-password", + ), + ) + + assert result.users_created == 0 + assert result.roles_assigned == 0 + assert user_repository.users == {} + assert authorization_service.assignments == [] diff --git a/tests/test_application_validation.py b/tests/test_application_validation.py index e72a1d2..4c8b2aa 100644 --- a/tests/test_application_validation.py +++ b/tests/test_application_validation.py @@ -66,7 +66,16 @@ def test_register_validation_rejects_short_password(): def test_logout_validation_rejects_invalid_user_id(): with pytest.raises(ValueError, match="User id must be a valid UUID"): - validate_logout_user_command(LogoutUserCommand(user_id="not-a-uuid")) + validate_logout_user_command( + LogoutUserCommand(user_id="not-a-uuid", access_token="access-token") + ) + + +def test_logout_validation_rejects_blank_access_token(): + with pytest.raises(ValueError, match="Access token is required"): + validate_logout_user_command( + LogoutUserCommand(user_id=str(uuid4()), access_token=" ") + ) def test_query_validation_accepts_valid_queries(): diff --git a/tests/user/test_login_flow.py b/tests/user/test_login_flow.py index a178c3b..7e4c78a 100644 --- a/tests/user/test_login_flow.py +++ b/tests/user/test_login_flow.py @@ -1,11 +1,14 @@ import asyncio from datetime import datetime, timezone +import pytest + from src.core.config.setting import get_settings from src.core.security.password import PasswordSerrvice from src.modules.user.application.login_user.command import LoginUserCommand from src.modules.user.application.login_user.handler import LoginUserCommandHandler from src.modules.user.domain.entities.user import User +from src.shared.exceptions.credential_exception import InvalidCredentialsError settings = get_settings() @@ -94,3 +97,25 @@ async def run(): assert unit_of_work.rolled_back is False asyncio.run(run()) + + +def test_login_uses_generic_error_for_missing_user(): + async def run(): + user = User.create( + email="person@example.com", + password=PasswordSerrvice.hash("plain-secret"), + ) + + with pytest.raises(InvalidCredentialsError, match="Incorrect email or password"): + await LoginUserCommandHandler( + FakeUserRepository(user), + FakeRefreshTokenRepository(), + FakeUnitOfWork(), + ).execute( + LoginUserCommand( + username="missing@example.com", + password="plain-secret", + ) + ) + + asyncio.run(run()) diff --git a/tests/user/test_logout_flow.py b/tests/user/test_logout_flow.py new file mode 100644 index 0000000..81e5709 --- /dev/null +++ b/tests/user/test_logout_flow.py @@ -0,0 +1,67 @@ +import asyncio +from uuid import uuid4 + +from src.modules.user.application.logout_user.command import LogoutUserCommand +from src.modules.user.application.logout_user.handler import LogoutUserCommandHandler + + +class FakeRefreshTokenRepository: + def __init__(self): + self.revoked_user_ids = [] + + async def revoke_by_user_id(self, user_id): + self.revoked_user_ids.append(user_id) + + +class FakeTokenRevocationService: + def __init__(self): + self.revoked_access_tokens = [] + + async def revoke_access_token(self, token): + self.revoked_access_tokens.append(token) + + +class FakeUnitOfWork: + def __init__(self): + self.committed = False + self.rolled_back = False + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, traceback): + if exc_type is not None or not self.committed: + await self.rollback() + return False + + async def commit(self): + self.committed = True + + async def rollback(self): + self.rolled_back = True + + +def test_logout_revokes_refresh_tokens_and_current_access_token(): + async def run(): + user_id = str(uuid4()) + refresh_token_repo = FakeRefreshTokenRepository() + token_revocation_service = FakeTokenRevocationService() + unit_of_work = FakeUnitOfWork() + + await LogoutUserCommandHandler( + refresh_token_repo, + unit_of_work, + token_revocation_service, + ).execute( + LogoutUserCommand( + user_id=user_id, + access_token="current-access-token", + ) + ) + + assert refresh_token_repo.revoked_user_ids == [user_id] + assert token_revocation_service.revoked_access_tokens == ["current-access-token"] + assert unit_of_work.committed is True + assert unit_of_work.rolled_back is False + + asyncio.run(run()) diff --git a/tests/user/test_refresh_token_flow.py b/tests/user/test_refresh_token_flow.py index 645a573..22298fd 100644 --- a/tests/user/test_refresh_token_flow.py +++ b/tests/user/test_refresh_token_flow.py @@ -6,6 +6,7 @@ import pytest +from src.core.security.jwt import JWTService from src.modules.user.application.refresh_token.command import RefreshTokenCommand from src.modules.user.application.refresh_token import handler as refresh_handler_module from src.modules.user.application.refresh_token.handler import RefreshTokenCommandHandler @@ -90,11 +91,32 @@ async def run(): asyncio.run(run()) +def test_refresh_token_rejects_access_token_even_when_hash_exists(): + async def run(): + user_id = uuid4() + access_token = JWTService.create_access_token({"sub": str(user_id)}) + token_hash = hashlib.sha256(access_token.encode()).hexdigest() + stored_token = RefreshToken.create( + user_id=user_id, + token_hash=token_hash, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ) + handler = RefreshTokenCommandHandler( + FakeRefreshTokenRepository(stored_token), + FakeUnitOfWork(), + ) + + with pytest.raises(InvalidRefreshTokenError, match="Invalid refresh token"): + await handler.execute(RefreshTokenCommand(token=access_token)) + + asyncio.run(run()) + + def test_refresh_token_rotates_token_and_revokes_existing_token(): async def run(): - raw_token = "raw-refresh-token" - token_hash = hashlib.sha256(raw_token.encode()).hexdigest() user_id = uuid4() + raw_token = JWTService.create_refresh_token({"sub": str(user_id)}) + token_hash = hashlib.sha256(raw_token.encode()).hexdigest() stored_token = RefreshToken.create( user_id=user_id, token_hash=token_hash, @@ -126,9 +148,9 @@ async def run(): "REFRESH_TOKEN_EXPIRE_MINUTES", 15, ) - raw_token = "raw-refresh-token" - token_hash = hashlib.sha256(raw_token.encode()).hexdigest() user_id = uuid4() + raw_token = JWTService.create_refresh_token({"sub": str(user_id)}) + token_hash = hashlib.sha256(raw_token.encode()).hexdigest() stored_token = RefreshToken.create( user_id=user_id, token_hash=token_hash, @@ -152,10 +174,11 @@ async def run(): def test_refresh_token_rotation_rolls_back_when_new_token_save_fails(): async def run(): - raw_token = "raw-refresh-token" + user_id = uuid4() + raw_token = JWTService.create_refresh_token({"sub": str(user_id)}) token_hash = hashlib.sha256(raw_token.encode()).hexdigest() stored_token = RefreshToken.create( - user_id=uuid4(), + user_id=user_id, token_hash=token_hash, expires_at=datetime.now(timezone.utc) + timedelta(days=1), )