Files
2026-03-20 12:01:02 +05:30

271 lines
6.7 KiB
Python

# === FILE: app/db.py ===
import psycopg
from psycopg.rows import dict_row
from app.config import Config
def get_connection():
"""Create and return a PostgreSQL database connection."""
conn = psycopg.connect(
host=Config.DB_HOST,
port=Config.DB_PORT,
dbname=Config.DB_NAME,
user=Config.DB_USER,
password=Config.DB_PASSWORD
)
return conn
def init_db():
"""Initialize the database by creating tables from schema.sql."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor()
# Read and execute schema.sql
with open('sql/schema.sql', 'r') as f:
schema = f.read()
cursor.execute(schema)
conn.commit()
cursor.close()
print("Database initialized successfully")
except Exception as e:
print(f"Error initializing database: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def get_user_by_google_id(google_id):
"""Fetch user by Google ID."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor(row_factory=dict_row)
cursor.execute(
"SELECT * FROM users WHERE google_id = %s",
(google_id,)
)
user = cursor.fetchone()
cursor.close()
return dict(user) if user else None
except Exception as e:
print(f"Error fetching user: {e}")
return None
finally:
if conn:
conn.close()
def create_user(google_id, email, name, profile_picture):
"""Create a new user and return user dict."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor(row_factory=dict_row)
cursor.execute(
"""
INSERT INTO users (google_id, email, name, profile_picture)
VALUES (%s, %s, %s, %s)
RETURNING *
""",
(google_id, email, name, profile_picture)
)
user = cursor.fetchone()
conn.commit()
cursor.close()
return dict(user) if user else None
except Exception as e:
print(f"Error creating user: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def get_or_create_user(google_id, email, name, profile_picture):
"""Get existing user or create new one (upsert logic)."""
user = get_user_by_google_id(google_id)
if user:
return user
else:
return create_user(google_id, email, name, profile_picture)
def get_files_by_user(user_id):
"""Get all files for a user, ordered by upload date (newest first)."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor(row_factory=dict_row)
cursor.execute(
"""
SELECT * FROM files
WHERE user_id = %s
ORDER BY uploaded_at DESC
""",
(user_id,)
)
files = cursor.fetchall()
cursor.close()
return [dict(f) for f in files] if files else []
except Exception as e:
print(f"Error fetching files: {e}")
return []
finally:
if conn:
conn.close()
def add_file(user_id, filename, original_name, s3_key, file_size, file_type):
"""Add a new file record to the database."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor(row_factory=dict_row)
cursor.execute(
"""
INSERT INTO files (user_id, filename, original_name, s3_key, file_size, file_type)
VALUES (%s, %s, %s, %s, %s, %s)
RETURNING *
""",
(user_id, filename, original_name, s3_key, file_size, file_type)
)
file = cursor.fetchone()
conn.commit()
cursor.close()
return dict(file) if file else None
except Exception as e:
print(f"Error adding file: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def delete_file(file_id, user_id):
"""Delete a file record (with ownership verification)."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute(
"DELETE FROM files WHERE id = %s AND user_id = %s",
(file_id, user_id)
)
deleted_count = cursor.rowcount
conn.commit()
cursor.close()
return deleted_count > 0
except Exception as e:
print(f"Error deleting file: {e}")
if conn:
conn.rollback()
return False
finally:
if conn:
conn.close()
def get_file(file_id, user_id):
"""Get a single file by ID (with ownership verification)."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor(row_factory=dict_row)
cursor.execute(
"SELECT * FROM files WHERE id = %s AND user_id = %s",
(file_id, user_id)
)
file = cursor.fetchone()
cursor.close()
return dict(file) if file else None
except Exception as e:
print(f"Error fetching file: {e}")
return None
finally:
if conn:
conn.close()
def set_share_token(file_id, user_id, token):
"""Set a share token for a file (with ownership verification)."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute(
"""
UPDATE files
SET share_token = %s
WHERE id = %s AND user_id = %s
""",
(token, file_id, user_id)
)
updated_count = cursor.rowcount
conn.commit()
cursor.close()
return updated_count > 0
except Exception as e:
print(f"Error setting share token: {e}")
if conn:
conn.rollback()
return False
finally:
if conn:
conn.close()
def get_file_by_share_token(token):
"""Get a file by its share token (no user_id check for public access)."""
conn = None
try:
conn = get_connection()
cursor = conn.cursor(row_factory=dict_row)
cursor.execute(
"SELECT * FROM files WHERE share_token = %s",
(token,)
)
file = cursor.fetchone()
cursor.close()
return dict(file) if file else None
except Exception as e:
print(f"Error fetching file by share token: {e}")
return None
finally:
if conn:
conn.close()