"""
Edward Dale
2005-12-19
Computer Graphics 2
Ray Tracer
Copyright (c) 2005, Edward DaleAll rights reserved.Redistribution and use in source and binary forms, with or without modification,are permitted provided that the following conditions are met:* Redistributions of source code must retain the above copyright notice, this  list of conditions and the following disclaimer.* Redistributions in binary form must reproduce the above copyright notice,  this list of conditions and the following disclaimer in the documentation  and/or other materials provided with the distribution.* Neither the name of Edward Dale nor the names of its contributors  may be used to endorse or promote products derived from this software  without specific prior written permission.THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ANDCONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OFMERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE AREDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORSBE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, ORCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENTOF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; ORBUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OFLIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDINGNEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THISSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."""

from math import sqrt, atan, atan2, acos, pi, modf

from cgtypes import vec3, mat4

class Color:
    """
    Represents an RGB color using channel values in the closed range [0.0, 1.0].

    The constructor clamps every channel into that range, so colors built by
    the constructor or by the arithmetic operators below are always valid.
    Assigning to r, g or b directly bypasses the clamping; clamp() returns an
    in-range copy of such a color.
    """
    
    def __init__(self, r=0.0, g=0.0, b=0.0):
        """
        Creates a color from the given channel values, clamped to [0.0, 1.0].

        >>> a=Color()
        >>> print a
        rgb(  0%,  0%,  0%)
        >>> a=Color(0.5, 0.5, 0.5)
        >>> print a
        rgb( 50%, 50%, 50%)
        >>> print Color(2.0, -1.0, 0.5)
        rgb(100%,  0%, 50%)
        """
        self.r=self.clampNumber(r)
        self.g=self.clampNumber(g)
        self.b=self.clampNumber(b)
    
    def __mul__(self,other):
        """
        Multiplies two colors channel by channel, or scales every channel of
        a color by a float or int.  The result is clamped.
        
        >>> a=Color(0.2, 0.2, 0.2)
        >>> print 2*a
        rgb( 40%, 40%, 40%)
        >>> print a*a
        rgb(  4%,  4%,  4%)
        >>> print 0.3*a
        rgb(  6%,  6%,  6%)
        """
        if isinstance(other, (int, float)):
            # Scalar: scale every channel by the same factor.
            return Color(self.r*other, self.g*other, self.b*other)
        # Otherwise treat other as a Color and multiply channel by channel.
        return Color(self.r*other.r, self.g*other.g, self.b*other.b)
        
    __rmul__=__mul__
    
    def __add__(self, other):
        """
        Adds two colors channel by channel, or adds a float or int to every
        channel.  The result is clamped.  Scalar support also lets sum() be
        used on a list of colors, because sum() starts from 0.
        
        >>> a=Color(0.2, 0.2, 0.2)
        >>> b=Color(0.2, 0.3, 0.4)
        >>> c=a+b
        >>> print c
        rgb( 40%, 50%, 60%)
        >>> print sum([a, b])
        rgb( 40%, 50%, 60%)
        """
        if isinstance(other, (int, float)):
            # Scalar: shift every channel by the same amount.
            return Color(self.r+other, self.g+other, self.b+other)
        return Color(self.r+other.r, self.g+other.g, self.b+other.b)
    
    __radd__=__add__
    
    def __neg__(self):
        """
        Returns the complement of a Color: each channel subtracted from 1.
        
        >>> a=Color()
        >>> print a
        rgb(  0%,  0%,  0%)
        >>> print -a
        rgb(100%,100%,100%)
        >>> print --a
        rgb(  0%,  0%,  0%)
        >>> print -Color(0.25, 0.5, 0.75)
        rgb( 75%, 50%, 25%)
        """
        return Color(1-self.r, 1-self.g, 1-self.b)
    
    def clamp(self):
        """
        Returns a new Color that has all its values clamped to [0.0, 1.0].
        
        >>> a = Color()
        >>> a.r = 5.0
        >>> b=a.clamp()
        >>> print a
        rgb(500%,  0%,  0%)
        >>> print b
        rgb(100%,  0%,  0%)
        """
        return Color(self.clampNumber(self.r), self.clampNumber(self.g), self.clampNumber(self.b))

    def clampNumber(self, number):
        """
        Clamps a number between 0.0 and 1.0 (inclusive).
        
        >>> a=Color()
        >>> a.clampNumber(5.0)
        1.0
        >>> a.clampNumber(-3.0)
        0.0
        >>> a.clampNumber(0.7)
        0.7
        """
        return max(0.0, min(1.0, number))
    
    def pilstr(self):
        """
        Returns a string that can be sent to PIL as a pixel color.  Each
        channel is written as an integer percentage, truncated rather than
        rounded (%d conversion).
        """
        return "rgb(%3d%%,%3d%%,%3d%%)" % (self.r*100,self.g*100,self.b*100)
    
    def __str__(self):
        """
        Pretty-prints the color in the same format as pilstr().
        """
        return self.pilstr()

