Mastering Practical Prediction and Model Interpretability
# Basic setup
library(tidymodels)
library(tidyverse)
library(vip)        # Variable importance plots
library(DALEX)      # Descriptive mAchine Learning EXplanations
library(DALEXtra)   # DALEX extensions, including tidymodels support
# Prepare the data and model
data("ames", package = "modeldata")
ames_clean <- ames |>
  select(Sale_Price, Lot_Area, Year_Built, 
         Neighborhood, House_Style, Overall_Qual) |>
  mutate(log_price = log10(Sale_Price))
# Train/test split
set.seed(123)
ames_split <- initial_split(ames_clean, prop = 0.8)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
Use the trained model to generate predictions on new data.
# Train a random forest model
rf_recipe <- recipe(log_price ~ ., data = ames_train) |>
  step_rm(Sale_Price) |>
  step_dummy(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors()) |>
  step_zv(all_predictors())
rf_spec <- rand_forest(
  mtry = 4,
  trees = 500,
  min_n = 5
) |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("regression")
rf_workflow <- workflow() |>
  add_recipe(rf_recipe) |>
  add_model(rf_spec)
# Fit the model
rf_fit <- rf_workflow |> fit(ames_train)
# Predict on the test data
test_predictions <- rf_fit |>
  predict(ames_test) |>
  bind_cols(ames_test |> select(log_price, Sale_Price))
test_predictions |>
  head(10)
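To sanity-check these predictions, yardstick's metrics() computes the standard regression metrics (RMSE, R-squared, MAE) straight from the prediction tibble. A minimal sketch using the test_predictions object built above:
# Sketch: score the test predictions on the held-out data
test_predictions |>
  metrics(truth = log_price, estimate = .pred)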
Quantify prediction uncertainty and produce interval estimates.
# Prediction intervals via quantile regression
library(quantreg)
# A single-quantile model looks like this (tau = 0.1 targets the 10th percentile)
quantile_spec <- linear_reg() |>
  set_engine("quantreg", tau = 0.1)
# Predictions at several quantiles
quantiles <- c(0.1, 0.25, 0.5, 0.75, 0.9)
quantile_predictions <- map_dfr(quantiles, ~ {
  quantile_model <- linear_reg() |>
    set_engine("quantreg", tau = .x) |>
    set_mode("regression")
  
  quantile_wf <- workflow() |>
    add_recipe(rf_recipe) |>
    add_model(quantile_model)
  
  quantile_fit <- quantile_wf |> fit(ames_train)
  
  predict(quantile_fit, ames_test) |>
    mutate(quantile = .x,
           row_id = row_number(),           # keeps observations aligned across quantiles
           log_price = ames_test$log_price)
})
# Visualize the prediction intervals
prediction_intervals <- quantile_predictions |>
  pivot_wider(id_cols = c(row_id, log_price),
              names_from = quantile, values_from = .pred, names_prefix = "q")
prediction_intervals |>
  slice_head(n = 50) |>
  ggplot(aes(x = row_id)) +
  geom_ribbon(aes(ymin = q0.1, ymax = q0.9), alpha = 0.3, fill = "lightblue") +
  geom_ribbon(aes(ymin = q0.25, ymax = q0.75), alpha = 0.5, fill = "lightgreen") +
  geom_line(aes(y = q0.5), color = "blue", linewidth = 1) +
  geom_point(aes(y = log_price), color = "red", alpha = 0.7) +
  labs(title = "Prediction Intervals",
       x = "Observation", y = "Log Price",
       subtitle = "Blue ribbon: 80% interval, green: 50% interval") +
  theme_minimal()
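A useful follow-up is to check empirical coverage: the share of test observations that actually fall inside each interval should sit near its nominal level. A minimal sketch on the prediction_intervals tibble from above:
# Sketch: empirical coverage of the nominal 80% and 50% intervals
prediction_intervals |>
  summarise(
    coverage_80 = mean(log_price >= q0.1  & log_price <= q0.9),
    coverage_50 = mean(log_price >= q0.25 & log_price <= q0.75)
  )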
            
# Prediction intervals via the bootstrap
library(rsample)
# Fit the model on each bootstrap resample
bootstrap_models <- bootstraps(ames_train, times = 100) |>
  mutate(
    models = map(splits, ~ {
      boot_data <- analysis(.x)
      rf_workflow |> fit(boot_data)
    })
  )
# Predict for a new data point (here, the first test sample)
new_observation <- ames_test |> slice(1)
bootstrap_predictions <- bootstrap_models |>
  mutate(
    predictions = map_dbl(models, ~ {
      predict(.x, new_observation)$.pred
    })
  )
# Compute the prediction interval
prediction_summary <- bootstrap_predictions |>
  summarise(
    median_pred = median(predictions),
    lower_95 = quantile(predictions, 0.025),
    upper_95 = quantile(predictions, 0.975),
    lower_50 = quantile(predictions, 0.25),
    upper_50 = quantile(predictions, 0.75)
  )
prediction_summary
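Because the model predicts log10(Sale_Price), the interval is easier to communicate after back-transforming to dollars. A minimal sketch:
# Sketch: convert the log10-scale interval back to dollars
prediction_summary |>
  mutate(across(everything(), ~ 10^.x))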
Leverage the random forest's built-in importance measures.
# Importance plots with the vip package
library(vip)
# Basic feature importance
importance_plot <- rf_fit |>
  extract_fit_parsnip() |>
  vip(num_features = 15, geom = "col") +
  theme_minimal() +
  labs(title = "Feature Importance (Random Forest)")
importance_plot
# Importance scores as a table
importance_values <- rf_fit |>
  extract_fit_parsnip() |>
  vi() |>
  arrange(desc(Importance))
importance_values
Permutation importance is a model-agnostic way to measure feature importance.
# Compute permutation importance
# Create a DALEX explainer; explain_tidymodels() from DALEXtra handles workflows
explainer <- explain_tidymodels(
  rf_fit,
  # Sale_Price stays in: the recipe's step_rm() expects it at predict time, then drops it
  data = ames_train |> select(-log_price),
  y = ames_train$log_price,
  label = "Random Forest",
  verbose = FALSE
)
# Permutation importance
perm_importance <- model_parts(explainer, type = "difference")
# Visualize the results
plot(perm_importance) + 
  ggtitle("Permutation Feature Importance") +
  theme_minimal()
# Numeric results (permutation == 0 holds the average over permutations)
perm_importance$result |>
  filter(permutation == 0) |>
  arrange(desc(dropout_loss)) |>
  head(10)
            
# SHAP values (computed on a small sample, since Shapley estimation is expensive)
sample_data <- ames_test |> 
  select(-log_price) |>   # Sale_Price stays in for the recipe, which then drops it
  slice_head(n = 10)
# SHAP decomposition for a single observation, reusing the explainer from above
shap_values <- predict_parts(
  explainer, 
  new_observation = sample_data[1, ],
  type = "shap"
)
# Visualize the SHAP values
plot(shap_values) + 
  ggtitle("SHAP Values for Single Prediction") +
  theme_minimal()
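To see which variables matter most across several cases rather than a single one, the absolute SHAP contributions can be averaged over a handful of observations. A rough sketch, assuming the predict_parts() result can be coerced to a data frame with variable_name and contribution columns (as in current DALEX versions):
# Sketch: mean absolute SHAP contribution over the first five sample rows
shap_many <- map_dfr(1:5, \(i) {
  predict_parts(explainer, new_observation = sample_data[i, ], type = "shap") |>
    as_tibble() |>
    mutate(case = i)
})
shap_many |>
  group_by(variable_name) |>
  summarise(mean_abs_contribution = mean(abs(contribution))) |>
  arrange(desc(mean_abs_contribution))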
Partial dependence plots visualize the average effect an individual feature has on the model's predictions.
# Build a partial dependence plot
pdp_year_built <- model_profile(
  explainer,
  variables = "Year_Built",
  N = 100  # number of observations used for the profile
)
# Visualize
plot(pdp_year_built) +
  ggtitle("Partial Dependence: Year_Built") +
  theme_minimal()
# Partial dependence for several variables
pdp_multi <- model_profile(
  explainer,
  variables = c("Year_Built", "Lot_Area", "Overall_Qual"),
  N = 50
)
plot(pdp_multi) +
  ggtitle("Partial Dependence: Multiple Variables") +
  theme_minimal()
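Profiles can also be split by a categorical variable to reveal heterogeneous effects; DALEX's model_profile() accepts a groups argument for this. A minimal sketch:
# Sketch: partial dependence of Year_Built, grouped by House_Style
pdp_grouped <- model_profile(
  explainer,
  variables = "Year_Built",
  groups = "House_Style",
  N = 100
)
plot(pdp_grouped) +
  ggtitle("Partial Dependence: Year_Built by House_Style") +
  theme_minimal()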
            
# Interaction between two variables: a manual 2D partial dependence grid.
# For each (Year_Built, Overall_Qual) pair, overwrite those columns in a
# subsample of the training data and average the model's predictions.
set.seed(42)
pd_sample <- ames_train |> slice_sample(n = 200)
pd_grid <- expand_grid(
  Year_Built = seq(1900, 2010, by = 20),
  Overall_Qual = unique(ames_train$Overall_Qual)
)
pdp_data <- pd_grid |>
  mutate(yhat = map_dbl(seq_len(n()), \(i) {
    grid_data <- pd_sample |>
      mutate(Year_Built = pd_grid$Year_Built[i],
             Overall_Qual = pd_grid$Overall_Qual[i])
    mean(predict(rf_fit, grid_data)$.pred)
  }))
# Visualize as a heatmap
ggplot(pdp_data, aes(x = Year_Built, y = Overall_Qual, fill = yhat)) +
  geom_tile() +
  scale_fill_viridis_c(name = "Predicted\nLog Price") +
  labs(title = "2D Partial Dependence Plot",
       x = "Year Built", y = "Overall Quality") +
  theme_minimal()
            
# Build an ALE plot
ale_plot <- model_profile(
  explainer,
  variables = "Year_Built",
  type = "accumulated"
)
plot(ale_plot) +
  ggtitle("Accumulated Local Effects: Year_Built") +
  theme_minimal()
# ALE for several variables
ale_multi <- model_profile(
  explainer,
  variables = c("Year_Built", "Lot_Area"),
  type = "accumulated"
)
plot(ale_multi) +
  ggtitle("ALE Plots: Multiple Variables") +
  theme_minimal()
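Plotting the PDP and ALE curves for the same variable together makes correlation bias visible: where predictors are correlated, the two curves drift apart. A rough sketch, assuming the agr_profiles element and its `_label_` column exposed by model_profile() in current DALEX versions:
# Sketch: overlay PDP and ALE for Year_Built
pdp_year_built$agr_profiles$`_label_` <- "partial dependence"
ale_plot$agr_profiles$`_label_` <- "accumulated local effects"
plot(pdp_year_built$agr_profiles, ale_plot$agr_profiles) +
  ggtitle("PDP vs ALE: Year_Built")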
Local interpretation with LIME (Local Interpretable Model-agnostic Explanations).
# Local explanations with LIME
library(lime)
# Create a LIME explainer
lime_explainer <- lime(
  x = ames_train |> select(-log_price),  # Sale_Price stays in for the recipe, which then drops it
  model = rf_fit,
  bin_continuous = TRUE,
  n_bins = 4
)
# Explain individual samples; lime::explain() is spelled out because
# DALEX also exports an explain() function
explanation <- lime::explain(
  x = ames_test |> slice(1:3) |> select(-log_price),
  explainer = lime_explainer,
  n_features = 10,
  n_permutations = 1000
)
# Visualize the results
plot_features(explanation) +
  labs(title = "LIME Explanations for Individual Predictions") +
  theme_minimal()
# Explanation details
explanation |>
  select(case, feature, feature_weight, feature_desc) |>
  arrange(case, desc(abs(feature_weight)))
# A more detailed LIME explanation
detailed_explanation <- lime::explain(
  x = ames_test |> slice(1) |> select(-log_price),
  explainer = lime_explainer,
  n_features = 5,
  feature_select = "highest_weights"
)
# Generate a plain-text explanation
explanation_text <- detailed_explanation |>
  mutate(
    contribution = ifelse(feature_weight > 0, 
                          paste("Pushes the prediction up:", feature_desc), 
                          paste("Pushes the prediction down:", feature_desc)),
    impact = paste0("(weight: ", round(abs(feature_weight), 3), ")")
  ) |>
  select(contribution, impact) |>
  unite("explanation", contribution:impact, sep = " ")
cat("Predicted:", round(detailed_explanation$prediction[1], 3), "\n")
cat("Actual:", round(ames_test$log_price[1], 3), "\n\n")
cat("Key explanatory factors:\n")
cat(paste(explanation_text$explanation, collapse = "\n"))
        
# A comprehensive interpretation helper
comprehensive_interpretation <- function(model, train_data, test_data, target_col) {
  
  # Basic information
  model_info <- list(
    model_type = class(model)[1],
    n_features = ncol(train_data) - 1,
    n_train = nrow(train_data),
    n_test = nrow(test_data)
  )
  
  # Global importance
  global_importance <- model |>
    extract_fit_parsnip() |>
    vi() |>
    arrange(desc(Importance))
  
  # Predictive performance
  predictions <- predict(model, test_data) |>
    bind_cols(test_data |> select(all_of(target_col)))
  
  performance <- predictions |>
    metrics(truth = !!sym(target_col), estimate = .pred)
  
  # Summary statistics for the numeric features
  feature_stats <- train_data |>
    select(-all_of(target_col)) |>
    summarise(across(where(is.numeric), list(
      mean = ~ mean(.x, na.rm = TRUE),
      median = ~ median(.x, na.rm = TRUE),
      sd = ~ sd(.x, na.rm = TRUE)
    )))
  
  list(
    model_info = model_info,
    performance = performance,
    importance = global_importance,
    feature_stats = feature_stats,
    predictions = predictions
  )
}
# Run it
interpretation_report <- comprehensive_interpretation(
  rf_fit, 
  ames_train, 
  ames_test, 
  "log_price"
)
# Display the report
interpretation_report$model_info
interpretation_report$performance
interpretation_report$importance |> head(10)
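For reporting, the pieces of the list can be condensed into a few console lines. A minimal sketch (the .metric/.estimate columns come from yardstick, Variable from vi()):
# Sketch: a one-glance summary of the interpretation report
rmse_val <- interpretation_report$performance |>
  filter(.metric == "rmse") |>
  pull(.estimate)
cat("Model type:  ", interpretation_report$model_info$model_type, "\n",
    "Test RMSE:   ", round(rmse_val, 4), "\n",
    "Top feature: ", interpretation_report$importance$Variable[1], "\n", sep = "")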
            
# Example code for a Shiny interpretation dashboard
library(shiny)
library(plotly)
# UI definition (abridged)
ui <- fluidPage(
  titlePanel("Model Interpretation Dashboard"),
  
  sidebarLayout(
    sidebarPanel(
      selectInput("interpretation_type", "解釈手法:",
                  choices = c("Feature Importance", "Partial Dependence", 
                             "SHAP Values", "LIME")),
      selectInput("feature", "特徴量選択:",
                  choices = names(ames_train)),
      numericInput("sample_id", "サンプルID:", value = 1, min = 1, max = 100)
    ),
    
    mainPanel(
      plotlyOutput("interpretation_plot"),
      verbatimTextOutput("interpretation_text")
    )
  )
)
# Server definition (abridged)
server <- function(input, output, session) {
  
  output$interpretation_plot <- renderPlotly({
    # Generate a plot for the selected interpretation method
    if(input$interpretation_type == "Feature Importance") {
      p <- rf_fit |>
        extract_fit_parsnip() |>
        vip(num_features = 10) +
        theme_minimal()
      ggplotly(p)
    }
    # Implement the remaining methods the same way
  })
  
  output$interpretation_text <- renderText({
    paste("選択された解釈手法:", input$interpretation_type)
  })
}
# Run the app (left commented out)
# shinyApp(ui = ui, server = server)
In the next chapter, Chapter 20, we take up ensemble learning, which combines multiple models. Using the stacks package, you will learn techniques that push predictive accuracy even further.