# Cross-validation for both penalties on one shared fold assignment.
# cv.glmnet draws a fresh random fold split on every call, so two independent
# calls would score ridge and lasso on different splits and the curves
# compared below would differ partly through fold noise. Passing the same
# `foldid` vector to both fits (as the glmnet documentation recommends when
# comparing alpha values) makes their CV errors directly comparable.
# When `foldid` is supplied, cv.glmnet derives the number of folds from it,
# so `nfolds` is no longer passed.
n_folds <- 10L
fold_id <- sample(rep_len(seq_len(n_folds), nrow(X)))
cv_ridge <- cv.glmnet(X, y, alpha = 0, foldid = fold_id)
cv_lasso <- cv.glmnet(X, y, alpha = 1, foldid = fold_id)
# Stack both CV curves into one long table for plotting: one row per lambda
# on each fit's path. The columns are log(lambda), the mean CV error (cvm),
# its standard error (cvsd) and a method label. Ridge rows come first, then
# Lasso rows.
comparison_data <- bind_rows(
  Map(
    function(fit, label) {
      tibble(
        log_lambda = log(fit$lambda),
        cvm = fit$cvm,
        cvsd = fit$cvsd,
        method = label
      )
    },
    list(cv_ridge, cv_lasso),
    c("Ridge", "Lasso")
  )
)
# Best lambdas: one row per method at its CV-selected lambda.min. Each row
# gives the x position (log scale, matching comparison_data) and the y
# position of the diamond marker drawn on that method's CV curve.
# The y value is read from cvm at the position of lambda.min in the lambda
# path, so the marker sits exactly on the curve at the lambda it labels.
# Recomputing min(cvm) only coincides with this when the CV measure is
# minimised (e.g. MSE). For a maximised measure such as AUC it would land on
# the worst point instead. The column keeps the name `min_cvm` because the
# plot below refers to it.
best_lambdas <- tibble(
  method = c("Ridge", "Lasso"),
  log_lambda = log(c(cv_ridge$lambda.min, cv_lasso$lambda.min)),
  min_cvm = c(
    cv_ridge$cvm[match(cv_ridge$lambda.min, cv_ridge$lambda)],
    cv_lasso$cvm[match(cv_lasso$lambda.min, cv_lasso$lambda)]
  )
)
# Overlay both CV curves. For each method the plot shows the mean CV error as
# a line, a shaded band of +/- one cvsd around it, and a diamond at that
# method's lambda.min (taken from best_lambdas).
p <- ggplot(comparison_data,
            aes(x = log_lambda, y = cvm, colour = method)) +
  # Band of cvm +/- cvsd. The outline is suppressed so only the fill shows.
  geom_ribbon(
    mapping = aes(ymin = cvm - cvsd, ymax = cvm + cvsd, fill = method),
    colour = NA,
    alpha = 0.2
  ) +
  geom_line(linewidth = 1) +
  # Markers come from the per-method summary table. x and y are remapped onto
  # its columns; colour is still inherited from its `method` column.
  geom_point(
    data = best_lambdas,
    mapping = aes(x = log_lambda, y = min_cvm),
    shape = 18,
    size = 4
  ) +
  scale_colour_manual(values = c(Ridge = col_blue, Lasso = col_orange)) +
  scale_fill_manual(values = c(Ridge = col_blue, Lasso = col_orange)) +
  labs(
    title = "Ridge vs Lasso: Cross-Validation Comparison",
    subtitle = "Diamonds mark lambda.min for each method",
    x = expression(log(lambda)),
    y = "CV Mean Squared Error",
    colour = "",
    fill = ""
  ) +
  theme(legend.position = "bottom")
# Write the figure to fig_dir as a PNG: 12 x 6 in ggsave's default units
# (inches), at 150 dpi. Then evaluate `p` so it is auto-printed when run
# interactively or in a notebook chunk.
ggsave(
  filename = file.path(fig_dir, "12_ridge_vs_lasso_comparison.png"),
  plot = p,
  width = 12,
  height = 6,
  dpi = 150
)
p