8 методов сравнения алгоритмов машинного обучения на R

В этой статье мы рассмотрим 8 методов сравнения алгоритмов машинного обучения на R. С их помощью вы сможете выбрать модель, обеспечивающую наиболее точный прогноз, оценить статистическую значимость, а также узнать, насколько данная модель превзошла остальные по абсолютному результату.

Выбор лучшей модели

Работая над проектом в области машинного обучения, мы, как правило, создаем несколько хороших моделей. Все они имеют различные характеристики.

С помощью кросс-валидации мы оцениваем, насколько точно модели прогнозируют неизвестные им данные. А далее нам необходимы методы, позволяющие сравнить модели на основе полученных оценок, чтобы выбрать одну или две лучшие из них.

Детальное сравнение моделей

При анализе нового набора данных полезно визуализировать данные несколькими способами, чтобы посмотреть на них с различных точек зрения.

Этот же подход применим и к задаче выбора модели. Чтобы выбрать лучшую модель, необходимо с помощью различных методов визуализировать и сравнить среднюю точность, дисперсию и другие характеристики распределений оценок моделей.

Сравнение и выбор моделей. Реализация на R

В качестве практического примера, мы возьмем набор данных, обучим на нем ряд моделей, а затем применим различные техники визуализации, чтобы сравнить точность моделей.

Наш пример разделен три этапа:

  1. Подготовка данных. Загружаем необходимые библиотеки и набор данных.
  2. Обучение моделей. Обучаем стандартные модели, чтобы оценить и сравнить их.
  3. Сравнение моделей. Сравниваем обученные модели, применяя 8 различных методов.

1. Подготовка данных

Мы будем использовать набор данных, содержащий информацию о случаях сахарного диабета среди индейцев Пима (Pima Indians diabetes dataset), который можно загрузить из репозитория Калифорнийского университета в Ирвайне (UCI). Этот набор данных также доступен в R-пакете mlbench.

Задача представляет собой двухклассовую классификацию: необходимо предсказать, возникнет ли данное заболевание у пациента в течение следующих пяти лет. Признаками являются числовые значения различных медицинских показателей пациентов женского пола.

Загружаем библиотеки и данные.

# load libraries

library(mlbench)

library(caret)

# load the dataset

data(PimaIndiansDiabetes)

2. Обучение моделей

Мы обучим и оценим 5 различных моделей с помощью 3-кратной 10-блочной кросс-валидации (3 repeats, 10 folds). Эта конфигурация является стандартной при сравнении моделей. В качестве метрик для оценивания моделей будем использовать точность (accuracy) и каппа (kappa), потому что их легко интерпретировать.

Для сравнения мы выбрали 5 алгоритмов различного типа:

  • классификационные и регрессионные деревья (classification and regression trees, CART);
  • линейный дискриминантный анализ (linear discriminant analysis, LDA);
  • метод опорных векторов с радиальной базисной функцией (support vector machine with radial basis function, SVM with RBF);
  • метод k ближайших соседей (k-nearest neighbors, kNN);
  • случайный лес (random forest, RF).

После того, как модели будут обучены, мы объединим их в список и применим к нему функцию resamples(). Эта функция проверяет возможность сравнения моделей, а также идентичность использованных схем обучения (т.е. конфигурацию trainControl). Объект results, возвращенный функцией resamples(), содержит оценки каждой модели для каждого валидационного блока в каждом повторении кросс-валидации. Все функции, которые мы будем использовать в дальнейшем, принимают на входе этот объект.

# prepare training scheme
control <- trainControl(method="repeatedcv", number=10, repeats=3)
# CART
set.seed(7)
fit.cart <- train(diabetes~., data=PimaIndiansDiabetes, method="rpart", trControl=control)
# LDA
set.seed(7)
fit.lda <- train(diabetes~., data=PimaIndiansDiabetes, method="lda", trControl=control)
# SVM
set.seed(7)
fit.svm <- train(diabetes~., data=PimaIndiansDiabetes, method="svmRadial", trControl=control)
# kNN
set.seed(7)
fit.knn <- train(diabetes~., data=PimaIndiansDiabetes, method="knn", trControl=control)
# Random Forest
set.seed(7)
fit.rf <- train(diabetes~., data=PimaIndiansDiabetes, method="rf", trControl=control)
# collect resamples
results <- resamples(list(CART=fit.cart, LDA=fit.lda, SVM=fit.svm, KNN=fit.knn, RF=fit.rf))

3. Сравнение моделей

Далее мы сравним созданные нами модели, используя 8 различных методов.

Итоговая таблица

Это самый простой метод сравнения. Просто вызываем функцию summary() для объекта results. Эта функция выведет на экран таблицу, каждая строка которой соответствует одной модели, а столбцы содержат различные статистические характеристики соответствующей метрики.

# summarize differences between models

summary(results)

Особый интерес представляют медианы (Median) и максимальные значения (Max.).

Accuracy
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.6234  0.7115 0.7403 0.7382  0.7760 0.8442    0
LDA  0.6711  0.7532 0.7662 0.7759  0.8052 0.8701    0
SVM  0.6711  0.7403 0.7582 0.7651  0.7890 0.8961    0
KNN  0.6184  0.6984 0.7321 0.7299  0.7532 0.8182    0
RF   0.6711  0.7273 0.7516 0.7617  0.7890 0.8571    0
 
