;;; This code was written by Patrick Stein <pat@nklein.com>
;;; Feel free to do with it as you like.

;; Dynamic (special) variable holding the degree of the polynomial being
;; solved for.  It is deliberately left unbound here; the macro
;; CALCULATE-POLYNOMIAL-SUBJECT-TO binds it around the condition forms
;; so that NTH-DERIVATIVE knows how many coefficients to produce.
;; A top-level DEFVAR already proclaims the symbol special, so no
;; separate DECLAIM is needed.
(defvar +order+)

;;;-------------------------------------------------------------
;;; The overall scheme of this code is to create a system of
;;; linear equations for a given set of conditions with the
;;; variables in the equations being the coefficients of a
;;; polynomial.  For example, if the conditions were:
;;;   f(1) = 0
;;;   f'(2) = 1
;;;   f''(3) = -1
;;;   f'''(4) = 5
;;;
;;; Then we can find A, B, C, and D in f(x) = A + Bx + Cx^2 + Dx^3
;;; by solving for A, B, C, and D in the following equations:
;;;    A +  B +    C +    D =  0    because    f(x) = A + Bx +  Cx^2 +  Dx^3
;;;         B + 2*2C + 4*3D =  1    because   f'(x) =     B  + 2Cx   + 3Dx^2
;;;               2C + 3*6D = -1    because  f''(x) =          2C    + 6Dx
;;;                      6D =  5    because f'''(x) =                  6D
;;;
;;; So, for each condition, the function NTH-DERIVATIVE builds the
;;; corresponding row of the augmented matrix (the coefficients of
;;; A, B, C, ... together with the right-hand side).


(defun nth-derivative ( nth &key at equals )
  "Return the augmented-matrix row (COEFFS EQUALS) expressing the
   condition that the NTH derivative of the polynomial
   f(x) = c0 + c1 x + ... + cN x^N  (N = +ORDER+)
   evaluated at x = AT equals EQUALS.

   Because d^k/dx^k x^n = n (n-1) ... (n-k+1) x^(n-k) for n >= k and
   is 0 for n < k, the entry for coefficient c_n is that falling
   factorial times AT^(n-k)."
  ;;; For exercise, rather than nested looping, the outer loop over
  ;;; the derivative order is implemented with recursion.
  (labels ((times-coeffs (nth initial)
             ;; Multiply the entry for x^nn by max(nn + 1 - NTH, 0): one
             ;; factor of the falling factorial for the NTH derivative.
             (mapcar #'*
                     initial
                     (loop for nn from 0 to +order+
                           collecting (max (- (1+ nn) nth) 0))))
           (rec ( nth current )
             ;; Apply the factors for NTH, NTH-1, ..., 1 to CURRENT.
             (if (plusp nth)
                 (times-coeffs nth (rec (1- nth) current))
                 current)))
    (list (rec nth (loop for nn from 0 to +order+
                         ;; After NTH derivatives, x^nn has become x^(nn-NTH).
                         ;; For nn < NTH the falling factorial is zero, so the
                         ;; power is irrelevant there; use 1.
                         collecting (if (> nn nth)
                                        (expt at (- nn nth))
                                        1)))
          equals)))
    

(defun value ( &key at equals )
  "Build the condition row stating f(AT) = EQUALS; this is simply the
   zeroth derivative handled by NTH-DERIVATIVE."
  (nth-derivative 0 :equals equals :at at))

(defun derivative ( &key at equals )
  "Build the condition row stating f'(AT) = EQUALS; this is simply the
   first derivative handled by NTH-DERIVATIVE."
  (nth-derivative 1 :equals equals :at at))

