FastAPI处理OAuth2

最近在重新学习FastAPI的文档,认证这部分相对独立,简单做个demo,基本都是官网的内容,稍作修改。

官网链接:https://fastapi.tiangolo.com/zh/tutorial/security/oauth2-jwt/

一、基本概念

JWT(JSON Web Token)的基本概念在网上可以搜到很多资料,目前认证的主流做法都是采用token的方式。

一个典型的JWT的结构是xxxxx.yyyyy.zzzzz

其中第一部分是header,第二部分是payload,第三部分是signature(签名)。核心内容保存在payload中。payload中有一些名称固定、含义约定好的标准字段(registered claims):

  • iss 【issuer】发布者的url地址

  • sub 【subject】该JWT所面向的主体,通常是用户的唯一标识(FastAPI官方教程就用sub保存用户名)

  • aud 【audience】该JWT的接收方(通常是接收方的标识或URL)

  • exp 【expiration】 该JWT的过期时间,过期后即失效;unix时间戳

  • nbf 【not before】 该jwt的使用时间不能早于该时间;unix时间戳

  • iat 【issued at】 该jwt的发布时间;unix 时间戳

  • jti 【JWT ID】 该jwt的唯一ID编号

当然也支持在payload中自定义参数。

二、demo

第一步,实现用户密码的哈希与校验(bcrypt是单向哈希,无法解密),并模拟数据库返回用户信息

from passlib.context import CryptContext
from pydantic import BaseModel
from typing import Optional

# All Pydantic schemas used in this demo.
# Response body of the token endpoint; token_type is always "bearer" here.
class Token(BaseModel):
    access_token: str
    token_type: str

# Claims pulled out of a decoded JWT (only the username in this demo).
class TokenData(BaseModel):
    username: Optional[str] = None

# Public user fields; used as response_model, so only these are returned.
class User(BaseModel):
    username: str
    email: Optional[str] = None
    disabled: Optional[bool] = None  # truthy -> rejected by get_current_active_user

# User record as stored in the (fake) database, including the password hash.
class UserInDB(User):
    hashed_password: str

# Simulated user "database", keyed by username.
# hashed_password stores a bcrypt hash ("$2b$12$..."), never the plaintext.
fake_users_db = {
    "guodabao": {
        "username": "guodabao",
        # The original address appears to have been replaced by the web page's
        # e-mail obfuscation ("[email protected]"); use a reserved example domain.
        "email": "guodabao@example.com",
        "hashed_password": "$2b$12$HQfuwwNR857Pp4ySQDvXNOAZkblQuU4wBcsIJVxqsf3sYoiMf7W42",
        "disabled": False,
    }
}
# passlib CryptContext performs password hashing and verification.
# schemes picks the hashing algorithm (bcrypt). deprecated="auto" treats every
# scheme other than the default as deprecated, so needs_update() can flag old
# hashes for re-hashing. This matches the configuration of the complete listing.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Check a plaintext password against a stored hash.
def verify_password(plain_password, hashed_password):
    """Return whether *plain_password* matches *hashed_password* per ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match

# Hash a plaintext password (e.g. to produce a hashed_password entry).
def get_password_hash(password):
    """Return the hash of *password* produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed

# Simulate loading a user record from the database.
def get_user(db, username: str):
    """Return ``UserInDB`` built from ``db[username]``, or ``None`` if absent."""
    if username not in db:
        return None
    return UserInDB(**db[username])

# Verify credentials and return the matching user record.
def authenticate_user(fake_db, username: str, password: str):
    """Return the ``UserInDB`` for *username* if *password* is correct, else ``False``.

    For an unknown username a dummy hash verification is still performed, so
    the response time does not reveal which usernames exist.
    """
    user = get_user(fake_db, username)
    if not user:
        # Spend roughly the same CPU time as a real bcrypt check; returning
        # immediately would open a username-enumeration timing side channel.
        pwd_context.dummy_verify()
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

第二步,实现认证

from datetime import datetime, timedelta
from jose import JWTError, jwt
from starlette.status import HTTP_401_UNAUTHORIZED

import os

