import os

from fastapi import FastAPI, Depends, HTTPException, status, Header
from keycove import encrypt, decrypt, hash, generate_token
from sqlalchemy import create_engine, Column, String, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

# Application object; the route decorators further down register on it.
app = FastAPI()

# Three slashes => a SQLite file path relative to the process's working directory.
SQLALCHEMY_DATABASE_URL = "sqlite:///db.sqlite3"

# check_same_thread=False disables sqlite3's "same thread only" guard so a
# connection may be used from a thread other than the one that created it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=dict(check_same_thread=False),
)

# Factory for sessions bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()


def get_db():
    """
    This generator opens a database session, hands it to the caller, and closes it afterwards.
    It is used as a dependency (via Depends) by the functions in this module that need database access.

    Yields:
    Session: A session created from SessionLocal.
    """

    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether the consumer finished normally or raised, so the session is always closed.
        session.close()


class Key(Base):
    """
    This class is the ORM model for one stored API key, mapped to the "keys" table.
    The plain API key itself is not a column; only its hashed and encrypted forms are.
    """

    __tablename__ = "keys"

    # Auto-incrementing integer primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)
    # Hash of the API key; required and unique, used to look a key up.
    hashed_key = Column(String, unique=True, nullable=False)
    # Encrypted form of the API key; the column allows NULL.
    encrypted_key = Column(String, nullable=True)


# Runs at import time: emits CREATE TABLE for the models registered on Base
# (here, the "keys" table) against the engine configured above.
Base.metadata.create_all(bind=engine)

# Secret used to encrypt and decrypt stored API keys.
# It is read from the KEYCOVE_SECRET_KEY environment variable when that is set.
# The fallback below was generated using the generate_secret_key function and is
# for demonstration purposes only - do not use it in production!!!
# Do not store your real secret key in your code; set the environment variable instead.
secret_key = os.environ.get(
    "KEYCOVE_SECRET_KEY", "EZdgwBIak481WZB8ZkZmzKHTDkRQFzDjeTrhSlU_v2g="
)


def verify_api_key(api_key: str = Header(None), db: Session = Depends(get_db)) -> None:
    """
    This function verifies the provided API key by hashing it and checking if the hashed key exists in the database.
    If the API key is missing, empty, or its hash does not exist in the database, it raises an HTTPException
    with a 404 status code.

    Parameters:
    api_key (str): The API key to verify. This is expected to be provided in the request header;
        it is None when the header is absent.
    db (Session): The database session to use for querying the database.

    Raises:
    HTTPException: If the provided API key is missing or not valid.
    """

    key = None
    # Header(None) yields None when the header is absent. Only hash and look up
    # a real, non-empty string instead of passing a non-string to hash().
    if api_key:
        key = db.query(Key).filter(Key.hashed_key == hash(api_key)).first()

    if key is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The provided API key is not valid. Please provide a valid API key.",
        )


@app.get("/protected")
def protected_route(verify_api_key: None = Depends(verify_api_key)):
    """
    This function is an example route that can only be reached with a valid API key.
    The check itself is done by the verify_api_key dependency, which raises before this body runs
    if the key is rejected.

    Parameters:
    verify_api_key (None): Present only to trigger the verify_api_key dependency; its value is not used.

    Returns:
    dict: A dictionary with a message indicating that access has been granted.
    """

    return dict(message="access granted")


@app.post("/create_api_key")
def create_api_key(db: Session = Depends(get_db)):
    """
    This function issues a new API key.
    It generates a token, stores the token's hash and its encrypted form as a new Key row,
    and returns the plain token to the caller.

    Parameters:
    db (Session): The database session used to store the new key.

    Returns:
    dict: A dictionary with the new API key under "api_key".
    """

    token = generate_token()

    # Only the hashed and encrypted forms are stored; the plain token is returned, not saved.
    record = Key(
        hashed_key=hash(token),
        encrypted_key=encrypt(token, secret_key),
    )

    db.add(record)
    db.commit()
    db.refresh(record)

    return {"api_key": token}


@app.get("/decrypt_api_key")
def decrypt_api_key(
    api_key: str = Header(None),
    verify_api_key: None = Depends(verify_api_key),
    db: Session = Depends(get_db),
):
    """
    This function returns the decrypted form of the API key supplied in the request header.
    It hashes the provided API key, queries the database for a key with the same hash, and decrypts
    the encrypted key stored on that row using the module-level secret key.
    The same secret key that was used to encrypt the API key must be used to decrypt it.

    Parameters:
    api_key (str): The API key to decrypt. This is expected to be provided in the request header.
    verify_api_key (None): Present only to trigger the verify_api_key dependency; its value is not used.
    db (Session): The database session to use for querying the database.

    Returns:
    str: The decrypted API key.

    Raises:
    HTTPException: If the provided API key is not valid, or if no encrypted key is stored for it.
    """

    key = db.query(Key).filter(Key.hashed_key == hash(api_key)).first()

    # The dependency has already checked the key, but guard the lookup anyway so a
    # missing row surfaces as a clean 404 rather than an AttributeError.
    if key is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The provided API key is not valid. Please provide a valid API key.",
        )

    # The encrypted_key column is nullable, so a row may have nothing to decrypt.
    if key.encrypted_key is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No encrypted key is stored for the provided API key.",
        )

    return decrypt(key.encrypted_key, secret_key)


@app.get("/keys")
def get_keys(db: Session = Depends(get_db)):
    """
    This function lists the hashed form of every key stored in the database.
    Note that this route does not depend on verify_api_key, so no API key is needed to call it.

    Parameters:
    db (Session): The database session to use for querying the database.

    Returns:
    dict: A dictionary with the list of all hashed keys under "keys".
    """

    return {"keys": [record.hashed_key for record in db.query(Key).all()]}