class PointLight:
    """
    A point light source: it has a position and lighting information but no
    size or bounds.
    """
    
    def __init__(self, pos, lighting):
        """
        Places the light at pos and attaches the given lighting information.
        Both arguments are stored as given, without copying or conversion.
        """
        self.pos, self.lighting = pos, lighting
    
    def translate(self, matrix):
        """
        Moves the light by transforming its position with matrix; the result
        is stored as a vec3.
        """
        moved = matrix * self.pos
        self.pos = vec3(moved)

class Lighting:
    """
    Stores the material properties of a surface.

    ambient, diffuse and specular are colors.  ambientK, diffuseK and
    specularK are scalar coefficients paired with those colors, and exponent
    is a scalar exponent.  How they are combined is up to the shading code,
    which is not part of this class.
    """
    def __init__(self, ambient=None, diffuse=None, specular=None, ambientK=0.2, diffuseK=0.5, specularK=0.2, exponent=1.0):
        """
        Creates a material.

        ambient and specular default to white, Color(1,1,1).  A new Color is
        built for each call, so materials never share one mutable default
        Color.  diffuse defaults to the ambient color (the same object, not a
        copy).
        """
        if ambient is None:
            ambient = Color(1,1,1)
        if specular is None:
            specular = Color(1,1,1)
        self.ambient=ambient
        # Fall back to the ambient color when no diffuse color is given.
        self.diffuse = ambient if diffuse is None else diffuse
        self.specular=specular
        self.ambientK=ambientK
        self.diffuseK=diffuseK
        self.specularK=specularK
        self.exponent=exponent

class Ray:
    """
    A ray described by an origin point and a normalized direction.
    """
    
    def __init__(self, origin, direction, name="Ray"):
        """
        Builds a ray.  origin and direction are copied into vec3 objects and
        the direction is normalized.  name is used when printing.
        """
        self.name = name
        self.origin = vec3(origin)
        unit = vec3(direction).normalize()
        self.direction = unit
    
    def __str__(self):
        """
        Returns "<name> (origin:<origin>, direction:<direction>)".
        """
        fields = {"n": self.name, "o": self.origin, "d": self.direction}
        return "%(n)s (origin:%(o)s, direction:%(d)s)" % fields

