earth package

| Function | Works |
|---|---|
| tidypredict_fit(), tidypredict_sql(), parse_model() | ✔ |
| tidypredict_to_column() | ✔ |
| tidypredict_test() | ✔ |
| tidypredict_interval(), tidypredict_sql_interval() | ✗ |
| parsnip | ✔ |
tidypredict_ functions
library(earth)
data("etitanic", package = "earth")
model <- earth(age ~ sibsp + parch, data = etitanic, degree = 3)
Create the R formula
tidypredict_fit(model)
#> 22.2918960405403 + (ifelse(parch > 2, parch - 2, 0) * 13.0493891423277) +
#> (ifelse(parch < 2, 2 - parch, 0) * 4.85356462114366) + (ifelse(sibsp >
#> 1, sibsp - 1, 0) * -7.71566779782023) + (ifelse(sibsp > 1,
#> sibsp - 1, 0) * ifelse(parch > 1, parch - 1, 0) * 4.41874354843212) +
#> (ifelse(sibsp > 1, sibsp - 1, 0) * ifelse(parch < 1, 1 -
#> parch, 0) * 7.40395975552272) + (ifelse(parch > 4, parch -
#> 4, 0) * -18.8998708031826)
SQL output example
tidypredict_sql(model, dbplyr::simulate_odbc())
#> <SQL> 22.2918960405403 + (CASE WHEN (`parch` > 2.0) THEN (`parch` - 2.0) WHEN NOT(`parch` > 2.0) THEN (0.0) END * 13.0493891423277) + (CASE WHEN (`parch` < 2.0) THEN (2.0 - `parch`) WHEN NOT(`parch` < 2.0) THEN (0.0) END * 4.85356462114366) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT(`sibsp` > 1.0) THEN (0.0) END * -7.71566779782023) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT(`sibsp` > 1.0) THEN (0.0) END * CASE WHEN (`parch` > 1.0) THEN (`parch` - 1.0) WHEN NOT(`parch` > 1.0) THEN (0.0) END * 4.41874354843212) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT(`sibsp` > 1.0) THEN (0.0) END * CASE WHEN (`parch` < 1.0) THEN (1.0 - `parch`) WHEN NOT(`parch` < 1.0) THEN (0.0) END * 7.40395975552272) + (CASE WHEN (`parch` > 4.0) THEN (`parch` - 4.0) WHEN NOT(`parch` > 4.0) THEN (0.0) END * -18.8998708031826)
Add the prediction to the original table
library(dplyr)
etitanic %>%
tidypredict_to_column(model) %>%
glimpse()
#> Rows: 1,046
#> Columns: 7
#> $ pclass <fct> 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, …
#> $ survived <int> 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, …
#> $ sex <fct> female, male, female, male, female, male, female, male, fema…
#> $ age <dbl> 29.0000, 0.9167, 2.0000, 30.0000, 25.0000, 48.0000, 63.0000,…
#> $ sibsp <int> 0, 1, 1, 1, 1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, …
#> $ parch <int> 0, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, …
#> $ fit <dbl> 31.99903, 22.29190, 22.29190, 22.29190, 22.29190, 31.99903, …
Confirm that tidypredict results match the model’s predict() results
tidypredict_test(model, etitanic)
tidypredict supports the glm argument as well:
model <- earth(survived ~ .,
data = etitanic,
glm = list(family = binomial), degree = 2)
tidypredict_fit(model)
#> 1 - 1/(1 + exp(2.91352600741339 + (ifelse(sex == "male", 1, 0) *
#> -3.1856245024853) + (ifelse(pclass == "3rd", 1, 0) * -5.03005595780699) +
#> (ifelse(sex == "male", 1, 0) * ifelse(age < 16, 16 - age,
#> 0) * 0.241814028713265) + (ifelse(pclass == "2nd", 1,
#> 0) * ifelse(sex == "male", 1, 0) * -1.76809447811123) + (ifelse(pclass ==
#> "3rd", 1, 0) * ifelse(sibsp < 4, 4 - sibsp, 0) * 0.61865274765985) +
#> (ifelse(pclass == "3rd", 1, 0) * ifelse(sex == "male", 1,
#> 0) * 1.22269536265148) + (ifelse(age > 32, age - 32,
#> 0) * -0.0375714917713112)))
The spec sets the is_glm entry to 1, as well as the family and link entries.
str(parse_model(model), 2)
#> List of 2
#> $ general:List of 6
#> ..$ model : chr "earth"
#> ..$ type : chr "tree"
#> ..$ version: num 2
#> ..$ is_glm : num 1
#> ..$ family : chr "binomial"
#> ..$ link : chr "logit"
#> $ terms :List of 8
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_tree" "list"
parsnip fitted models are also supported by tidypredict:
library(parsnip)
p_model <- mars(mode = "regression", prod_degree = 3) %>%
set_engine("earth") %>%
fit(age ~ sibsp + parch, data = etitanic)
tidypredict_fit(p_model)
#> 22.2918960405403 + (ifelse(parch > 2, parch - 2, 0) * 13.0493891423277) +
#> (ifelse(parch < 2, 2 - parch, 0) * 4.85356462114366) + (ifelse(sibsp >
#> 1, sibsp - 1, 0) * -7.71566779782023) + (ifelse(sibsp > 1,
#> sibsp - 1, 0) * ifelse(parch > 1, parch - 1, 0) * 4.41874354843212) +
#> (ifelse(sibsp > 1, sibsp - 1, 0) * ifelse(parch < 1, 1 -
#> parch, 0) * 7.40395975552272) + (ifelse(parch > 4, parch -
#> 4, 0) * -18.8998708031826)
Here is an example of the model spec:
pm <- parse_model(model)
str(pm, 2)
#> List of 2
#> $ general:List of 6
#> ..$ model : chr "earth"
#> ..$ type : chr "tree"
#> ..$ version: num 2
#> ..$ is_glm : num 1
#> ..$ family : chr "binomial"
#> ..$ link : chr "logit"
#> $ terms :List of 8
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_tree" "list"
str(pm$terms[1:2])
#> List of 2
#> $ :List of 4
#> ..$ label : chr "(Intercept)"
#> ..$ coef : num 2.91
#> ..$ is_intercept: num 1
#> ..$ fields : list()
#> $ :List of 4
#> ..$ label : chr "sexmale"
#> ..$ coef : num -3.19
#> ..$ is_intercept: num 0
#> ..$ fields :List of 1
#> .. ..$ :List of 4
#> .. .. ..$ type: chr "conditional"
#> .. .. ..$ col : chr "sex"
#> .. .. ..$ val : chr "male"
#> .. .. ..$ op : chr "equal"