# Basic settings needed to issue and verify JWTs.
# Signing key generated with `openssl rand -hex 32`. This demo value is public
# (it is printed in this article), so set JWT_SECRET_KEY in the environment to
# override it in any real deployment.
SECRET_KEY = os.environ.get(
    "JWT_SECRET_KEY",
    "d6f72340d4afe840c036ba5d593f7fd36c3c811602f77abebdef8127b74ddce2",
)
ALGORITHM = "HS256"  # signature algorithm (HMAC with SHA-256)
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime of tokens issued by the login endpoint
TOKEN_URL = "access-token"  # relative URL where clients obtain a token


# Core logic for creating a JWT.
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims for the payload. The dict is copied, so the caller's
            object is not modified.
        expires_delta: Token lifetime. ``None`` falls back to 15 minutes; any
            other value (including ``timedelta(0)``) is used as given.

    Returns:
        The encoded token string (``header.payload.signature``).
    """
    from datetime import timezone  # local import keeps this snippet self-contained

    to_encode = data.copy()  # payload claims
    # Compare against None: timedelta(0) is falsy and must not be replaced.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

# Decode the bearer token and map it back to a stored user.
# NOTE: Depends, HTTPException and oauth2_scheme come from step 3 below; in the
# complete listing they are defined before this function.
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Return the ``UserInDB`` identified by the JWT in *token*.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``jwt.decode`` rejects the token, when the payload has no string
            "username" claim, or when that user does not exist.
    """
    credentials_exception = HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc
    # This demo keys tokens by a custom "username" claim. Reject a missing or
    # non-string value up front: otherwise TokenData validation could raise
    # and surface as a 500 instead of a 401.
    username = payload.get("username")
    if not isinstance(username, str):
        raise credentials_exception
    token_data = TokenData(username=username)
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user

第三步,实现API

from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

# Bearer-token dependency; tokenUrl tells the generated API docs where to log in.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=TOKEN_URL)

# Application object the routes below are registered on.
app = FastAPI()

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the authenticated user unless the account is flagged as disabled.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.disabled`` is truthy.
    """
    is_active = not current_user.disabled
    if is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")

# Token endpoint. The path is derived from TOKEN_URL so it cannot drift from
# the tokenUrl advertised by OAuth2PasswordBearer (both are "access-token").
@app.post(f"/{TOKEN_URL}", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    # Exchange username/password form fields for a bearer JWT; 401 on failure.
    # (Comments rather than a docstring: FastAPI publishes route docstrings
    # as the OpenAPI operation description.)
    from starlette.concurrency import run_in_threadpool

    # bcrypt verification is deliberately slow and CPU-bound. Run it in the
    # thread pool so it does not block the event loop inside this async route.
    user = await run_in_threadpool(
        authenticate_user, fake_users_db, form_data.username, form_data.password
    )
    if not user:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"username": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

# Return the caller's profile. get_current_active_user has already validated
# the token and rejected disabled users. response_model=User limits the output
# to User's fields, so the stored hashed_password is not sent back.
@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user

@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(get_current_active_user)):
    # Demo payload: one hard-coded item tagged with the caller's username.
    owned_item = {"item_id": "Foo", "owner": current_user.username}
    return [owned_item]

@app.post("/password/")  # helper: produce a hashed_password value for fake_users_db
def get_hashed_password(password: str):
    # NOTE(review): `password` is a plain (query) parameter, so the plaintext
    # can end up in URLs and access logs, and the endpoint is unauthenticated.
    # Fine for a local demo only; do not expose it in production.
    hashed = get_password_hash(password)
    return hashed

三、完整代码

from datetime import datetime, timedelta
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette.status import HTTP_401_UNAUTHORIZED
from typing import Optional


import os

# Signing key (generated with `openssl rand -hex 32`). The default below is
# public, so set JWT_SECRET_KEY in the environment for any real deployment.
SECRET_KEY = os.environ.get(
    "JWT_SECRET_KEY",
    "d6f72340d4afe840c036ba5d593f7fd36c3c811602f77abebdef8127b74ddce2",
)
ALGORITHM = "HS256"  # signature algorithm (HMAC with SHA-256)
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime of tokens issued by the login endpoint
TOKEN_URL = "access-token"  # relative URL where clients obtain a token


# Simulated user "database", keyed by username; hashed_password is a bcrypt hash.
fake_users_db = {
    "guodabao": {
        "username": "guodabao",
        # The original address appears to have been replaced by the web page's
        # e-mail obfuscation ("[email protected]"); use a reserved example domain.
        "email": "guodabao@example.com",
        "hashed_password": "$2b$12$HQfuwwNR857Pp4ySQDvXNOAZkblQuU4wBcsIJVxqsf3sYoiMf7W42",
        "disabled": False,
    }
}


# Response body of the token endpoint; token_type is always "bearer" here.
class Token(BaseModel):
    access_token: str
    token_type: str


# Claims pulled out of a decoded JWT (only the username in this demo).
class TokenData(BaseModel):
    username: Optional[str] = None


# Public user fields; used as response_model, so only these are returned.
class User(BaseModel):
    username: str
    email: Optional[str] = None
    disabled: Optional[bool] = None  # truthy -> rejected by get_current_active_user


# User record as stored in the (fake) database, including the password hash.
class UserInDB(User):
    hashed_password: str


# Application object the routes below are registered on.
app = FastAPI()

# Bearer-token dependency; tokenUrl tells the generated API docs where to log in.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=TOKEN_URL)

# Password hashing context: bcrypt, with any other scheme treated as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """Return whether *plain_password* matches *hashed_password* per ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match