class Sphere:
    """
    Represents a sphere using a center position and a radius.
    """
    
    def __init__(self, pos, radius, lighting=None, name="Sphere", shader=None):
        """
        Creates a sphere.

        pos is copied into a vec3.  When lighting is omitted, a new default
        Lighting() is created for this sphere, so spheres never share one
        mutable default material.
        """
        self.pos = vec3(pos)
        self.radius = radius
        if lighting is None:
            lighting = Lighting()
        self.lighting = lighting
        self.name = name
        self.shader = shader
    
    def project(self,isect):
        """
        Maps a point on the sphere's surface to texture coordinates.

        Let proj = pos - isect.  u is the azimuth of proj in the xy-plane,
        scaled from [0, 2*pi) to [0, 1].  v is the polar angle of proj
        measured from the +z axis, scaled from [0, pi] to [0, 1].
        Returns u, v.
        """
        proj = self.pos - isect
        # atan2 covers the full circle and copes with proj.x == 0, which
        # atan(y/x) does not.
        az = atan2(proj.y, proj.x) % (2*pi)
        # Rounding can push |z/r| slightly above 1, which would make acos
        # raise ValueError, so clamp the cosine first.
        cos_zen = max(-1.0, min(1.0, proj.z / float(self.radius)))
        zen = acos(cos_zen)
        u = az / (2*pi)
        v = zen / pi
        return u, v
    
    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        Returns an empty list when the ray misses the sphere, or when the
        sphere lies entirely behind the ray origin.  Otherwise returns a
        one-element list [(point, normal, self)] for the nearest hit at or in
        front of the origin.  normal is the outward unit normal at point.
        http://www.siggraph.org/education/materials/HyperGraph/raytrace/rtinter1.htm
        
        >>> ray=Ray(vec3(0,0,0), vec3(0,0,1))
        >>> hits=Sphere(vec3(0,0,5), 1).intersect(ray)
        >>> hits[0][0], hits[0][1]
        ((0, 0, 4), (0, 0, -1))
        >>> hits=Sphere(vec3(0,1,5), 1).intersect(ray)
        >>> hits[0][0], hits[0][1]
        ((0, 0, 5), (0, -1, 0))
        >>> Sphere(vec3(0,0,-5), 1).intersect(ray)
        []
        """
        diff = ray.origin - self.pos
        # Assumes ray.direction has unit length (Ray.__init__ normalizes
        # it), so the quadratic t^2 + b*t + c = 0 has leading coefficient 1.
        b = 2*(ray.direction.x*diff.x + ray.direction.y*diff.y + ray.direction.z*diff.z)
        c = diff.x**2 + diff.y**2 + diff.z**2 - self.radius**2
        disc = b**2 - 4*c
        if disc < 0:
            # No real roots: the ray's line misses the sphere.
            return []
        if disc == 0:
            # One root: the ray touches the sphere tangentially.
            candidates = [-b/2.0]
        else:
            root = sqrt(disc)
            candidates = [(-b + root)/2.0, (-b - root)/2.0]
        # Only hits at or in front of the ray origin count, including the
        # tangent case.
        ahead = [t for t in candidates if t >= 0]
        if not ahead:
            return []
        omega = min(ahead)
        
        isect = vec3( ray.origin+ray.direction*omega )
        isectnorm = vec3((isect-self.pos) / self.radius)
        return [(isect, isectnorm, self)]
    
    def translate( self, matrix ):
        """
        Translates the center of the sphere using the given matrix.

        >>> s=Sphere(vec3(0,0,0), radius=1)
        >>> trans = mat4().translation(vec3(2,3,4))
        >>> print s    
        Sphere (pos=(0, 0, 0), radius=1.000000)
        >>> s.translate(trans)
        >>> print s
        Sphere (pos=(2, 3, 4), radius=1.000000)
        """
        self.pos = vec3(matrix*self.pos)

    def __str__(self):
        """
        Pretty-prints this sphere.
        """
        return "%s (pos=%s, radius=%f)" % (self.name, self.pos, self.radius)

class InfinitePlane:
    """
    Models an infinite plane using a point on the plane and a normal.

    The plane is one-sided.  intersect() only reports hits for rays whose
    direction opposes the normal, i.e. rays arriving from the side the
    normal points to.
    """
    
    def __init__(self, normal, point, name="Plane", lighting=None, shader=None):
        """
        Creates an InfinitePlane.

        normal is normalized with its normalize() method; point is stored as
        given.  When lighting is omitted, a new default Lighting() is created
        for this plane, so planes never share one mutable default material.
        """
        self.normal = normal.normalize()
        self.point = point
        self.name = name
        if lighting is None:
            lighting = Lighting()
        self.lighting = lighting
        self.shader = shader
    
    def project(self, isect):
        """
        Maps a point on the plane to texture coordinates.

        The offset of isect from point is measured along normal.ortho() and
        along ortho().cross(normal).  The fractional part of each coordinate,
        which lies in (-1, 1), is mapped into (0, 1).  The exact orientation
        of these axes is undefined.  Returns u, v.
        """
        d = self.normal.ortho()
        offset = isect - self.point
        u = modf(d*offset)[0]
        v = modf(d.cross(self.normal)*offset)[0]
        return u/2.0 + 0.5, v/2.0 + 0.5

    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        Returns an empty list when the ray is parallel to the plane, arrives
        from behind it, or the plane lies behind the ray origin.  Otherwise
        returns [(point, normal, self)].
        http://www.siggraph.org/education/materials/HyperGraph/raytrace/rayplane_intersection.htm

        >>> origin = vec3(0,5,5)
        >>> otherpoint = vec3(0,5,0)
        >>> pointonplane = vec3(0,0,0)
        >>> ray = Ray( origin, vec3(pointonplane-origin), "Ray" )
        >>> ray2 = Ray( otherpoint, vec3(pointonplane-otherpoint), "Ray2")
        >>> away = Ray( origin, vec3(1,5,1), "Away")
        >>> parallel = Ray( origin, vec3(1,0,0), "Parallel")
        >>> p = InfinitePlane( vec3(0,1,0), pointonplane )
        >>> p.intersect(ray)[0][:2]
        ((0, 0, 0), (0, 1, 0))
        >>> p.intersect(ray2)[0][:2]
        ((0, 0, 0), (0, 1, 0))
        >>> p.intersect(away)
        []
        >>> p.intersect(parallel)
        []

        >>> plane=InfinitePlane(normal=vec3(0,-1,0),point=vec3(1,1,1))
        >>> ray=Ray(origin=vec3(0,0,0),direction=vec3(0,1,0))
        >>> plane.intersect(ray)[0][:2]
        ((0, 1, 0), (0, -1, 0))

        >>> up=InfinitePlane(normal=vec3(0,1,0),point=vec3(0,1,0))
        >>> up.intersect(Ray(vec3(0,5,0), vec3(0,-1,0)))[0][0]
        (0, 1, 0)
        """
        vd = self.normal * ray.direction
        if vd >= 0:
            # vd == 0: the ray is parallel to the plane.
            # vd > 0: the ray arrives from behind the (one-sided) plane.
            return []
        # Solve normal . (origin + t*direction - point) = 0 for t.  The plane
        # constant is -normal.point; its sign matters, so no abs() here.
        t = (self.normal * (self.point - ray.origin)) / vd
        if t < 0:
            # The plane lies behind the ray origin.
            return []
        return [(ray.origin + ray.direction*t, self.normal, self)]
    
    def translate( self, matrix ):
        """
        Translates the plane by transforming its point.  The normal is
        unchanged.
        
        >>> p = InfinitePlane(normal=vec3(0,1,0), point=vec3(0,0,0))
        >>> print p
        Plane (normal:(0, 1, 0), point:(0, 0, 0))
        >>> trans = mat4().translation(vec3(2,3,4))
        >>> p.translate(trans)
        >>> print p
        Plane (normal:(0, 1, 0), point:(2, 3, 4))
        """
        self.point = vec3(matrix*self.point)
        
    def __str__(self):
        """
        Pretty-prints this plane.
        """
        return "%s (normal:%s, point:%s)" % (self.name, self.normal, self.point)

