"""
Edward Dale
2006-1-27
Computer Graphics 2
Ray Tracer

Copyright (c) 2005, Edward Dale
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of Edward Dale nor the names of its contributors
  may be used to endorse or promote products derived from this software
  without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

from cgtypes import vec3, mat4
from math import sqrt,atan,acos,pi,modf

class Color2:
    """
    Represents an RGB color using integer channels on a 0-255 scale
    (pilstr() maps 255 to 100%).  Channels are converted to ints on
    construction and are not clamped.
    """

    def __init__(self, r=0, g=0, b=0):
        """
        Creates a color; each channel is converted with int(), which
        drops any fractional part.
        """
        self.r = int(r)
        self.g = int(g)
        self.b = int(b)

    def average(illums):
        """
        Accepts a non-empty list of Color2 objects and returns their
        per-channel average as a new Color2 (converted to ints by the
        constructor).  An empty list raises ZeroDivisionError.

        >>> print(Color2.average([Color2(0, 0, 0), Color2(102, 204, 0)]))
        rgb( 20%, 40%,  0%)
        """
        racc = 0
        gacc = 0
        bacc = 0
        for illum in illums:
            racc += illum.r
            gacc += illum.g
            bacc += illum.b
        count = len(illums)
        return Color2(racc/count, gacc/count, bacc/count)

    average = staticmethod(average)

    def __mul__(self, other):
        """
        Lets two colors be multiplied together channel by channel.
        A color can also be multiplied by a float or int scalar.

        >>> a = Color2(51, 51, 51)
        >>> print(2*a)
        rgb( 40%, 40%, 40%)
        >>> print(a*Color2(2, 1, 0))
        rgb( 40%, 20%,  0%)
        """
        if isinstance(other, (float, int)):
            return Color2(self.r*other, self.g*other, self.b*other)

        return Color2(self.r*other.r, self.g*other.g, self.b*other.b)

    __rmul__ = __mul__

    def __add__(self, other):
        """
        Lets two colors be added together channel by channel.

        >>> print(Color2(51, 51, 51) + Color2(51, 102, 153))
        rgb( 40%, 60%, 80%)
        """
        return Color2(self.r+other.r, self.g+other.g, self.b+other.b)

    __radd__ = __add__

    def pilstr(self):
        """
        Returns a string that can be sent to PIL as a pixel color.

        Channels are read on a 0-255 scale and emitted as integer
        percentages (255 -> 100%); negative channels are clamped to 0%.

        >>> print(Color2(51, 102, 255).pilstr())
        rgb( 20%, 40%,100%)
        """
        # Scale 0-255 to 0-100 before formatting as a percentage; the
        # previous code divided by 255 without multiplying by 100.
        return "rgb(%3d%%,%3d%%,%3d%%)" % (max(0, self.r)*100/255.0,
                                           max(0, self.g)*100/255.0,
                                           max(0, self.b)*100/255.0)

    def __str__(self):
        """
        Pretty-prints the color using its PIL representation.
        """
        return self.pilstr()

class Color:
    """
    An RGB color whose channels are floats nominally in (0.0, 1.0).

    Channels are stored exactly as given and arithmetic does not clamp
    them, so results can leave that range; only pilstr() clamps, and only
    negative values (to 0%).
    """

    def __init__(self, r=0.0, g=0.0, b=0.0):
        """
        Build a color from its red, green and blue components.

        >>> print(Color())
        rgb(  0%,  0%,  0%)
        >>> print(Color(0.5, 0.5, 0.5))
        rgb( 50%, 50%, 50%)
        """
        self.r, self.g, self.b = r, g, b

    def average(illums):
        """
        Return a new Color holding the per-channel mean of the Colors in
        illums.  Sums start from 0.0 so the result is always a float
        division; an empty list raises ZeroDivisionError.
        """
        count = len(illums)
        red = sum([c.r for c in illums], 0.0)
        green = sum([c.g for c in illums], 0.0)
        blue = sum([c.b for c in illums], 0.0)
        return Color(red/count, green/count, blue/count)

    average = staticmethod(average)

    def __mul__(self, other):
        """
        Multiply by another Color (channel by channel) or by a float/int
        scalar (every channel scaled).

        >>> a = Color(0.2, 0.2, 0.2)
        >>> print(2*a)
        rgb( 40%, 40%, 40%)
        >>> print(a*a)
        rgb(  4%,  4%,  4%)
        >>> print(0.3*a)
        rgb(  6%,  6%,  6%)
        """
        if not isinstance(other, (float, int)):
            # Not a plain number: treat it as a color and modulate.
            return Color(self.r*other.r, self.g*other.g, self.b*other.b)
        return Color(self.r*other, self.g*other, self.b*other)

    __rmul__ = __mul__

    def __add__(self, other):
        """
        Add two colors channel by channel.

        >>> print(Color(0.2, 0.2, 0.2) + Color(0.2, 0.3, 0.4))
        rgb( 40%, 50%, 60%)
        """
        red = self.r + other.r
        green = self.g + other.g
        blue = self.b + other.b
        return Color(red, green, blue)

    __radd__ = __add__

    def pilstr(self):
        """
        Format the color as a PIL "rgb(R%,G%,B%)" string.  Each channel is
        scaled by 100 and truncated by %d; negatives are clamped to 0.
        """
        percents = tuple([max(0, channel)*100
                          for channel in (self.r, self.g, self.b)])
        return "rgb(%3d%%,%3d%%,%3d%%)" % percents

    def __str__(self):
        """
        Human-readable form; identical to pilstr().
        """
        return self.pilstr()

class PointLight:
    """
    A point light source: just a position and the lighting/color
    information it carries.  It has no extent or bounds.
    """

    def __init__(self, pos, lighting):
        """
        Place the light at pos, carrying the given lighting information.
        """
        self.pos = pos
        self.lighting = lighting

    def translate(self, matrix):
        """
        Move the light by transforming its position with matrix; the
        product is re-wrapped as a vec3.
        """
        moved = matrix*self.pos
        self.pos = vec3(moved)

class Lighting:
    """
    Stores the material properties: ambient, diffuse and specular colors,
    their weighting coefficients, and an exponent (presumably the specular
    exponent -- the shading code is not in this class; confirm with caller).
    """
    def __init__(self, ambient=None, diffuse=None, specular=None,
                 ambientK=0.2, diffuseK=0.5, specularK=0.2, exponent=1.0):
        """
        ambient and specular default to a fresh white Color(1, 1, 1) per
        instance.  Previously the defaults were created once at definition
        time and shared by every Lighting, so mutating one material's color
        changed all of them.  diffuse defaults to the ambient color object.
        """
        if ambient is None:
            ambient = Color(1, 1, 1)
        if specular is None:
            specular = Color(1, 1, 1)
        self.ambient = ambient
        if diffuse is None:
            self.diffuse = ambient
        else:
            self.diffuse = diffuse
        self.specular = specular
        self.ambientK = ambientK
        self.diffuseK = diffuseK
        self.specularK = specularK
        self.exponent = exponent

class Ray(object):
    """
    Represents a ray using an origin and a direction.  The direction is
    always stored normalized, both at construction and when assigned
    through the direction property.
    """

    def __init__(self, origin, direction=vec3(0,0,1), name="Ray"):
        """
        Creates a ray from origin along direction, optionally named.
        origin is wrapped in a new vec3 and direction is normalized.
        """
        self.name = name
        self.origin = vec3(origin)
        self.__direction = vec3(direction).normalize()

    def getdirection(self):
        """
        Retrieves the (normalized) direction of the ray.
        """
        return self.__direction

    def setdirection(self, direction):
        """
        Stores the direction of the ray, normalizing in the process.
        """
        self.__direction = vec3(direction).normalize()

    direction = property(fget=getdirection, fset=setdirection)

    def reflect(self, normal, origin):
        """
        Returns a new Ray starting at origin whose direction is this ray's
        direction reflected about normal.
        """
        return Ray(origin, self.direction.reflect(normal))

    def refract(self, normal, origin, eta):
        """
        Returns a new Ray starting at origin whose direction is this ray's
        direction refracted through a surface with the given normal, using
        eta as passed to vec3.refract.

        Returns None when the refracted direction has zero length, which
        is how vec3.refract is expected to signal total internal
        reflection.
        """
        refracted = self.direction.refract(normal, eta)
        if abs(refracted):
            return Ray(origin, refracted)
        else:
            return None

    def __str__(self):
        """
        Pretty-prints this ray.
        """
        return "%s (origin:%s, direction:%s)" % (self.name, self.origin, self.__direction)

class Primitive(object):
    """
    Common base for every renderable primitive.  Holds the shader, the
    Lighting material, a name, the reflection (kr) and transmission (kt)
    coefficients, and the index of refraction.
    """
    def __init__(self, shader, lighting, name, kr, kt, refractindex):
        """
        Record the shared per-primitive attributes verbatim.
        """
        self.shader, self.lighting, self.name = shader, lighting, name
        self.kr, self.kt, self.refractindex = kr, kt, refractindex

class Sphere(Primitive):
    """
    Represents a sphere using a center position and a radius.
    """

    def create(config, section):
        """
        Factory method to create a Sphere from the specified ConfigParser
        object.  The Sphere data is located in section, which must provide
        the options pos, radius, kr, kt and refractindex.  The section name
        becomes the Sphere's name.
        """
        pos = vec3(config.get(section, "pos"))
        radius = config.getfloat(section, "radius")
        kr = config.getfloat(section, "kr")
        kt = config.getfloat(section, "kt")
        refractindex = config.getfloat(section, "refractindex")
        return Sphere(pos=pos, radius=radius,
                      name=section, kr=kr, kt=kt,
                      refractindex=refractindex)
    create = staticmethod(create)

    def __init__(self, pos, radius, lighting=None, name="Sphere",
                 shader=None, kr=0.0, kt=0.0, refractindex=1.0):
        """
        Creates a sphere centred at pos with the given radius.

        If lighting is omitted (or None), a new Lighting() is built for
        this sphere.  The former default instance was created once and
        shared by every sphere.
        """
        if lighting is None:
            lighting = Lighting()
        Primitive.__init__(self, shader, lighting, name, kr, kt, refractindex)
        self.pos = vec3(pos)
        self.radius = radius

    def project(self, isect):
        """
        Determines where the intersection occurs on the sphere.
        Returns u,v with u in [0, 1) (azimuth) and v in [0, 1] (zenith).
        """
        from math import atan2
        proj = self.pos-isect
        # atan2 covers all four quadrants and tolerates proj.x == 0; the
        # former atan(y/x) divided by zero there and mapped the x<0
        # hemisphere onto the x>0 one.
        az = atan2(proj.y, proj.x) % (2*pi)
        # Clamp so float drift just off the surface cannot push acos out
        # of its [-1, 1] domain.
        cosz = max(-1.0, min(1.0, proj.z/self.radius))
        zen = acos(cosz)
        u = az/(2*pi)
        v = zen/pi
        return u, v

    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        Returns an empty list when the ray misses the sphere or the sphere
        lies entirely behind the ray origin.  Otherwise returns a one-element
        list [(point, normal, primitive)] for the nearest intersection at a
        non-negative distance along the ray.  normal is the outward unit
        normal at point, and primitive is this sphere.

        The quadratic's leading coefficient is taken as 1 because Ray
        always stores a normalized direction.
        http://www.siggraph.org/education/materials/HyperGraph/raytrace/rtinter1.htm

        >>> ray = Ray(vec3(0,0,0), vec3(0,0,1))
        >>> sphere1 = Sphere(vec3(0,0,5), 1)
        >>> isect, normal, obj = sphere1.intersect(ray)[0]
        >>> print(isect)
        (0, 0, 4)
        >>> print(normal)
        (0, 0, -1)
        >>> obj is sphere1
        True
        >>> sphere2 = Sphere(vec3(0,1,5), 1)
        >>> isect, normal, obj = sphere2.intersect(ray)[0]
        >>> print(isect)
        (0, 0, 5)
        >>> print(normal)
        (0, -1, 0)
        >>> Sphere(vec3(0,0,-5), 1).intersect(ray)
        []
        >>> Sphere(vec3(0,1,-5), 1).intersect(ray)
        []
        """
        diff = ray.origin-self.pos
        d = ray.direction
        b = 2*(d.x*diff.x + d.y*diff.y + d.z*diff.z)
        c = diff.x**2 + diff.y**2 + diff.z**2 - self.radius**2
        disc = b**2 - 4*c
        if disc < 0:
            # No real roots: the ray misses the sphere.
            return []
        if disc == 0:
            # One root: the ray grazes the sphere tangentially.
            omegas = [-b/2.0]
        else:
            s = sqrt(disc)
            omegas = [(-b + s)/2.0, (-b - s)/2.0]
        # Discard hits behind the ray origin.  This now also applies to the
        # tangential case, which previously reported hits behind the origin.
        omegas = [w for w in omegas if w >= 0]
        if not omegas:
            return []
        omega = min(omegas)

        isect = vec3(ray.origin+ray.direction*omega)
        isectnorm = vec3((isect-self.pos) / self.radius)
        return [(isect, isectnorm, self)]

    def translate(self, matrix):
        """
        Translates the center of the sphere using the given matrix.

        >>> s=Sphere(vec3(0,0,0), radius=1)
        >>> trans = mat4().translation(vec3(2,3,4))
        >>> print(s)
        Sphere (pos=(0, 0, 0), radius=1.000000)
        >>> s.translate(trans)
        >>> print(s)
        Sphere (pos=(2, 3, 4), radius=1.000000)
        """
        self.pos = vec3(matrix*self.pos)

    def __str__(self):
        """
        Pretty-prints this sphere.
        """
        return "%s (pos=%s, radius=%f)" % (self.name, self.pos, self.radius)

