Mastering practical prediction and model interpretability
# Basic setup
library(tidymodels)
library(tidyverse)
library(vip)      # variable importance plots
library(DALEXtra) # DALEX extensions, including tidymodels support
library(DALEX)    # moDel Agnostic Language for Exploration and eXplanation
# Prepare the data and model
data("ames", package = "modeldata")
ames_clean <- ames |>
  select(Sale_Price, Lot_Area, Year_Built,
         Neighborhood, House_Style, Overall_Qual) |>
  mutate(log_price = log10(Sale_Price))
# Train/test split
set.seed(123)
ames_split <- initial_split(ames_clean, prop = 0.8)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
We now use the trained model to make predictions on new data.
# Train a random forest model
rf_recipe <- recipe(log_price ~ ., data = ames_train) |>
  # keep Sale_Price as an ID column: it is then neither used as a predictor nor
  # required when predicting on data that lacks it
  update_role(Sale_Price, new_role = "id") |>
  update_role_requirements(role = "id", bake = FALSE) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |> # drop zero-variance columns before normalizing
  step_normalize(all_numeric_predictors())
rf_spec <- rand_forest(
  mtry = 4,
  trees = 500,
  min_n = 5
) |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("regression")
rf_workflow <- workflow() |>
  add_recipe(rf_recipe) |>
  add_model(rf_spec)
# Fit the model
rf_fit <- rf_workflow |> fit(ames_train)
# Predict on the test data
test_predictions <- rf_fit |>
  predict(ames_test) |>
  bind_cols(ames_test |> select(log_price, Sale_Price))
test_predictions |>
  head(10)
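As a quick check of these predictions (a sketch, not part of the original text), the combined tibble can be scored directly with yardstick; for a regression outcome metrics() reports RMSE, R-squared, and MAE.
# Test-set accuracy of the random forest
test_predictions |>
  metrics(truth = log_price, estimate = .pred)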
Next, we quantify the uncertainty of the predictions and construct interval estimates.
# Prediction intervals via quantile regression
library(quantreg)
# Model specification for a single quantile (the 10th percentile), shown for reference
quantile_spec <- linear_reg() |>
  set_engine("quantreg", tau = 0.1)
# Fit one model per quantile and collect the predictions
quantiles <- c(0.1, 0.25, 0.5, 0.75, 0.9)
quantile_predictions <- map_dfr(quantiles, ~ {
  quantile_model <- linear_reg() |>
    set_engine("quantreg", tau = .x) |>
    set_mode("regression")
  quantile_wf <- workflow() |>
    add_recipe(rf_recipe) |>
    add_model(quantile_model)
  quantile_fit <- quantile_wf |> fit(ames_train)
  predict(quantile_fit, ames_test) |>
    mutate(quantile = .x)
}) |>
  bind_cols(ames_test |> select(log_price) |> slice(rep(1:n(), 5)))
# Visualize the prediction intervals
prediction_intervals <- quantile_predictions |>
  mutate(row_id = rep(seq_len(nrow(ames_test)), times = length(quantiles))) |>
  pivot_wider(id_cols = c(row_id, log_price),
              names_from = quantile, values_from = .pred, names_prefix = "q")
prediction_intervals |>
  slice_head(n = 50) |>
  ggplot(aes(x = row_id)) +
  geom_ribbon(aes(ymin = q0.1, ymax = q0.9), alpha = 0.3, fill = "lightblue") +
  geom_ribbon(aes(ymin = q0.25, ymax = q0.75), alpha = 0.5, fill = "lightgreen") +
  geom_line(aes(y = q0.5), color = "blue", linewidth = 1) +
  geom_point(aes(y = log_price), color = "red", alpha = 0.7) +
  labs(title = "Prediction Intervals",
       x = "Observation", y = "Log Price",
       subtitle = "Blue ribbon: 80% interval, green ribbon: 50% interval") +
  theme_minimal()
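As an additional sanity check (a hedged sketch, not in the original), the nominal interval levels can be compared with their empirical coverage on the test set, using the prediction_intervals table built above.
# Empirical coverage of the 80% and 50% prediction intervals
prediction_intervals |>
  summarise(
    coverage_80 = mean(log_price >= q0.1 & log_price <= q0.9),
    coverage_50 = mean(log_price >= q0.25 & log_price <= q0.75)
  )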
# Prediction intervals via the bootstrap
library(rsample)
# Fit the model on each bootstrap sample
bootstrap_models <- bootstraps(ames_train, times = 100) |>
  mutate(
    models = map(splits, ~ {
      boot_data <- analysis(.x)
      rf_workflow |> fit(boot_data)
    })
  )
# Predict a new data point (here, the first test observation)
new_observation <- ames_test |> slice(1)
bootstrap_predictions <- bootstrap_models |>
  mutate(
    predictions = map_dbl(models, ~ {
      predict(.x, new_observation)$.pred
    })
  )
# Compute the prediction interval
# Note: these intervals reflect model (resampling) uncertainty only, not residual noise
prediction_summary <- bootstrap_predictions |>
  summarise(
    median_pred = median(predictions),
    lower_95 = quantile(predictions, 0.025),
    upper_95 = quantile(predictions, 0.975),
    lower_50 = quantile(predictions, 0.25),
    upper_50 = quantile(predictions, 0.75)
  )
prediction_summary
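The same bootstrap models can produce intervals for several observations at once; the following sketch (an extension of the code above, not part of the original) stacks the 100 bootstrap predictions into a matrix with one column per test row.
# Bootstrap prediction intervals for the first five test rows
new_rows <- ames_test |> slice(1:5)
boot_matrix <- bootstrap_models$models |>
  map(~ predict(.x, new_rows)$.pred) |>
  reduce(rbind) # 100 x 5 matrix: one row per bootstrap model
tibble(
  row_id = seq_len(nrow(new_rows)),
  median_pred = apply(boot_matrix, 2, median),
  lower_95 = apply(boot_matrix, 2, quantile, probs = 0.025),
  upper_95 = apply(boot_matrix, 2, quantile, probs = 0.975)
)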
First, we use the random forest's built-in importance measures.
# Importance visualization with the vip package
library(vip)
# Basic feature importance
importance_plot <- rf_fit |>
  extract_fit_parsnip() |>
  vip(num_features = 15, geom = "col") +
  theme_minimal() +
  labs(title = "Feature Importance (Random Forest)")
importance_plot
# Importance scores as a table
importance_values <- rf_fit |>
  extract_fit_parsnip() |>
  vi() |>
  arrange(desc(Importance))
importance_values
Next is a model-agnostic way of measuring importance.
# Compute permutation importance
library(DALEX)
# Create an explainer; explain_tidymodels() from DALEXtra knows how to handle workflow objects
explainer <- explain_tidymodels(
  model = rf_fit,
  data = ames_train |> select(-log_price, -Sale_Price),
  y = ames_train$log_price,
  label = "Random Forest",
  verbose = FALSE
)
# Compute permutation importance
perm_importance <- model_parts(explainer, type = "difference")
# Visualize the results
plot(perm_importance) +
  ggtitle("Permutation Feature Importance") +
  theme_minimal()
# Numeric results
perm_importance$result |>
  filter(permutation == 0) |> # rows with permutation == 0 hold the mean over permutations
  arrange(desc(dropout_loss)) |>
  head(10)
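model_parts() also accepts an explicit loss function and the number of permutation repeats; the sketch below (not in the original) uses DALEX's loss_root_mean_square with 25 repeats to stabilize the estimates.
# Permutation importance with an explicit loss function and more repeats
perm_importance_rmse <- model_parts(
  explainer,
  loss_function = loss_root_mean_square,
  B = 25,
  type = "difference"
)
plot(perm_importance_rmse) +
  ggtitle("Permutation Importance (RMSE loss, 25 permutations)") +
  theme_minimal()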
# Compute SHAP values (on a small sample)
# The Shapley values below come from DALEX::predict_parts(); the shapr package is not needed here
# A small sample for the SHAP computation
sample_data <- ames_test |>
  select(-log_price, -Sale_Price) |>
  slice_head(n = 10)
# Prepare the explainer for the SHAP computation
shap_explainer <- explain_tidymodels(
  model = rf_fit,
  data = ames_train |> select(-log_price, -Sale_Price),
  y = ames_train$log_price,
  verbose = FALSE
)
# Compute the SHAP values
shap_values <- predict_parts(
  shap_explainer,
  new_observation = sample_data[1, ],
  type = "shap"
)
# Visualize the SHAP values
plot(shap_values) +
  ggtitle("SHAP Values for Single Prediction") +
  theme_minimal()
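To get a rough global picture from these local explanations, the per-observation Shapley contributions can be averaged. This is a hedged sketch, not part of the original: it assumes the predict_parts() result exposes variable_name and contribution columns (as in current DALEX versions), and it is slow because one decomposition is computed per row.
# Average absolute Shapley contribution over the 10 sampled observations
shap_all <- map_dfr(seq_len(nrow(sample_data)), \(i) {
  predict_parts(shap_explainer, new_observation = sample_data[i, ], type = "shap") |>
    as_tibble()
})
shap_all |>
  group_by(variable_name) |>
  summarise(mean_abs_contribution = mean(abs(contribution))) |>
  arrange(desc(mean_abs_contribution))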
We now visualize the average effect that individual features have on the model's predictions.
# Create a partial dependence profile
pdp_year_built <- model_profile(
  explainer,
  variables = "Year_Built",
  N = 100 # number of sampled observations
)
# Visualize
plot(pdp_year_built) +
  ggtitle("Partial Dependence: Year_Built") +
  theme_minimal()
# Partial dependence profiles for several variables
pdp_multi <- model_profile(
  explainer,
  variables = c("Year_Built", "Lot_Area", "Overall_Qual"),
  N = 50
)
plot(pdp_multi) +
  ggtitle("Partial Dependence: Multiple Variables") +
  theme_minimal()
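Profiles can also be grouped by a categorical variable to see whether an effect differs across subgroups; a minimal sketch (not in the original), assuming model_profile()'s groups argument:
# Partial dependence of Year_Built, grouped by House_Style
pdp_grouped <- model_profile(
  explainer,
  variables = "Year_Built",
  groups = "House_Style",
  N = 100
)
plot(pdp_grouped) +
  ggtitle("Partial Dependence: Year_Built by House_Style") +
  theme_minimal()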
# Interaction between two variables
# model_profile() computes one-dimensional profiles only, so the two-way partial
# dependence surface is built manually here: fix both variables at each grid point
# and average the model's predictions over the training data
pdp_grid <- expand_grid(
  Year_Built = seq(1900, 2010, by = 20),
  Overall_Qual = sort(unique(ames_train$Overall_Qual))
)
pdp_data <- pdp_grid |>
  mutate(yhat = map2_dbl(Year_Built, Overall_Qual, \(yb, oq) {
    grid_data <- ames_train |>
      mutate(Year_Built = yb, Overall_Qual = oq)
    mean(predict(rf_fit, grid_data)$.pred)
  }))
# Visualize as a heatmap
ggplot(pdp_data, aes(x = Year_Built, y = Overall_Qual, fill = yhat)) +
  geom_tile() +
  scale_fill_viridis_c(name = "Predicted\nLog Price") +
  labs(title = "2D Partial Dependence Plot",
       x = "Year Built", y = "Overall Quality") +
  theme_minimal()
# Create an accumulated local effects (ALE) plot
ale_plot <- model_profile(
  explainer,
  variables = "Year_Built",
  type = "accumulated"
)
plot(ale_plot) +
  ggtitle("Accumulated Local Effects: Year_Built") +
  theme_minimal()
# ALE plots for several variables
ale_multi <- model_profile(
  explainer,
  variables = c("Year_Built", "Lot_Area"),
  type = "accumulated"
)
plot(ale_multi) +
  ggtitle("ALE Plots: Multiple Variables") +
  theme_minimal()
Next comes local interpretation with LIME (Local Interpretable Model-agnostic Explanations).
# Local explanations with LIME
library(lime) # note: lime::explain() masks DALEX::explain() from here on
# Create the LIME explainer
lime_explainer <- lime(
  x = ames_train |> select(-log_price, -Sale_Price),
  model = rf_fit,
  bin_continuous = TRUE,
  n_bins = 4
)
# Explain individual samples
explanation <- explain(
  x = ames_test |> slice(1:3) |> select(-log_price, -Sale_Price),
  explainer = lime_explainer,
  n_features = 10,
  n_permutations = 1000
)
# Visualize the results
plot_features(explanation) +
  labs(title = "LIME Explanations for Individual Predictions") +
  theme_minimal()
# Inspect the explanation details
explanation |>
  select(case, feature, feature_weight, feature_desc) |>
  arrange(case, desc(abs(feature_weight)))
# A more detailed LIME explanation
detailed_explanation <- explain(
  x = ames_test |> slice(1) |> select(-log_price, -Sale_Price),
  explainer = lime_explainer,
  n_features = 5,
  feature_select = "highest_weights"
)
# Build a text description of the explanation
explanation_text <- detailed_explanation |>
  mutate(
    contribution = ifelse(feature_weight > 0,
                          paste("Pushes the prediction up:", feature_desc),
                          paste("Pushes the prediction down:", feature_desc)),
    impact = paste0("(weight: ", round(abs(feature_weight), 3), ")")
  ) |>
  select(contribution, impact) |>
  unite("explanation", contribution:impact, sep = " ")
cat("Predicted value:", round(detailed_explanation$prediction[1], 3), "\n")
cat("Actual value:", round(ames_test$log_price[1], 3), "\n\n")
cat("Key explanatory factors:\n")
cat(paste(explanation_text$explanation, collapse = "\n"))
# A comprehensive interpretation helper
comprehensive_interpretation <- function(model, train_data, test_data, target_col) {
  # Basic information
  model_info <- list(
    model_type = class(model)[1],
    n_features = ncol(train_data) - 1,
    n_train = nrow(train_data),
    n_test = nrow(test_data)
  )
  # Global feature importance
  global_importance <- model |>
    extract_fit_parsnip() |>
    vi() |>
    arrange(desc(Importance))
  # Predictive performance
  predictions <- predict(model, test_data) |>
    bind_cols(test_data |> select(all_of(target_col)))
  performance <- predictions |>
    rename(truth = all_of(target_col)) |>
    metrics(truth = truth, estimate = .pred)
  # Feature statistics (numeric predictors only)
  feature_stats <- train_data |>
    select(-all_of(target_col)) |>
    summarise(across(where(is.numeric), list(
      mean = ~ mean(.x, na.rm = TRUE),
      median = ~ median(.x, na.rm = TRUE),
      sd = ~ sd(.x, na.rm = TRUE)
    )))
  list(
    model_info = model_info,
    performance = performance,
    importance = global_importance,
    feature_stats = feature_stats,
    predictions = predictions
  )
}
# Run the helper
interpretation_report <- comprehensive_interpretation(
  rf_fit,
  ames_train,
  ames_test,
  "log_price"
)
# Display the report
interpretation_report$model_info
interpretation_report$performance
interpretation_report$importance |> head(10)
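As one simple use of the stored predictions (a sketch, not part of the original), the test-set predictions can be plotted against the actual values:
# Predicted vs. actual log prices on the test set
interpretation_report$predictions |>
  ggplot(aes(x = log_price, y = .pred)) +
  geom_point(alpha = 0.4) +
  geom_abline(linetype = "dashed", color = "red") +
  labs(title = "Predicted vs. Actual (Test Set)",
       x = "Actual log price", y = "Predicted log price") +
  theme_minimal()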
# Example code for a Shiny application
library(shiny)
library(plotly)
# UI definition (simplified)
ui <- fluidPage(
  titlePanel("Model Interpretation Dashboard"),
  sidebarLayout(
    sidebarPanel(
      selectInput("interpretation_type", "Interpretation method:",
                  choices = c("Feature Importance", "Partial Dependence",
                              "SHAP Values", "LIME")),
      selectInput("feature", "Feature:",
                  choices = names(ames_train)),
      numericInput("sample_id", "Sample ID:", value = 1, min = 1, max = 100)
    ),
    mainPanel(
      plotlyOutput("interpretation_plot"),
      verbatimTextOutput("interpretation_text")
    )
  )
)
# Server definition (simplified)
server <- function(input, output, session) {
  output$interpretation_plot <- renderPlotly({
    # Generate the plot for the selected interpretation method
    if (input$interpretation_type == "Feature Importance") {
      p <- rf_fit |>
        extract_fit_parsnip() |>
        vip(num_features = 10) +
        theme_minimal()
      ggplotly(p)
    }
    # The other interpretation methods would be implemented similarly
  })
  output$interpretation_text <- renderText({
    paste("Selected interpretation method:", input$interpretation_type)
  })
}
# Run the application (commented out)
# shinyApp(ui = ui, server = server)
Guidelines for interpretability:
In the next chapter, Chapter 20, we turn to ensemble learning, which combines multiple models. Using the stacks package, you will learn techniques for pushing predictive accuracy even further.