В этой статье мы рассмотрим 8 методов сравнения алгоритмов машинного обучения на R. С их помощью вы сможете выбрать модель, обеспечивающую наиболее точный прогноз, оценить статистическую значимость, а также узнать, насколько данная модель превзошла остальные по абсолютному результату.
Выбор лучшей модели
Работая над проектом в области машинного обучения, мы, как правило, создаем несколько хороших моделей. Все они имеют различные характеристики.
С помощью кросс-валидации мы оцениваем, насколько точно модели прогнозируют неизвестные им данные. А далее нам необходимы методы, позволяющие сравнить модели на основе полученных оценок, чтобы выбрать одну или две лучшие из них.
Детальное сравнение моделей
При анализе нового набора данных полезно визуализировать данные несколькими способами, чтобы посмотреть на них с различных точек зрения.
Этот же подход применим и к задаче выбора модели. Чтобы выбрать лучшую модель, необходимо с помощью различных методов визуализировать и сравнить среднюю точность, дисперсию и другие характеристики распределений оценок моделей.
Сравнение и выбор моделей. Реализация на R
В качестве практического примера, мы возьмем набор данных, обучим на нем ряд моделей, а затем применим различные техники визуализации, чтобы сравнить точность моделей.
Наш пример разделен три этапа:
- Подготовка данных. Загружаем необходимые библиотеки и набор данных.
- Обучение моделей. Обучаем стандартные модели, чтобы оценить и сравнить их.
- Сравнение моделей. Сравниваем обученные модели, применяя 8 различных методов.
1. Подготовка данных
Мы будем использовать набор данных, содержащий информацию о случаях сахарного диабета среди индейцев Пима (Pima Indians diabetes dataset), который можно загрузить из репозитория Калифорнийского университета в Ирвайне (UCI). Этот набор данных также доступен в R-пакете mlbench.
Задача представляет собой двухклассовую классификацию: необходимо предсказать, возникнет ли данное заболевание у пациента в течение следующих пяти лет. Признаками являются числовые значения различных медицинских показателей пациентов женского пола.
Загружаем библиотеки и данные.
# load libraries library(mlbench) library(caret) # load the dataset data(PimaIndiansDiabetes)
2. Обучение моделей
Мы обучим и оценим 5 различных моделей с помощью 3-кратной 10-блочной кросс-валидации (3 repeats, 10 folds). Эта конфигурация является стандартной при сравнении моделей. В качестве метрик для оценивания моделей будем использовать точность (accuracy) и каппа (kappa), потому что их легко интерпретировать.
Для сравнения мы выбрали 5 алгоритмов различного типа:
- классификационные и регрессионные деревья (classification and regression trees, CART);
- линейный дискриминантный анализ (linear discriminant analysis, LDA);
- метод опорных векторов с радиальной базисной функцией (support vector machine with radial basis function, SVM with RBF);
- метод k ближайших соседей (k-nearest neighbors, kNN);
- случайный лес (random forest, RF).
После того, как модели будут обучены, мы объединим их в список и применим к нему функцию resamples(). Эта функция проверяет возможность сравнения моделей, а также идентичность использованных схем обучения (т.е. конфигурацию trainControl). Объект results, возвращенный функцией resamples(), содержит оценки каждой модели для каждого валидационного блока в каждом повторении кросс-валидации. Все функции, которые мы будем использовать в дальнейшем, принимают на входе этот объект.
# prepare training scheme control <- trainControl(method="repeatedcv", number=10, repeats=3) # CART set.seed(7) fit.cart <- train(diabetes~., data=PimaIndiansDiabetes, method="rpart", trControl=control) # LDA set.seed(7) fit.lda <- train(diabetes~., data=PimaIndiansDiabetes, method="lda", trControl=control) # SVM set.seed(7) fit.svm <- train(diabetes~., data=PimaIndiansDiabetes, method="svmRadial", trControl=control) # kNN set.seed(7) fit.knn <- train(diabetes~., data=PimaIndiansDiabetes, method="knn", trControl=control) # Random Forest set.seed(7) fit.rf <- train(diabetes~., data=PimaIndiansDiabetes, method="rf", trControl=control) # collect resamples results <- resamples(list(CART=fit.cart, LDA=fit.lda, SVM=fit.svm, KNN=fit.knn, RF=fit.rf))
3. Сравнение моделей
Далее мы сравним созданные нами модели, используя 8 различных методов.
Итоговая таблица
Это самый простой метод сравнения. Просто вызываем функцию summary() для объекта results. Эта функция выведет на экран таблицу, каждая строка которой соответствует одной модели, а столбцы содержат различные статистические характеристики соответствующей метрики.
# summarize differences between models summary(results)
Особый интерес представляют медианы (Median) и максимальные значения (Max.).
Accuracy Min. 1st Qu. Median Mean 3rd Qu. Max. NA's CART 0.6234 0.7115 0.7403 0.7382 0.7760 0.8442 0 LDA 0.6711 0.7532 0.7662 0.7759 0.8052 0.8701 0 SVM 0.6711 0.7403 0.7582 0.7651 0.7890 0.8961 0 KNN 0.6184 0.6984 0.7321 0.7299 0.7532 0.8182 0 RF 0.6711 0.7273 0.7516 0.7617 0.7890 0.8571 0 Kappa Min. 1st Qu. Median Mean 3rd Qu. Max. NA's CART 0.1585 0.3296 0.3765 0.3934 0.4685 0.6393 0 LDA 0.2484 0.4196 0.4516 0.4801 0.5512 0.7048 0 SVM 0.2187 0.3889 0.4167 0.4520 0.5003 0.7638 0 KNN 0.1113 0.3228 0.3867 0.3819 0.4382 0.5867 0 RF 0.2624 0.3787 0.4516 0.4588 0.5193 0.6781 0
Ящики с усами
Ящики с усами (box and whisker plot) позволяют эффективно визуализировать и сравнить разброс оценок для каждой модели.
# box and whisker plots to compare models scales <- list(x=list(relation="free"), y=list(relation="free")) bwplot(results, scales=scales)
Обратите внимание, ящики отсортированы по убыванию медианы оценок (черные точки). В данной визуализации наиболее полезную информацию нам дают медианы и области перекрытия ящиков (т.е. области перекрытия межквартильных интервалов).
Графики функций плотности вероятности
Мы можем визуализировать распределения оценок точности моделей с помощью графиков функций плотности вероятности (probability density function). Этот метод позволяет сопоставить поведение различных моделей.
# density plots of accuracy scales <- list(x=list(relation="free"), y=list(relation="free")) densityplot(results, scales=scales, pch = "|")
Здесь полезно сравнить высоту пиков и ширину оснований графиков.
Точечные диаграммы
Точечная диаграмма (dot plot) визуализирует среднее значение оценок модели (точка), а также 95% доверительный интервал (т.е. интервал, которому принадлежит 95% всех оценок).
# dot plots of accuracy scales <- list(x=list(relation="free"), y=list(relation="free")) dotplot(results, scales=scales)
Здесь сравниваем средние значения и обращаем внимание на перекрытие доверительных интервалов.
Параллельные графики
Каждый из параллельных графиков (parallel plots) визуализирует оценки моделей для данного валидационного блока кросс-валидации. Мы видим, что один и тот же валидационный блок мог представлять сложность для одной модели, но быть простым для другой.
# parallel plots to compare models parallelplot(results)
Эта диаграмма позволяет нам выяснить, какие модели можно объединить в ансамбль. В частности, если мы наблюдаем коррелированные движения в противоположных направлениях, значит, данные модели являются кандидатами на участие в ансамбле.
Матрица диаграмм рассеяния
Матрица диаграмм рассеяния (scatter plot matrix) визуализирует попарное сравнение всех моделей. В каждой отдельной диаграмме рассеяния сравниваются оценки двух моделей для всех валидационных блоков кросс-валидации.
# pair-wise scatterplots of predictions to compare models splom(results)
Эта диаграмма незаменима в тех случаях, когда необходимо выяснить, коррелируют ли прогнозы двух данных моделей. Модели, которым свойственна слабая корреляция, являются кандидатами для объединения в ансамбль.
Например, мы видим, что таким парам, как LDA и SVM, а также SVM и RF свойственна сильная корреляция. С другой стороны, SVM и CART коррелируют слабо.
Индивидуальная диаграмма рассеяния
Для детального анализа мы можем вывести индивидуальную диаграмму рассеяния для любой пары моделей в более крупном масштабе.
# xyplot plots to compare models xyplot(results, models=c("LDA", "SVM"))
Для примера выведем сравнение моделей LDA и SVM, которым свойственна корреляция.
Тест статистической значимости
Мы можем вычислить статистическую значимость разницы между распределениями оценок различных моделей. Для этого воспользуемся функциями diff() и summary().
# difference in model predictions diffs <- diff(results) # summarize p-values for pair-wise comparisons summary(diffs)
В результате мы получим таблицу статистической значимости.
Нижняя диагональ таблицы содержит p-значения (величины статистической значимости) для нулевой гипотезы, которая утверждает, что распределения одинаковы. Чем меньше p-значение, тем менее вероятна справедливость нулевой гипотезы. Например, мы видим отсутствие разницы между CART и kNN, а также незначительную разницу между LDA и SVM.
Верхняя диагональ таблицы содержит разницу между распределениями. Эти значения показывают, как модели соотносятся друг с другом по абсолютному значению точности.
p-value adjustment: bonferroni Upper diagonal: estimates of the difference Lower diagonal: p-value for H0: difference = 0 Accuracy CART LDA SVM KNN RF CART -0.037759 -0.026908 0.008248 -0.023473 LDA 0.0050068 0.010851 0.046007 0.014286 SVM 0.0919580 0.3390336 0.035156 0.003435 KNN 1.0000000 1.218e-05 0.0007092 -0.031721 RF 0.1722106 0.1349151 1.0000000 0.0034441
Для повышения надежности рекомендуется увеличить количество испытаний, увеличив тем самым объем выборки оценок. Мы также могли бы визуализировать представленную выше информацию, но, на мой взгляд, сама по себе таблица более информативна.
Заключение
В этой статье мы рассмотрели 8 различных методов сравнения моделей на R:
- итоговая таблица;
- ящики с усами;
- графики функций плотности вероятности;
- точечные диаграммы;
- параллельные графики;
- матрица диаграмм рассеяния;
- индивидуальная диаграмма рассеяния;
- тест статистической значимости.
Возможно, мы не назвали ваш любимый метод сравнения моделей? Напишите о нем в комментариях, нам будет интересно узнать!
По материалам: machinelearningmastery.com