Onnx는 pytorch, tensorflow 등의 머신러닝 프레임워크에서 만들어진 모델들이 서로 호환될 수 있도록 만들어줍니다. 이를 사용하여 pytorch에서 만든 모델을 onnx로 export 한 후 tensorflow에서 사용할 수 있습니다.
1. Export
import torch
import torchvision.models as models
# 모델 생성
model = models.vgg11(pretrained=True)
# 평가 모드로 설정
model.eval()
pytorch모델을 준비하고 모델을 model.eval() 또는 model.train(False)로 eval모드로 바꿔줘야 합니다. (dropout, batchnorm의 비활성화를 위해서 필요)
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "vgg11.onnx")
onnx 변환은 torch.onnx을 이용합니다. 이때, 사용할 input size와 같은 사이즈의 dummy가 필요합니다. 여기서는 (1, 3, 224, 224) 사이즈로 생성해주었습니다. ONNX로 변환된 그래프의 경우 입력값의 사이즈는 모든 차원에 대해 고정됩니다.
2. Check
import onnx
model = onnx.load("vgg11.onnx")
print(onnx.helper.printable_graph(model.graph))
printable_graph을 사용하여 model의 정보를 확인할 수 있습니다.
graph torch-jit-export (
%input.1[FLOAT, 1x3x224x224]
) initializers (
%classifier.0.bias[FLOAT, 4096]
%classifier.0.weight[FLOAT, 4096x25088]
%classifier.3.bias[FLOAT, 4096]
%classifier.3.weight[FLOAT, 4096x4096]
%classifier.6.bias[FLOAT, 1000]
%classifier.6.weight[FLOAT, 1000x4096]
%features.0.bias[FLOAT, 64]
%features.0.weight[FLOAT, 64x3x3x3]
%features.11.bias[FLOAT, 512]
%features.11.weight[FLOAT, 512x256x3x3]
%features.13.bias[FLOAT, 512]
%features.13.weight[FLOAT, 512x512x3x3]
%features.16.bias[FLOAT, 512]
%features.16.weight[FLOAT, 512x512x3x3]
%features.18.bias[FLOAT, 512]
%features.18.weight[FLOAT, 512x512x3x3]
%features.3.bias[FLOAT, 128]
%features.3.weight[FLOAT, 128x64x3x3]
%features.6.bias[FLOAT, 256]
%features.6.weight[FLOAT, 256x128x3x3]
%features.8.bias[FLOAT, 256]
%features.8.weight[FLOAT, 256x256x3x3]
)
3. Compare
import time
import numpy as np
import onnxruntime
dummy = torch.randn(1, 3, 224, 224)
model = models.vgg11(pretrained=True)
model.eval()
start = time.time()
for _ in range(100):
torch_output = model(dummy)
print("torch inference:", time.time() - start)
onnxruntime_session = onnxruntime.InferenceSession("vgg11.onnx")
start = time.time()
for _ in range(100):
onnxruntime_outputs = onnxruntime_session.run(None, {"input.1": dummy.numpy()})
print("onnx inference:", time.time() - start)
onnxruntime을 이용해 모델을 추론할 수 있습니다. onnx모델이 얼마나 더 빠른지 100번 반복하여 실행한 후 비교하였습니다.
torch inference: 7.426271200180054
onnx inference: 2.333432912826538
속도 차이가 유의미하게 나는 것을 확인할 수 있습니다.
'Pytorch' 카테고리의 다른 글
[Pytorch] DataParallel vs DIstributedDataParallel (0) | 2022.07.17 |
---|---|
[Pytorch] 다양한 Learning Rate Scheduler (0) | 2022.06.26 |
[Pytorch] 유용한 method (view,reshape,squeeze,permute,stack,repeat,gather...) (0) | 2022.01.16 |