diff --git a/ajusta_bling/database/queries/auth.py b/ajusta_bling/database/queries/auth.py deleted file mode 100644 index 2d8df34..0000000 --- a/ajusta_bling/database/queries/auth.py +++ /dev/null @@ -1,16 +0,0 @@ -import psycopg2.extensions - - -def insert_token(cur: psycopg2.extensions.cursor, ip_address: str, access_token: str, refresh_token: str, expires_in: str): - cur.execute( - """ - INSERT INTO auth.tokens - (ip_address,access_token,refresh_token,expires_in) VALUES (%s,%s,%s,%s) - """, - ( - ip_address, # ip - access_token, - refresh_token, - expires_in, - ) - ) \ No newline at end of file diff --git a/ajusta_bling/database/queries/tokens.py b/ajusta_bling/database/queries/tokens.py new file mode 100644 index 0000000..3e0bd56 --- /dev/null +++ b/ajusta_bling/database/queries/tokens.py @@ -0,0 +1,55 @@ +from time import time + +import psycopg2.extensions + + +def insert_token(cur: psycopg2.extensions.cursor, ip_address: str, access_token: str, refresh_token: str, expires_in: str): + cur.execute( + """ + INSERT INTO auth.tokens + (ip_address,access_token,refresh_token,expires_in) VALUES (%s,%s,%s,%s) + """, + ( + ip_address, # ip + access_token, + refresh_token, + expires_in, + ) + ) + +def token_cleanup(cur: psycopg2.extensions.cursor) -> int | bool: + cur.execute( + """ + DELETE FROM auth.tokens + WHERE extract(epoch from created_at) + expires_in < extract(epoch from now()) + """ + ) + return cur.rowcount if cur.rowcount > 0 else False + + +def get_valid_token(cur: psycopg2.extensions.cursor, ip_address: str) -> tuple[str, str, str]: + cur.execute( + """ + SELECT access_token, refresh_token + FROM auth.tokens + WHERE ip_address = %s + ORDER BY created_at DESC + """, + (ip_address,) + ) + for row in cur.fetchall(): + if row['created_at_unix']+row['expires_in'] < time(): + delete_token(cur, row['uuid']) + return row['access_token'], row['refresh_token'], row['expires_in'] + +def delete_token(cur: psycopg2.extensions.cursor, uuid: str) -> bool: + cur.execute( + """ + DELETE FROM auth.tokens + WHERE uuid = %s + """, + (uuid,) + ) + return cur.rowcount > 0 + + diff --git a/ajusta_bling/oauth/__init__.py b/ajusta_bling/oauth/__init__.py new file mode 100644 index 0000000..dae7248 --- /dev/null +++ b/ajusta_bling/oauth/__init__.py @@ -0,0 +1,2 @@ +from . import actions, builders, client +from .client import Client diff --git a/ajusta_bling/oauth/actions.py b/ajusta_bling/oauth/actions.py new file mode 100644 index 0000000..f90014d --- /dev/null +++ b/ajusta_bling/oauth/actions.py @@ -0,0 +1,27 @@ +import requests + +from . import builders +from .client import Client + + +def process_callback(client: Client, code: str) -> tuple[str, str, int]: + payload = { + 'grant_type': 'authorization_code', + 'code': code, + } + + header = builders.build_callback_authorization_header(client) + + response = requests.post(client.access_url, + data = payload, + headers = header) + data = response.json() + access_token: str = str(data["access_token"]) + refresh_token: str = str(data["refresh_token"]) + expires_in: int = int(data["expires_in"]) + + return ( + access_token, + refresh_token, + expires_in + ) \ No newline at end of file diff --git a/ajusta_bling/oauth/builders.py b/ajusta_bling/oauth/builders.py new file mode 100644 index 0000000..79be594 --- /dev/null +++ b/ajusta_bling/oauth/builders.py @@ -0,0 +1,18 @@ +from base64 import b64encode + +from .client import Client + + +def build_authorize_url(client: Client, state: str) -> str: + return "%(url)s?client_id=%(client_id)s&redirect_uri=%(redirect)s&response_type=code&state=%(state)s" % { + "url": client.authorize_url, + "client_id": client.client_id, + "redirect": client.redirect_uri, + "state": state + } + +def build_callback_authorization_header(client: Client) -> dict: + return {"Authorization": "Basic " + + b64encode(f"{client.client_id}:{client.client_secret}" + .encode()) + .decode()} \ No newline at end of file diff --git a/ajusta_bling/oauth/client.py b/ajusta_bling/oauth/client.py new file mode 100644 index 0000000..aa12bab --- /dev/null +++ b/ajusta_bling/oauth/client.py @@ -0,0 +1,37 @@ +from os import getenv + + +class Client: + def __init__( + self, + client_id: str = getenv('OAUTH_CLIENT_ID'), + client_secret: str = getenv('OAUTH_CLIENT_SECRET'), + redirect_uri: str = getenv('OAUTH_REDIRECT_URI'), + authorize_url: str = getenv('OAUTH_URL_AUTHORIZE'), + access_url: str = getenv('OAUTH_URL_ACCESS_TOKEN'), + ) -> None: + self.__client_id = client_id + self.__client_secret = client_secret + self.__redirect_uri = redirect_uri + self.__authorize_url = authorize_url + self.__access_url = access_url + + @property + def client_id(self) -> str: + return self.__client_id + + @property + def client_secret(self) -> str: + return self.__client_secret + + @property + def redirect_uri(self) -> str: + return self.__redirect_uri + + @property + def authorize_url(self) -> str: + return self.__authorize_url + + @property + def access_url(self) -> str: + return self.__access_url \ No newline at end of file diff --git a/ajusta_bling/web/__init__.py b/ajusta_bling/web/__init__.py index 4b289cc..36a686f 100644 --- a/ajusta_bling/web/__init__.py +++ b/ajusta_bling/web/__init__.py @@ -1,54 +1,33 @@ from __future__ import annotations import secrets -from base64 import b64encode -from os import getenv -import requests +import database.queries.tokens as token_sql from flask import Flask, redirect, request, session, url_for -import ajusta_bling.database.queries.auth as sqlAuth +import ajusta_bling.oauth as oauth from ajusta_bling.common import Args from ajusta_bling.database import Database app = Flask(__name__) app.secret_key = "#^A549639t5@#&$p" db: Database | None = None +oauth_client: oauth.Client = oauth.Client(); @app.get('/auth') def auth(): session["state"] = secrets.token_urlsafe(16) - - return redirect("%(url)s?client_id=%(client_id)s&redirect_uri=%(redirect)s&response_type=code&state=%(state)s" % { - "url": getenv('OAUTH_URL_AUTHORIZE'), - "client_id": getenv('OAUTH_CLIENT_ID'), - "redirect": url_for('callback', _external = True), - "state": session["state"] - }) + return redirect(oauth.builders.build_authorize_url(oauth_client, session["state"])) @app.get('/callback') def callback(): if request.args.get("state") != session.pop("state", "fartnugget"): return "I banish thee, to the state of Ohio", 403 - payload = { - 'grant_type': 'authorization_code', - 'code': request.args.get('code'), - } - - header = "Basic " + b64encode(f"{getenv('OAUTH_CLIENT_ID')}:{getenv('OAUTH_CLIENT_SECRET')}".encode()).decode() - - response = requests.post(getenv('OAUTH_URL_ACCESS_TOKEN'), - data = payload, - headers= {"Authorization": header}) - data = response.json() - access_token: str = str(data["access_token"]) - refresh_token: str = str(data["refresh_token"]) - expires_in: int = int(data["expires_in"]) + access_token, refresh_token, expires_in = oauth.actions.process_callback(oauth_client, request.args.get("code")) with db.get_cur() as cur: - sqlAuth.insert_token(cur, request.remote_addr, access_token, refresh_token, expires_in) - + token_sql.insert_token(cur, request.remote_addr, access_token, refresh_token, expires_in) return redirect(url_for('index')) @app.route("/")