Inverse functions with fixed-points July 18th, 2013
Patrick Stein

Introduction

SICP has a few sections devoted to using a general, damped fixed-point iteration to compute square roots and then n-th roots. The Functional Programming In Scala course that I took on Coursera did the same exercise (at least as far as square roots go).

The idea goes like this. Say that I want to find the square root of five. I am looking then for some number s so that s^2 = 5. This means that I’m looking for some number s so that s = \frac{5}{s}. So, if I had a function f(x) = \frac{5}{x} and I could find some point s where s = f(s), I’d be done. Such a point is called a fixed point of f.

There is a general method that can find a fixed point of many functions. If you type some random number into a calculator (in radian mode) and hit the “COS” button over and over, your calculator is eventually going to get stuck at 0.739085…. What happens is that you are doing a recurrence where x_{n+1} = \cos(x_n). Eventually, you end up at a point where x_{n+1} = x_{n} (to the limits of your calculator’s precision/display). After that, you’re stuck. You’ve found a fixed point. No matter how much you iterate, you’re going to be stuck in the same spot.
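The calculator experiment is easy to reproduce in Lisp, whose cos works in radians. This is a minimal sketch of the plain, undamped recurrence; the name iterate-until-stuck, its tolerance, and its step cap are illustrative choices, not part of the original code.

```lisp
;; Apply FN repeatedly, starting from X, until two successive values agree
;; to within TOLERANCE (or MAX-STEPS runs out). There is no damping here:
;; this is the plain recurrence x_{n+1} = f(x_n), like pressing COS again
;; and again.
(defun iterate-until-stuck (fn x &key (tolerance 1d-12) (max-steps 1000))
  (loop :for i :below max-steps
        :for next = (funcall fn x)
        :when (<= (abs (- next x)) tolerance)
          :return next
        :do (setf x next)
        :finally (return x)))
```

Calling (iterate-until-stuck #'cos 1d0) returns approximately 0.7390851332151607d0.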

Now, there are some situations where you might end up in an oscillation where x_{n+1} \ne x_n, but x_{n+1} = x_{n-k} for some k \ge 1. To avoid that, one usually does the iteration x_{n+1} = \mathsf{avg}(x_n,f(x_n)) for some averaging function \mathsf{avg}. This “damps” the oscillation.
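The square-root problem shows exactly this. For f(x) = \frac{5}{x} we have f(f(x)) = x, so the plain recurrence just bounces between x_0 and \frac{5}{x_0}. A small sketch with exact rationals (undamped-iterates is an illustrative name, not from the original):

```lisp
;; First COUNT iterates of the undamped recurrence x_{n+1} = 5 / x_n,
;; starting from X0 (which must be nonzero). Exact rationals make the
;; 2-cycle easy to see.
(defun undamped-iterates (x0 count)
  (loop :for x = x0 :then (/ 5 x)
        :repeat count
        :collect x))

;; (undamped-iterates 1 6) => (1 5 1 5 1 5)
;; (undamped-iterates 2 4) => (2 5/2 2 5/2)
```

Averaging does not change what we are looking for: \mathsf{avg}(s, f(s)) = s exactly when f(s) = s, so the damped iteration has the same fixed points as the original one.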

The Fixed Point higher-order function

In languages with first-class functions, it is easy to write a higher-order function called fixed-point that takes a function and iterates (with damping) to find a fixed point. In SICP and the Scala course mentioned above, the fixed-point function was written recursively.

(defun fixed-point (fn &optional (initial-guess 1) (tolerance 1e-8))
  (labels ((close-enough? (v1 v2)
             (<= (abs (- v1 v2)) tolerance))
           (average (v1 v2)
             (/ (+ v1 v2) 2))
           (try (guess)
             (let ((next (funcall fn guess)))
               (cond
                ((close-enough? guess next) next)
                (t (try (average guess next)))))))
    (try (* initial-guess 1d0))))

It is easy to express the recursion there iteratively instead if that’s easier for you to see/think about.

(defun fixed-point (fn &optional (initial-guess 1) (tolerance 1e-8))
  (flet ((close-enough? (v1 v2)
            (<= (abs (- v1 v2)) tolerance))
         (average (v1 v2)
            (/ (+ v1 v2) 2)))
    (loop :for guess = (* initial-guess 1d0) :then (average guess next)
          :for next = (funcall fn guess)
          :until (close-enough? guess next)
          :finally (return next))))

Using the Fixed Point function to find k-th roots

Above, we showed that the square root of 5 is a fixed point of f(x) = \frac{5}{x}. The same argument shows that the square root of any positive n is a fixed point of f(x) = \frac{n}{x}. Now, we can use that to write our own square root function:

(defun my-sqrt (n)
  (fixed-point (lambda (x) (/ n x))))
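With plain averaging as the damping, the guesses that fixed-point tries for my-sqrt follow x \mapsto \frac{1}{2}\left(x + \frac{n}{x}\right), which is also Newton's method applied to x^2 - n. The sketch below lists the first few guesses for n = 5; damped-sqrt-iterates is a hypothetical helper, and it uses exact rationals where fixed-point uses double floats.

```lisp
;; The first COUNT guesses obtained by averaging x with n/x, starting at 1.
;; Up to floating-point rounding, this is the GUESS sequence inside
;; FIXED-POINT when it is called on (lambda (x) (/ n x)).
(defun damped-sqrt-iterates (n count)
  (loop :for x = 1 :then (/ (+ x (/ n x)) 2)
        :repeat count
        :collect x))

;; (damped-sqrt-iterates 5 5) => (1 3 7/3 47/21 2207/987)
```

The fifth guess, 2207/987 ≈ 2.2360689, is already within 10^{-6} of √5 ≈ 2.2360680.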

By the same argument we used with the square root, we can find the k-th root of 5 by finding the fixed point of f(x) = \frac{5}{x^{k-1}}. We can make a function that returns a function that does k-th roots:

(defun kth-root (k)
  (lambda (n)
    (fixed-point (lambda (x) (/ n (expt x (1- k)))))))

(setf (symbol-function 'cbrt) (kth-root 3))
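One caveat before using the returned function for larger k. Near the root s = n^{1/k}, the damped update g(x) = \frac{1}{2}\left(x + \frac{n}{x^{k-1}}\right) only pulls nearby guesses in if |g'(s)| < 1. Differentiating and using s^k = n:

```latex
g'(x) = \frac{1}{2}\left(1 - (k-1)\,\frac{n}{x^{k}}\right),
\qquad
g'(s) = \frac{1}{2}\bigl(1 - (k-1)\bigr) = \frac{2-k}{2}.
```

So k = 2 gives g'(s) = 0 (the fast square-root case) and cube roots are fine with |g'(s)| = \frac{1}{2}, but at k = 4 the damped map stops being a contraction, and for k \ge 5 the root repels nearby guesses. SICP's exercise on n-th roots deals with this by averaging repeatedly.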

Inverting functions

I found myself wanting to find inverses of various complicated functions. All that I knew about the functions was that, if you restricted their domain to the unit interval, they were one-to-one and their range was also the unit interval. What I needed was the inverse of each function on that interval.

For some functions (like f(x) = x^2), the inverse is easy enough to calculate. For other functions (like f(x) = 6x^5 - 15x^4 + 10x^3), the inverse seems possible but incredibly tedious to calculate.

Could I use fixed points to find inverses of general functions? We’ve already used them to find inverses for f(x) = x^k. Can we extend it further?

After flailing around Google for quite some time, I found this article by Chen, Lu, Chen, Ruchala, and Olivera about using fixed-point iteration to find inverses for deformation fields.

There, the approach to inverting f(x) was to let u(x) = f(x) - x and v(x) = f^{-1}(x) - x, so that f(y) = y + u(y) and f^{-1}(x) = x + v(x). Then

x = f(f^{-1}(x)) =  f^{-1}(x) + u(f^{-1}(x)) = x + v(x) + u(x + v(x))

That leaves the relationship v(x) = -u(x + v(x)). So, for each x, the goal is to find a fixed point of the map v \mapsto -u(x + v); adding x to that fixed point gives f^{-1}(x).

I messed this up a few times by conflating f and u, so I abandoned it in favor of the tinkering that follows in the next section. Here, though, is a debugged version based on the cited paper:

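(defun pseudo-inverse (fn &optional (tolerance 1d-10))
  (lambda (x)
    ;; With u(y) = f(y) - y, the offset v = f^{-1}(x) - x satisfies
    ;; v = -u(x + v). Iterate that map from v = 0, then add x back.
    (let ((iterant (lambda (v)
                     (flet ((u (y)
                              (- (funcall fn y) y)))
                       (- (u (+ x v)))))))
      (+ x (fixed-point iterant 0d0 tolerance)))))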

Now, I can easily check the mean squared error of f(f^{-1}(x)) against x over evenly spaced points in the unit interval:

(defun check-pseudo-inverse (fn &optional (steps 100))
  ;; Mean squared error of f(f^{-1}(x)) - x over STEPS evenly spaced
  ;; points from 0 to 1 inclusive.
  (flet ((sqr (x) (* x x)))
    (/ (loop :with dx = (/ (1- steps))
             :with inverse = (pseudo-inverse fn)
             :for i :from 0 :below steps
             :for x = (* i dx)
             :summing (sqr (- (funcall fn (funcall inverse x)) x)))
       steps)))

(check-pseudo-inverse #'identity) => 0.0d0
(check-pseudo-inverse #'sin)      => 2.8820112095939962D-12
(check-pseudo-inverse #'sqrt)     => 2.7957469632748447D-19
(check-pseudo-inverse (lambda (x) (* x x x (+ (* x (- (* x 6) 15)) 10))))
                                  => 1.3296561385041381D-21

A tinkering attempt when I couldn’t get the previous to work

When I had abandoned the above, I spent some time tinkering on paper. To find f^{-1}(x), I need to find y so that f(y) = x. Multiplying both sides by y and dividing by f(y), I get y = \frac{xy}{f(y)}. So, to find f^{-1}(x), I need to find a y that is a fixed point for \frac{xy}{f(y)}:

(defun pseudo-inverse (fn &optional (tolerance 1d-10))
  (lambda (x)
    (let ((iterant (lambda (y)
                     (/ (* x y) (funcall fn y)))))
      (fixed-point iterant 1 tolerance))))

This version, however, has the disadvantage of using division. Division is more expensive and has obvious problems if you bump into zero on your way to your goal. Avoiding division, as the first version does, also allows that approach to be generalized for inverting endomorphisms of vector spaces (the \mathsf{avg} function being the only slightly tricky part).

Denouement

I finally found a use of the fixed-point function that goes beyond k-th roots. Wahoo!

Calculating the mean and variance with one pass February 15th, 2011
Patrick Stein

A friend showed me this about 15 years ago. I use it every time I need to calculate the variance of some data set. I always forget the exact details and have to derive it again. But it’s easy enough to derive that it’s never a problem.

I had to derive it again on Friday and thought I should make sure more people get this tool into their utility belts.

First, a quick refresher on what we’re talking about here. The mean \mu of a data set \{ x_1, x_2, \ldots, x_n \} is defined to be \frac{1}{n} \sum_{i=1}^n x_i. The variance \sigma^2 is defined to be \frac{1}{n} \sum_{i=1}^n (x_i - \mu)^2.

A naïve approach to calculating the variance then goes something like this:

(defun mean-variance (data)
  (flet ((square (x) (* x x)))
    (let* ((n (length data))
           (sum (reduce #'+ data :initial-value 0))
           (mu (/ sum n))
           (vv (reduce #'(lambda (accum xi)
                           (+ accum (square (- xi mu))))
                       data :initial-value 0)))
      (values mu (/ vv n)))))

This code runs through the data list once to count the items, once to calculate the mean, and once to calculate the variance. It is easy to see how we could count the items at the same time we are summing them. It is not as obvious how we can calculate the sum of squared terms involving the mean until we’ve calculated the mean.

If we expand the squared term and pull the constant \mu outside of the summations it ends up in, we find that:

\frac{\sum (x_i - \mu)^2}{n} = \frac{\sum x_i^2}{n} - 2 \mu \frac{\sum x_i}{n} + \mu^2 \frac{\sum 1}{n}

When we recognize that \frac{\sum x_i}{n} = \mu and \sum_{i=1}^n 1 = n, we get:

\sigma^2 = \frac{\sum x_i^2}{n} - \mu^2 = \frac{\sum x_i^2}{n} - \left( \frac{\sum x_i}{n} \right)^2.

This leads to the following code:

(defun mean-variance (data)
  (flet ((square (x) (* x x)))
    (destructuring-bind (n xs x2s)
        (reduce #'(lambda (accum xi)
                    (list (1+ (first accum))
                          (+ (second accum) xi)
                          (+ (third accum) (square xi))))
                data :initial-value '(0 0 0))
      (let ((mu (/ xs n)))
        (values mu (- (/ x2s n) (square mu)))))))

The code is not as simple, but you gain a great deal of flexibility. You can easily convert the above concept to continuously track the mean and variance as you iterate through an input stream. You do not have to keep data around to iterate through later. You can deal with things one sample at a time.
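Here is a minimal sketch of that streaming version, assuming samples arrive one at a time. The structure running-stats and the functions add-sample and running-mean-variance are illustrative names, not from the original post; they keep the same three totals as the reduce above.

```lisp
;; Running totals for the one-pass formulas: count, sum of x_i, sum of x_i^2.
(defstruct running-stats
  (n 0)
  (sum 0)
  (sum-of-squares 0))

(defun add-sample (stats x)
  "Fold one sample X into STATS and return STATS."
  (incf (running-stats-n stats))
  (incf (running-stats-sum stats) x)
  (incf (running-stats-sum-of-squares stats) (* x x))
  stats)

(defun running-mean-variance (stats)
  "Mean and population variance of the samples seen so far.
Call only after at least one sample has been added."
  (let* ((n (running-stats-n stats))
         (mu (/ (running-stats-sum stats) n)))
    (values mu (- (/ (running-stats-sum-of-squares stats) n) (* mu mu)))))

;; Usage: feed samples one at a time, ask for the answer whenever needed.
(let ((stats (make-running-stats)))
  (dolist (x '(2 4 4 4 5 5 7 9))
    (add-sample stats x))
  (running-mean-variance stats))   ; => 5, 4
```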

The same concept extends to higher-order moments, too.
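For instance, also keeping \sum x_i^3 gives the third central moment, since expanding the cube the same way yields \frac{\sum (x_i - \mu)^3}{n} = \frac{\sum x_i^3}{n} - 3\mu \frac{\sum x_i^2}{n} + 2\mu^3. A one-pass sketch (mean-variance-third-moment is an illustrative name, not from the original):

```lisp
(defun mean-variance-third-moment (data)
  "Return the mean, population variance, and third central moment of DATA."
  (let ((n 0) (s1 0) (s2 0) (s3 0))
    ;; One pass over DATA accumulating the power sums.
    (map nil
         (lambda (x)
           (incf n)
           (incf s1 x)
           (incf s2 (* x x))
           (incf s3 (* x x x)))
         data)
    (let* ((mu (/ s1 n))
           (m2 (/ s2 n))     ; mean of x^2
           (m3 (/ s3 n)))    ; mean of x^3
      (values mu
              (- m2 (* mu mu))
              (+ m3 (* -3 mu m2) (* 2 mu mu mu))))))

;; (mean-variance-third-moment '(2 4 4 4 5 5 7 9)) => 5, 4, 21/4
```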

Happy counting.

Edit: As many have pointed out, this isn’t the most numerically stable way to do this calculation. For my part, I was doing it with Lisp integers, so every step, including the divisions, was exact rational arithmetic, and I’ve got all of the stability I could ever want. 🙂 But yes, if you are intending to use these numbers for big-time decision making, you probably want to look up a numerically stable algorithm.
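One widely used stable alternative is Welford's online algorithm, which updates the running mean and the running sum of squared deviations from it, instead of the raw power sums. This is a sketch using the same population-variance definition as above; welford-mean-variance is an illustrative name.

```lisp
;; Welford's online algorithm. Track the count, the running mean, and M2,
;; the running sum of squared deviations from the current mean. Updates
;; work with differences from the mean, so they never subtract two huge,
;; nearly equal sums the way the sum-of-squares formula does.
(defun welford-mean-variance (data)
  "Return the mean and population variance of DATA in one pass."
  (let ((n 0) (mean 0) (m2 0))
    (map nil
         (lambda (x)
           (incf n)
           (let ((delta (- x mean)))   ; deviation from the OLD mean
             (incf mean (/ delta n))
             ;; The second factor uses the UPDATED mean.
             (incf m2 (* delta (- x mean)))))
         data)
    (values mean (/ m2 n))))
```

With the data 10^9 + 4, 10^9 + 7, 10^9 + 13, 10^9 + 16 as double floats, \frac{\sum x_i^2}{n} is around 10^{18}, where adjacent double floats are more than 100 apart, so subtracting \mu^2 cannot reliably recover the variance of 22.5. The update above only ever handles deviations of a few units and stays accurate.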
