Thursday, 2 June 2011

Photon Mapping in Haskell

For a while I have been working on extending my Haskell raytracer to include a global illumination model. I chose to implement photon mapping, as it offered a number of interesting challenges and seemed well suited to a functional programming approach.

Photon Tracing

One of the main challenges of implementing an algorithm like photon mapping is that it relies on a small amount of mutable state for its Monte Carlo integration. As each photon is traced through the scene, a random number is needed to decide the photon's fate at each photon-material interaction. A random number generator typically carries state from one number to the next. Such state is easy to incorporate into imperative code, but it presents a challenge to the Haskell programmer.

There are two main challenges with state: managing and propagating it, and deciding which code should share a given state context. For example, in the photon mapper I could have introduced a single, global state for the random number generator. Or, I could have given each light source its own state. Or each photon. The wider the scope of the state, the greater the coupling in the code and the harder it becomes to parallelise.

I chose to give each traced photon its own random number generator state, and each photon had a different initial seed.
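Here is a minimal sketch of per-photon generator state, under some stated assumptions. To keep it dependency-free, a small hand-written generator stands in for System.Random's StdGen. The names Gen, photonGen, nextUniform and photonRandoms are hypothetical. Each photon's seed is derived from a global seed plus the photon's index. With the random package, mkStdGen (globalSeed + photonIndex) would play the same role.

```haskell
import Data.Bits (shiftR, xor)
import Data.Word (Word64)

-- Generator state for one photon. A hand-written 64-bit LCG stands in for
-- System.Random's StdGen so that this sketch only depends on base.
newtype Gen = Gen Word64

-- SplitMix64 finaliser: scrambles nearby inputs (such as consecutive photon
-- indices) into unrelated-looking 64-bit values
mix64 :: Word64 -> Word64
mix64 z0 = z2 `xor` (z2 `shiftR` 31)
  where
    z1 = (z0 `xor` (z0 `shiftR` 30)) * 0xBF58476D1CE4E5B9
    z2 = (z1 `xor` (z1 `shiftR` 27)) * 0x94D049BB133111EB

-- Each photon gets its own generator, seeded from a global seed plus its index
photonGen :: Word64 -> Int -> Gen
photonGen globalSeed photonIndex = Gen (mix64 (globalSeed + fromIntegral photonIndex))

-- Draw a uniform Double in [0, 1) from the top 53 bits and return the advanced state
nextUniform :: Gen -> (Double, Gen)
nextUniform (Gen s) = (fromIntegral (mix64 s' `shiftR` 11) / 2 ^ (53 :: Int), Gen s')
  where
    s' = s * 6364136223846793005 + 1442695040888963407 -- Knuth's MMIX LCG step

-- The first n random numbers a photon would consume while being traced.
-- No state is shared between photons.
photonRandoms :: Word64 -> Int -> Int -> [Double]
photonRandoms globalSeed n photonIndex = go n (photonGen globalSeed photonIndex)
  where
    go k g
      | k <= 0 = []
      | otherwise = let (x, g') = nextUniform g in x : go (k - 1) g'
```

For example, photonRandoms 42 4 0 always replays the same four numbers for photon 0, while photon 1 gets an unrelated sequence. Photons can therefore be traced in any order, or on any thread.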

-- Realistic Image Synthesis Using Photon Mapping p60
tracePhoton :: [Photon] -> Photon -> SceneGraph -> StdGen -> (Int, Int) -> [Photon]
tracePhoton !currentPhotons (Photon !photonPower !photonPosDir) sceneGraph !rndState !(bounce, maxBounces) = 
    -- See if the photon intersects any surfaces
    case findNearestIntersection sceneGraph ray of
      Nothing -> currentPhotons
      Just (obj, t, subId) -> case photonFate of
                                -- Diffuse reflection. Here, we store the photon that got reflected, and trace a new photon - but only if it's bright enough to be worthwhile
                                DiffuseReflect -> if Colour.magnitude newPhotonPower > brightnessEpsilon && (bounce + 1) <= maxBounces
                                                  then tracePhoton (storedPhoton : currentPhotons) reflectedPhoton sceneGraph rndState'' (bounce + 1, maxBounces)
                                                  else storedPhoton : currentPhotons
                                    where
                                      !reflectedPhoton = Photon newPhotonPower (surfacePos, reflectedDir)
                                      !(reflectedDir, rndState'') = diffuseReflectionDirection rndState' tanSpace


                                -- Specular reflection. Here, we reflect the photon in the fashion that the surface would reflect towards the viewer and
                                -- aim to absorb it somewhere else in the photon map
                                SpecularReflect -> if Colour.magnitude newPhotonPower > brightnessEpsilon && (bounce + 1) <= maxBounces
                                                   then tracePhoton currentPhotons reflectedPhoton sceneGraph rndState' (bounce + 1, maxBounces)
                                                   else currentPhotons
                                    where
                                      !reflectedPhoton = Photon newPhotonPower (surfacePos, reflectedDir)
                                      !reflectedDir = Vector.negate (snd photonPosDir) `reflect` normal


                                -- Absorb. The photon simply gets absorbed into the map
                                Absorb -> storedPhoton : currentPhotons
          where
            !(photonFate, rndState') = runState (choosePhotonFate coefficients) rndState
            !coefficients = russianRouletteCoefficients (material obj)
            !newPhotonPower = computeNewPhotonPower photonFate coefficients photonPower (material obj)
            !tanSpace = primitiveTangentSpace (primitive obj) subId hitPosition obj
            !normal = thr tanSpace
            !hitPosition = pointAlongRay ray t
            !surfacePos = hitPosition + (normal Vector.<*> surfaceEpsilon)
            !brightnessEpsilon = 0.1
            !storedPhoton = Photon photonPower (surfacePos, snd photonPosDir)
    where
      !ray = rayWithPosDir photonPosDir 10000
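The fate selection in tracePhoton is a Russian roulette step. The sketch below is a hedged illustration of the scheme described in Jensen's book, not the renderer's actual choosePhotonFate and computeNewPhotonPower. The Roulette type, rouletteFate and survivingPower are hypothetical names, and power is a single scalar rather than a colour.

```haskell
-- Possible outcomes of a photon-surface interaction, mirroring the post's
-- DiffuseReflect / SpecularReflect / Absorb cases
data PhotonFate = DiffuseReflect | SpecularReflect | Absorb
  deriving (Show, Eq)

-- Hypothetical per-material probabilities of diffuse and specular reflection.
-- Whatever probability remains (1 - pd - ps) is absorption.
data Roulette = Roulette
  { diffuseProb :: Double
  , specularProb :: Double
  }

-- Russian roulette: a single uniform random number xi in [0, 1) picks the fate
rouletteFate :: Roulette -> Double -> PhotonFate
rouletteFate (Roulette pd ps) xi
  | xi < pd = DiffuseReflect
  | xi < pd + ps = SpecularReflect
  | otherwise = Absorb

-- A surviving photon's power is scaled by the surface reflectance for the
-- chosen event, divided by that event's probability. The expected reflected
-- power, prob * (power * reflectance / prob), is then power * reflectance.
survivingPower :: Double -> Double -> Double -> Double
survivingPower power reflectance prob = power * reflectance / prob
```

For example, rouletteFate (Roulette 0.5 0.3) 0.6 gives SpecularReflect. Dividing by the event probability keeps the expected reflected power correct, even though some photons are terminated early.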

KD-Tree Construction

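The resulting list of photons is inserted into a KD tree, so that the photons near a given point in space can be located efficiently. Construction is a simple recursive operation. Each node splits its photons along the largest axis of their bounding box, at the mean photon position along that axis: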

buildKDTree :: [Photon] -> PhotonMapTree
buildKDTree photons
    | length photons == 1 = PhotonMapLeaf (head photons)
    | length photons >= 2 = let (boxMin, boxMax) = photonsBoundingBox photons
                                !axis = largestAxis (boxMax - boxMin)
                                !photonsMedian = foldr ((+) . fst . posDir) zeroVector photons Vector.<*> (1 / fromIntegral (length photons)) -- mean (not median) photon position
                                !value = component photonsMedian axis
                                photonsGT = Prelude.filter (\x -> component ((fst . posDir) x) axis > value) photons
                                photonsLE = Prelude.filter (\x -> component ((fst . posDir) x) axis <= value) photons
                            in if length photonsGT > 0 && length photonsLE > 0
                               then PhotonMapNode axis value (buildKDTree photonsGT) (buildKDTree photonsLE)
                               else let (photons0', photons1') = trace "Using degenerate case" $ degenerateSplitList photons in PhotonMapNode axis value (buildKDTree photons0') (buildKDTree photons1')
    | otherwise = error ("Invalid case, length of array is " ++ show (length photons) ++ "\n")
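To experiment with the splitting strategy outside the renderer, here is a self-contained sketch of the same construction. V3 and KDTree are simplified stand-ins for the renderer's Vector and PhotonMapTree types. The degenerate case simply halves the list; this is an assumption, since degenerateSplitList is not shown.

```haskell
-- Minimal stand-ins for the renderer's Vector and PhotonMapTree types
data V3 = V3 !Double !Double !Double deriving (Show, Eq)

data KDTree
  = Leaf V3
  | Node Int Double KDTree KDTree -- split axis, split value, (> value) child, (<= value) child
  deriving (Show)

component :: V3 -> Int -> Double
component (V3 x _ _) 0 = x
component (V3 _ y _) 1 = y
component (V3 _ _ z) _ = z

-- Axis (0, 1 or 2) along which the points' bounding box is widest
largestAxis :: [V3] -> Int
largestAxis ps = snd (maximum [(extent a, a) | a <- [0, 1, 2]])
  where
    extent a = let cs = map (`component` a) ps in maximum cs - minimum cs

-- Split at the mean coordinate along the widest axis, as in the post.
-- If every point lands on one side (for example when all points coincide),
-- fall back to splitting the list in half.
buildKD :: [V3] -> KDTree
buildKD [] = error "buildKD: cannot build a tree from no points"
buildKD [p] = Leaf p
buildKD ps
  | null above || null atOrBelow = Node axis value (buildKD firstHalf) (buildKD secondHalf)
  | otherwise = Node axis value (buildKD above) (buildKD atOrBelow)
  where
    axis = largestAxis ps
    n = length ps
    value = sum (map (`component` axis) ps) / fromIntegral n
    above = filter (\p -> component p axis > value) ps
    atOrBelow = filter (\p -> component p axis <= value) ps
    (firstHalf, secondHalf) = splitAt (n `div` 2) ps

-- All points stored in the tree's leaves
leaves :: KDTree -> [V3]
leaves (Leaf p) = [p]
leaves (Node _ _ gt le) = leaves gt ++ leaves le
```

The split value is the mean coordinate rather than a true median, so a skewed photon distribution can produce an unbalanced tree. Sorting along the axis and splitting at the middle element would guarantee balance, at the cost of a sort per level.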

KD-Tree Query

The photon mapping algorithm uses an interesting combination of a KD tree and a max-heap to locate only the closest N photons to a point. The max-heap is keyed on the squared distance to each photon. The most distant photon gathered so far therefore sits at the top of the heap, where it is easy to discard.
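Here is a dependency-free sketch of how such a bounded max-heap behaves. The renderer itself uses MaxHeap from the heap package; this version is a small leftist max-heap keyed on squared distance, and Heap, NearestHeap and insertNearest are hypothetical names. Inserting into a full heap and then deleting the root discards exactly the most distant photon.

```haskell
-- A leftist max-heap keyed on squared distance: the root always holds the
-- most distant photon gathered so far. Items are left abstract (type a).
data Heap a = Empty | HNode Int Double a (Heap a) (Heap a) -- rank, key, item, left, right

rank :: Heap a -> Int
rank Empty = 0
rank (HNode r _ _ _ _) = r

merge :: Heap a -> Heap a -> Heap a
merge Empty h = h
merge h Empty = h
merge h1@(HNode _ k1 x1 l1 r1) h2@(HNode _ k2 _ _ _)
  | k1 >= k2 = node k1 x1 l1 (merge r1 h2)
  | otherwise = merge h2 h1
  where
    -- Keep the child with the shorter right spine on the right (leftist property)
    node k x a b
      | rank a >= rank b = HNode (rank b + 1) k x a b
      | otherwise = HNode (rank a + 1) k x b a

deleteMax :: Heap a -> Heap a
deleteMax Empty = Empty
deleteMax (HNode _ _ _ l r) = merge l r

-- A heap that never holds more than `capacity` photons
data NearestHeap a = NearestHeap {capacity :: Int, count :: Int, heap :: Heap a}

emptyNearest :: Int -> NearestHeap a
emptyNearest cap = NearestHeap cap 0 Empty

-- Insert, then discard the furthest photon (the root) if over capacity
insertNearest :: Double -> a -> NearestHeap a -> NearestHeap a
insertNearest d x (NearestHeap cap n h)
  | n < cap = NearestHeap cap (n + 1) h'
  | otherwise = NearestHeap cap n (deleteMax h')
  where
    h' = merge (HNode 1 d x Empty Empty) h

-- Squared distances held, in ascending order (pops the maximum repeatedly)
ascendingKeys :: Heap a -> [Double]
ascendingKeys = reverse . go
  where
    go Empty = []
    go h@(HNode _ k _ _ _) = k : go (deleteMax h)
```

For example, inserting the squared distances 5, 1, 4, 2, 3, 9 and 0.5 into a NearestHeap of capacity 3 leaves 0.5, 1 and 2.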

Traversal of the tree starts with a point of interest and a search radius. Any photons found within that radius are added to the max-heap. If the heap exceeds its specified size, the excess, most distant photons are dropped from the top. Because the top of the heap always holds the furthest photon, it is easy to track the current furthest distance and tighten the search radius during traversal.


-- Use a max heap to make it easy to eliminate distant photons
data GatheredPhoton = GatheredPhoton Float Photon deriving (Show)
type PhotonHeap = MaxHeap GatheredPhoton


instance Ord GatheredPhoton where
    -- Gathered photons are ordered by squared distance alone
    compare (GatheredPhoton dist1 _) (GatheredPhoton dist2 _) = compare dist1 dist2


instance Eq GatheredPhoton where
    (GatheredPhoton dist1 _) == (GatheredPhoton dist2 _) = dist1 == dist2


-- Current squared search radius. Once the heap holds maxPhotons photons, nothing further
-- away than the furthest of them (the top of the heap) can be accepted, so the radius
-- shrinks to that distance. Until then, the caller's radius is kept.
minimalSearchRadius :: Int -> Float -> PhotonHeap -> Float
minimalSearchRadius !maxPhotons !rSq !photonHeap
    | size photonHeap < maxPhotons = rSq
    | otherwise = case viewHead photonHeap of
                    Nothing -> rSq
                    Just (GatheredPhoton dSq _) -> Prelude.min rSq dSq


-- Gather photons for irradiance computations
-- Algorithm adapted from Realistic Image Synthesis Using Photon Mapping p73
gatherPhotons :: PhotonMapTree -> Position -> Float -> PhotonHeap -> Int -> PhotonHeap
gatherPhotons (PhotonMapNode !axis !value gtChild leChild) !pos !rSq !photonHeap !maxPhotons
    -- In this case, the split plane bisects the search sphere - search both halves of tree.
    -- The heap from the first search is threaded into the second, so each photon is counted once
    | (value - posComponent) ** 2 <= rSq = let heap1 = gatherPhotons gtChild pos rSq' photonHeap maxPhotons
                                               rSq'' = minimalSearchRadius maxPhotons rSq' heap1
                                           in gatherPhotons leChild pos rSq'' heap1 maxPhotons


    -- One side of the tree...
    | posComponent > value = gatherPhotons gtChild pos rSq' photonHeap maxPhotons


    -- ... or the other
    | posComponent <= value = gatherPhotons leChild pos rSq' photonHeap maxPhotons


    -- Only reachable if a comparison involves NaN
    | otherwise = error "gatherPhotons: unexpected case (NaN position or split value?)"
    where
      !posComponent = component pos axis
      !rSq' = minimalSearchRadius maxPhotons rSq photonHeap -- Refine search radius as we go down tree to search no further than closest allowed photon
gatherPhotons (PhotonMapLeaf !p) !pos !rSq !photonHeap !maxPhotons
    | distSq < rSq = let newHeap = insert (GatheredPhoton distSq p) photonHeap
                     in Data.Heap.drop (size newHeap - maxPhotons) newHeap -- Discard any excess photons - we get rid of the furthest ones
    | otherwise = photonHeap
    where !distSq = pos `distanceSq` (fst . posDir) p
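The following self-contained sketch puts the pieces together, so that the traversal can be checked against a brute-force search. It is not the renderer's code, and all names are hypothetical. A sorted list (furthest first) stands in for the max-heap, and the tree splits at the median on cycling axes. The child containing the query point is visited first, and the radius only shrinks once maxN photons have been gathered.

```haskell
import Data.List (insertBy, sortOn)
import Data.Ord (Down (..), comparing)

data V3 = V3 !Double !Double !Double deriving (Show, Eq)

-- Split axis, split value, child with coordinates >= value, child with coordinates <= value
data KDTree = Leaf V3 | Node Int Double KDTree KDTree

component :: V3 -> Int -> Double
component (V3 x _ _) 0 = x
component (V3 _ y _) 1 = y
component (V3 _ _ z) _ = z

distSq :: V3 -> V3 -> Double
distSq (V3 a b c) (V3 x y z) = dx * dx + dy * dy + dz * dz
  where
    (dx, dy, dz) = (a - x, b - y, c - z)

-- Median split, cycling through the axes (simpler than the post's widest-axis mean split)
build :: Int -> [V3] -> KDTree
build _ [] = error "build: no points"
build _ [p] = Leaf p
build depth ps = Node axis value (build (depth + 1) upper) (build (depth + 1) lower)
  where
    axis = depth `mod` 3
    (lower, upper) = splitAt (length ps `div` 2) (sortOn (`component` axis) ps)
    value = component (last lower) axis

-- Gathered photons as (squared distance, point), furthest first. This sorted
-- list stands in for the max-heap: its head is what the heap's root would hold.
type Gathered = [(Double, V3)]

-- Shrink the search radius to the furthest gathered photon, but only once
-- maxN photons have been gathered; before that, keep the caller's radius
searchRadiusSq :: Int -> Double -> Gathered -> Double
searchRadiusSq maxN rSq g = case g of
  ((d, _) : _) | length g >= maxN -> min rSq d
  _ -> rSq

gather :: Int -> V3 -> Double -> KDTree -> Gathered -> Gathered
gather maxN pos rSq (Leaf p) g
  | d < searchRadiusSq maxN rSq g =
      let g' = insertBy (comparing (Down . fst)) (d, p) g
       in if length g' > maxN then drop 1 g' else g' -- drop the furthest
  | otherwise = g
  where
    d = distSq pos p
gather maxN pos rSq (Node axis value upper lower) g
  -- Visit the far side only if the split plane lies within the current radius
  | delta * delta <= searchRadiusSq maxN rSq g1 = gather maxN pos rSq far g1
  | otherwise = g1
  where
    delta = component pos axis - value
    (near, far) = if delta > 0 then (upper, lower) else (lower, upper)
    g1 = gather maxN pos rSq near g -- the side containing pos first

-- Squared distances of the gathered photons, nearest first
nearestDistances :: Gathered -> [Double]
nearestDistances = reverse . map fst

-- Brute-force reference answer for checking the tree search
bruteForce :: Int -> V3 -> Double -> [V3] -> [Double]
bruteForce maxN pos rSq ps = take maxN (sortOn id [d | p <- ps, let d = distSq pos p, d < rSq])
```

For example, gather 20 q 0.1 tree [] collects up to 20 points within squared distance 0.1 of q. The result of nearestDistances on that list should equal bruteForce 20 q 0.1 pts.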


Parallelisation

The photon tracing stage is trivial to parallelise, provided that you have carefully handled the state required by the random number generator. Since each traced photon has its own, uniquely initialised generator state, all photons are data-independent and can therefore easily be traced in parallel.
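Because each photon owns its generator state, the tracing work can be split across threads using nothing but base. This sketch uses forkIO and MVar. The function tracePhotonStub is a hypothetical stand-in for tracing one photon, and the chunking scheme is an assumption. The parallel package's parMap with rdeepseq would be a more idiomatic alternative.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (evaluate)
import Data.Word (Word64)

-- Hypothetical stand-in for tracing one photon: some work whose result
-- depends only on the photon's own seed, never on shared state
tracePhotonStub :: Word64 -> Double
tracePhotonStub seed = fromIntegral (iterate step seed !! 1000)
  where
    step s = s * 6364136223846793005 + 1442695040888963407

-- Split a list into consecutive chunks of at most n elements
chunksOf :: Int -> [a] -> [[a]]
chunksOf _ [] = []
chunksOf n xs = let (a, b) = splitAt n xs in a : chunksOf n b

-- Trace the photons in roughly equal chunks, one forkIO thread per chunk.
-- Compile with -threaded and run with +RTS -N to use several cores.
traceInParallel :: Int -> [Word64] -> IO [Double]
traceInParallel threads seeds = do
  let threads' = max 1 threads
      perThread = max 1 ((length seeds + threads' - 1) `div` threads')
  vars <- mapM spawn (chunksOf perThread seeds)
  concat <$> mapM takeMVar vars -- results come back in the original photon order
  where
    spawn chunk = do
      var <- newEmptyMVar
      _ <- forkIO $ do
        let results = map tracePhotonStub chunk
        _ <- evaluate (sum results) -- do the work in this thread, not in the reader
        putMVar var results
      pure var
```

Every result depends only on its own seed. traceInParallel therefore returns exactly what map tracePhotonStub would, in the same order, however many threads are used.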

I have not yet parallelised the photon gathering stage, because of some issues with the efficiency of the current implementation. Two options present themselves for parallelisation. The first is to parallelise the traversal of the high-level sub-branches of the KD tree. This naive approach would yield some speedup, but the work given to each thread is likely to be highly unbalanced.

Speculatively, a work-stealing style queue may help here. The high-level nodes to be traversed could initially be inserted into a queue. When a worker thread pulls work from the queue, it could compare the number of nodes remaining against the total number of workers. If the ratio of nodes to workers becomes very low, nodes could repeatedly be removed from the queue and replaced with their child nodes until the ratio becomes favourable. This would fill the queue with a heterogeneous mix of large and small work items, which would help to keep every processor busy and balance the work across the threads.
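The queue-refinement idea can be sketched as a pure function. This is speculative in the same way as the proposal above. Tree is a hypothetical simplification of the KD tree in which each leaf records only its photon count. The post does not say which nodes to split, so splitting the node with the most work first is an assumption. The sketch covers only the refinement rule, not the worker threads or the stealing itself.

```haskell
-- Hypothetical shape of the KD tree for scheduling purposes: each leaf only
-- records how many photons it holds, which stands in for its amount of work
data Tree = Leaf Int | Node Tree Tree
  deriving (Show)

work :: Tree -> Int
work (Leaf n) = n
work (Node l r) = work l + work r

-- Replace the largest splittable node in the queue with its two children.
-- Returns Nothing once the queue contains only leaves.
splitLargest :: [Tree] -> Maybe [Tree]
splitLargest queue =
  case [(work t, i) | (i, t@(Node _ _)) <- zip [0 :: Int ..] queue] of
    [] -> Nothing
    candidates ->
      let (_, target) = maximum candidates
       in case splitAt target queue of
            (before, Node l r : after) -> Just (before ++ [l, r] ++ after)
            _ -> Nothing

-- Keep splitting until there are at least minPerWorker items per worker,
-- or nothing is left to split
refineQueue :: Int -> Int -> [Tree] -> [Tree]
refineQueue workers minPerWorker queue
  | length queue >= workers * minPerWorker = queue
  | otherwise = maybe queue (refineQueue workers minPerWorker) (splitLargest queue)
```

Consider a tree whose leaves hold 10, 3, 4, 7 and 1 photons. For it, refineQueue 2 2 yields four items: the leaves with 10, 7 and 1 photons, and the internal node holding 3 and 4.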

Results

Here is the current output image, with all options cranked up:



The code requires refinement and tuning to eliminate the bias, improve the colour bleeding and fix a few minor artefacts.

The photon mapper is currently very slow. It can take hours on my MacBook Pro to render a scene containing 100,000 photons, and that is a very small number of photons.

Most of the time is spent in the gathering phase, whose current implementation is very memory-intensive and requires further serious attention. The photon emission and tracing step, by contrast, is extremely quick, taking only a few seconds to trace 100,000 photons.

Code

Full source code is available on github:

https://github.com/TomHammersley/HaskellRenderer

Future Work

I am now working on the efficiency of the gathering phase, which should speed up both rendering and my development cycle. I am also implementing an irradiance cache to further reduce execution time. Various minor bugs also clearly need addressing.

References

Further implementation details can be found primarily in Henrik Wann Jensen's book, Realistic Image Synthesis Using Photon Mapping, supplemented by various SIGGRAPH course materials.