"""
Edward Dale
2006-2-1
Computer Graphics 2
Ray Tracer

Copyright (c) 2006, Edward Dale
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of Edward Dale nor the names of its contributors
  may be used to endorse or promote products derived from this software
  without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import ConfigParser
from math import pi
from random import *
import timing
import sys

from cgtypes import vec3, mat4 # http://cgkit.sf.net
from PIL import Image
from PIL import ImageDraw

from primitives import *
from shaders import *

class Output:
    """
    The output of the ray tracer.

    Rendering happens into an in-memory PIL image that is ``supersample``
    times larger than the requested size in each dimension; close()
    downsamples it back to ``realwidth`` x ``realheight`` (when supersampling)
    and writes it to ``filename``.
    """
    def __init__(self, config):
        """
        Read the [Output] section of ``config`` and open the drawing surface.

        Reads: filename, height, width, supersample, bgcolor.
        """
        section="Output"
        self.filename=config.get(section, "filename")
        self.realheight=config.getint(section, "height")
        self.realwidth=config.getint(section, "width")
        self.supersample=config.getint(section, "supersample")
        # The working canvas is enlarged by the supersample factor.
        self.height=self.realheight * self.supersample
        self.width=self.realwidth * self.supersample
        # NOTE(review): eval() executes arbitrary code from the config file;
        # only render trusted configuration files.
        self.bgcolor=eval(config.get(section, "bgcolor"))
        self.open()

    def open(self):
        """
        Create the (supersampled) canvas filled with the background colour and
        a drawing context for it.
        """
        self.im = Image.new( "RGB", (self.width, self.height), self.bgcolor.pilstr() )
        self.draw = ImageDraw.Draw( self.im )

    def close(self):
        """
        Release the drawing context and write the image to ``filename``.

        When supersampling, the canvas is resized down to the real output size
        with a high-quality filter before saving.  The image format is chosen
        by PIL from the file name's extension.
        """
        del self.draw
        # Binary mode: image data must not go through newline translation.
        output = open(self.filename, "wb")
        try:
            if self.supersample>1:
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
                # filter and exists in Pillow, while very old PIL only has
                # ANTIALIAS.
                resample = getattr(Image, "LANCZOS", None)
                if resample is None:
                    resample = Image.ANTIALIAS
                self.im.resize((self.realwidth, self.realheight), resample).save(output)
            else:
                self.im.save(output)
        finally:
            # Close the handle even if encoding/saving fails.
            output.close()

    def drawPixel(self, x, y, color):
        """
        Paint the canvas pixel at (x, y) with ``color`` (an object exposing
        pilstr()).
        """
        self.draw.point((x,y), color.pilstr())

class Scene:
    """
    Represents the scene and all the parameters of it including the
    objects therein and the camera parameters.

    Objects and lights are transformed into camera coordinates on
    construction.
    """
    # Name of the configuration section holding the scene-wide settings.
    section="Scene"

    def __init__(self, config):
        """
        Read the [Scene] section of ``config``, build the objects and lights,
        and set up the camera / view plane.
        """
        self.defaultshader = eval(config.get(Scene.section, "shader"))
        self.jitter = config.getfloat(Scene.section, "jitter")
        self.refractindex=config.getfloat(Scene.section, "refractindex")
        self.depth = config.getint(Scene.section, "depth")
        self.reflectrays = config.getint(Scene.section, "reflectrays")
        self.refractrays = config.getint(Scene.section, "refractrays")

        self.objects = self.buildObjects(config)
        self.lights = self.buildLights(config)
        self.ambient = Lighting()

        # Distance from the eye to the view plane.
        self.d =config.getint(Scene.section, "eyedistance")

        self.eye = vec3(config.get(Scene.section, "eye"))
        self.up = vec3(config.get(Scene.section, "up"))
        self.lookat = vec3(config.get(Scene.section, "lookat"))
        self.dir = (self.eye-self.lookat).normalize()

        # World -> camera transform; the eye position is mapped through it to
        # get the ray origin used for all eye rays.
        trans = mat4().lookAt(self.eye,self.lookat, self.up).inverse()
        self.rayorigin = vec3(trans*self.eye)

        viewnorm = -self.dir
        pointonplane = vec3((self.dir*self.d) + self.rayorigin)

        self.ray = Ray( self.rayorigin, pointonplane, "Ray" )
        self.viewPlane = ViewPlane( viewnorm, pointonplane )

        # Point where the central ray meets the view plane (first hit's point).
        self.planeCenter = self.viewPlane.intersect( self.ray )[0][0]

        # Translate the objects&lights into camera coordinates
        for i in self.objects+self.lights:
            i.translate( trans )

    def fireRay(self, ray, object=None):
        """
        Return the closest intersection of ``ray`` with the scene as a tuple
        (distance, point, normal, hitObject), or None if nothing is hit.

        ``object`` is skipped to prevent a ray from intersecting the object it
        originates from.  On equal distances the first hit found wins.
        """
        isects = []
        for candidate in self.objects:
            if candidate != object:
                isects += candidate.intersect(ray)

        if not isects:
            return None

        hits = [ (abs(point-ray.origin), point, normal, hitObject)
                 for point, normal, hitObject in isects ]
        # Compare on distance only: sorting whole tuples would fall through to
        # comparing points/normals/objects whenever two distances tie.
        return min(hits, key=lambda hit: hit[0])

    def buildLights(self, config):
        """
        Create and return the PointLights named in the scene's "lights" option.

        Each light section must provide "pos"; "lighting" is optional and
        defaults to Lighting().
        """
        lights=[]
        for light in config.get(Scene.section, "lights").split():
            pos = vec3(config.get(light, "pos"))
            # Read the lighting from this light's own section (previously the
            # builtin ``object`` was used as the section name by mistake, so
            # the configured lighting was always silently ignored).
            if config.has_option(light, "lighting"):
                lighting=eval(config.get(light, "lighting"))
            else:
                lighting=Lighting()
            lights.append( PointLight(pos=pos, lighting=lighting) )
        return lights

    def buildObjects(self, config):
        """
        Create and return the primitives named in the scene's "objects" option.

        Each object section must provide "type" (evaluated to a class with a
        create(config, section=...) factory).  "shader" is optional (the
        primitive keeps its own default otherwise); "lighting" is optional and
        defaults to Lighting().
        """
        objects=[]
        for name in config.get(Scene.section, "objects").split():
            primType = eval(config.get(name, "type"))
            newprim = primType.create(config, section=name)
            if config.has_option(name, "shader"):
                newprim.shader=eval(config.get(name, "shader"))
            if config.has_option(name, "lighting"):
                newprim.lighting=eval(config.get(name, "lighting"))
            else:
                newprim.lighting=Lighting()

            objects.append( newprim )
        return objects

def _jitterDirection(ray, amount):
    """
    Add a uniform random offset in [-amount, amount] to each component of
    ray.direction, then renormalize it (modifies ``ray`` in place).
    """
    ray.direction.x += uniform(-amount, amount)
    ray.direction.y += uniform(-amount, amount)
    ray.direction.z += uniform(-amount, amount)
    ray.direction = ray.direction.normalize()

def illuminate(ray, depth=0, source=None, inside=False):
    """
    This is the global illumination step.  It will shoot the ray into the scene
    and spawn additional rays at each intersection.  Depth is how many rays
    have been spawned on this path so far.  Source is used to prevent rays
    from intersecting their source object.  ``inside`` selects the inverted
    relative index of refraction and flipped normal for transmission.

    Returns the colour seen along ``ray`` (the output background colour if
    nothing is hit).  Reads the module-level ``scene`` and ``output``.
    """
    global scene, output
    hit = scene.fireRay(ray, source)
    if not hit:
        return output.bgcolor

    _dist, isect, isectNormal, isectObject = hit
    shader = scene.defaultshader
    if isectObject.shader is not None:
        shader = isectObject.shader
    color = shader.getColor(isectObject, isectNormal, isect)

    # Stop spawning secondary rays once the configured depth is reached.
    if depth >= scene.depth:
        return color

    jitter = scene.jitter

    # If there's reflection, average several (optionally jittered) reflected
    # samples.
    if isectObject.kr > 0:
        illums = []
        for _ in xrange(scene.reflectrays):
            reflect = ray.reflect( origin=isect, normal=isectNormal )
            if jitter:
                _jitterDirection(reflect, jitter)
            illums.append(illuminate(reflect, depth+1, isectObject))
        # Skip when no samples were configured rather than averaging nothing.
        if illums:
            color += isectObject.kr * Color.average(illums)

    # If there's transmission, average several (optionally jittered)
    # refracted samples.
    if isectObject.kt > 0:
        # The normal and relative index of refraction are the same for every
        # sample, so compute them once.
        normal = isectNormal
        eta = scene.refractindex/isectObject.refractindex
        if inside:
            eta = 1.0/eta
            normal = -isectNormal

        illums = []
        for _ in xrange(scene.refractrays):
            refract = ray.refract( origin=isect, normal=normal, eta=eta )
            # If there's total internal reflectance, then make the
            # refract ray a reflect ray.
            # http://www.cs.stevens.edu/~quynh/courses/cs537-notes/lesson8-RayTracing1.ppt
            if not refract:
                refract = ray.reflect( origin=isect, normal=normal )
            if jitter:
                _jitterDirection(refract, jitter)
            # NOTE(review): the flag is toggled even after total internal
            # reflection, and ``source`` excludes this object from the next
            # fireRay, so the next hit is always a different object -- the
            # ``inside`` bookkeeping looks unreliable; confirm intent.
            illums.append(illuminate(refract, depth+1,
                                     isectObject, not inside))
        if illums:
            color += isectObject.kt * Color.average(illums)

    return color

def raytrace(config):
    """
    Runs the ray tracing algorithm using the parameters stored in the ConfigParser
    object passed in.
    """
    global scene, output
    output = Output(config)
    scene = Scene(config)

    try:
        timing.start()
        eyeray = Ray(origin=scene.rayorigin)
        # TODO: Use a generator for ray generation
        for y in range(-output.height/2, output.height/2):
            print (float(y)/output.height)+0.5
            for x in range(-output.width/2, output.width/2):
                target = scene.planeCenter + vec3( x*float(scene.d)/output.width, y*float(scene.d)/output.height , 0)
                # Only jitter the ray if there is supersampling and a jitter amount
                if scene.jitter and output.supersample>1:
                    target.x += uniform(-scene.jitter,scene.jitter)
                    target.y += uniform(-scene.jitter,scene.jitter)
                    target.z += uniform(-scene.jitter,scene.jitter)
                
                # Renormalize the ray after jittering
                eyeray.direction=target.normalize()
                # Everything is illuminated
                color = illuminate( eyeray )
                output.drawPixel(output.width/2-x-1,output.height/2-y-1, color )
    
    finally:
        timing.finish()
        print "seconds:%d" % timing.seconds()
        if output:
            output.close()

if __name__ == "__main__":
    # The raytracer is being run on the command line, so get the filename and 
    # start ray tracing.
    try:
        config = ConfigParser.ConfigParser()
        config.readfp(open(sys.argv[1]))
        raytrace(config)
    except IndexError, IOError:
        print "Error loading configuration."