class InfinitePlane(Primitive):
    """
    Models a one-sided infinite plane using a point on it and a normal.
    Only rays travelling against the normal (hitting the front face)
    intersect it.
    """

    def create(config, section):
        """
        Factory method to create an InfinitePlane from the specified
        ConfigParser object.  The InfinitePlane data is located in section,
        which must provide normal, point, kr, kt and refractindex.
        """
        normal = vec3(config.get(section, "normal"))
        point = vec3(config.get(section, "point"))
        kr = config.getfloat(section, "kr")
        kt = config.getfloat(section, "kt")
        refractindex = config.getfloat(section, "refractindex")
        # Pass by keyword: positionally, kr/kt previously landed in the
        # lighting/shader slots and refractindex was dropped.
        return InfinitePlane(normal, point, name=section, kr=kr, kt=kt,
                             refractindex=refractindex)
    create = staticmethod(create)

    def __init__(self, normal, point, name="Plane", lighting=None,
                 shader=None, kr=0.0, kt=0.0, refractindex=1.0):
        """
        Creates an InfinitePlane.  normal is normalized.  If lighting is
        omitted (or None), a new Lighting() is built for this plane rather
        than sharing one default instance.
        """
        if lighting is None:
            lighting = Lighting()
        Primitive.__init__(self, shader, lighting, name, kr, kt, refractindex)
        self.normal = normal.normalize()
        self.point = point

    def project(self, isect):
        """
        Determines where an intersection exists on the InfinitePlane.
        The exact orientation is undefined.
        Returns u,v between 0 and 1.
        """
        d = self.normal.ortho()
        u = modf(d*(isect-self.point))[0]
        v = modf(d.cross(self.normal)*(isect-self.point))[0]
        u = (u/2.0)+0.5
        v = (v/2.0)+0.5
        return u, v

    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        Returns an empty list when the ray is parallel to the plane, hits
        its back face, or meets the plane behind the ray origin.  Otherwise
        returns [(point, normal, primitive)], where primitive is this plane.
        http://www.siggraph.org/education/materials/HyperGraph/raytrace/rayplane_intersection.htm

        >>> p = InfinitePlane(vec3(0,1,0), vec3(0,0,0))
        >>> i, n, obj = p.intersect(Ray(vec3(0,5,0), vec3(0,-1,0)))[0]
        >>> print(i)
        (0, 0, 0)
        >>> print(n)
        (0, 1, 0)
        >>> p.intersect(Ray(vec3(0,5,0), vec3(1,0,0)))
        []
        >>> p.intersect(Ray(vec3(0,-5,0), vec3(0,1,0)))
        []
        >>> p.intersect(Ray(vec3(0,-5,0), vec3(0,-1,0)))
        []

        >>> offset = InfinitePlane(vec3(0,1,0), vec3(0,5,0))
        >>> i, n, obj = offset.intersect(Ray(vec3(0,10,0), vec3(0,-1,0)))[0]
        >>> print(i)
        (0, 5, 0)

        >>> plane=InfinitePlane(normal=vec3(0,-1,0),point=vec3(1,1,1))
        >>> ray=Ray(origin=vec3(0,0,0),direction=vec3(0,1,0))
        >>> i, n, obj = plane.intersect(ray)[0]
        >>> print(i)
        (0, 1, 0)
        """
        vd = self.normal * ray.direction
        if vd >= 0:
            # vd == 0: ray is parallel to the plane.
            # vd > 0: ray hits the back face; the plane is one-sided.
            return []
        # Solve N.(O + tD) = N.P0 for t.  The former abs(N.P0) form was
        # only correct when N.P0 <= 0.
        t = (self.normal*self.point - self.normal*ray.origin) / vd
        if t < 0:
            # Ray intersects behind origin
            return []
        i = ray.origin+ray.direction*t
        return [(i, self.normal, self)]

    def translate(self, matrix):
        """
        Translates the plane.

        >>> p = InfinitePlane(normal=vec3(0,1,0), point=vec3(0,0,0))
        >>> print(p)
        Plane (normal:(0, 1, 0), point:(0, 0, 0))
        >>> trans = mat4().translation(vec3(2,3,4))
        >>> p.translate(trans)
        >>> print(p)
        Plane (normal:(0, 1, 0), point:(2, 3, 4))
        """
        self.point = vec3(matrix*self.point)

    def __str__(self):
        """
        Pretty-prints this plane.
        """
        return "%s (normal:%s, point:%s)" % (self.name, self.normal, self.point)

class ViewPlane(InfinitePlane):
    """
    An InfinitePlane whose default name is "ViewPlane".
    """
    def __init__(self, normal, point, name="ViewPlane", lighting=None):
        """
        If lighting is omitted (or None), a new Lighting() is built rather
        than sharing one default instance across all ViewPlanes.
        """
        if lighting is None:
            lighting = Lighting()
        InfinitePlane.__init__(self, normal, point, name, lighting)
        

class Circle(InfinitePlane):
    """
    A Circle is a disc of the given radius around center, lying on a
    one-sided InfinitePlane through center.
    """

    def create(config, section):
        """
        Factory method to create a Circle from the specified ConfigParser
        object.  The Circle data is located in section, which must provide
        normal, center, radius, kr, kt and refractindex.
        """
        normal = vec3(config.get(section, "normal"))
        center = vec3(config.get(section, "center"))
        radius = config.getfloat(section, "radius")
        kr = config.getfloat(section, "kr")
        kt = config.getfloat(section, "kt")
        refractindex = config.getfloat(section, "refractindex")
        # Pass by keyword.  Positionally, kr/kt/refractindex previously
        # landed in the lighting/shader/kr slots.
        return Circle(normal, center, radius, name=section, kr=kr, kt=kt,
                      refractindex=refractindex)
    create = staticmethod(create)

    def __init__(self, normal, center, radius, name="Circle",
                 lighting=None, shader=None, kr=0.0, kt=0.0,
                 refractindex=1.0):
        """
        Build the Circle.  If lighting is omitted (or None), a new
        Lighting() is built rather than sharing one default instance.
        """
        if lighting is None:
            lighting = Lighting()
        InfinitePlane.__init__(self, normal, center, name, lighting,
                               shader, kr, kt, refractindex)
        self.center = center
        self.radius = radius

    def translate(self, matrix):
        """
        Translates the plane point and the center of the Circle.
        """
        InfinitePlane.translate(self, matrix)
        self.center = vec3(matrix*self.center)

    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        Intersects the underlying InfinitePlane first, then keeps the hit
        only if it lies within radius of center.  Returns the plane's
        [(point, normal, primitive)] list on a hit, otherwise [].

        >>> ray = Ray(vec3(0,0,0), vec3(0,1,0))
        >>> circ = Circle(vec3(0,-1,0), vec3(0,1,0), radius=1)
        >>> point, normal, obj = circ.intersect(ray)[0]
        >>> print(point)
        (0, 1, 0)
        >>> Circle(vec3(0,-1,0), vec3(2,1,0), radius=1).intersect(ray)
        []
        """
        isects = InfinitePlane.intersect(self, ray)
        if isects and abs(isects[0][0]-self.center) <= self.radius:
            return isects
        # Either the plane is missed, or the hit lies outside the disc.
        return []
    
