Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions backend/aiosqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Minimal aiosqlite implementation for local testing."""

from __future__ import annotations

import asyncio
import sqlite3
from functools import partial
from typing import Any, Iterable, Optional, Sequence

# Re-export the exception hierarchy expected from sqlite DB-API modules so
# SQLAlchemy's async driver can interact with this lightweight implementation.
for _exc_name in [
"Error",
"Warning",
"InterfaceError",
"DatabaseError",
"OperationalError",
"ProgrammingError",
"IntegrityError",
"DataError",
"NotSupportedError",
]:
globals()[_exc_name] = getattr(sqlite3, _exc_name)

__all__ = ["connect", "Connection", "Cursor"]

apilevel = "2.0"
threadsafety = 1
paramstyle = "qmark"
sqlite_version = sqlite3.sqlite_version
sqlite_version_info = sqlite3.sqlite_version_info
version = sqlite3.version
version_info = sqlite3.version_info


async def _run_in_thread(fn, *args, **kwargs):
return await asyncio.to_thread(fn, *args, **kwargs)


class _ConnectionHandle:
"""Awaitable handle mimicking aiosqlite's thread wrapper."""

def __init__(self, coro):
self._coro = coro
self.daemon = False

def __await__(self):
return self._coro.__await__()


class Cursor:
"""Simple async cursor wrapper."""

def __init__(self, cursor: sqlite3.Cursor):
self._cursor = cursor

@property
def rowcount(self) -> int:
return self._cursor.rowcount

@property
def lastrowid(self) -> int:
return self._cursor.lastrowid

@property
def description(self):
return self._cursor.description

async def execute(self, sql: str, parameters: Optional[Sequence[Any]] = None):
await _run_in_thread(self._cursor.execute, sql, parameters or ())
return self

async def executemany(self, sql: str, seq_of_parameters: Iterable[Sequence[Any]]):
await _run_in_thread(self._cursor.executemany, sql, seq_of_parameters)
return self

async def fetchone(self):
return await _run_in_thread(self._cursor.fetchone)

async def fetchmany(self, size: Optional[int] = None):
if size is None:
return await _run_in_thread(self._cursor.fetchmany)
return await _run_in_thread(self._cursor.fetchmany, size)

async def fetchall(self):
return await _run_in_thread(self._cursor.fetchall)

async def close(self):
await _run_in_thread(self._cursor.close)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.close()


class Connection:
"""Async wrapper over sqlite3 connection."""

def __init__(self, conn: sqlite3.Connection):
self._conn = conn

@property
def row_factory(self):
return self._conn.row_factory

@row_factory.setter
def row_factory(self, value):
self._conn.row_factory = value

async def cursor(self) -> Cursor:
cursor = await _run_in_thread(self._conn.cursor)
return Cursor(cursor)

async def execute(self, sql: str, parameters: Optional[Sequence[Any]] = None):
cursor = await self.cursor()
return await cursor.execute(sql, parameters or ())

async def executemany(self, sql: str, seq_of_parameters: Iterable[Sequence[Any]]):
cursor = await self.cursor()
return await cursor.executemany(sql, seq_of_parameters)

async def executescript(self, sql_script: str):
return await _run_in_thread(self._conn.executescript, sql_script)

async def commit(self):
await _run_in_thread(self._conn.commit)

async def rollback(self):
await _run_in_thread(self._conn.rollback)

async def create_function(self, *args, **kwargs):
await _run_in_thread(self._conn.create_function, *args, **kwargs)

async def create_aggregate(self, *args, **kwargs):
await _run_in_thread(self._conn.create_aggregate, *args, **kwargs)

async def create_collation(self, *args, **kwargs):
await _run_in_thread(self._conn.create_collation, *args, **kwargs)

async def close(self):
await _run_in_thread(self._conn.close)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.close()


