🔮 第19章: 予測と解釈

実用的な予測とモデル解釈可能性をマスターする

🎯 Chapter 19の学習目標

# 基本セットアップ
library(tidymodels)
library(tidyverse)
library(vip)        # Variable Importance Plots
library(DALEXtra)   # Model Agnostic Explanations
library(DALEX)      # Descriptive mAchine Learning EXplanations

# データとモデルの準備
data("ames", package = "modeldata")
ames_clean <- ames |>
  select(Sale_Price, Lot_Area, Year_Built, 
         Neighborhood, House_Style, Overall_Qual) |>
  mutate(log_price = log10(Sale_Price))

# Train/Test分割
set.seed(123)
ames_split <- initial_split(ames_clean, prop = 0.8)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

🔍 新しいデータでの予測

1. 基本的な予測実行

訓練済みモデルを使用して新しいデータで予測を行います。

# Random Forest モデルの訓練
rf_recipe <- recipe(log_price ~ ., data = ames_train) |>
  step_rm(Sale_Price) |>
  step_dummy(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors()) |>
  step_zv(all_predictors())

rf_spec <- rand_forest(
  mtry = 4,
  trees = 500,
  min_n = 5
) |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("regression")

rf_workflow <- workflow() |>
  add_recipe(rf_recipe) |>
  add_model(rf_spec)

# モデル訓練
rf_fit <- rf_workflow |> fit(ames_train)

# テストデータでの予測
test_predictions <- rf_fit |>
  predict(ames_test) |>
  bind_cols(ames_test |> select(log_price, Sale_Price))

test_predictions |>
  head(10)

2. 予測区間の計算

予測の不確実性を量化して区間推定を行います。

# 量分位回帰による予測区間
library(quantreg)

# 分位点予測のためのモデル(複数の分位点)
quantile_spec <- linear_reg() |>
  set_engine("quantreg", tau = 0.1)  # 10%分位点

# 異なる分位点での予測
quantiles <- c(0.1, 0.25, 0.5, 0.75, 0.9)

quantile_predictions <- map_dfr(quantiles, ~ {
  quantile_model <- linear_reg() |>
    set_engine("quantreg", tau = .x) |>
    set_mode("regression")
  
  quantile_wf <- workflow() |>
    add_recipe(rf_recipe) |>
    add_model(quantile_model)
  
  quantile_fit <- quantile_wf |> fit(ames_train)
  
  predict(quantile_fit, ames_test) |>
    mutate(quantile = .x)
}) |>
  bind_cols(ames_test |> select(log_price) |> slice(rep(1:n(), 5)))

# 予測区間の可視化
prediction_intervals <- quantile_predictions |>
  pivot_wider(names_from = quantile, values_from = .pred, names_prefix = "q") |>
  mutate(row_id = row_number())

prediction_intervals |>
  slice_head(n = 50) |>
  ggplot(aes(x = row_id)) +
  geom_ribbon(aes(ymin = q0.1, ymax = q0.9), alpha = 0.3, fill = "lightblue") +
  geom_ribbon(aes(ymin = q0.25, ymax = q0.75), alpha = 0.5, fill = "lightgreen") +
  geom_line(aes(y = q0.5), color = "blue", size = 1) +
  geom_point(aes(y = log_price), color = "red", alpha = 0.7) +
  labs(title = "Prediction Intervals",
       x = "Observation", y = "Log Price",
       subtitle = "Blue ribbon: 80% interval, Green: 50% interval") +
  theme_minimal()

3. Bootstrap による予測区間

# Bootstrap による予測区間
library(rsample)

# Bootstrap サンプルでのモデル訓練
bootstrap_models <- bootstraps(ames_train, times = 100) |>
  mutate(
    models = map(splits, ~ {
      boot_data <- analysis(.x)
      rf_workflow |> fit(boot_data)
    })
  )

# 新しいデータ点での予測(例:最初のテストサンプル)
new_observation <- ames_test |> slice(1)

bootstrap_predictions <- bootstrap_models |>
  mutate(
    predictions = map_dbl(models, ~ {
      predict(.x, new_observation)$.pred
    })
  )

# 予測区間の計算
prediction_summary <- bootstrap_predictions |>
  summarise(
    median_pred = median(predictions),
    lower_95 = quantile(predictions, 0.025),
    upper_95 = quantile(predictions, 0.975),
    lower_50 = quantile(predictions, 0.25),
    upper_50 = quantile(predictions, 0.75)
  )

prediction_summary

📊 特徴量重要度分析

1. モデル固有の重要度

Random Forestの内蔵重要度指標を活用します。

# vipパッケージによる重要度可視化
library(vip)

# 基本的な特徴量重要度
importance_plot <- rf_fit |>
  extract_fit_parsnip() |>
  vip(num_features = 15, geom = "col") +
  theme_minimal() +
  labs(title = "Feature Importance (Random Forest)")

