feat: separeted oauth from web server
This commit is contained in:
parent
3867885b89
commit
f97c0a542f
|
@ -1,16 +0,0 @@
|
|||
import psycopg2.extensions
|
||||
|
||||
|
||||
def insert_token(cur: psycopg2.extensions.cursor, ip_address: str, access_token: str, refresh_token: str, expires_in: str):
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO auth.tokens
|
||||
(ip_address,access_token,refresh_token,expires_in) VALUES (%s,%s,%s,%s)
|
||||
""",
|
||||
(
|
||||
ip_address, # ip
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
)
|
||||
)
|
55
ajusta_bling/database/queries/tokens.py
Normal file
55
ajusta_bling/database/queries/tokens.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
from time import time
|
||||
|
||||
import psycopg2.extensions
|
||||
|
||||
|
||||
def insert_token(cur: psycopg2.extensions.cursor, ip_address: str, access_token: str, refresh_token: str, expires_in: str):
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO auth.tokens
|
||||
(ip_address,access_token,refresh_token,expires_in) VALUES (%s,%s,%s,%s)
|
||||
""",
|
||||
(
|
||||
ip_address, # ip
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
)
|
||||
)
|
||||
|
||||
def token_cleanup(cur: psycopg2.extensions.cursor) -> int | bool:
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM auth.tokens
|
||||
WHERE extract(epoch from created_at) + expires_in < extract(epoch from now())
|
||||
"""
|
||||
)
|
||||
return cur.rowcount if cur.rowcount > 0 else False
|
||||
|
||||
|
||||
def get_valid_token(cur: psycopg2.extensions.cursor, ip_address: str) -> tuple[str, str, str]:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT access_token, refresh_token
|
||||
FROM auth.tokens
|
||||
WHERE ip_address = %s
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
(ip_address,)
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
if row['created_at_unix']+row['expires_in'] < time():
|
||||
delete_token(cur, row['uuid'])
|
||||
return row['access_token'], row['refresh_token'], row['expires_in']
|
||||
|
||||
def delete_token(cur: psycopg2.extensions.cursor, uuid: str) -> bool:
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM auth.tokens
|
||||
WHERE uuid = %s
|
||||
""",
|
||||
(uuid,)
|
||||
)
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
2
ajusta_bling/oauth/__init__.py
Normal file
2
ajusta_bling/oauth/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from . import actions, builders, client
|
||||
from .client import Client
|
27
ajusta_bling/oauth/actions.py
Normal file
27
ajusta_bling/oauth/actions.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import requests
|
||||
|
||||
from . import builders
|
||||
from .client import Client
|
||||
|
||||
|
||||
def process_callback(client: Client, code: str) -> tuple[str, str, int]:
|
||||
payload = {
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code,
|
||||
}
|
||||
|
||||
header = builders.build_callback_authorization_header(client)
|
||||
|
||||
response = requests.post(client.access_url,
|
||||
data = payload,
|
||||
headers = header)
|
||||
data = response.json()
|
||||
access_token: str = str(data["access_token"])
|
||||
refresh_token: str = str(data["refresh_token"])
|
||||
expires_in: int = int(data["expires_in"])
|
||||
|
||||
return (
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in
|
||||
)
|
18
ajusta_bling/oauth/builders.py
Normal file
18
ajusta_bling/oauth/builders.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from base64 import b64encode
|
||||
|
||||
from .client import Client
|
||||
|
||||
|
||||
def build_authorize_url(client: Client, state: str) -> str:
|
||||
return "%(url)s?client_id=%(client_id)s&redirect_uri=%(redirect)s&response_type=code&state=%(state)s" % {
|
||||
"url": client.authorize_url,
|
||||
"client_id": client.client_id,
|
||||
"redirect": client.redirect_uri,
|
||||
"state": state
|
||||
}
|
||||
|
||||
def build_callback_authorization_header(client: Client) -> dict:
|
||||
return {"Authorization": "Basic " +
|
||||
b64encode(f"{client.client_id}:{client.client_secret}"
|
||||
.encode())
|
||||
.decode()}
|
37
ajusta_bling/oauth/client.py
Normal file
37
ajusta_bling/oauth/client.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from os import getenv
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str = getenv('OAUTH_CLIENT_ID'),
|
||||
client_secret: str = getenv('OAUTH_CLIENT_SECRET'),
|
||||
redirect_uri: str = getenv('OAUTH_REDIRECT_URI'),
|
||||
authorize_url: str = getenv('OAUTH_URL_AUTHORIZE'),
|
||||
access_url: str = getenv('OAUTH_URL_ACCESS_TOKEN'),
|
||||
) -> None:
|
||||
self.__client_id = client_id
|
||||
self.__client_secret = client_secret
|
||||
self.__redirect_uri = redirect_uri
|
||||
self.__authorize_url = authorize_url
|
||||
self.__access_url = access_url
|
||||
|
||||
@property
|
||||
def client_id(self) -> str:
|
||||
return self.__client_id
|
||||
|
||||
@property
|
||||
def client_secret(self) -> str:
|
||||
return self.__client_secret
|
||||
|
||||
@property
|
||||
def redirect_uri(self) -> str:
|
||||
return self.__redirect_uri
|
||||
|
||||
@property
|
||||
def authorize_url(self) -> str:
|
||||
return self.__authorize_url
|
||||
|
||||
@property
|
||||
def access_url(self) -> str:
|
||||
return self.__access_url
|
|
@ -1,54 +1,33 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from base64 import b64encode
|
||||
from os import getenv
|
||||
|
||||
import requests
|
||||
import database.queries.tokens as token_sql
|
||||
from flask import Flask, redirect, request, session, url_for
|
||||
|
||||
import ajusta_bling.database.queries.auth as sqlAuth
|
||||
import ajusta_bling.oauth as oauth
|
||||
from ajusta_bling.common import Args
|
||||
from ajusta_bling.database import Database
|
||||
|
||||
app = Flask(__name__)
|
||||
app.secret_key = "#^A549639t5@#&$p"
|
||||
db: Database | None = None
|
||||
oauth_client: oauth.Client = oauth.Client();
|
||||
|
||||
@app.get('/auth')
|
||||
def auth():
|
||||
session["state"] = secrets.token_urlsafe(16)
|
||||
|
||||
return redirect("%(url)s?client_id=%(client_id)s&redirect_uri=%(redirect)s&response_type=code&state=%(state)s" % {
|
||||
"url": getenv('OAUTH_URL_AUTHORIZE'),
|
||||
"client_id": getenv('OAUTH_CLIENT_ID'),
|
||||
"redirect": url_for('callback', _external = True),
|
||||
"state": session["state"]
|
||||
})
|
||||
return redirect(oauth.builders.build_authorize_url(oauth_client, session["state"]))
|
||||
|
||||
@app.get('/callback')
|
||||
def callback():
|
||||
if request.args.get("state") != session.pop("state", "fartnugget"):
|
||||
return "I banish thee, to the state of Ohio", 403
|
||||
|
||||
payload = {
|
||||
'grant_type': 'authorization_code',
|
||||
'code': request.args.get('code'),
|
||||
}
|
||||
|
||||
header = "Basic " + b64encode(f"{getenv('OAUTH_CLIENT_ID')}:{getenv('OAUTH_CLIENT_SECRET')}".encode()).decode()
|
||||
|
||||
response = requests.post(getenv('OAUTH_URL_ACCESS_TOKEN'),
|
||||
data = payload,
|
||||
headers= {"Authorization": header})
|
||||
data = response.json()
|
||||
access_token: str = str(data["access_token"])
|
||||
refresh_token: str = str(data["refresh_token"])
|
||||
expires_in: int = int(data["expires_in"])
|
||||
access_token, refresh_token, expires_in = oauth.actions.process_callback(oauth_client, request.args.get("code"))
|
||||
|
||||
with db.get_cur() as cur:
|
||||
sqlAuth.insert_token(cur, request.remote_addr, access_token, refresh_token, expires_in)
|
||||
|
||||
token_sql.insert_token(cur, request.remote_addr, access_token, refresh_token, expires_in)
|
||||
return redirect(url_for('index'))
|
||||
|
||||
@app.route("/")
|
||||
|
|
Loading…
Reference in New Issue
Block a user