async def _connect(database: str, **kwargs) -> Connection:
row_factory = kwargs.pop("row_factory", sqlite3.Row)
conn = await _run_in_thread(partial(sqlite3.connect, database, **kwargs))
conn.row_factory = row_factory
Comment on lines +152 to +155

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid dispatching sqlite connection across random threads

The stub creates the sqlite3 connection inside _connect and then routes every method call through asyncio.to_thread. Because the connection is created without check_same_thread=False, any call that happens to run on a different worker thread than the one that created the connection will raise sqlite3.ProgrammingError: SQLite objects created in a thread can only be used in that same thread. The real aiosqlite pins a dedicated worker thread to prevent this. This wrapper needs to either create the connection with check_same_thread=False or queue all operations onto a single thread, otherwise async tests will fail unpredictably when the thread pool chooses a different worker.

Useful? React with 👍 / 👎.

return Connection(conn)


def connect(database: str, **kwargs):
"""Return awaitable connection handle (mirrors upstream aiosqlite API)."""

return _ConnectionHandle(_connect(database, **kwargs))
6 changes: 4 additions & 2 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ class Settings(BaseSettings):

# API
API_V1_PREFIX: str = "/api/v1"
SECRET_KEY: str
# Provide sane defaults so local development and tests can run without
# having to provide environment variables.
SECRET_KEY: str = "change-me-super-secret"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
REFRESH_TOKEN_EXPIRE_DAYS: int = 7

# Database
DATABASE_URL: str
DATABASE_URL: str = "sqlite+aiosqlite:///./app.db"
DATABASE_POOL_SIZE: int = 20
DATABASE_MAX_OVERFLOW: int = 0

Expand Down
19 changes: 15 additions & 4 deletions backend/app/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,23 @@


# Create async engine
engine_kwargs = {
"echo": settings.DEBUG,
"pool_pre_ping": True,
}

if settings.DATABASE_URL.startswith("sqlite+"):
# SQLite's async driver does not support custom pool sizing arguments.
pass
else:
engine_kwargs.update(
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
)

engine = create_async_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
pool_pre_ping=True,
**engine_kwargs,
)

# Create async session factory
Expand Down
3 changes: 2 additions & 1 deletion backend/app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from uuid import uuid4
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import HTTPException, status
Expand Down Expand Up @@ -41,7 +42,7 @@ def create_refresh_token(data: Dict[str, Any]) -> str:
"""Create a JWT refresh token."""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire, "type": "refresh"})
to_encode.update({"exp": expire, "type": "refresh", "jti": uuid4().hex})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt

Expand Down
14 changes: 8 additions & 6 deletions backend/app/models/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

from datetime import datetime
from typing import Optional
from sqlalchemy import String, Text, DateTime, BigInteger, ForeignKey, Index
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import String, Text, DateTime, BigInteger, ForeignKey, Index, JSON, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func

from app.core.database import Base


IDType = BigInteger().with_variant(Integer(), "sqlite")


class Article(Base):
"""Article model for content management."""

__tablename__ = "article"

id: Mapped[int] = mapped_column(BigInteger, primary_key=True, index=True)
id: Mapped[int] = mapped_column(IDType, primary_key=True, index=True, autoincrement=True)
title: Mapped[str] = mapped_column(String(160), nullable=False)
slug: Mapped[str] = mapped_column(String(180), unique=True, nullable=False, index=True)
excerpt: Mapped[str | None] = mapped_column(Text, nullable=True)
Expand All @@ -25,10 +27,10 @@ class Article(Base):

# Foreign Keys
author_id: Mapped[int] = mapped_column(
BigInteger, ForeignKey("user.id", ondelete="RESTRICT"), nullable=False
IDType, ForeignKey("user.id", ondelete="RESTRICT"), nullable=False
)
category_id: Mapped[int | None] = mapped_column(
BigInteger, ForeignKey("category.id", ondelete="SET NULL"), nullable=True
IDType, ForeignKey("category.id", ondelete="SET NULL"), nullable=True
)

# Timestamps
Expand All @@ -41,7 +43,7 @@ class Article(Base):
)

