Flask?
Python 기반 웹 프레임워크. Django와 같은 Python 웹 프레임워크지만 Django 보다 더 쉽고 간단하다는 느낌을 받았다. Django는 MTV 패턴처럼 어느 정도 서비스 개발 사이클이 존재하는 비교적 전통적 방식의 프레임워크라면, Flask는 더 간단하고 빠르게 개발을 진행할 수 있다. 다만, Django와 비교하여 인증 및 인가, Admin 페이지에 있어서 기본 세팅이 되어 있지 않기 때문에 따로 라이브러리를 사용해야 하는 어려움이 있다.
Flask를 모델 서빙에 많이 사용하는 이유
Python 언어로 작성되어 있는 모델이 대부분이기 때문에 호환성을 위해 Python 웹 프레임워크를 사용한다. 모델에서 단순히 HTTP 기반 API를 사용하고 싶다면 간단하게 라우팅이 가능한 Flask나 FastAPI가 적합하다. 특히 Flask는 HTTP 응답과 요청을 처리하는 데 특화되어 있기 때문에 RESTful API를 작성할 때 많이 사용된다. 나도 이러한 이유로 Flask를 사용하여 모델에서 API를 뽑아냈다. 최근에는 FastAPI가 성능 면에서 더 우수하고 커뮤니티도 활성화되어 있어 FastAPI로 배포하는 것도 좋은 방법이 될 수 있겠다.
환경
- Mac OS
- python 3.9
구현
1. 가상환경 구축 및 활성화
여러 개의 파이썬 프로젝트를 진행하는 상황이라면 가상환경을 구축하는 것을 권장한다. 프로젝트마다 필요한 라이브러리, 프레임워크의 버전을 관리하기 용이하기 때문.
$ mkdir myproject
$ cd myproject
$ python3 -m venv .venv
$ . .venv/bin/activate
2. 플라스크
$ pip install Flask
$ vim app.py
from flask import Flask
app = Flask(__name__)
@app.route('/')
def index():
return 'Hello, Flask!'
if __name__ == '__main__':
app.run(debug=True)
$ export FLASK_APP=app.py
$ flask run
위 과정을 통해 Flask가 가상환경 위에서 동작하는지 확인한다.
3. 모델 연동
- pytorch를 동작시키기 위해 필요한 라이브러리 설치. (프로젝트의 모델의 성격에 따라 필요한 라이브러리를 설치하면 된다.)
$ pip install torch==1.13.0 torchvision==0.14.0 numpy==1.26.1 pytorch-pretrained-vit==0.0.7
- pytorch 파일을 app.py와 같은 디렉토리에 위치.
* requirements.txt에 설치한 라이브러리를 기록해 두면 나중에 배포 환경에서 다시 설치할 때 편리하다.
* .gitignore로 용량이 큰 pytorch 파일과 가상환경 디렉토리를 넣어둔다.
$ pip freeze > requirements.txt
- 모델 로드 및 요청 처리
코드는 pytorch 공식 문서를 참고했지만 이미지 전처리 방식이나 모델 로딩 방식은 프로젝트 성격에 따라 변경하여 사용해야 한다.
from flask import Flask, request, jsonify
from torchvision.transforms import transforms
import torchvision
from PIL import Image
import torch
from torch import nn
from pytorch_pretrained_vit import ViT
app = Flask(__name__)
# 모델 로드
device = torch.device('cpu')
vit = ViT('B_16_imagenet1k', pretrained=True)
class VisionTransformer(nn.Module):
def __init__(self):
super(VisionTransformer, self).__init__()
self.vit = vit
self.linear = nn.Linear(1000, 1)
def forward(self, x):
x = self.vit(x)
x = torch.sigmoid(self.linear(x))
return x
model = VisionTransformer()
model = model.to("cpu")
model.load_state_dict(torch.load("vit_weight.pt", map_location=device))
model.eval()
# 이미지 전처리
transform = transforms.Compose([
transforms.Resize(size=(384,384)),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
# 라우팅 url 및 HTTP 메소드 정의
@app.route('/api/v1/attention', methods=['POST'])
def predict():
# POST 요청에서 이미지 파일 받기
if 'image' not in request.files:
return jsonify({'error': 'No image found'}), 400
image_file = request.files['image']
try:
image = Image.open(image_file)
except:
return jsonify({'error': 'Invalid image file'}), 400
image = transform(image)
# 추론
with torch.no_grad():
output = model(image.unsqueeze(0))
if(float(output[0][0]) > 0.5):
prediction = 1
else:
prediction = 0
# 결과 - json 형식으로 반환
return jsonify({'prediction': prediction})
if __name__ == '__main__':
app.run(host='0.0.0.0')
- 테스트
app.py를 실행 시키고 테스트를 진행한다.
python requests 라이브러리를 활용하거나 postman과 같은 통합 테스트 도구를 사용할 수 있다.
여기서는 postman에서 테스트를 진행하였다.
app.py에 작성했던 url로 form data 형식으로 이미지를 전송했을 때 json 형식으로 output이 반환됨을 확인할 수 있다.
정리
지금까지 Flask를 사용하여 Pytorch 모델을 REST API로 배포하였다. 본 글에서는 로컬 환경에서 동작시키는 것을 다뤘지만 배포 환경에서 동작시킬 때는 EC2 인스턴스를 띄워 진행할 수 있다. 다음 글에서는 EC2 환경에서 모델을 띄워보도록 하겠다.
출처
https://flask.palletsprojects.com/en/3.0.x/quickstart/
https://tutorials.pytorch.kr/intermediate/flask_rest_api_tutorial.html
'백엔드 > model serving' 카테고리의 다른 글
PyTorch 모델 서빙 (3) - EC2 환경에서 모델 실행+ 에러 (1) | 2024.01.04 |
---|---|
PyTorch 모델 서빙 (1) - 여러 방법들 (1) | 2024.01.02 |