use async_trait::async_trait;
use mas_data_model::User;
use mas_storage::{
    user::{UserFilter, UserRepository},
    Clock,
};
use rand::RngCore;
use sea_query::{Expr, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
    filter::{Filter, StatementExt},
    iden::Users,
    pagination::QueryBuilderExt,
    tracing::ExecuteExt,
    DatabaseError,
};
mod email;
mod password;
mod recovery;
mod session;
mod terms;
#[cfg(test)]
mod tests;
pub use self::{
    email::PgUserEmailRepository, password::PgUserPasswordRepository,
    recovery::PgUserRecoveryRepository, session::PgBrowserSessionRepository,
    terms::PgUserTermsRepository,
};
/// PostgreSQL-backed implementation of the user repository.
///
/// Every query is issued on the borrowed connection, so the repository's
/// lifetime is tied to that connection borrow.
pub struct PgUserRepository<'c> {
    /// Connection all user queries are executed on.
    conn: &'c mut PgConnection,
}

impl<'c> PgUserRepository<'c> {
    /// Wrap a mutable borrow of a PostgreSQL connection in a user repository.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        PgUserRepository { conn }
    }
}
// NOTE(review): presumably kept in its own module so that the items generated
// by `#[enum_def]` fall under the `allow(missing_docs)` below instead of
// tripping that lint at the parent level — confirm.
mod priv_ {
    #![allow(missing_docs)]
    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;
    /// One row of the `users` table as selected by the user repository
    /// queries (`lookup`, `find_by_username` and `list`).
    ///
    /// `#[enum_def]` additionally generates the `UserLookupIden` identifier
    /// enum, whose variants serve as column aliases in the dynamically built
    /// `list` query so that `sqlx::FromRow` can map the columns back by name.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        // Primary key; converted to/from a ULID at the repository boundary.
        pub(super) user_id: Uuid,
        pub(super) username: String,
        // Nullable reference to the user's primary email, if any.
        pub(super) primary_user_email_id: Option<Uuid>,
        pub(super) created_at: DateTime<Utc>,
        // NULL while the user is unlocked; set by `lock`, cleared by `unlock`.
        pub(super) locked_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
    }
}
use priv_::{UserLookup, UserLookupIden};
impl From<UserLookup> for User {
    fn from(value: UserLookup) -> Self {
        let id = value.user_id.into();
        Self {
            id,
            username: value.username,
            sub: id.to_string(),
            primary_user_email_id: value.primary_user_email_id.map(Into::into),
            created_at: value.created_at,
            locked_at: value.locked_at,
            can_request_admin: value.can_request_admin,
        }
    }
}
impl Filter for UserFilter<'_> {
    /// Turn a [`UserFilter`] into a condition on the `users` table.
    ///
    /// Each criterion set on the filter contributes one clause, and unset
    /// criteria contribute nothing. The clauses are joined with
    /// `Condition::all()`, so a row must satisfy every one of them.
    /// `_has_joins` is not used because every column is already
    /// table-qualified.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        // Locked users have a non-NULL `locked_at`; active users have NULL.
        let state_clause = self.state().map(|state| {
            let locked_at = Expr::col((Users::Table, Users::LockedAt));
            if state.is_locked() {
                locked_at.is_not_null()
            } else {
                locked_at.is_null()
            }
        });

        let can_request_admin_clause = self
            .can_request_admin()
            .map(|flag| Expr::col((Users::Table, Users::CanRequestAdmin)).eq(flag));

        sea_query::Condition::all()
            .add_option(state_clause)
            .add_option(can_request_admin_clause)
    }
}
#[async_trait]
impl UserRepository for PgUserRepository<'_> {
    type Error = DatabaseError;
    #[tracing::instrument(
        name = "db.user.lookup",
        skip_all,
        fields(
            db.query.text,
            user.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
        let res = sqlx::query_as!(
            UserLookup,
            r#"
                SELECT user_id
                     , username
                     , primary_user_email_id
                     , created_at
                     , locked_at
                     , can_request_admin
                FROM users
                WHERE user_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;
        let Some(res) = res else { return Ok(None) };
        Ok(Some(res.into()))
    }
    #[tracing::instrument(
        name = "db.user.find_by_username",
        skip_all,
        fields(
            db.query.text,
            user.username = username,
        ),
        err,
    )]
    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
        let res = sqlx::query_as!(
            UserLookup,
            r#"
                SELECT user_id
                     , username
                     , primary_user_email_id
                     , created_at
                     , locked_at
                     , can_request_admin
                FROM users
                WHERE username = $1
            "#,
            username,
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;
        let Some(res) = res else { return Ok(None) };
        Ok(Some(res.into()))
    }
    #[tracing::instrument(
        name = "db.user.add",
        skip_all,
        fields(
            db.query.text,
            user.username = username,
            user.id,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        username: String,
    ) -> Result<User, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record("user.id", tracing::field::display(id));
        let res = sqlx::query!(
            r#"
                INSERT INTO users (user_id, username, created_at)
                VALUES ($1, $2, $3)
                ON CONFLICT (username) DO NOTHING
            "#,
            Uuid::from(id),
            username,
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        DatabaseError::ensure_affected_rows(&res, 1)?;
        Ok(User {
            id,
            username,
            sub: id.to_string(),
            primary_user_email_id: None,
            created_at,
            locked_at: None,
            can_request_admin: false,
        })
    }
    #[tracing::instrument(
        name = "db.user.exists",
        skip_all,
        fields(
            db.query.text,
            user.username = username,
        ),
        err,
    )]
    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
        let exists = sqlx::query_scalar!(
            r#"
                SELECT EXISTS(
                    SELECT 1 FROM users WHERE username = $1
                ) AS "exists!"
            "#,
            username
        )
        .traced()
        .fetch_one(&mut *self.conn)
        .await?;
        Ok(exists)
    }
    #[tracing::instrument(
        name = "db.user.lock",
        skip_all,
        fields(
            db.query.text,
            %user.id,
        ),
        err,
    )]
    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
        if user.locked_at.is_some() {
            return Ok(user);
        }
        let locked_at = clock.now();
        let res = sqlx::query!(
            r#"
                UPDATE users
                SET locked_at = $1
                WHERE user_id = $2
            "#,
            locked_at,
            Uuid::from(user.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        DatabaseError::ensure_affected_rows(&res, 1)?;
        user.locked_at = Some(locked_at);
        Ok(user)
    }
    #[tracing::instrument(
        name = "db.user.unlock",
        skip_all,
        fields(
            db.query.text,
            %user.id,
        ),
        err,
    )]
    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
        if user.locked_at.is_none() {
            return Ok(user);
        }
        let res = sqlx::query!(
            r#"
                UPDATE users
                SET locked_at = NULL
                WHERE user_id = $1
            "#,
            Uuid::from(user.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        DatabaseError::ensure_affected_rows(&res, 1)?;
        user.locked_at = None;
        Ok(user)
    }
    #[tracing::instrument(
        name = "db.user.set_can_request_admin",
        skip_all,
        fields(
            db.query.text,
            %user.id,
            user.can_request_admin = can_request_admin,
        ),
        err,
    )]
    async fn set_can_request_admin(
        &mut self,
        mut user: User,
        can_request_admin: bool,
    ) -> Result<User, Self::Error> {
        let res = sqlx::query!(
            r#"
                UPDATE users
                SET can_request_admin = $2
                WHERE user_id = $1
            "#,
            Uuid::from(user.id),
            can_request_admin,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        DatabaseError::ensure_affected_rows(&res, 1)?;
        user.can_request_admin = can_request_admin;
        Ok(user)
    }
    #[tracing::instrument(
        name = "db.user.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: UserFilter<'_>,
        pagination: mas_storage::Pagination,
    ) -> Result<mas_storage::Page<User>, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr_as(
                Expr::col((Users::Table, Users::UserId)),
                UserLookupIden::UserId,
            )
            .expr_as(
                Expr::col((Users::Table, Users::Username)),
                UserLookupIden::Username,
            )
            .expr_as(
                Expr::col((Users::Table, Users::PrimaryUserEmailId)),
                UserLookupIden::PrimaryUserEmailId,
            )
            .expr_as(
                Expr::col((Users::Table, Users::CreatedAt)),
                UserLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::LockedAt)),
                UserLookupIden::LockedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::CanRequestAdmin)),
                UserLookupIden::CanRequestAdmin,
            )
            .from(Users::Table)
            .apply_filter(filter)
            .generate_pagination((Users::Table, Users::UserId), pagination)
            .build_sqlx(PostgresQueryBuilder);
        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;
        let page = pagination.process(edges).map(User::from);
        Ok(page)
    }
    #[tracing::instrument(
        name = "db.user.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr(Expr::col((Users::Table, Users::UserId)).count())
            .from(Users::Table)
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);
        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;
        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }
    #[tracing::instrument(
        name = "db.user.acquire_lock_for_sync",
        skip_all,
        fields(
            db.query.text,
            user.id = %user.id,
        ),
        err,
    )]
    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
        sqlx::query!(
            r#"
                SELECT pg_advisory_xact_lock($1)
            "#,
            lock_id,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        Ok(())
    }
}