(defun back-substitution ( augmented-matrix )
  "Solve an already Gaussian-reduced AUGMENTED-MATRIX (a list of rows,
   each ending in its right-hand-side value) by back substitution and
   return the unknowns as a list in column order.  When a pivot entry
   is zero the corresponding unknown is simply set to 0, so for a
   degenerate system this yields one particular solution and reports
   nothing about the span of solutions."
  (let ((solved (make-list (length (rest (first augmented-matrix)))
                           :initial-element 0))
        (column 1))
    ;; SOLVED is kept in reverse column order (last unknown first).  That
    ;; lines up with each row once the row is reversed: element 0 of a
    ;; reversed row is its right-hand side, and element K (K >= 1)
    ;; multiplies the K-th entry of SOLVED.
    (flet ((solve-one (col reversed-row known)
             ;; rhs - sum(a_k * x_k) over already-known unknowns (unknown
             ;; ones are still 0), divided by the pivot entry.
             (let ((rhs (reduce #'-
                                (mapcar #'* (rest reversed-row) known)
                                :initial-value (first reversed-row)))
                   (pivot (nth col reversed-row)))
               (if (zerop pivot)
                   0
                   (/ rhs pivot)))))
      ;; Walk the rows bottom-up; the i-th row from the bottom determines
      ;; the i-th unknown from the right.
      (dolist (row (reverse augmented-matrix))
        (let ((reversed-row (reverse row)))
          (setf solved
                (loop for cc from 1 below (length reversed-row)
                      collecting (cond ((< cc column) (nth (1- cc) solved))
                                       ((= cc column)
                                        (solve-one column reversed-row solved))
                                       (t 0)))))
        (incf column)))
    (nreverse solved)))

(defun gaussian-elimination ( rows )
  "Performs Gaussian elimination on a (possibly augmented) matrix
   specified as a list of rows.  Returns a new list of rows in which
   row K (counting from 0) is the row chosen as pivot for column K,
   divided through by its column-K entry (unless that entry is zero);
   column K has been eliminated from every row that comes after it.
   Uses partial pivoting: the pivot for each column is the remaining
   row with the largest absolute value in that column.  One output row
   is produced per input row, consuming one column per row."
  ;; NOTE(review): if there are more rows than entries per row, the
  ;; column index eventually runs past the row, (nth column ...) returns
  ;; NIL and ABS signals an error -- assumes #rows <= row length.
  (labels ((find-pivot-row ( rows column &optional pivot remaining )
	     ;; Scan ROWS for the entry of largest magnitude in COLUMN.
	     ;; Returns two values: that row, and the list of all other rows
	     ;; (not necessarily in their original order).  Because the test
	     ;; is a strict >, ties keep the earlier-seen row as pivot.
	     (cond
	       ((null rows)  (values pivot remaining))
	       ((null pivot) (find-pivot-row (rest rows)
					     column
					     (first rows)
					     remaining))
	       ((> (abs (nth column (first rows)))
		   (abs (nth column pivot)))
		             (find-pivot-row (rest rows)
					     column
					     (first rows)
					     (cons pivot remaining)))
	       (t            (find-pivot-row (rest rows)
					     column
					     pivot
					     (cons (first rows) remaining)))))
	   (eliminate-row (row pivot pp column)
	     ;; Subtract the multiple of PIVOT that zeroes ROW's entry in
	     ;; COLUMN.  PP is the pivot's (unscaled) entry in COLUMN.  ROW is
	     ;; returned unchanged when either entry is already zero.
	     (let ((numerator (nth column row))
		   (denominator pp))
	       (cond
		 ((zerop denominator) row)
		 ((zerop numerator) row)
		 (t (let ((ratio (/ numerator denominator)))
		      (mapcar #'(lambda (a b)
				  (- a (* b ratio)))
			      row
			      pivot))))))
	   (scale-pivot ( pivot column )
	     ;; Divide PIVOT through by its COLUMN entry so that entry
	     ;; becomes 1; a zero entry leaves the row unchanged.
	     (let ((denominator (nth column pivot)))
	       (cond
		 ((zerop denominator) pivot)
		 (t (mapcar #'(lambda (v)
				(/ v denominator))
			    pivot)))))
	   (eliminate ( rows column &optional complete )
	     ;; Pick a pivot for COLUMN, eliminate COLUMN from the remaining
	     ;; rows (using the unscaled pivot), then recurse on the next
	     ;; column.  COMPLETE accumulates the scaled pivot rows in
	     ;; reverse order.
	     (cond
	       ((null rows) complete)
	       (t (multiple-value-bind (pivot remaining)
		      (find-pivot-row rows column)
		    (eliminate (loop with pp = (nth column pivot)
				  for row in remaining
				  collecting (eliminate-row row
							    pivot
							    pp
							    column))
			       (1+ column)
			       (cons (scale-pivot pivot column) complete)))))))
    ;; Put the pivot rows back in column order.
    (nreverse (eliminate rows 0))))

