2022年7月30日土曜日

覚書:打ち切りデータの扱い

打ち切り.knit

 統計のダルい系覚書第二弾.生存時間解析などにおける「打ち切りデータ」の扱い.

1. 打ち切りとは?

 打ち切りとは,「ある値以上(あるいは以下)になることは分かっているが,それ以上(あるいは以下)のどのような値になるか分からない」という状況を指す.例えば,ある生物が「出生してから死亡するまでの時間」を観察していた場合、最後まで死ななかった(右側打ち切りRight censored)とか,前回は生存が確認されたけど今回は死んでいた(間隔打ち切りInterval censored)といったことが頻繁に起こる.また,観測開始時点ですでに生まれてからしばらく経っていた(左側打ち切りLeft censored)ということもあるだろう.

 このような場合,打ち切りデータとして扱わないと,推定にバイアスがかかる.例えば,右側打ち切りのデータを排除して平均寿命を求めたら,長生きしている個体を無視していることになるので,寿命を過小評価することになる.では,どのように打ち切りデータを扱えばいいだろうか.以下では,指数分布における「右側打ち切り」を例に,尤度評価やR/JAGS/Stanでの実装についてまとめておく.

 先に言っておくが,打ち切りを扱う理屈は難しくない。だけど,なかなかダルい.とくにJAGSでの打ち切りの扱いは,はっきり言って死んでいる(残念ながら,Nimbleでもばっちり引き継がれてしまっている.WinBUGSやOpenBUGSよりも明らかに使いにくくなった).あと,打ち切りの有無(1/0)をベクトルとして与えることが多いが,打ち切りアリを1とするか,観測アリ(すなわち打ち切りなし)を1とするかがソフトや関数によってばらばらで,じつに面倒くさい.

2. 打ち切りの尤度評価

 打ち切りの尤度評価については,自分で尤度関数を定義したらわかりやすい.

 今,ある生物の死亡までの時間tがパラメータλを持つ指数分布に従っていたとする.打ち切りアリ(観測ナシ)の場合の観測値をt0,打ち切りナシ(観測アリ)のデータの打ち切り時間をt1と表記する.また,Rに倣って指数分布の確率密度関数をdexp,累積分布関数(ある時点までにイベントが起こる確率の関数=dexpを積分した値に相当)をpexpとする.

 この場合,打ち切りナシ(観測アリ)のデータの対数尤度は,当然,log(dexp(t1, λ))となる.これに対して,打ち切りアリ(観測ナシ)の対数尤度は,その時点までイベント(死亡)が起こらなかった確率が尤度になる.累積分布関数pexpは,その時点までにイベントが発生する確率なので,1から引いてlog(1 – pexp(t0, λ))が対数尤度となる.

3. Rのoptim関数の利用

 Rで対数尤度を定義して,optim関数で,最小化するλの値を探す場合は以下のように出来る.なお,マイナス対数尤度にしているのは,optim関数が最大ではなく最小にするパラメータを探す仕様だから.

n <- 100
t.original <- rexp(n,  1/10) # 真のλの期待値を10にして指数乱数の発生
cens.time <- 15 # 打ち切り時間は15にする
is.censored <- ifelse(t.original > cens.time, 1, 0) # 打ち切りアリ(観測ナシ)1
t.optim <- t.original
##### 使うのは以下
t1 <- t.optim [t.optim < 15] # 観測値
t0 <- rep(cens.time, n - length(t1))  # 打ち切り時間
 

loglf <- function(par){ #マイナス対数尤度を計算する関数の定義
  minus.loglik <- - sum(log(dexp(t1, par))) - sum(log(1 - pexp(t0, par)))
  return(minus.loglik)
}

1/optim(0.15, loglf)$par #0.15は初期値
## Warning in optim(0.15, loglf): one-dimensional optimization by Nelder-Mead is unreliable:
## use "Brent" or optimize() directly
## [1] 10.30596

 Nelder-Meadによる最適化は,1変量の最適化には不向きだと警告が発せられるが,ほぼ真のλが推定される.

