Abstract
This is the supplementary R file for regression splines in the lecture note “Spline”.
We use the U.S. birth rate data as an example. The data record birth rates from 1917 to 2003, and the trend is clearly very nonlinear.
load("../data/birthrates.Rda")
head(birthrates)
## Year Birthrate
## 1917 1917 183.1
## 1918 1918 183.9
## 1919 1919 163.1
## 1920 1920 179.5
## 1921 1921 181.4
## 1922 1922 173.4
par(mar = c(4,4,1,1))
plot(birthrates, pch = 19, col = "darkorange")
It might be interesting to fit a linear regression with high-order polynomials to approximate this curve. This can be carried out using the poly() function, which generates an orthogonal polynomial basis up to a given degree. Note that this is numerically more stable than writing out the raw powers, such as I(Year^2), I(Year^3), etc., because the Year variable takes very large values, so its raw powers are numerically unstable.
par(mfrow=c(1,2))
par(mar = c(2,3,2,0))
lmfit <- lm(Birthrate ~ poly(Year, 3), data = birthrates)
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, col = "deepskyblue", lwd = 2)
title("degree = 3")
par(mar = c(2,3,2,0))
lmfit <- lm(Birthrate ~ poly(Year, 5), data = birthrates)
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, col = "deepskyblue", lwd = 2)
title("degree = 5")
These fits do not seem to perform very well. How about a different approach: model the curve locally. We already know one approach that works in a similar way, \(k\)NN, but let's try something new. First, divide the year range into several non-overlapping intervals, say, one every 10 years. Then estimate the regression function within each interval by averaging the observations, just like \(k\)NN. The only difference is that for prediction we no longer recalculate the neighbors; we simply check which interval the point falls into.
par(mar = c(2,2,2,0))
# indicator basis functions for the 8 interior knots 1927, 1937, ..., 1997
mybasis = matrix(NA, nrow(birthrates), 8)
for (l in 1:8)
  mybasis[, l] = 1*(birthrates$Year >= 1917 + l*10)
lmfit <- lm(birthrates$Birthrate ~ ., data = data.frame(mybasis))
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, type = "s", col = "deepskyblue", lwd = 2)
title("Histgram Regression")
This method is called a histogram regression. Let \(\phi(x)\) denote the interval that contains a given testing point \(x\); then we are fitting the model
\[\widehat{f}(x) = \frac{\sum_{i=1}^n Y_i \,\, I\{X_i \in \phi(x)\} }{ \sum_{i=1}^n I\{X_i \in \phi(x)\}}\]
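To verify that the least squares fit above really reduces to these within-interval averages, here is a quick sanity check (my own sketch, not part of the original note) using base R's cut() and tapply(); the breaks reproduce the 10-year bins used above.
bins = cut(birthrates$Year, breaks = seq(1917, 2007, by = 10), right = FALSE)
tapply(birthrates$Birthrate, bins, mean)  # matches the fitted values bin by bin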
You may know the word histogram from plotting the density of a set of observations. Indeed, the two are motivated by the same philosophy, and we will discuss the connection later on. For the purpose of fitting a regression function, however, the histogram regression does not perform ideally, since there are jumps at the edges of the intervals. Hence we need a more flexible framework.
Instead of fitting constant functions within each interval (between two knots), we may consider fitting a line. Consider a simpler case where we use just 3 knots, at 1936, 1960, and 1978, which gives 4 intervals.
par(mfrow=c(1,2))
myknots = c(1936, 1960, 1978)
bounds = c(1917, myknots, 2003)
# piecewise constant
mybasis = cbind("x_1" = (birthrates$Year < myknots[1]),
"x_2" = (birthrates$Year >= myknots[1])*(birthrates$Year < myknots[2]),
"x_3" = (birthrates$Year >= myknots[2])*(birthrates$Year < myknots[3]),
"x_4" = (birthrates$Year >= myknots[3]))
lmfit <- lm(birthrates$Birthrate ~ . -1, data = data.frame(mybasis))
par(mar = c(2,3,2,0))
plot(birthrates, pch = 19, col = "darkorange")
abline(v = myknots, lty = 2)
title("Piecewise constant")
for (k in 1:4)
points(c(bounds[k], bounds[k+1]), rep(lmfit$coefficients[k], 2), type = "l", lty = 1, col = "deepskyblue", lwd = 4)
# piecewise linear
mybasis = cbind("x_1" = (birthrates$Year < myknots[1]),
"x_2" = (birthrates$Year >= myknots[1])*(birthrates$Year < myknots[2]),
"x_3" = (birthrates$Year >= myknots[2])*(birthrates$Year < myknots[3]),
"x_4" = (birthrates$Year >= myknots[3]),
"x_11" = birthrates$Year*(birthrates$Year < myknots[1]),
"x_21" = birthrates$Year*(birthrates$Year >= myknots[1])*(birthrates$Year < myknots[2]),
"x_31" = birthrates$Year*(birthrates$Year >= myknots[2])*(birthrates$Year < myknots[3]),
"x_41" = birthrates$Year*(birthrates$Year >= myknots[3]))
lmfit <- lm(birthrates$Birthrate ~ .-1, data = data.frame(mybasis))
par(mar = c(2,3,2,0))
plot(birthrates, pch = 19, col = "darkorange")
abline(v = myknots, lty = 2)
title("Piecewise linear")
for (k in 1:4)
points(c(bounds[k], bounds[k+1]), lmfit$coefficients[k] + c(bounds[k], bounds[k+1])*lmfit$coefficients[k+4],
type = "l", lty = 1, col = "deepskyblue", lwd = 4)
However, these functions are not continuous at the knots. Hence we use a trick to construct a continuous basis:
pos <- function(x) x*(x>0)
mybasis = cbind("int" = 1, "x_1" = birthrates$Year,
"x_2" = pos(birthrates$Year - myknots[1]),
"x_3" = pos(birthrates$Year - myknots[2]),
"x_4" = pos(birthrates$Year - myknots[3]))
par(mar = c(2,2,2,0))
matplot(birthrates$Year, mybasis[, -1], type = "l", lty = 1,
yaxt = 'n', ylim = c(0, 50), lwd = 2)
title("Spline Basis Functions")
With this definition, any fitted model will be of the form
\[\widehat{f}(x) = \widehat{\beta}_0 + \widehat{\beta}_1 x + \widehat{\beta}_2 (x - 1936)_+ + \widehat{\beta}_3 (x - 1960)_+ + \widehat{\beta}_4 (x - 1978)_+\]
where \((x - \kappa)_+ = \max(x - \kappa, 0)\) is the pos() function above. The fitted function is continuous and piecewise linear, with the slope changing at each knot. The resulting model is called a spline.
lmfit <- lm(birthrates$Birthrate ~ .-1, data = data.frame(mybasis))
par(mar = c(2,3,2,0))
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, col = "deepskyblue", lwd = 4)
abline(v = myknots, lty = 2)
title("Linear Spline")
Of course, writing this out explicitly is tedious, so we can use the bs() function from the splines package instead.
par(mar = c(2,2,2,0))
lmfit <- lm(Birthrate ~ splines::bs(Year, degree = 1, knots = myknots), data = birthrates)
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, col = "deepskyblue", lwd = 4)
title("Linear spline with the bs() function")
The next step is to increase the degree to account for more complicated functions. A few things need to be considered here: the polynomial degree within each interval, the number of knots, and the placement of the knots.
For example, consider the following setting.
par(mar = c(2,2,2,0))
lmfit <- lm(Birthrate ~ splines::bs(Year, degree = 3, knots = myknots), data = birthrates)
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, col = "deepskyblue", lwd = 4)
title("Cubic spline with 3 knots")
All of these affect the performance. In particular, the number of knots and the polynomial degree in each region together determine the total degrees of freedom. For simplicity, we can control this using the df parameter: here we use a total of 6 parameters, with the knots chosen automatically by the function. However, this does not seem to perform better than the knots we specified ourselves. The choice of knots can be crucial.
par(mar = c(2,2,2,0))
lmfit <- lm(Birthrate ~ splines::bs(Year, degree = 1, df = 5), data = birthrates)
plot(birthrates, pch = 19, col = "darkorange")
lines(birthrates$Year, lmfit$fitted.values, lty = 1, col = "deepskyblue", lwd = 4)
title("Linear spline with 6 degrees of parameters")
There are different ways to construct a spline basis. We have used two so far: the truncated power basis (our manual construction) and the B-spline basis (used by bs()). The B-spline basis has some computational advantages, such as local support. Here is a comparison of B-spline bases with different degrees.
par(mfrow = c(4, 1), mar = c(0, 0, 2, 0))
for (d in 0:3)
{
bs_d = splines2::bSpline(1:100, degree = d, knots = seq(10, 90, 10), intercept = TRUE)
matplot(1:100, bs_d, type = ifelse(d == 0, "s", "l"), lty = 1, ylab = "spline",
xaxt = 'n', yaxt = 'n', ylim = c(-0.05, 1.05), lwd = 2)
title(paste("degree =", d))
}
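One convenient property worth verifying numerically (my own sketch): with intercept = TRUE, the full B-spline basis forms a partition of unity, meaning the basis functions sum to one at every point inside the boundary knots.
bs3 = splines2::bSpline(1:100, degree = 3, knots = seq(10, 90, 10), intercept = TRUE)
range(rowSums(bs3))  # both values should be 1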
Extrapolation is generally dangerous because the fitted function can behave wildly outside the range of the observed data. In linear models fit with bs(), extrapolating beyond the boundary knots will trigger a warning.
library(splines)
fit.bs = lm(Birthrate ~ bs(Year, df=6), data=birthrates)
plot(birthrates$Year, birthrates$Birthrate, ylim=c(0,280), pch = 19,
xlim = c(1900, 2020), xlab = "year", ylab = "rate", col = "darkorange")
lines(seq(1900, 2020), predict(fit.bs, data.frame("Year"= seq(1900, 2020))), col="deepskyblue", lty=1, lwd = 3)
## Warning in bs(Year, degree = 3L, knots = c(`25%` = 1938.5, `50%` = 1960, : some 'x' values beyond boundary knots may cause ill-conditioned bases
fit.ns = lm(Birthrate ~ ns(Year, df=6), data=birthrates)
lines(seq(1900, 2020), predict(fit.ns, data.frame("Year"= seq(1900, 2020))), col="darkgreen", lty=1, lwd = 3)
legend("topright", c("Cubic B-Spline", "Natural Cubic Spline"), col = c("deepskyblue", "darkgreen"), lty = 1, lwd = 3, cex = 1.2)
title("Birth rate extrapolation")
This motivates us to consider additional constraints that force the extrapolation to be more regular. The natural cubic spline does this by forcing the second and third derivatives to be zero beyond the two boundary knots, so the fitted function is linear outside them.
par(mar = c(0, 2, 2, 0))
ncs = ns(1:100, df = 6, intercept = TRUE)
matplot(1:100, ncs, type = "l", lty = 1, ylab = "spline",
xaxt = 'n', yaxt = 'n', lwd = 3)
title("Natural Cubic Spline")
A smoothing spline can be fit using the smooth.spline() function from the stats package. It minimizes the penalized criterion \(\sum_{i=1}^n \big(y_i - f(x_i)\big)^2 + \lambda \int f''(t)^2 \, dt\), and by default tunes \(\lambda\) automatically using GCV. However, since the birth rate data have little variation between adjacent years, the GCV choice over-fits quite severely here.
# smoothing spline
fit = smooth.spline(birthrates$Year, birthrates$Birthrate)
plot(birthrates$Year, birthrates$Birthrate, pch = 19,
xlab = "Year", ylab = "BirthRates", col = "darkorange")
lines(seq(1917, 2003), predict(fit, seq(1917, 2003))$y, col="deepskyblue", lty=1, lwd = 3)
# the degrees of freedom is very large
fit$df
## [1] 60.7691
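When the automatic choice over-fits, one simple remedy (my own suggestion, not part of the original note) is to fix the effective degrees of freedom by hand through the df argument of smooth.spline():
# a much smoother fit with the degrees of freedom fixed manually
fit15 = smooth.spline(birthrates$Year, birthrates$Birthrate, df = 15)
lines(seq(1917, 2003), predict(fit15, seq(1917, 2003))$y,
      col = "darkgreen", lty = 2, lwd = 3)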
Let’s look at another simulated example, where this method performs reasonably well.
set.seed(1)
n = 100
x = seq(0, 1, length.out = n)
y = sin(12*(x+0.2))/(x+0.2) + rnorm(n)
# fit smoothing spline
fit = smooth.spline(x, y)
# the degrees of freedom
fit$df
## [1] 9.96443
# fitted model
plot(x, y, pch = 19, xlim = c(0, 1), xlab = "x", ylab = "y", col = "darkorange")
lines(x, sin(12*(x+0.2))/(x+0.2), col="red", lty=1, lwd = 3)
lines(x, predict(fit, x)$y, col="deepskyblue", lty=1, lwd = 3)
legend("bottomright", c("Truth", "Smoothing Splines"), col = c("red", "deepskyblue"), lty = 1, lwd = 3, cex = 1.2)
Since all spline approaches can be written as some kind of linear model, if we postulate an additive structure we can fit a multivariate model with
\[f(x) = \sum_j h_j(x_j) = \sum_j \sum_k N_{jk}(x_j) \beta_{jk}\] where \(h_j(x_j)\) is a univariate function of \(x_j\) that can be approximated by the spline basis \(N_{jk}(\cdot), k = 1, \ldots, K\). This works for both linear regression and generalized linear models. For the South Africa Heart Disease data, we use the gam() function in the gam (generalized additive models) package to fit a logistic regression model using natural splines (note that famhist is included as a factor).
library(ElemStatLearn)
library(gam)
form = formula("chd ~ ns(sbp,df=4) + ns(tobacco,df=4) +
ns(ldl,df=4) + famhist + ns(obesity,df=4) +
ns(alcohol,df=4) + ns(age,df=4)")
# note that we can also do
# m = glm(form, data=SAheart, family=binomial)
# print(summary(m), digits=3)
# however, the gam function provides more information
m = gam(form, data=SAheart, family=binomial)
## Warning in model.matrix.default(mt, mf, contrasts): non-list contrasts argument ignored
summary(m)
##
## Call: gam(formula = form, family = binomial, data = SAheart)
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -1.7245 -0.8265 -0.3884 0.8870 2.9589
##
## (Dispersion Parameter for binomial family taken to be 1)
##
## Null Deviance: 596.1084 on 461 degrees of freedom
## Residual Deviance: 457.6318 on 436 degrees of freedom
## AIC: 509.6318
##
## Number of Local Scoring Iterations: 6
##
## Anova for Parametric Effects
## Df Sum Sq Mean Sq F value Pr(>F)
## ns(sbp, df = 4) 4 6.31 1.5783 1.4242 0.224956
## ns(tobacco, df = 4) 4 18.09 4.5218 4.0802 0.002941 **
## ns(ldl, df = 4) 4 12.05 3.0137 2.7194 0.029290 *
## famhist 1 19.70 19.7029 17.7788 3.019e-05 ***
## ns(obesity, df = 4) 4 3.66 0.9161 0.8266 0.508701
## ns(alcohol, df = 4) 4 1.28 0.3200 0.2887 0.885278
## ns(age, df = 4) 4 17.64 4.4100 3.9794 0.003496 **
## Residuals 436 483.19 1.1082
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
par(mfrow = c(3, 3), mar = c(5, 5, 2, 0))
plot(m, se = TRUE, residuals = TRUE)
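As a quick follow-up (my own sketch, not from the note), we can obtain in-sample predicted probabilities from the fitted additive logistic model and tabulate a crude confusion matrix at the 0.5 cutoff:
pred = as.numeric(predict(m, type = "response") > 0.5)
table(predicted = pred, observed = SAheart$chd)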