(defun linear-solve (matrix equals)
  "Solve the linear system whose coefficient rows are MATRIX and whose
   right-hand sides are the matching elements of EQUALS.  Each value is
   appended to its row to form the augmented matrix, which is then
   Gaussian-eliminated and back-substituted.  Returns the solution list."
  (let ((augmented (loop for row in matrix
                         for value in equals
                         collect (append row (list value)))))
    (back-substitution (gaussian-elimination augmented))))

(defmacro calculate-polynomial-subject-to (&body conditions)
  "Each of CONDITIONS should be a call to VALUE, DERIVATIVE, or
   NTH-DERIVATIVE.  The expansion dynamically binds +ORDER+ to one less
   than the number of conditions (so there are exactly as many unknown
   coefficients as conditions), evaluates the conditions in order to get
   their augmented rows, and solves the resulting linear system.  The
   result is the list of polynomial coefficients, constant term first."
  (let ((rows (gensym "AUGMENTED-ROWS")))
    ;; LET* establishes the special binding of +ORDER+ before the
    ;; condition forms are evaluated.
    `(let* ((+order+ (1- ,(length conditions)))
            (,rows (list ,@conditions)))
       (linear-solve (mapcar #'first ,rows)
                     (mapcar #'second ,rows)))))

(defun polynomial-to-string ( coefficients )
  "Render COEFFICIENTS (constant term first) as a crude string of terms
   of the form \"+ c*x^n \", omitting every zero coefficient.  An empty
   or all-zero list yields the empty string."
  (with-output-to-string (out)
    (loop for cc in coefficients
          for nn upfrom 0
          unless (zerop cc)
            do (format out "+ ~A*x^~A " cc nn))))

;;;;;;;;;;;;;; some tests...
;; Each form below builds a polynomial from its conditions (one
;; coefficient per condition) and renders it with POLYNOMIAL-TO-STRING.
;; Nothing prints or checks the resulting strings when the file is
;; LOADed; evaluate the forms at a REPL to inspect them.

;; Seven conditions -> degree 6: f, f' and f'' vanish at x = 0;
;; f(1) = 1 with f'(1) = f''(1) = 0; and f(1/2) = 3/4.
(polynomial-to-string
  (calculate-polynomial-subject-to
    (value :at 0 :equals 0)
    (derivative :at 0 :equals 0)
    (nth-derivative 2 :at 0 :equals 0)
    (value :at 1 :equals 1)
    (derivative :at 1 :equals 0)
    (nth-derivative 2 :at 1 :equals 0)
    (value :at 1/2 :equals 3/4)))

;; Eight conditions -> degree 7: f(0) = f'(0) = 0, f''(0) = -1;
;; stationary points at x = 1/2 (f = 5/4) and x = 7/8 (f = 15/16);
;; and f(1) = 1.
(polynomial-to-string
  (calculate-polynomial-subject-to
    (value :at 0 :equals 0)
    (derivative :at 0 :equals 0)
    (nth-derivative 2 :at 0 :equals -1)
    (derivative :at 1/2 :equals 0)
    (value :at 1/2 :equals 5/4)
    (derivative :at 7/8 :equals 0)
    (value :at 7/8 :equals 15/16)
    (value :at 1 :equals 1)))

;; Six conditions -> degree 5: f(0) = f'(0) = 0, f''(0) = 1,
;; f''(1) = -5, f(1) = 1, f'(1) = 0.
(polynomial-to-string
  (calculate-polynomial-subject-to
    (value :at 0 :equals 0)
    (derivative :at 0 :equals 0)
    (nth-derivative 2 :at 0 :equals 1)
    (nth-derivative 2 :at 1 :equals -5)
    (value :at 1 :equals 1)
    (derivative :at 1 :equals 0)))

;; Nine conditions -> degree 8: f(0) = f'(0) = 0, f''(0) = -2, f(1) = 1,
;; f(.77) = 1.1, and f' = 0 at .77, .88, .93 and .97.  The float inputs
;; mean the coefficients will generally come out as floats rather than
;; exact rationals.
(polynomial-to-string
  (calculate-polynomial-subject-to
    (value :at 0 :equals 0)
    (derivative :at 0 :equals 0)
    (nth-derivative 2 :at 0 :equals -2)
    (value :at 1 :equals 1)
    (value :at .77 :equals 1.1)
    (derivative :at .77 :equals 0)
    (derivative :at .88 :equals 0)
    (derivative :at .93 :equals 0)
    (derivative :at .97 :equals 0)))