from math import sqrt
import sets

# Tolerance added to a ball's radius by contains(), so points lying on the
# boundary (up to floating-point rounding) still count as inside.
epsilon = 1e-6

def distanceFunction( c1, c2 ):
	"""Return the Euclidean distance between the 2-D points c1 and c2.

	Only the first two coordinates of each point are used.
	"""
	dx, dy = c2[0] - c1[0], c2[1] - c1[1]
	return sqrt( dx ** 2 + dy ** 2 )

def nDimDistance( c1, c2 ):
	"""Return the Euclidean distance between two points of equal dimension.

	c1, c2 -- sequences of numbers with the same length

	Raises ValueError when the lengths differ. Returning 0 in that case
	would report the points as identical, making a mismatched point fall
	inside every ball tested with contains().
	"""
	if len( c1 ) != len( c2 ):
		raise ValueError( "points have different dimensions: %d != %d" % ( len( c1 ), len( c2 ) ) )
	# abs() keeps the squared term real for complex coordinates as well.
	return sqrt( sum( abs( a - b ) ** 2 for a, b in zip( c1, c2 ) ) )

def buildBalls( S, d ):
	"""Build candidate balls from ordered pairs of same-label examples.

	S -- dictionary of training examples, coord => label
	d -- distance function d(c1, c2)

	For every ordered pair (center, other) of distinct points that share a
	label, one ball is produced, centred at center with radius
	d(center, other). Returns a dictionary keyed 1, 2, 3, ...:
	i => {"center": coord, "distance": radius, "value": label}
	"""
	ret = {}
	# items() (not iteritems()) behaves the same on Python 2 and 3.
	examples = list( S.items() )
	for center, label in examples:
		for other, otherLabel in examples:
			if ( otherLabel == label ) and ( other != center ):
				radius = d( center, other )
				# Keys are consecutive from 1, so len(ret) + 1 is the next key.
				ret[len( ret ) + 1] = {"center":center, "distance":radius, "value":label}
	return ret

def contains( ball, point, d ):
	"""Return True if point lies inside ball, allowing a small tolerance.

	ball  -- dict with "center" and "distance" (radius) entries
	point -- coordinate passed to d together with the ball's center
	d     -- distance function d(c1, c2)

	The module-level epsilon is added to the radius so that points on the
	boundary (up to rounding error) count as inside.
	"""
	gap = d( ball["center"], point )
	return bool( gap < ( ball["distance"] + epsilon ) )

def scm( S, t, pen, s, H, d ):
	"""
	Greedily choose balls for a Set Covering Machine.

	S   = dictionary of training examples, coord => True/False
	t   = 'c' (conjunction) or 'd' (disjunction); any other value selects
	      no examples, so [] is returned
	pen = penalty per P example contained in a ball
	s   = stopping point: the maximum number of balls to choose
	H   = dictionary of candidate balls, key => {"center": coord,
	      "distance": radius, "value": label} (the shape buildBalls makes)
	d   = distance function d(c1, c2), passed through to contains()

	Returns the chosen balls (values of H) in the order they were picked.
	Stops when every N example is covered, when s balls have been chosen,
	or when no remaining ball has a positive usefulness.
	"""
	# Step 1
	# Split the data set into P and N. For a conjunction P holds the True
	# examples and N the False ones; a disjunction swaps the roles.
	# The builtin frozenset replaces the sets module (removed in Python 3).
	P = N = frozenset()
	if t in ( 'c', 'd' ):
		pLabel = ( t == 'c' )
		P = frozenset( ex for ex, val in S.items() if val == pLabel )
		N = frozenset( ex for ex, val in S.items() if val == ( not pLabel ) )

	# Step 2
	# Q[k] = N examples inside ball k; R[k] = P examples inside ball k.
	# A ball containing no example of a group gets no entry for that group.
	# NOTE(review): membership ignores ball["value"], whereas scmFunction
	# treats a ball as firing when value == contains(...). For balls whose
	# value is not the covering label this may count the wrong side of the
	# ball -- confirm against the SCM feature definition.
	Q, R = {}, {}
	for k, h in H.items():
		inN = frozenset( ex for ex in N if contains( h, ex, d ) )
		if inN:
			Q[k] = inN
		inP = frozenset( ex for ex in P if contains( h, ex, d ) )
		if inP:
			R[k] = inP

	ret = [] # List of balls
	while N and ( len( ret ) < s ):
		# Step 3
		# Pick the ball with the largest usefulness |Q[k]| - pen * |R[k]|.
		# Ties go to the first such ball in Q's iteration order.
		largest, largestValue = None, 0
		for k, fe in Q.items():
			usefulness = len( fe ) - pen * len( R.get( k, () ) )
			if usefulness > largestValue:
				largest, largestValue = k, usefulness
		if largest is None:
			# No ball still helps; stop instead of indexing H with a key
			# that was never chosen (previously a KeyError on H[0]).
			break

		# Step 4
		# Keep the best ball and drop the N examples it covers.
		ret.append( H[largest] )
		coveredN = Q[largest]
		coveredP = R.get( largest, frozenset() )
		N = N - coveredN

		# Step 5
		# Examples already covered no longer count toward any other ball.
		Q = dict( ( k, q - coveredN ) for k, q in Q.items() )
		R = dict( ( k, r - coveredP ) for k, r in R.items() )
	return ret

def printBalls( balls ):
	"""Write every element of balls to stdout, one per line."""
	for entry in balls:
		print( entry )

def scmFunction( bestBalls, machineType, point, d ):
	"""
	Evaluate a trained SCM at point.

	Each ball votes True when ball["value"] == contains(ball, point, d).
	For machineType 'c' the result is the conjunction of the votes (True
	when there are no balls). For any other machineType it is the
	disjunction (False when there are no balls). Evaluation stops at the
	first vote that decides the result.
	"""
	votes = ( ball["value"] == contains( ball, point, d ) for ball in bestBalls )
	if machineType == 'c':
		return all( votes )
	return any( votes )
	
# Demo: train an SCM on a small 2-D data set and classify one point.
# distanceFunction is a 2-D-only alternative to nDimDistance.
d = nDimDistance

# Training examples, coord => label. Named trainingSet rather than `set`
# so the builtin set type is not shadowed for the rest of the module.
trainingSet = {(3,3):False, (3.5,2.5):False, (4.5,3.5):False, (3,10):True, (4,9):True, (6,5):True, (7.5,6.5):True, (9.2,8.2):True, (10,5.5):True, (11,6):True, (11,8):True, (10,9):False, (8,11):False}
machineType = 'd'

H = buildBalls( trainingSet, d )
# pen = 1000 per P example inside a ball, s = at most 5 balls.
bestBalls = scm( trainingSet, machineType, 1000, 5, H, d )
print( "The solution" )
printBalls( bestBalls )
f = lambda point: scmFunction( bestBalls, machineType, point, d )
print( f( (1.5,3.5) ) )