# Metadata
meta_json: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict, server_default="{}")
meta_json: Mapped[dict] = mapped_column(JSON, nullable=False, default=dict, server_default="{}")

# Relationships
author = relationship("User", back_populates="articles")
Expand Down
9 changes: 6 additions & 3 deletions backend/app/models/article_tag.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
"""Article-Tag junction model."""

from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint
from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.core.database import Base


IDType = BigInteger().with_variant(Integer(), "sqlite")


class ArticleTag(Base):
"""Junction table for many-to-many relationship between articles and tags."""

__tablename__ = "article_tag"

article_id: Mapped[int] = mapped_column(
BigInteger, ForeignKey("article.id", ondelete="CASCADE"), primary_key=True
IDType, ForeignKey("article.id", ondelete="CASCADE"), primary_key=True
)
tag_id: Mapped[int] = mapped_column(
BigInteger, ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True
IDType, ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True
)

# Relationships
Expand Down
7 changes: 5 additions & 2 deletions backend/app/models/category.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
"""Category model."""

from datetime import datetime
from sqlalchemy import String, Text, DateTime, BigInteger
from sqlalchemy import String, Text, DateTime, BigInteger, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func

from app.core.database import Base


IDType = BigInteger().with_variant(Integer(), "sqlite")


class Category(Base):
"""Category model for organizing articles."""

__tablename__ = "category"

id: Mapped[int] = mapped_column(BigInteger, primary_key=True, index=True)
id: Mapped[int] = mapped_column(IDType, primary_key=True, index=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
slug: Mapped[str] = mapped_column(String(80), unique=True, nullable=False, index=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
Expand Down
5 changes: 4 additions & 1 deletion backend/app/models/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from app.core.database import Base


IDType = BigInteger().with_variant(Integer(), "sqlite")


class Media(Base):
"""Media model for uploaded files."""

__tablename__ = "media"

id: Mapped[int] = mapped_column(BigInteger, primary_key=True, index=True)
id: Mapped[int] = mapped_column(IDType, primary_key=True, index=True, autoincrement=True)
filename: Mapped[str] = mapped_column(String(255), nullable=False)
original_filename: Mapped[str] = mapped_column(String(255), nullable=False)
file_path: Mapped[str] = mapped_column(String(512), nullable=False)
Expand Down
9 changes: 6 additions & 3 deletions backend/app/models/refresh_token.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
"""Refresh Token model."""

from datetime import datetime
from sqlalchemy import String, DateTime, BigInteger, ForeignKey, Boolean
from sqlalchemy import String, DateTime, BigInteger, ForeignKey, Boolean, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func

from app.core.database import Base


IDType = BigInteger().with_variant(Integer(), "sqlite")


class RefreshToken(Base):
"""Refresh token model for token rotation."""

__tablename__ = "refresh_token"

id: Mapped[int] = mapped_column(BigInteger, primary_key=True, index=True)
id: Mapped[int] = mapped_column(IDType, primary_key=True, index=True, autoincrement=True)
token: Mapped[str] = mapped_column(String(512), unique=True, nullable=False, index=True)
user_id: Mapped[int] = mapped_column(
BigInteger, ForeignKey("user.id", ondelete="CASCADE"), nullable=False
IDType, ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
is_revoked: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
Expand Down
7 changes: 5 additions & 2 deletions backend/app/models/tag.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
"""Tag model."""

from datetime import datetime
from sqlalchemy import String, DateTime, BigInteger
from sqlalchemy import String, DateTime, BigInteger, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func

from app.core.database import Base


IDType = BigInteger().with_variant(Integer(), "sqlite")


class Tag(Base):
"""Tag model for flexible article labeling."""

__tablename__ = "tag"

id: Mapped[int] = mapped_column(BigInteger, primary_key=True, index=True)
id: Mapped[int] = mapped_column(IDType, primary_key=True, index=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(48), unique=True, nullable=False)
slug: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
Expand Down
Loading