
5. Batch Inference


1. Mission



2. TorchServe for Batch Inference

1) Create handler.py

  • Create handler.py as follows.


import ast
import logging
import os
import json
import torch
from abc import ABC
from transformers import M2M100ForConditionalGeneration, AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler

torch.set_grad_enabled(False)
torch.set_num_threads(1)
logger = logging.getLogger(__name__)

class InferenceHandler(BaseHandler, ABC):
    def __init__(self):
        super(InferenceHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.PID = os.getpid()
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.tgt_lang = self.manifest["model"]["modelName"]
        setup_config_path = os.path.join(model_dir, "setup_config.json")

        with open(setup_config_path) as setup_config_file:
            self.setup_config = json.load(setup_config_file)

        self.device = "cuda:" + self.setup_config["cuda_id"][self.tgt_lang]
        self.model = M2M100ForConditionalGeneration.from_pretrained(model_dir)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        logger.info(f"\nPID: {self.PID} / Tgt_Lang: {self.tgt_lang} / Device: {self.device}\nMANIFEST: {self.manifest}\nPROPERTIES: {properties}\n{model_dir} loaded successfully\n")

        self.initialized = True

    def preprocess(self, requests):
        logger.info(f"\nPID: {self.PID} / Tgt_Lang: {self.tgt_lang} / Device: {self.device}\nLength of Requests: {len(requests)}\nRequests: {requests}\n")

        input_ids_batch, attention_mask_batch = None, None
        input_ids_list, attention_mask_list = [], []
        max_length = 0

        # Iterate over the incoming requests
        for idx, data in enumerate(requests):
            # Parse the request payload
            p = data.get("data")

            if p is None:
                p = data.get("body")

            if isinstance(p, (bytes, bytearray)):
                p = ast.literal_eval(p.decode("utf-8"))

            src_lang = p["src_lang"]
            src_text = p["src_text"]
            # tgt_lang = p["tgt_lang"]

            logger.info(f"\nPID: {self.PID} / Tgt_Lang: {self.tgt_lang} / Device: {self.device}\nSrc_Lang: {src_lang} / Src_Text: {src_text}\n")

            # Switch the tokenizer to the request's source language
            self.tokenizer.src_lang = src_lang
            # Tokenize src_text into {"input_ids": Tensor, "attention_mask": Tensor}
            inputs = self.tokenizer(src_text, padding=True, truncation=True, return_tensors='pt')

            # Move input_ids and attention_mask to the target device
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            # Collect them in input_ids_list and attention_mask_list
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

            # Track the maximum sequence length across the requests -> used for padding below
            if input_ids.shape[1] > max_length:
                max_length = input_ids.shape[1]

        # Pad every collected sequence to max_length
        for i in range(len(input_ids_list)):
            # Length of the current sequence
            length = input_ids_list[i].shape[1]
            # Number of padding positions needed (max_length - length)
            pad_length = max_length - length
            # pad_ones: pad token ids appended to input_ids (M2M100's pad_token_id is 1)
            pad_ones = torch.ones(1, pad_length, dtype=torch.long).to(self.device)
            # pad_zeros: zeros appended to attention_mask so the padded positions are ignored
            pad_zeros = torch.zeros(1, pad_length, dtype=torch.long).to(self.device)

            # padding
            input_ids_list[i] = torch.cat((input_ids_list[i], pad_ones), dim=1)
            attention_mask_list[i] = torch.cat((attention_mask_list[i], pad_zeros), dim=1)

            # Concatenate into input_ids_batch and attention_mask_batch
            if input_ids_batch is None:
                input_ids_batch = input_ids_list[i]
                attention_mask_batch = attention_mask_list[i]
            else:
                input_ids_batch = torch.cat((input_ids_batch, input_ids_list[i]), dim=0)
                attention_mask_batch = torch.cat((attention_mask_batch, attention_mask_list[i]), dim=0)

        return (input_ids_batch, attention_mask_batch)

    def inference(self, preprocessed):
        input_ids_batch, attention_mask_batch = preprocessed
        encoded_batch = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

        # Run generate() on the whole batch at once
        generated_tokens = self.model.generate(
            **encoded_batch,
            forced_bos_token_id=self.tokenizer.get_lang_id(self.tgt_lang),
            num_beams=2,
        )
        # Decode the generated tokens for the whole batch
        translated_batch = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        logger.info(f"\nPID: {self.PID} / Tgt_Lang: {self.tgt_lang} / Device: {self.device}\nTranslated: {translated_batch}\n")

        return translated_batch

    def postprocess(self, inferred_batch):
        # Shape the outputs into the response format expected by TorchServe (one dict per request)
        postprocessed = [{"translated": translated} for translated in inferred_batch]

        return postprocessed

    def handle(self, requests, context):
        self.context = context

        preprocessed = self.preprocess(requests)
        inferred_batch = self.inference(preprocessed)
        postprocessed = self.postprocess(inferred_batch)

        return postprocessed
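

  • For reference, once the models have been archived and registered with torchserve.sh below, this handler can be called directly on TorchServe's inference address. The port (9999) and model name (en) are the values configured later in config.properties and torchserve.sh; a minimal sketch of the request format preprocess() expects:


$(torchserve) curl -X POST "http://127.0.0.1:9999/predictions/en" \
    -H "Content-Type: application/json" \
    -d '{"src_lang": "ko", "src_text": "안녕하세요", "tgt_lang": "en"}'
# Each request gets back one dict from postprocess(), e.g. {"translated": "..."}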


2) Copy handler.py

  • To assign a different GPU to each language, create a handlers directory and copy the handler.py written above to a separate <language>.py for each language (a loop over all languages is sketched below the commands).


$(torchserve) mkdir handlers
$(torchserve) cp handler.py ./handlers/ko.py && cp handler.py ./handlers/en.py && ..
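

  • The same copies can also be made with a loop; a minimal sketch, assuming the language codes used in torchserve.sh below:


$(torchserve) for lang in ko en zh es fr ru id ja vi de it hi ar pt; do cp handler.py ./handlers/$lang.py; done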


3) Modify torchserve.sh

  • Modify torchserve.sh as well to automate archiving and registration.


#!/bin/bash

model_names="ko en zh es fr ru id ja vi de it hi ar pt"
model_dir=../model/voiceprint/m2m100_418M
model_ver=1.0

if [ "$1" = "new" ]; then
        if [ -f ./model_dirlist.txt ]; then
                rm ./model_dirlist.txt
        fi

        ls $model_dir > ./model_dirlist.txt
        readarray files < ./model_dirlist.txt
        if [ -f ./model_dirlist.txt ]; then
                rm ./model_dirlist.txt
        fi

        files_len=${#files[@]}

        for ids in ${!files[@]}
        do
                if [[ "${files[$ids]}" =~ ".bin" ]]; then
                        model_file=${files[$ids]}
                        files_len=$((files_len-1))
                else
                        file_name=${files[$ids]}

                        if [ $ids = $files_len ]; then
                                extra_files_args=$extra_files_args$model_dir/$file_name
                        else
                                extra_files_args=$extra_files_args$model_dir/$file_name,
                        fi
                fi
        done

        extra_files_args=`echo $extra_files_args | tr -d ' '`

        if [ -d ./logs ]; then
                rm -rf ./logs
                echo Directory ./logs Removed!!!
        fi

        if [ -d ./model_store ]; then
                rm -rf ./model_store
                echo Directory ./model_store Removed!!!
        fi

        if [ ! -d ./model_store ]; then
                mkdir ./model_store
                echo Directory ./model_store Made!!!
        fi

        echo Torch Model Archiver Start...

        for model_name in $model_names
        do
                if [ -f ./$model_name.mar ]; then
                        rm ./$model_name.mar
                fi

                time torch-model-archiver --model-name $model_name --version $model_ver --serialized-file $model_dir/$model_file --handler ./handlers/$model_name.py --extra-files "$extra_files_args"

                mv ./$model_name.mar ./model_store
                echo ./$model_name.mar Moved to ./model_store
        done
fi

torchserve --start --model-store ./model_store --ts-config ./config.properties && sleep 5s \
&& \
        for model_name in $model_names
        do
                curl -X POST "localhost:8081/models?url="$model_name".mar&batch_size=64&max_batch_delay=100"
        done
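

  • After the script finishes, the registered models can be checked through the management address (port 8081 in config.properties); the per-model endpoint also reports the batch_size and max_batch_delay each model was registered with:


$(torchserve) curl http://127.0.0.1:8081/models
$(torchserve) curl http://127.0.0.1:8081/models/en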


4) Modify config.properties

  • Modify config.properties to match the batch inference setup as well.


# Configure TorchServe listening address and port
inference_address=http://127.0.0.1:9999
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082

# Limit GPU usage
number_of_gpu=4

# Load models at startup
model_store=./model_store
#load_models=en.mar

# Other properties
# Worker / Capacity
default_workers_per_model=1
job_queue_size=4096
default_response_timeout=86400
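

  • Whether these addresses are in effect can be checked with TorchServe's ping endpoint on the inference port once the server is up:


$(torchserve) curl http://127.0.0.1:9999/ping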


5) Create setup_config.json

  • Create setup_config.json in the model directory, /root/Project/model/voiceprint/m2m100_418M.


{
  "cuda_id": {
    "ko": "0",
    "en": "0",
    "zh": "0",
    "es": "0",
    "fr": "0",
    "ru": "1",
    "id": "1",
    "ja": "1",
    "vi": "1",
    "de": "1",
    "it": "2",
    "hi": "2",
    "ar": "2",
    "pt": "2"
  }
}


3. Gunicorn Server

1) Create classifier.py

  • Create classifier.py, which routes each request to the matching model based on its tgt_lang.


from flask import Flask, request
import requests
import json
import os

app = Flask(__name__)

class Classifier:
    def __init__(self):
        self.PID = os.getpid()
        print(f"PID: {self.PID} / Loaded Classifier!")

def request_text(URL: str, p: dict):
    response = requests.post(URL, data=json.dumps(p)).json()

    return response

classifier = Classifier()

@app.route("/", methods=["POST"])
def main():
    p = request.get_json(force=True)

    src_lang = p["src_lang"]
    src_text = p["src_text"]
    tgt_lang = p["tgt_lang"]

    URL = "http://127.0.0.1:9999/predictions/" + tgt_lang

    translated = request_text(URL, p)

    print(f"\nPID: {classifier.PID} \ URL: {URL}\nSrc_Lang: {src_lang} \ Src_text: {src_text}\nTgt_Lang: {tgt_lang} \ Translated: {translated['translated']}")

    return translated

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=80, debug=True)


2) Modify gunicorn.conf.py

  • Modify gunicorn.conf.py as follows.


import multiprocessing

# Python Module:Variable
wsgi_app = "classifier:app"

# Host:Port
bind = "127.0.0.1:8888"

# Worker Class == Default: "sync"
worker_class = "gevent"

# Num of Workers
# workers = int(multiprocessing.cpu_count() * 0.1)
workers = 16

# Num of Worker Connections
worker_connections = 1024

# Timeout == 0(Deactivate)
timeout = 86400
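

  • With classifier.py and gunicorn.conf.py in the same directory, the classifier can then be served with Gunicorn (the gevent worker class above requires the gevent package to be installed):


$(torchserve) gunicorn -c gunicorn.conf.py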


4. Nginx Server

1) Modify nginx.conf

  • Modify nginx.conf as follows.


user www-data;
worker_processes 10; # Max number of concurrent requests = worker_processes * worker_connections
pid /run/nginx.pid;
include /etc/nginx/modules-enabled/*.conf;
# worker_rlimit_nofile 2048;

events {
        worker_connections 4096; # Doubled since this is a reverse proxy (each request uses a client-side and an upstream connection)
        multi_accept on;
        use epoll;
}

http {

        ##
        # Basic Settings
        ##

        sendfile on;
        tcp_nopush on;
        tcp_nodelay on;

        ##
        # Server Settings
        ##

        server {
                listen 7777;
                server_name localhost;

                location / {
                        proxy_pass http://127.0.0.1:8888;
                }

                error_page 500 502 503 504 /50x.html;
                location = /50x.html {
                        root /usr/share/nginx/html;
                }
        }

        ##
        # Timeout Settings
        ##

        # keepalive_timeout 86400; # How long a keep-alive connection stays open
        # keepalive_requests 100; # Max number of requests per keep-alive connection

        client_body_timeout 86400; # Time allowed to receive the request body from the client
        send_timeout 86400; # Time allowed to send the response body to the client
        reset_timedout_connection off; # Reset connections that were closed due to a timeout (disabled here)

        proxy_connect_timeout 86400; # Time to establish a connection with the proxied server (cannot exceed 75 seconds)
        proxy_send_timeout 86400; # Time allowed to send a request to the proxied server
        proxy_read_timeout 86400; # Time allowed to read a response from the proxied server

        types_hash_max_size 2048;
        # server_tokens off;

        # server_names_hash_bucket_size 64;
        # server_name_in_redirect off;

        include /etc/nginx/mime.types;
        default_type application/octet-stream;

        ##
        # SSL Settings
        ##

        ssl_protocols TLSv1 TLSv1.1 TLSv1.2 TLSv1.3; # Dropping SSLv3, ref: POODLE
        ssl_prefer_server_ciphers on;

        ##
        # Logging Settings
        ##

        access_log /var/log/nginx/access.log;
        error_log /var/log/nginx/error.log;

        ##
        # Gzip Settings
        ##

        gzip on;

        # gzip_vary on;
        # gzip_proxied any;
        # gzip_comp_level 6;
        # gzip_buffers 16 8k;
        # gzip_http_version 1.1;
        # gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;

        ##
        # Virtual Host Configs
        ##

        # include /etc/nginx/conf.d/*.conf;
        include /etc/nginx/sites-enabled/*;
}

#mail {
#       # See sample authentication script at:
#       # http://wiki.nginx.org/ImapAuthenticateWithApachePhpScript
#
#       # auth_http localhost/auth.php;
#       # pop3_capabilities "TOP" "USER";
#       # imap_capabilities "IMAP4rev1" "UIDPLUS";
#
#       server {
#               listen     localhost:110;
#               protocol   pop3;
#               proxy      on;
#       }
#
#       server {
#               listen     localhost:143;
#               protocol   imap;
#               proxy      on;
#       }
#}
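

  • After editing, the configuration can be validated and reloaded with the standard nginx commands:


$(torchserve) nginx -t
$(torchserve) nginx -s reload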


5. Batch Inference Test

1) Run TorchServe

  • Run TorchServe as follows.


$(torchserve) bash torchserve.sh new


2) Check GPU Usage

  • GPU usage can be monitored as follows.


$(torchserve) watch -n 1 nvidia-smi
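

  • With Nginx and Gunicorn also running, a request through the full chain (Nginx :7777 → Gunicorn :8888 → TorchServe :9999) should show up in the GPU usage above; a minimal end-to-end test, assuming the ports configured earlier:


$(torchserve) curl -X POST "http://127.0.0.1:7777/" \
    -H "Content-Type: application/json" \
    -d '{"src_lang": "ko", "src_text": "안녕하세요", "tgt_lang": "ja"}'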

