統計のダルい系覚書第二弾.生存時間解析などにおける「打ち切りデータ」の扱い.
1. 打ち切りとは?
打ち切りとは,「ある値以上(あるいは以下)になることは分かっているが,それ以上(あるいは以下)のどのような値になるか分からない」という状況を指す.例えば,ある生物が「出生してから死亡するまでの時間」を観察していた場合、最後まで死ななかった(右側打ち切りRight censored)とか,前回は生存が確認されたけど今回は死んでいた(間隔打ち切りInterval censored)といったことが頻繁に起こる.また,観測開始時点ですでに生まれてからしばらく経っていた(左側打ち切りLeft censored)ということもあるだろう.
このような場合,打ち切りデータとして扱わないと,推定にバイアスがかかる.例えば,右側打ち切りのデータを排除して平均寿命を求めたら,長生きしている個体を無視していることになるので,寿命を過小評価することになる.では,どのように打ち切りデータを扱えばいいだろうか.以下では,指数分布における「右側打ち切り」を例に,尤度評価やR/JAGS/Stanでの実装についてまとめておく.
先に言っておくが,打ち切りを扱う理屈は難しくない。だけど,なかなかダルい.とくにJAGSでの打ち切りの扱いは,はっきり言って死んでいる(残念ながら,Nimbleでもばっちり引き継がれてしまっている.WinBUGSやOpenBUGSよりも明らかに使いにくくなった).あと,打ち切りの有無(1/0)をベクトルとして与えることが多いが,打ち切りアリを1とするか,観測アリ(すなわち打ち切りなし)を1とするかがソフトや関数によってばらばらで,じつに面倒くさい.
2. 打ち切りの尤度評価
打ち切りの尤度評価については,自分で尤度関数を定義したらわかりやすい.
今,ある生物の死亡までの時間tがパラメータλを持つ指数分布に従っていたとする.打ち切りアリ(観測ナシ)の場合の観測値をt0,打ち切りナシ(観測アリ)のデータの打ち切り時間をt1と表記する.また,Rに倣って指数分布の確率密度関数をdexp,累積分布関数(ある時点までにイベントが起こる確率の関数=dexpを積分した値に相当)をpexpとする.
この場合,打ち切りナシ(観測アリ)のデータの対数尤度は,当然,log(dexp(t1, λ))となる.これに対して,打ち切りアリ(観測ナシ)の対数尤度は,その時点までイベント(死亡)が起こらなかった確率が尤度になる.累積分布関数pexpは,その時点までにイベントが発生する確率なので,1から引いてlog(1 – pexp(t0, λ))が対数尤度となる.
3. Rのoptim関数の利用
Rで対数尤度を定義して,optim関数で,最小化するλの値を探す場合は以下のように出来る.なお,マイナス対数尤度にしているのは,optim関数が最大ではなく最小にするパラメータを探す仕様だから.
n <- 100
t.original <- rexp(n, 1/10) # 真のλの期待値を10にして指数乱数の発生
cens.time <- 15 # 打ち切り時間は15にする
is.censored <- ifelse(t.original > cens.time, 1, 0) # 打ち切りアリ(観測ナシ)1
t.optim <- t.original
##### 使うのは以下
t1 <- t.optim [t.optim < 15] # 観測値
t0 <- rep(cens.time, n - length(t1)) # 打ち切り時間
loglf <- function(par){ #マイナス対数尤度を計算する関数の定義
minus.loglik <- - sum(log(dexp(t1, par))) - sum(log(1 - pexp(t0, par)))
return(minus.loglik)
}
1/optim(0.15, loglf)$par #0.15は初期値
## Warning in optim(0.15, loglf): one-dimensional optimization by Nelder-Mead is unreliable:
## use "Brent" or optimize() directly
## [1] 10.30596
Nelder-Meadによる最適化は,1変量の最適化には不向きだと警告が発せられるが,ほぼ真のλが推定される.
4. Rのflexsurvパッケージ
このパッケージは,いろんな確率分布を含んでいるので割と便利.このパッケージの場合,「観測もしくは打ち切りのタイミング」のベクトルと,観測の有無(打ち切りナシ(観測アリ)= 1)のベクトルを準備する.
# n <- 100
# t.original <- rexp(n, 1/10) # 先ほどと同様真の期待値を10にして指数乱数の発生
# cens.time <- 15 # 打ち切り時間は15にする
##### 使うのは以下
tflex <- ifelse(t.original > cens.time, cens.time, t.original) # 観測or打ち切りタイミング
obs.flex <- ifelse(t.original > cens.time, 0, 1) # 打ち切りナシ(観測アリ) = 1
library(flexsurv)
## 要求されたパッケージ survival をロード中です
model <- flexsurvreg(Surv(tflex, obs.flex) ~ 1, dist = "exponential")
1/model$res[1]
## [1] 10.30653
5. Stan
Stanは,optim関数と似た感じで打ち切りを扱える.
# n <- 100
# t.original <- rexp(n, 1/10) # 先ほどと同様真の期待値を10にして指数乱数の発生
# cens.time <- 15 # 打ち切り時間は15にする
##### 使うのは以下
tstan <- t.original [t.original <= cens.time] # 観測時間
cstan <- cens.time # 打ち切り時間
Nstan <- length(tstan) # 観測ありデータ数
Ncstan <- length(cstan) # 打ち切りデータ数
推定モデルは以下のような感じ(以下では,一応アヒル本に倣って書いた.打ち切りの対数尤度は同じなので掛けている).
data {
int<lower=0> Nstan;
int<lower=0> Ncstan;
real tstan[Nstan];
real<lower=max(ttstan)> cens.time;
}
parameters {
real<lower=0> lambda;
}
model {
for(i in 1:Nstan)
tstan[i] ~ exponential(lambda);
target += Ncstan * exponential_lccdf(cstan | lambda);
}
6. JAGS
簡単なモデルだと非常に使い勝手の良いJAGS.打ち切りの仕様だけは何とかしてほしい.なぜこうなるかは一応理解したつもりではいるが,余り深く考えないで「こうやればできる」でいいと思う.
3つベクトルを準備する.
ベクトル | 説明 |
---|---|
tjags | 打ち切りナシの場合は観測時間,アリの場合はNA(観測時間がNot Available) |
is.censored | 打ち切りの有無を1/0で.打ち切りアリ = 1 |
c | ほぼイミフの値.観測された場合は観測時間よりも大きな値,そうでない場合は打ち切り時間 |
# n <- 100
# t.original <- rexp(n, 1/10) # 先ほどと同様真の期待値を10にして指数乱数の発生
# cens.time <- 15 # 打ち切り時間は15にする
##### 使うのは以下
tjags <- ifelse(t.original > cens.time, NA, t.original) # 観測のタイミング,打ち切りの場合NA
is.censored <- ifelse(t.original > cens.time, 1, 0) # 打ち切りの有無(打ち切りアリ = 1)
c <- rep(cens.time, length(tjags))
c[is.censored == 0] <- tjags[is.censored == 0] + 0.01 # 打ち切りアリの場合は打ち切り時間,ナシの場合は観測時間 + 0.01
library(jagsUI)
model.file <- "C:\\bugstemp\\model1.txt"
sink(model.file)
cat("model
{
# Likelihood
for(i in 1:Nstay){
is.censored[i] ~ dinterval(tjags[i], c[i])
tjags[i] ~ dexp(1/lambda)
}
# Prior distribution
lambda~dgamma(0.001,0.001)
}
")
sink()
datalist <- list(tjags = tjags, c = c, is.censored = is.censored, Nstay = length(tjags))
inits <- function(){list(lambda = 10)}
parameters <- c("lambda")
n.chain <- 3; n.iter <- 20000; n.burnin <-10000; n.thin <-2
model <- jags(datalist, inits, parameters, model.file, n.chain = n.chain, n.iter = n.iter, n.burnin = n.burnin, n.thin = n.thin, parallel = TRUE)
model$mean$lam
ダルい・・・.