kristofer revised this gist . Go to revision
1 file changed, 146 insertions
FastAPI_auth_example.py(file created)
@@ -0,0 +1,146 @@ | |||
1 | + | from fastapi import FastAPI, Depends, HTTPException, status, Header | |
2 | + | from sqlalchemy import create_engine, Column, String, Integer | |
3 | + | from sqlalchemy.ext.declarative import declarative_base | |
4 | + | from sqlalchemy.orm import sessionmaker | |
5 | + | from keycove import encrypt, decrypt, hash, generate_token | |
6 | + | from sqlalchemy.orm import Session | |
7 | + | ||
8 | + | app = FastAPI() | |
9 | + | ||
10 | + | SQLALCHEMY_DATABASE_URL = "sqlite:///db.sqlite3" | |
11 | + | engine = create_engine( | |
12 | + | SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} | |
13 | + | ) | |
14 | + | ||
15 | + | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | |
16 | + | Base = declarative_base() | |
17 | + | ||
18 | + | ||
19 | + | def get_db(): | |
20 | + | db = SessionLocal() | |
21 | + | try: | |
22 | + | yield db | |
23 | + | finally: | |
24 | + | db.close() | |
25 | + | ||
26 | + | ||
27 | + | class Key(Base): | |
28 | + | __tablename__ = "keys" | |
29 | + | id = Column(Integer, primary_key=True, autoincrement=True) | |
30 | + | hashed_key = Column(String, nullable=False, unique=True) | |
31 | + | encrypted_key = Column(String, nullable=True) | |
32 | + | ||
33 | + | ||
34 | + | Base.metadata.create_all(bind=engine) | |
35 | + | ||
36 | + | # this secret key was generated using the generate_secret_key function. | |
37 | + | # do not use this secret key in production!!! | |
38 | + | # do not store your secret key in your code - this is just for demonstration purposes | |
39 | + | # you should store your secret key in an environment variable | |
40 | + | secret_key = "EZdgwBIak481WZB8ZkZmzKHTDkRQFzDjeTrhSlU_v2g=" | |
41 | + | ||
42 | + | ||
43 | + | def verify_api_key(api_key: str = Header(None), db: Session = Depends(get_db)) -> None: | |
44 | + | """ | |
45 | + | This function verifies the provided API key by hashing it and checking if the hashed key exists in the database. | |
46 | + | If the hashed key does not exist in the database, it raises an HTTPException with a 404 status code. | |
47 | + | ||
48 | + | Parameters: | |
49 | + | api_key (str): The API key to verify. This is expected to be provided in the request header. | |
50 | + | db (Session): The database session to use for querying the database. | |
51 | + | ||
52 | + | Raises: | |
53 | + | HTTPException: If the provided API key is not valid. | |
54 | + | """ | |
55 | + | ||
56 | + | key = db.query(Key).filter(Key.hashed_key == hash(api_key)).first() | |
57 | + | if not key: | |
58 | + | raise HTTPException( | |
59 | + | status_code=status.HTTP_404_NOT_FOUND, | |
60 | + | detail=f"The provided API key is not valid. Please provide a valid API key.", | |
61 | + | ) | |
62 | + | ||
63 | + | ||
64 | + | @app.get("/protected") | |
65 | + | def protected_route(verify_api_key: None = Depends(verify_api_key)): | |
66 | + | """ | |
67 | + | This function is a protected route that requires a valid API key to access. | |
68 | + | If the API key is valid, it returns a message indicating that access has been granted. | |
69 | + | ||
70 | + | Parameters: | |
71 | + | verify_api_key (None): This is a dependency that ensures the API key is verified before the function is accessed. | |
72 | + | ||
73 | + | Returns: | |
74 | + | dict: A dictionary with a message indicating that access has been granted. | |
75 | + | """ | |
76 | + | ||
77 | + | return {"message": "access granted"} | |
78 | + | ||
79 | + | ||
80 | + | @app.post("/create_api_key") | |
81 | + | def create_api_key(db: Session = Depends(get_db)): | |
82 | + | """ | |
83 | + | This function creates a new API key, hashes it, encrypts it, and stores it in the database. | |
84 | + | It then returns the new API key. | |
85 | + | ||
86 | + | Parameters: | |
87 | + | db (Session): The database session to use for querying the database. | |
88 | + | ||
89 | + | Returns: | |
90 | + | dict: A dictionary with the new API key. | |
91 | + | """ | |
92 | + | ||
93 | + | api_key = generate_token() | |
94 | + | hashed_key = hash(api_key) | |
95 | + | encrypted_key = encrypt(api_key, secret_key) | |
96 | + | new_key = Key(hashed_key=hashed_key, encrypted_key=encrypted_key) | |
97 | + | ||
98 | + | db.add(new_key) | |
99 | + | db.commit() | |
100 | + | db.refresh(new_key) | |
101 | + | ||
102 | + | return {"api_key": api_key} | |
103 | + | ||
104 | + | ||
105 | + | @app.get("/decrypt_api_key") | |
106 | + | def decrypt_api_key( | |
107 | + | api_key: str = Header(None), | |
108 | + | verify_api_key: str = Depends(verify_api_key), | |
109 | + | db: Session = Depends(get_db), | |
110 | + | ): | |
111 | + | """ | |
112 | + | This function decrypts a given API key using a provided secret key. | |
113 | + | The same secret key that was used to encrypt the API key should be used to decrypt it. | |
114 | + | The function first hashes the provided API key and then queries the database for a key with the same hash. | |
115 | + | If such a key is found, the function decrypts the encrypted key stored in the database and returns it. | |
116 | + | ||
117 | + | Parameters: | |
118 | + | api_key (str): The API key to decrypt. | |
119 | + | verify_api_key (str): A dependency that ensures the API key is verified before the function is accessed. | |
120 | + | db (Session): The database session to use for querying the database. | |
121 | + | ||
122 | + | Returns: | |
123 | + | str: The decrypted API key. | |
124 | + | ||
125 | + | Raises: | |
126 | + | HTTPException: If the provided API key is not valid. | |
127 | + | """ | |
128 | + | ||
129 | + | key = db.query(Key).filter(Key.hashed_key == hash(api_key)).first() | |
130 | + | return decrypt(key.encrypted_key, secret_key) | |
131 | + | ||
132 | + | ||
133 | + | @app.get("/keys") | |
134 | + | def get_keys(db: Session = Depends(get_db)): | |
135 | + | """ | |
136 | + | This function retrieves all the hashed keys from the database. | |
137 | + | ||
138 | + | Parameters: | |
139 | + | db (Session): The database session to use for querying the database. | |
140 | + | ||
141 | + | Returns: | |
142 | + | dict: A dictionary with all the hashed keys. | |
143 | + | """ | |
144 | + | ||
145 | + | keys = db.query(Key).all() | |
146 | + | return {"keys": [key.hashed_key for key in keys]} |
Newer
Older