def get_password_hash(password):
    """Return the hash of *password* produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed


def get_user(db, username: str):
    """Return ``UserInDB`` built from ``db[username]``, or ``None`` if absent."""
    if username not in db:
        return None
    return UserInDB(**db[username])


def authenticate_user(fake_db, username: str, password: str):
    """Return the ``UserInDB`` for *username* if *password* is correct, else ``False``.

    For an unknown username a dummy hash verification is still performed, so
    the response time does not reveal which usernames exist.
    """
    user = get_user(fake_db, username)
    if not user:
        # Spend roughly the same CPU time as a real bcrypt check; returning
        # immediately would open a username-enumeration timing side channel.
        pwd_context.dummy_verify()
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims for the payload. The dict is copied, so the caller's
            object is not modified.
        expires_delta: Token lifetime. ``None`` falls back to 15 minutes; any
            other value (including ``timedelta(0)``) is used as given.

    Returns:
        The encoded token string (``header.payload.signature``).
    """
    from datetime import timezone  # only needed here; keeps the import list unchanged

    to_encode = data.copy()
    # Compare against None: timedelta(0) is falsy and must not be replaced.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Return the ``UserInDB`` identified by the JWT in *token*.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``jwt.decode`` rejects the token, when the payload has no string
            "username" claim, or when that user does not exist.
    """
    credentials_exception = HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc
    # This demo keys tokens by a custom "username" claim. Reject a missing or
    # non-string value up front: otherwise TokenData validation could raise
    # and surface as a 500 instead of a 401.
    username = payload.get("username")
    if not isinstance(username, str):
        raise credentials_exception
    token_data = TokenData(username=username)
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user


async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the authenticated user unless the account is flagged as disabled.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.disabled`` is truthy.
    """
    is_active = not current_user.disabled
    if is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")


# Token endpoint. The path is derived from TOKEN_URL so it cannot drift from
# the tokenUrl advertised by OAuth2PasswordBearer (both are "access-token").
@app.post(f"/{TOKEN_URL}", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    # Exchange username/password form fields for a bearer JWT; 401 on failure.
    # (Comments rather than a docstring: FastAPI publishes route docstrings
    # as the OpenAPI operation description.)
    from starlette.concurrency import run_in_threadpool

    # bcrypt verification is deliberately slow and CPU-bound. Run it in the
    # thread pool so it does not block the event loop inside this async route.
    user = await run_in_threadpool(
        authenticate_user, fake_users_db, form_data.username, form_data.password
    )
    if not user:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"username": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}


# Return the caller's profile. get_current_active_user has already validated
# the token and rejected disabled users. response_model=User limits the output
# to User's fields, so the stored hashed_password is not sent back.
@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user


@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(get_current_active_user)):
    # Demo payload: one hard-coded item tagged with the caller's username.
    owned_item = {"item_id": "Foo", "owner": current_user.username}
    return [owned_item]


@app.post("/password/")  # helper: produce a hashed_password value for fake_users_db
def get_hashed_password(password: str):
    # NOTE(review): `password` is a plain (query) parameter, so the plaintext
    # can end up in URLs and access logs, and the endpoint is unauthenticated.
    # Fine for a local demo only; do not expose it in production.
    hashed = get_password_hash(password)
    return hashed
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章