329 lines
8.0 KiB
Python
329 lines
8.0 KiB
Python
"""CRUD operations for database models."""
|
|
|
|
from datetime import datetime, time
from typing import Optional, List

from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession

from bot.db.models import User, Reminder
from bot.logging_config import get_logger
|
|
|
|
# Module-level logger from the project's get_logger helper, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
# ==================== User Operations ====================
|
|
|
|
|
|
async def get_or_create_user(
    session: AsyncSession,
    tg_user_id: int,
    username: Optional[str] = None,
    first_name: Optional[str] = None,
    last_name: Optional[str] = None,
) -> User:
    """
    Get existing user or create new one.

    If a user with ``tg_user_id`` already exists and any of username,
    first name or last name differs from the values passed in, all three are
    overwritten (including with ``None``), ``updated_at`` is bumped and the
    change is committed. Otherwise a new User row is inserted and committed.

    Args:
        session: Database session
        tg_user_id: Telegram user ID
        username: Telegram username
        first_name: User's first name
        last_name: User's last name

    Returns:
        User instance (refreshed from the database after any commit)
    """
    # Try to get existing user
    result = await session.execute(
        select(User).where(User.tg_user_id == tg_user_id)
    )
    user = result.scalar_one_or_none()

    if user:
        # Update user info if changed
        if user.username != username or user.first_name != first_name or user.last_name != last_name:
            user.username = username
            user.first_name = first_name
            user.last_name = last_name
            user.updated_at = datetime.utcnow()
            await session.commit()
            # With the default expire_on_commit=True, commit() expires the
            # loaded attributes; reload them (as the create path below does)
            # so callers can read the returned object without implicit I/O.
            await session.refresh(user)
            logger.debug(f"Updated user info: {tg_user_id}")
        return user

    # Create new user
    user = User(
        tg_user_id=tg_user_id,
        username=username,
        first_name=first_name,
        last_name=last_name,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    logger.info(f"Created new user: {tg_user_id}")
    return user
|
|
|
|
|
|
async def get_user_by_tg_id(session: AsyncSession, tg_user_id: int) -> Optional[User]:
    """
    Look up a user by Telegram ID.

    Args:
        session: Database session
        tg_user_id: Telegram user ID to match against ``User.tg_user_id``

    Returns:
        The matching User, or None when no row matches
    """
    stmt = select(User).where(User.tg_user_id == tg_user_id)
    rows = await session.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
# ==================== Reminder Operations ====================
|
|
|
|
|
|
async def create_reminder(
    session: AsyncSession,
    user_id: int,
    text: str,
    days_interval: int,
    time_of_day: time,
    next_run_at: datetime,
) -> Reminder:
    """
    Create a new reminder, commit it and return the refreshed instance.

    Args:
        session: Database session
        user_id: User's database ID
        text: Reminder text
        days_interval: Days between reminders
        time_of_day: Time of day for reminder (a ``datetime.time`` value)
        next_run_at: Next execution datetime

    Returns:
        Created Reminder instance
    """
    # NOTE: the annotation above is the ``datetime.time`` class. The previous
    # ``datetime.time`` annotation named the ``datetime.datetime.time`` method,
    # because ``datetime`` here is the class, not the module.
    reminder = Reminder(
        user_id=user_id,
        text=text,
        days_interval=days_interval,
        time_of_day=time_of_day,
        next_run_at=next_run_at,
    )
    session.add(reminder)
    await session.commit()
    # Reload so database-assigned columns such as ``id`` (logged below) are populated.
    await session.refresh(reminder)
    logger.info(f"Created reminder {reminder.id} for user {user_id}")
    return reminder
|
|
|
|
|
|
async def get_reminder_by_id(session: AsyncSession, reminder_id: int) -> Optional[Reminder]:
    """
    Fetch a single reminder by its ``id`` column.

    Args:
        session: Database session
        reminder_id: Value to match against ``Reminder.id``

    Returns:
        The matching Reminder, or None when no row matches
    """
    stmt = select(Reminder).where(Reminder.id == reminder_id)
    return (await session.execute(stmt)).scalar_one_or_none()
|
|
|
|
|
|
async def get_user_reminders(
    session: AsyncSession,
    user_id: int,
    active_only: bool = False,
) -> List[Reminder]:
    """
    List a user's reminders, newest first.

    Args:
        session: Database session
        user_id: User's database ID
        active_only: When True, only reminders whose ``is_active`` is true are returned

    Returns:
        Reminders ordered by ``created_at`` descending
    """
    criteria = [Reminder.user_id == user_id]
    if active_only:
        # Builds a SQL comparison; ``is True`` would not produce an expression.
        criteria.append(Reminder.is_active == True)  # noqa: E712

    stmt = (
        select(Reminder)
        .where(*criteria)
        .order_by(Reminder.created_at.desc())
    )
    scalars = (await session.execute(stmt)).scalars()
    return list(scalars.all())
|
|
|
|
|
|
async def get_due_reminders(session: AsyncSession, current_time: datetime) -> List[Reminder]:
    """
    Return active reminders whose ``next_run_at`` is at or before ``current_time``.

    Args:
        session: Database session
        current_time: Cut-off datetime; reminders scheduled later are excluded

    Returns:
        Due reminders ordered by ``next_run_at`` ascending (most overdue first)
    """
    stmt = (
        select(Reminder)
        .where(
            Reminder.is_active == True,  # noqa: E712 - SQL expression, not a Python bool test
            Reminder.next_run_at <= current_time,
        )
        .order_by(Reminder.next_run_at)
    )
    due = await session.execute(stmt)
    return list(due.scalars().all())
|
|
|
|
|
|
async def update_reminder(
    session: AsyncSession,
    reminder_id: int,
    **kwargs,
) -> Optional[Reminder]:
    """
    Update reminder fields.

    Each keyword whose name is an existing attribute of the reminder is
    assigned. Names the reminder does not have are skipped and reported with a
    warning log instead of being dropped silently. ``updated_at`` is always set
    to ``datetime.utcnow()`` afterwards, overriding any value passed in
    ``kwargs``.

    Args:
        session: Database session
        reminder_id: Reminder ID
        **kwargs: Fields to update

    Returns:
        Updated Reminder instance, or None if the reminder does not exist
    """
    reminder = await get_reminder_by_id(session, reminder_id)
    if not reminder:
        return None

    ignored = []
    for key, value in kwargs.items():
        # NOTE(review): hasattr() also accepts methods and ORM internals
        # (e.g. _sa_instance_state); consider whitelisting mapped columns.
        if hasattr(reminder, key):
            setattr(reminder, key, value)
        else:
            ignored.append(key)

    if ignored:
        # Surface likely caller typos instead of silently ignoring them.
        logger.warning(f"Ignored unknown fields for reminder {reminder_id}: {ignored}")

    reminder.updated_at = datetime.utcnow()
    await session.commit()
    await session.refresh(reminder)
    logger.debug(f"Updated reminder {reminder_id}")
    return reminder
|
|
|
|
|
|
async def delete_reminder(session: AsyncSession, reminder_id: int) -> bool:
    """
    Delete the reminder with the given ID using a DELETE statement, then commit.

    Args:
        session: Database session
        reminder_id: Value to match against ``Reminder.id``

    Returns:
        True if at least one row was deleted, False if no reminder matched
    """
    outcome = await session.execute(
        delete(Reminder).where(Reminder.id == reminder_id)
    )
    await session.commit()

    removed = outcome.rowcount > 0
    if removed:
        logger.info(f"Deleted reminder {reminder_id}")
    return removed
|
|
|
|
|
|
async def mark_reminder_done(
    session: AsyncSession,
    reminder_id: int,
    next_run_at: datetime,
) -> Optional[Reminder]:
    """
    Mark reminder as done and schedule next run.

    Sets ``last_done_at`` and ``updated_at`` to the same naive UTC timestamp
    from ``datetime.utcnow()``. Assigns ``next_run_at``, increments
    ``total_done_count`` and commits.

    Args:
        session: Database session
        reminder_id: Reminder ID
        next_run_at: Next execution datetime (computed by the caller)

    Returns:
        Updated Reminder instance, or None if the reminder does not exist
    """
    reminder = await get_reminder_by_id(session, reminder_id)
    if not reminder:
        return None

    # Read the clock once so the completion time and the modification time
    # describe the same event; two separate utcnow() calls drift apart.
    now = datetime.utcnow()
    reminder.last_done_at = now
    reminder.next_run_at = next_run_at
    reminder.total_done_count += 1
    reminder.updated_at = now

    await session.commit()
    await session.refresh(reminder)
    logger.debug(f"Marked reminder {reminder_id} as done")
    return reminder
|
|
|
|
|
|
async def snooze_reminder(
    session: AsyncSession,
    reminder_id: int,
    next_run_at: datetime,
) -> Optional[Reminder]:
    """
    Postpone a reminder to ``next_run_at`` and count the snooze.

    Increments ``snooze_count``, bumps ``updated_at`` and commits.

    Args:
        session: Database session
        reminder_id: Reminder ID
        next_run_at: New execution datetime

    Returns:
        The refreshed Reminder, or None if the reminder does not exist
    """
    found = await get_reminder_by_id(session, reminder_id)
    if found:
        found.next_run_at = next_run_at
        found.snooze_count += 1
        found.updated_at = datetime.utcnow()

        await session.commit()
        await session.refresh(found)
        logger.debug(f"Snoozed reminder {reminder_id}")
        return found
    return None
|
|
|
|
|
|
async def toggle_reminder_active(
    session: AsyncSession,
    reminder_id: int,
    is_active: bool,
) -> Optional[Reminder]:
    """
    Set a reminder's ``is_active`` flag to the given value.

    Despite the name, this does not flip the current value: ``is_active`` is
    assigned exactly as passed. ``updated_at`` is bumped and the change is
    committed.

    Args:
        session: Database session
        reminder_id: Reminder ID
        is_active: New active status

    Returns:
        The refreshed Reminder, or None if the reminder does not exist
    """
    target = await get_reminder_by_id(session, reminder_id)
    if target:
        target.is_active = is_active
        target.updated_at = datetime.utcnow()

        await session.commit()
        await session.refresh(target)
        logger.debug(f"Set reminder {reminder_id} active={is_active}")
        return target
    return None
|