Kappa
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.1585  0.3296 0.3765 0.3934  0.4685 0.6393    0
LDA  0.2484  0.4196 0.4516 0.4801  0.5512 0.7048    0
SVM  0.2187  0.3889 0.4167 0.4520  0.5003 0.7638    0
KNN  0.1113  0.3228 0.3867 0.3819  0.4382 0.5867    0
RF   0.2624  0.3787 0.4516 0.4588  0.5193 0.6781    0

Ящики с усами

Ящики с усами (box and whisker plot) позволяют эффективно визуализировать и сравнить разброс оценок для каждой модели.

# box and whisker plots to compare models

scales <- list(x=list(relation="free"), y=list(relation="free"))

bwplot(results, scales=scales)

Обратите внимание, ящики отсортированы по убыванию медианы оценок (черные точки). В данной визуализации наиболее полезную информацию нам дают медианы и области перекрытия ящиков (т.е. области перекрытия межквартильных интервалов).

Графики функций плотности вероятности

Мы можем визуализировать распределения оценок точности моделей с помощью графиков функций плотности вероятности (probability density function). Этот метод позволяет сопоставить поведение различных моделей.

# density plots of accuracy

scales <- list(x=list(relation="free"), y=list(relation="free"))

densityplot(results, scales=scales, pch = "|")

Здесь полезно сравнить высоту пиков и ширину оснований графиков.

Точечные диаграммы

Точечная диаграмма (dot plot) визуализирует среднее значение оценок модели (точка), а также 95% доверительный интервал (т.е. интервал, которому принадлежит 95% всех оценок).

# dot plots of accuracy

scales <- list(x=list(relation="free"), y=list(relation="free"))

dotplot(results, scales=scales)

Здесь сравниваем средние значения и обращаем внимание на перекрытие доверительных интервалов.

Параллельные графики

Каждый из параллельных графиков (parallel plots) визуализирует оценки моделей для данного валидационного блока кросс-валидации. Мы видим, что один и тот же валидационный блок мог представлять сложность для одной модели, но быть простым для другой.

# parallel plots to compare models

parallelplot(results)

Эта диаграмма позволяет нам выяснить, какие модели можно объединить в ансамбль. В частности, если мы наблюдаем коррелированные движения в противоположных направлениях, значит, данные модели являются кандидатами на участие в ансамбле.

Матрица диаграмм рассеяния

Матрица диаграмм рассеяния (scatter plot matrix) визуализирует попарное сравнение всех моделей. В каждой отдельной диаграмме рассеяния сравниваются оценки двух моделей для всех валидационных блоков кросс-валидации.

# pair-wise scatterplots of predictions to compare models

splom(results)

Эта диаграмма незаменима в тех случаях, когда необходимо выяснить, коррелируют ли прогнозы двух данных моделей. Модели, которым свойственна слабая корреляция, являются кандидатами для объединения в ансамбль.

Например, мы видим, что таким парам, как LDA и SVM, а также SVM и RF свойственна сильная корреляция. С другой стороны, SVM и CART коррелируют слабо.

Индивидуальная диаграмма рассеяния

Для детального анализа мы можем вывести индивидуальную диаграмму рассеяния для любой пары моделей в более крупном масштабе.

# xyplot plots to compare models

xyplot(results, models=c("LDA", "SVM"))

Для примера выведем сравнение моделей LDA и SVM, которым свойственна корреляция.

Тест статистической значимости

Мы можем вычислить статистическую значимость разницы между распределениями оценок различных моделей. Для этого воспользуемся функциями diff() и summary().

# difference in model predictions

diffs <- diff(results)

# summarize p-values for pair-wise comparisons

summary(diffs)

В результате мы получим таблицу статистической значимости.

Нижняя диагональ таблицы содержит p-значения (величины статистической значимости) для нулевой гипотезы, которая утверждает, что распределения одинаковы. Чем меньше p-значение, тем менее вероятна справедливость нулевой гипотезы. Например, мы видим отсутствие разницы между CART и kNN, а также незначительную разницу между LDA и SVM.

Верхняя диагональ таблицы содержит разницу между распределениями. Эти значения показывают, как модели соотносятся друг с другом по абсолютному значению точности.

p-value adjustment: bonferroni

Upper diagonal: estimates of the difference

Lower diagonal: p-value for H0: difference = 0

Accuracy

     CART      LDA       SVM       KNN       RF     

CART           -0.037759 -0.026908  0.008248 -0.023473

LDA  0.0050068            0.010851  0.046007  0.014286

SVM  0.0919580 0.3390336            0.035156  0.003435

KNN  1.0000000 1.218e-05 0.0007092           -0.031721

RF   0.1722106 0.1349151 1.0000000 0.0034441

Для повышения надежности рекомендуется увеличить количество испытаний, увеличив тем самым объем выборки оценок. Мы также могли бы визуализировать представленную выше информацию, но, на мой взгляд, сама по себе таблица более информативна.

Заключение

В этой статье мы рассмотрели 8 различных методов сравнения моделей на R:

  • итоговая таблица;
  • ящики с усами;
  • графики функций плотности вероятности;
  • точечные диаграммы;
  • параллельные графики;
  • матрица диаграмм рассеяния;
  • индивидуальная диаграмма рассеяния;
  • тест статистической значимости.

Возможно, мы не назвали ваш любимый метод сравнения моделей? Напишите о нем в комментариях, нам будет интересно узнать!

По материалам: machinelearningmastery.com

Добавить комментарий

Ваш e-mail не будет опубликован.

закрыть

Поделиться

Отправить на почту
закрыть

Вход

закрыть

Регистрация

+ =