Конвертер любой модельки в любую модельку (в рамках ONNX
, Torch
, TFLite
)
Идея конвертирования следующая -- конвертировать любую модельку на входе в ONNX, а дальше конвертировать из ONNX в любую желаемую модельку. Сравнение пайплайнов "напрямую" и "через onnx" ниже:
напрямую | через ONNX |
---|---|
Реализация 6 связей | Реализация 4 связей |
Иерархия классов выглядит следующим образом (я не стал указывать, аргументы функций и тип возврата, ибо было бы громоздко -- чекайте код для этого):
BaseConverter
-- это абстрактный класс, в котором есть поля, соответствующие всем моделькам и их путям. Содержит реализации конвертирования в TFLite
и Torch
по умолчанию (из ONNX
), но заставляет потомков переопределить конвертирование в ONNX
. Важный момент: наследник переопределяет лишь защищённые методы конвертирования (начинающиеся с нижнего подчёркивания). В то время, как пользователь вызывает публичные методы, которые, в зависимости от условий делают следующее:
- возвращают уже сконвертированную модельку, если она уже была сконвертирована ранее
- загружают модель с диска, если пользователь указал
load_intermediate_models_from_disk
= True (и если эта модель действительно лежит на диске) - вызывает соответствующий защищённый метод конвертирования
_to_...()
, если ни одно из условий не выполнилось
Наследнику необходимо переопределить:
_to_onnx()
: метод конвертирования в onnx, который сохраняет новую модель на диске и возвращает её в видеonnx.ModelProto
_to_...()
: метод конвертирования в свою же модельfrom_model()
: конструктор от модели -- необходимо инициализировать соответствующее_..._model
поле!
Наследники -- это классы ONNXConverter
, TorchConverter
, TFLiteConverter
. Название каждого из этих классов отражает тип входной модели.
В реализации TorchConverter
необходимо определить размер по умолчанию для модели (это нужно, так как onnx и tflite имеют статически вычисляемый граф, а torch -- динамически вычисляемый), поэтому приходится переопределять ещё и конструкторы __init__
и from_model_path
.
Реализация TFLiteConverter
не предполагает конструирование от модели (только от пути). Это связано с трудностями работы с tf.lite.Interpreter
, который не даёт сохранить модель на диске (помимо этого, либы конвертирования не принимают tflite в виде модели -- только путь самурая)
Класс Converter
содержит в себе единственное поле -- base
, которое инициализируется одним из наследников BaseConverter
. Инициализация происходит автоматически на основе типа поданной модели, или расширения в переданном имени. То есть, пользователю не нужно вызывать from_model или from_path, всё, что требуется -- это передать аргументы model
и/или model_path
.
Опишем предполагаемый вариант использования Converter
в рамках проекта, в котором надо:
- конвертировать любую модельку во что-то, что можно натренировать
- натренировать на своём датасете
- конвертировать к обратному виду
Предлагается юзать 2 конвертера:
- Для конвертирования в
torch
для тренировки - Для конвертирования из
torch
обратно
Тогда код для TFLite-модельки будет выглядеть примерно так:
# создаём конвертер из нашей tflite модельки
converter = Converter(
model_path='model.flite',
load_intermediate_models_from_disk=False,
)
# тренируем торч-модельку (реализовано лишь для задачи классификации)
trainer = Trainer(
model=converter.to_torch(),
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
test_dataloader=test_dataloader,
)
trainer.train(epoch=20)
# создаём конвертер из торча и конвертируем обратно в tflite
input_shape = OnnxIO(converter.to_onnx()).input_shape
converter = Converter(
model=trainer.best_model,
input_shape=input_shape,
load_intermediate_models_from_disk=False,
)
converter.to_tflite()
Процессы конвертирования логируются в logs/converter.log
.
Есть примерчик в test_convertion.py
- Сам конвертер может быть с багами -- я его не успел нормально оттестировать.
- Возможно, класс Converter должен хранить поле, которое бы отвечало за тип модели. Тогда будет удобно конвертировать модель после тренировки к обратному виду, вызывая какой-нибудь метод
to_origin_type()
и не работая напрямую сto_tflite()
,to_torch()
илиto_onnx()
. - Неудобно масштабировать -- если кто-то захочет добавить условный
keras
, то придётся добавлять его и вBaseConverter
, и делать новогонаследника
, и дополнятьConverter
- Имеет смысл пересмотреть реализацию конструкторов -- сделать так, чтобы метод
to_onnx()
вызывался в конструкторе, раз мы всё равно от него пляшем.
Оказалось, что моё решение проблемы слоя Mul
либы onnx2tf
некорректное (оно работает для YOLO, но руинит drone.tflite). Придётся всё же вернуться к Reshape -- весь трэйс проблемы описан в тайге (начиная с 20.03
и акцентируя внимание на 17.04
)
Я написал класс тренировщика лишь для задачи классификации (лежит в ...). Но нам нужен тренировщик для Object Detection. Скорее всего, придётся тоже делать BaseTrainer и от него наследоваться для различных задач ComputerVision (можно подсмотреть, как это сделано у ultralytics)