class Rectangle(InfinitePlane):
    """
    A Rectangle is an InfinitePlane with rectangular bounds centred on
    point.  height is measured along dir and width along dir x normal.
    The bounds tests assume dir is a unit vector perpendicular to normal;
    dir is used as given (not normalized).
    """

    def create(config, section):
        """
        Factory method to create a Rectangle from the specified ConfigParser
        object.  The Rectangle data is located in section, which must
        provide normal, dir, point, width, height, kr, kt and refractindex.
        """
        normal = vec3(config.get(section, "normal"))
        dir = vec3(config.get(section, "dir"))
        point = vec3(config.get(section, "point"))
        width = config.getfloat(section, "width")
        height = config.getfloat(section, "height")
        kr = config.getfloat(section, "kr")
        kt = config.getfloat(section, "kt")
        refractindex = config.getfloat(section, "refractindex")
        return Rectangle(normal, dir, point, name=section, width=width,
                         height=height, kr=kr, kt=kt, refractindex=refractindex)
    create = staticmethod(create)

    def __init__(self, normal, dir, point, name="Plane", lighting=None,
                 width=1, height=1, shader=None, kr=0.0, kt=0.0,
                 refractindex=1.0):
        """
        Creates a Rectangle.  If lighting is omitted (or None), a new
        Lighting() is built rather than sharing one default instance.

        >>> r = Rectangle(vec3(0,1,0), vec3(0,0,1), vec3(0,0,0), refractindex=1.5)
        >>> r.refractindex
        1.5
        """
        if lighting is None:
            lighting = Lighting()
        # refractindex was previously not forwarded, so it was always 1.0.
        InfinitePlane.__init__(self, normal, point, name, lighting, shader,
                               kr, kt, refractindex)
        self.width = width
        self.height = height
        self.dir = dir

    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        Intersects the underlying InfinitePlane, then keeps the hit only if
        it lies strictly within height/2 along dir and width/2 along
        dir x normal of point.  Returns the plane's
        [(point, normal, primitive)] list on a hit, otherwise [].
        """
        isects = InfinitePlane.intersect(self, ray)
        if isects:
            offset = isects[0][0]-self.point
            along = self.dir*offset
            across = self.dir.cross(self.normal)*offset
            if abs(along) < self.height/2.0 and abs(across) < self.width/2.0:
                return isects
        return []

    def project(self, isect):
        """
        Determines where an intersection exists on the Rectangle.
        Returns u,v between 0 and 1 for points inside the bounds.
        """
        u = (self.dir*(isect-self.point))/self.height
        v = (self.dir.cross(self.normal)*(isect-self.point))/self.width
        return u+0.5, v+0.5

def _test():
    """
    Run the unit tests for each method using the doctest module.
    """
    import doctest
    doctest.testmod()

# Running this file directly (rather than importing it) executes the
# module's doctest examples through _test().
if __name__ == "__main__":
    _test()
