您好,登錄后才能下訂單哦!
用Tensorflow和FastAPI構建圖像分類API,很多新手對此不是很清楚,為了幫助大家解決這個難題,下面小編將為大家詳細講解,有這方面需求的人可以來學習下,希望你能有所收獲。
讓我們從一個簡單的helloworld示例開始
首先,我們導入FastAPI類并創建一個對象應用程序。這個類有一些有用的參數,比如我們可以傳遞swaggerui的標題和描述。
from fastapi import FastAPI app = FastAPI(title='Hello world')
我們定義一個函數并用@app.get. 這意味著我們的API/index支持GET方法。這里定義的函數是異步的,FastAPI通過為普通的def函數創建線程池來自動處理異步和不使用異步方法,并且它為異步函數使用異步事件循環。
@app.get('/index') async def hello_world(): return "hello world"
我們將創建一個API來對圖像進行分類,我們將其命名為predict/image。我們將使用Tensorflow來創建圖像分類模型。
Tensorflow圖像分類教程:https://aniketmaurya.ml/blog/tensorflow/deep%20learning/2019/05/12/image-classification-with-tf2.html
我們創建了一個函數load_model,它將返回一個帶有預訓練權重的MobileNet CNN模型,即它已經被訓練為對1000個不同類別的圖像進行分類。
import tensorflow as tf def load_model(): model = tf.keras.applications.MobileNetV2(weights="imagenet") print("Model loaded") return model model = load_model()
我們定義了一個predict函數,它將接受圖像并返回預測。我們將圖像大小調整為224x224,并將像素值規格化為[-1,1]。
from tensorflow.keras.applications.imagenet_utils import decode_predictions
decode_predictions用于解碼預測對象的類名。這里我們將返回前2個可能的類。
def predict(image: Image.Image): image = np.asarray(image.resize((224, 224)))[..., :3] image = np.expand_dims(image, 0) image = image / 127.5 - 1.0 result = decode_predictions(model.predict(image), 2)[0] response = [] for i, res in enumerate(result): resp = {} resp["class"] = res[1] resp["confidence"] = f"{res[2]*100:0.2f} %" response.append(resp) return response
現在我們將創建一個支持文件上傳的API/predict/image。我們將過濾文件擴展名以僅支持jpg、jpeg和png格式的圖像。
我們將使用Pillow加載上傳的圖像。
def read_imagefile(file) -> Image.Image: image = Image.open(BytesIO(file)) return image @app.post("/predict/image") async def predict_api(file: UploadFile = File(...)): extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" image = read_imagefile(await file.read()) prediction = predict(image) return prediction
import uvicorn from fastapi import FastAPI, File, UploadFile from application.components import predict, read_imagefile app = FastAPI() @app.post("/predict/image") async def predict_api(file: UploadFile = File(...)): extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" image = read_imagefile(await file.read()) prediction = predict(image) return prediction @app.post("/api/covid-symptom-check") def check_risk(symptom: Symptom): return symptom_check.get_risk_level(symptom) if __name__ == "__main__": uvicorn.run(app, debug=True)
看完上述內容是否對您有幫助呢?如果還想對相關知識有進一步的了解或閱讀更多相關文章,請關注億速云行業資訊頻道,感謝您對億速云的支持。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。