CISE is a package that fits the multiple-graph factorization (M-GRAF) model (Wang et al., 2017) for undirected graphs \(A_i, i=1,\dots,n\), \[ A_{i[uv]}\mid\Pi_{i[uv]}\stackrel{\mbox{ind}}{\sim}\mbox{Bernoulli}(\Pi_{i[uv]}),\ u>v;u,v\in\{1,2,\dots,V\}, \\ \mbox{logit}(\Pi_{i})=Z+D_{i},\ i=1,\dots,n, \] where \(Z\) is a shared baseline and \(D_i\) is a subject-specific deviation. Suppose \(D_i\) has low rank \(K\). To accommodate different degrees of heterogeneity in the data, M-GRAF allows three types of decomposition for \(D_i\):
1. \(D_{i}=Q_{i}\Lambda_{i}Q_{i}^{\top}\) (both eigenvectors \(Q_i\) and eigenvalues \(\Lambda_i\) are subject-specific)
2. \(D_{i}=Q_{i}\Lambda Q_{i}^{\top}\) (eigenvalues \(\Lambda\) shared across subjects)
3. \(D_{i}=Q\Lambda_{i}Q^{\top}\) (eigenvectors \(Q\) shared across subjects)
Model checking is needed to decide which assumption on \(D_i\) is reasonable.
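To make the model concrete, here is a small simulation sketch in plain R (no CISE functions) that generates one network under variant 2, where \(Q_i\) is subject-specific and \(\Lambda\) is shared. All dimensions and values below are illustrative, not taken from the HCP data.

```r
set.seed(1)
V <- 10; K <- 2           # V nodes, rank-K deviation (illustrative sizes)
Z <- matrix(0, V, V)      # shared baseline log-odds (zero for simplicity)

# Subject-specific orthonormal Q_i via QR; shared eigenvalues Lambda (variant 2)
Q_i    <- qr.Q(qr(matrix(rnorm(V * K), V, K)))  # V x K, orthonormal columns
Lambda <- c(2, -1)
D_i <- Q_i %*% diag(Lambda) %*% t(Q_i)          # rank-K symmetric deviation

Pi_i <- 1 / (1 + exp(-(Z + D_i)))               # inverse-logit edge probabilities

# Sample the upper triangle and symmetrize into an undirected adjacency matrix
A_i <- matrix(0, V, V)
up  <- upper.tri(A_i)
A_i[up] <- rbinom(sum(up), 1, Pi_i[up])
A_i <- A_i + t(A_i)
```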
Let’s focus on studying the relationship between brain connectivity and one cognitive trait, visuospatial processing (VP). We first load HCP data on the brain networks and VP ability of 212 subjects.
library(CISE)
data(A)
data(VSPLOT)
According to Wang et al. (2017), variant 2 of \(D_i\) is reasonable for this dataset while maintaining a concise model. We use the MGRAF2 function in the CISE package to extract the low-rank components \(\{Q_i\}\) and \(\Lambda\). Subjects are then classified into high and low VP groups via a distance-based procedure using these low-rank components, as described in Wang et al. (2017). The following code repeats 10-fold cross-validation (CV) 30 times and displays the mean and standard deviation of the CV accuracies under different choices of \(K\).
n <- length(VSPLOT)
y <- numeric(n)
y[VSPLOT == "high"] <- 1
zero_id <- which(y == 0)
one_id  <- which(y == 1)

## Frobenius distance between D_i and D_j under variant 2, via the identity
## ||D_i - D_j||_F = sqrt(2*sum(Lambda^2) - 2*tr(Lambda M Lambda M'))
## with M = Q_j' Q_i and orthonormal Q_i, Q_j
pair_dist <- function(i, j, Q, Lambda) {
  M <- crossprod(Q[, , j], Q[, , i])  # K x K
  temp <- 0
  for (k in seq_along(Lambda)) {
    temp <- temp + Lambda[k] * sum(M[k, ] * Lambda * M[k, ])
  }
  sqrt(2 * sum(Lambda^2) - 2 * temp)
}

max_K <- 7
acc_MGRAF2 <- matrix(0, nrow = max_K, ncol = 2)  # store mean and sd of CV accuracies

for (K in 1:max_K) {
  res <- MGRAF2(A = A, K = K, tol = 0.01, maxit = 5)
  Q_best <- res$Q
  Lambda_best <- res$Lambda

  ## do 10-fold cross-validation and repeat 30 times
  nfd <- 10                  # number of folds
  rep <- 30                  # number of repetitions
  acc <- numeric(nfd * rep)  # store accuracies
  set.seed(22)
  for (t in 1:rep) {
    foldid <- sample(1:nfd, size = n, replace = TRUE)
    for (fd in 1:nfd) {
      test_id <- which(foldid == fd)
      train_zero_id <- setdiff(zero_id, test_id)
      train_one_id  <- setdiff(one_id, test_id)
      ## assign each test subject to the group with the smaller average distance
      pred <- sapply(test_id, function(i) {
        ave_dist_zero <- mean(sapply(train_zero_id,
                                     function(j) pair_dist(i, j, Q_best, Lambda_best)))
        ave_dist_one  <- mean(sapply(train_one_id,
                                     function(j) pair_dist(i, j, Q_best, Lambda_best)))
        if (ave_dist_zero < ave_dist_one) 0 else 1
      })
      acc[nfd * (t - 1) + fd] <- sum(pred == y[test_id]) / length(pred)
    }
  }
  acc_MGRAF2[K, 1] <- mean(acc)
  acc_MGRAF2[K, 2] <- sd(acc)
}
library(knitr)
acc_data <- data.frame(K = 1:max_K, mean = acc_MGRAF2[, 1], sd = acc_MGRAF2[, 2])
kable(acc_data, digits = 3)
K | mean | sd |
---|---|---|
1 | 0.561 | 0.101 |
2 | 0.621 | 0.107 |
3 | 0.622 | 0.104 |
4 | 0.623 | 0.105 |
5 | 0.643 | 0.102 |
6 | 0.641 | 0.104 |
7 | 0.629 | 0.105 |
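The highest mean CV accuracy in the table above occurs at \(K=5\). As a sketch, reconstructing the table as a data frame lets us pick that \(K\) programmatically (the values are copied from the table; `acc_data` would already hold them after running the code above):

```r
acc_data <- data.frame(
  K    = 1:7,
  mean = c(0.561, 0.621, 0.622, 0.623, 0.643, 0.641, 0.629),
  sd   = c(0.101, 0.107, 0.104, 0.105, 0.102, 0.104, 0.105)
)
best_K <- acc_data$K[which.max(acc_data$mean)]  # K with highest mean CV accuracy
```

Note that the means for \(K=2\) through \(K=7\) all lie within one standard deviation of each other, so the choice of \(K\) beyond 2 makes little practical difference here.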