importance_plot

# 数値での重要度取得
importance_values <- rf_fit |>
  extract_fit_parsnip() |>
  vi() |>
  arrange(desc(Importance))

importance_values

2. Permutation Importance

モデルに依存しない重要度測定手法です。

# Permutation importanceの計算
library(DALEX)

# DALEX explainerの作成
explainer <- explain(
  model = rf_fit,
  data = ames_train |> select(-log_price, -Sale_Price),
  y = ames_train$log_price,
  label = "Random Forest",
  verbose = FALSE
)

# Permutation importanceの計算
perm_importance <- model_parts(explainer, type = "difference")

# 結果の可視化
plot(perm_importance) + 
  ggtitle("Permutation Feature Importance") +
  theme_minimal()

# 数値結果
perm_importance$result |>
  filter(permutation == 0) |>  # baseline
  arrange(desc(dropout_loss)) |>
  head(10)

3. SHAP値による重要度

# SHAP値の計算(サンプルデータで)
library(shapr)

# 少数サンプルでのSHAP値計算
sample_data <- ames_test |> 
  select(-log_price, -Sale_Price) |>
  slice_head(n = 10)

# SHAP値計算の準備
shap_explainer <- explain(
  model = rf_fit,
  data = ames_train |> select(-log_price, -Sale_Price),
  y = ames_train$log_price,
  verbose = FALSE
)

# SHAP値計算
shap_values <- predict_parts(
  shap_explainer, 
  new_observation = sample_data[1, ],
  type = "shap"
)

# SHAP値の可視化
plot(shap_values) + 
  ggtitle("SHAP Values for Single Prediction") +
  theme_minimal()

🌍 大域的解釈:部分依存プロット

1. 単変数部分依存プロット

個々の特徴量がモデル予測に与える平均的な影響を可視化します。

# 部分依存プロットの作成
pdp_year_built <- model_profile(
  explainer,
  variables = "Year_Built",
  N = 100  # サンプル数
)

# 可視化
plot(pdp_year_built) +
  ggtitle("Partial Dependence: Year_Built") +
  theme_minimal()

# 複数変数の部分依存プロット
pdp_multi <- model_profile(
  explainer,
  variables = c("Year_Built", "Lot_Area", "Overall_Qual"),
  N = 50
)

plot(pdp_multi) +
  ggtitle("Partial Dependence: Multiple Variables") +
  theme_minimal()

2. 2次元部分依存プロット

# 2変数間の相互作用
pdp_2d <- model_profile(
  explainer,
  variables = c("Year_Built", "Overall_Qual"),
  type = "partial",
  variable_splits = list(
    Year_Built = seq(1900, 2010, by = 20),
    Overall_Qual = 1:10
  )
)

# ヒートマップでの可視化
library(plotly)

# データ整形
pdp_data <- pdp_2d$result |>
  select(`_vname_`, `_x_`, `_yhat_`) |>
  pivot_wider(names_from = `_vname_`, values_from = `_x_`) |>
  filter(!is.na(Year_Built), !is.na(Overall_Qual))

# ヒートマップ作成
ggplot(pdp_data, aes(x = Year_Built, y = Overall_Qual, fill = `_yhat_`)) +
  geom_tile() +
  scale_fill_viridis_c(name = "Predicted\nLog Price") +
  labs(title = "2D Partial Dependence Plot",
       x = "Year Built", y = "Overall Quality") +
  theme_minimal()

3. 累積局所効果プロット(ALE)

# ALE plotの作成
ale_plot <- model_profile(
  explainer,
  variables = "Year_Built",
  type = "accumulated"
)

plot(ale_plot) +
  ggtitle("Accumulated Local Effects: Year_Built") +
  theme_minimal()

# 複数変数のALE
ale_multi <- model_profile(
  explainer,
  variables = c("Year_Built", "Lot_Area"),
  type = "accumulated"
)

plot(ale_multi) +
  ggtitle("ALE Plots: Multiple Variables") +
  theme_minimal()

🎯 局所的解釈:LIME

1. LIME による個別予測の説明

Local Interpretable Model-agnostic Explanations による局所的解釈です。

# LIMEによる局所的説明
library(lime)

# LIME explainerの作成
lime_explainer <- lime(
  x = ames_train |> select(-log_price, -Sale_Price),
  model = rf_fit,
  bin_continuous = TRUE,
  n_bins = 4
)

# 個別サンプルの説明
explanation <- explain(
  x = ames_test |> slice(1:3) |> select(-log_price, -Sale_Price),
  explainer = lime_explainer,
  n_features = 10,
  n_permutations = 1000
)

# 結果の可視化
plot_features(explanation) +
  labs(title = "LIME Explanations for Individual Predictions") +
  theme_minimal()