class ViewPlane(InfinitePlane):
    """
    An InfinitePlane whose default name is "ViewPlane".  It takes no shader
    argument.
    """
    def __init__(self, normal, point, name="ViewPlane", lighting=None):
        """
        Creates the plane.  When lighting is omitted, a new default
        Lighting() is created instead of one instance being shared by every
        ViewPlane.
        """
        if lighting is None:
            lighting = Lighting()
        InfinitePlane.__init__(self, normal, point, name, lighting)
        

class Circle(InfinitePlane):
    """
    A Circle is a disc: the part of an InfinitePlane that lies within radius
    of center.
    """
    def __init__(self, normal, center, radius, name="Circle", lighting=None, shader=None):
        """
        Builds the Circle.  center is also used as the plane's reference
        point.  When lighting is omitted, a new default Lighting() is created
        for this circle.
        """
        if lighting is None:
            lighting = Lighting()
        InfinitePlane.__init__(self, normal, center, name, lighting, shader)
        self.center = center
        self.radius = radius
    
    def translate(self, matrix):
        """
        Translates the Circle.  Both the plane's point and center are
        transformed by matrix.
        """
        InfinitePlane.translate(self, matrix)
        self.center = vec3(matrix*self.center)
        
    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        First intersects the underlying InfinitePlane, then keeps the hit
        only if it lies within radius of center (inclusive).  Returns the
        plane's result [(point, normal, self)] on a hit, otherwise [].

        >>> circ=Circle(vec3(0,-1,0),vec3(0,1,1),radius=1)
        >>> ray=Ray(vec3(0,0,0),vec3(1,1,1))
        >>> circ.intersect(ray)[0][:2]
        ((1, 1, 1), (0, -1, 0))
        
        >>> circ=Circle(vec3(0,-1,0),vec3(0,1,0),radius=1)
        >>> ray=Ray(vec3(0,0,0),vec3(1,1,1))
        >>> circ.intersect(ray)
        []
        """
        hits = InfinitePlane.intersect(self, ray)
        # abs() of the vec3 difference is its length, i.e. the distance from
        # the hit point to the center.
        if hits and abs(hits[0][0] - self.center) <= self.radius:
            return hits
        # Either the plane was missed, or the hit lies outside the disc.
        return []
    
class Rectangle(InfinitePlane):
    """
    A Rectangle is an InfinitePlane with rectangular bounds centred on point.

    dir and dir.cross(normal) are used as the rectangle's in-plane axes.
    height is measured along dir and width along dir.cross(normal).  The
    bounds test and project() are therefore only meaningful when dir is a
    unit vector perpendicular to normal.
    """
    
    def __init__(self, normal, dir, point, name="Plane", lighting=None, width=1, height=1, shader=None):
        """
        Creates a Rectangle.  When lighting is omitted, a new default
        Lighting() is created for this rectangle.
        """
        if lighting is None:
            lighting = Lighting()
        InfinitePlane.__init__( self, normal, point, name, lighting, shader )
        self.width=width
        self.height=height
        self.dir=dir
    
    def intersect(self, ray):
        """
        Tests for intersection with a ray.

        First intersects the underlying InfinitePlane, then keeps the hit
        only if it lies strictly inside the rectangle.  Returns the plane's
        result [(point, normal, self)] on a hit, otherwise [].
        """
        hits = InfinitePlane.intersect( self, ray )
        if not hits:
            return []
        offset = hits[0][0] - self.point
        # Strict bounds: points exactly on the border count as misses.
        if abs(self.dir*offset) < self.height/2.0 and abs(self.dir.cross(self.normal)*offset) < self.width/2.0:
            return hits
        return []
    
    def project(self, isect):
        """
        Maps a point on the Rectangle to texture coordinates.

        u runs along dir (scaled by height) and v along dir.cross(normal)
        (scaled by width).  Both are shifted so the rectangle's centre maps
        to 0.5, giving values between 0 and 1 for points inside it.
        """
        offset = isect - self.point
        u = (self.dir*offset)/self.height
        v = (self.dir.cross(self.normal)*offset)/self.width
        return u+0.5, v+0.5

def _test():
    """
    Runs every doctest embedded in this module's docstrings.
    """
    from doctest import testmod
    testmod()

# Running this file directly, rather than importing it, runs the doctests.
if __name__ == "__main__":
    _test()