4. Rのflexsurvパッケージ

 このパッケージは,いろんな確率分布を含んでいるので割と便利.このパッケージの場合,「観測もしくは打ち切りのタイミング」のベクトルと,観測の有無(打ち切りナシ(観測アリ)= 1)のベクトルを準備する.

# n <- 100
# t.original <- rexp(n, 1/10) # 先ほどと同様真の期待値を10にして指数乱数の発生
# cens.time <- 15 # 打ち切り時間は15にする

##### 使うのは以下
tflex <- ifelse(t.original > cens.time, cens.time, t.original) # 観測or打ち切りタイミング
obs.flex <- ifelse(t.original > cens.time, 0, 1) # 打ち切りナシ(観測アリ) = 1

library(flexsurv)
##  要求されたパッケージ survival をロード中です
model <- flexsurvreg(Surv(tflex, obs.flex) ~ 1, dist = "exponential")
1/model$res[1]
## [1] 10.30653

5. Stan

 Stanは,optim関数と似た感じで打ち切りを扱える.

# n <- 100  
# t.original <- rexp(n, 1/10) # 先ほどと同様真の期待値を10にして指数乱数の発生
# cens.time <- 15 # 打ち切り時間は15にする

##### 使うのは以下
tstan <- t.original [t.original <= cens.time] # 観測時間
cstan <- cens.time # 打ち切り時間
Nstan <- length(tstan) # 観測ありデータ数
Ncstan <- length(cstan) # 打ち切りデータ数

推定モデルは以下のような感じ(以下では,一応アヒル本に倣って書いた.打ち切りの対数尤度は同じなので掛けている).

data {
  int<lower=0> Nstan;
  int<lower=0> Ncstan;
  real tstan[Nstan];
  real<lower=max(ttstan)> cens.time;
}

parameters {
  real<lower=0> lambda;
}

model {
for(i in 1:Nstan)
  tstan[i] ~ exponential(lambda);
  target += Ncstan * exponential_lccdf(cstan | lambda);
}

6. JAGS

 簡単なモデルだと非常に使い勝手の良いJAGS.打ち切りの仕様だけは何とかしてほしい.なぜこうなるかは一応理解したつもりではいるが,余り深く考えないで「こうやればできる」でいいと思う.

 3つベクトルを準備する.

ベクトル 説明
tjags 打ち切りナシの場合は観測時間,アリの場合はNA(観測時間がNot Available)
is.censored 打ち切りの有無を1/0で.打ち切りアリ = 1
c ほぼイミフの値.観測された場合は観測時間よりも大きな値,そうでない場合は打ち切り時間
# n <- 100
# t.original <- rexp(n, 1/10) # 先ほどと同様真の期待値を10にして指数乱数の発生
# cens.time <- 15 # 打ち切り時間は15にする

##### 使うのは以下

tjags <- ifelse(t.original > cens.time, NA, t.original) # 観測のタイミング,打ち切りの場合NA
is.censored <- ifelse(t.original > cens.time, 1, 0) # 打ち切りの有無(打ち切りアリ = 1)
c <- rep(cens.time, length(tjags))
c[is.censored == 0] <- tjags[is.censored == 0] + 0.01 # 打ち切りアリの場合は打ち切り時間,ナシの場合は観測時間 + 0.01

library(jagsUI)
model.file <- "C:\\bugstemp\\model1.txt"
sink(model.file)
cat("model
    {   
    # Likelihood
    for(i in 1:Nstay){
      is.censored[i] ~ dinterval(tjags[i], c[i])
      tjags[i] ~ dexp(1/lambda)
    }   

    # Prior distribution
    lambda~dgamma(0.001,0.001)
    }
    ")

sink()

datalist <- list(tjags = tjags, c = c, is.censored = is.censored, Nstay = length(tjags))

inits <- function(){list(lambda = 10)}
parameters <- c("lambda")
n.chain <- 3; n.iter <- 20000; n.burnin <-10000; n.thin <-2
model <- jags(datalist, inits, parameters, model.file, n.chain = n.chain, n.iter = n.iter, n.burnin = n.burnin, n.thin = n.thin, parallel = TRUE)
model$mean$lam

ダルい・・・.