# 説明の詳細
explanation |>
  select(case, feature, feature_weight, feature_desc) |>
  arrange(case, desc(abs(feature_weight)))

2. テキスト形式での解釈

# より詳細な LIME 説明
detailed_explanation <- explain(
  x = ames_test |> slice(1) |> select(-log_price, -Sale_Price),
  explainer = lime_explainer,
  n_features = 5,
  feature_select = "highest_weights"
)

# テキスト形式の説明生成
explanation_text <- detailed_explanation |>
  mutate(
    contribution = ifelse(feature_weight > 0, 
                         paste("増加要因:", feature_desc), 
                         paste("減少要因:", feature_desc)),
    impact = paste0("(影響度: ", round(abs(feature_weight), 3), ")")
  ) |>
  select(contribution, impact) |>
  unite("explanation", contribution:impact, sep = " ")

cat("予測値:", round(detailed_explanation$prediction[1], 3), "\n")
cat("実際値:", round(ames_test$log_price[1], 3), "\n\n")
cat("主要な説明要因:\n")
cat(paste(explanation_text$explanation, collapse = "\n"))

🔧 解釈可能性のベストプラクティス

1. 包括的解釈レポート

# 包括的解釈関数
comprehensive_interpretation <- function(model, train_data, test_data, target_col) {
  
  # 基本情報
  model_info <- list(
    model_type = class(model)[1],
    n_features = ncol(train_data) - 1,
    n_train = nrow(train_data),
    n_test = nrow(test_data)
  )
  
  # 全体的重要度
  global_importance <- model |>
    extract_fit_parsnip() |>
    vi() |>
    arrange(desc(Importance))
  
  # 予測性能
  predictions <- predict(model, test_data) |>
    bind_cols(test_data |> select(all_of(target_col)))
  
  performance <- predictions |>
    metrics(truth = all_of(target_col), estimate = .pred)
  
  # 特徴量統計
  feature_stats <- train_data |>
    select(-all_of(target_col)) |>
    summarise(across(everything(), list(
      mean = ~ mean(.x, na.rm = TRUE),
      median = ~ median(.x, na.rm = TRUE),
      sd = ~ sd(.x, na.rm = TRUE)
    )))
  
  list(
    model_info = model_info,
    performance = performance,
    importance = global_importance,
    feature_stats = feature_stats,
    predictions = predictions
  )
}

# 実行
interpretation_report <- comprehensive_interpretation(
  rf_fit, 
  ames_train, 
  ames_test, 
  "log_price"
)

# レポート表示
interpretation_report$model_info
interpretation_report$performance
interpretation_report$importance |> head(10)

2. 対話的解釈ダッシュボード

# Shiny アプリケーション用のコード例
library(shiny)
library(plotly)

# UI定義(簡略版)
ui <- fluidPage(
  titlePanel("Model Interpretation Dashboard"),
  
  sidebarLayout(
    sidebarPanel(
      selectInput("interpretation_type", "解釈手法:",
                  choices = c("Feature Importance", "Partial Dependence", 
                             "SHAP Values", "LIME")),
      selectInput("feature", "特徴量選択:",
                  choices = names(ames_train)),
      numericInput("sample_id", "サンプルID:", value = 1, min = 1, max = 100)
    ),
    
    mainPanel(
      plotlyOutput("interpretation_plot"),
      verbatimTextOutput("interpretation_text")
    )
  )
)

# Server定義(簡略版)
server <- function(input, output, session) {
  
  output$interpretation_plot <- renderPlotly({
    # 選択された解釈手法に応じてプロット生成
    if(input$interpretation_type == "Feature Importance") {
      p <- rf_fit |>
        extract_fit_parsnip() |>
        vip(num_features = 10) +
        theme_minimal()
      ggplotly(p)
    }
    # 他の解釈手法も同様に実装
  })
  
  output$interpretation_text <- renderText({
    paste("選択された解釈手法:", input$interpretation_type)
  })
}

# アプリケーション実行(コメントアウト)
# shinyApp(ui = ui, server = server)

解釈可能性のガイドライン:

  • 複数手法の組み合わせ: 大域的・局所的解釈の併用
  • ステークホルダー考慮: 対象ユーザーに応じた説明
  • 継続的検証: 時間経過による解釈の変化確認
  • バイアス検出: 不公平な判断基準の特定

🎯 まとめと次のステップ

Chapter 19で学んだこと

実用化のポイント

  • 解釈の限界を理解: 相関と因果の区別
  • 計算コストの考慮: 実時間制約との バランス
  • 説明の品質担保: 解釈結果の妥当性検証
  • 継続的監視: モデルドリフトの検出

次の章への準備

次の第20章では、複数のモデルを組み合わせるアンサンブル学習を学習します。stacksパッケージを使用して、予測精度をさらに向上させる手法を習得しましょう。