import sys
import os
import string
import math
import IMP
import IMP.core
import IMP.atom
import random
import IMP.algebra
import sys
import unittest
from StringIO import StringIO
import math
from random import random

#---------------------------------------------------------
# Sequence length of each chain (units presumably base pairs -- confirm).
# Every `sep` units of sequence become one bead, plus one so that even a
# chain shorter than `sep` still gets at least one bead.
chain_seq = {"chain1": 3000000, "chain2": 10000, "chain3": 3000, "chain4": 400,
             "chain5": 1000, "chain6": 2000, "chain7": 1000, "chain8": 1000}
sep = 30
chain_bead = {}  # number of beads for each chain
nbead = 0        # total number of beads over all chains
bead_start = {}  # index of the first bead of each chain
# Beads of all chains are laid out contiguously, in dict iteration order,
# so chain k occupies indices bead_start[k] .. bead_start[k]+chain_bead[k]-1.
for i, length in chain_seq.items():
    # Floor division on purpose: plain `/` yields a float under Python 3
    # (or `from __future__ import division`), and a float bead count breaks
    # the range()/get_particle() calls that consume these tables.
    n = length // sep + 1
    chain_bead[i] = n
    bead_start[i] = nbead
    nbead += n
#---------------------------------------------------------
def mdstep(t,step):
    o = IMP.atom.MolecularDynamics()
    o.set_model(m)
    md = IMP.atom.VelocityScalingOptimizerState(xyzr,t,10)  
    o.add_optimizer_state(md)
    print 'optimizing with temperature',t,'and',step,'steps'
    s=o.optimize(step)
    o.remove_optimizer_state(md)
    print 'MD',step,'steps done @',datetime.datetime.now()
    return s

def cgstep(step):
    o = IMP.core.ConjugateGradients()
    o.set_model(m)
    f=o.optimize(step)
    print 'CG',step,'steps done @',datetime.datetime.now()
    return f
#___________________________ IMP starts _____________________________________
m = IMP.Model()
r = 1.0
lb = 2.0 # length of bond
kbend=0.2 

xyzr = IMP.core.create_xyzr_particles(m,nbead,r)
chain = IMP.container.ListSingletonContainer(xyzr)
# First beads
corner1=IMP.algebra.Vector3D(-1000,-1000,-1000)
corner2=IMP.algebra.Vector3D(1000,1000,1000)
box=IMP.algebra.BoundingBox3D(corner1,corner2)
rdummy=int(random()*10000)
for i in range(rdummy):
    ranvec = IMP.algebra.get_random_vector_in(box)
#----------------------------------------------------------  
print 'total bead:',nbead
for i in range(nbead):
    p0 = chain.get_particle(i)
    IMP.atom.Mass.setup_particle(p0,1)
    p = IMP.core.XYZR(p0)
    coor = IMP.algebra.get_random_vector_in(box)
    p.set_coordinates(coor)
#---------------------------------------------------------------------------------
# Bond each bead to the next one within the same chain (never across chains).
bonds = IMP.container.ListSingletonContainer()
# `chain_id` rather than `id`: the loop variable lives at module scope and
# would otherwise shadow the builtin id() for the rest of the script.
for chain_id in chain_seq:
    istart = bead_start[chain_id]
    iend = istart + chain_bead[chain_id]
    # The first bead of a chain has no predecessor; it only needs the
    # Bonded decoration so that the next bead can be bonded to it.
    IMP.atom.Bonded.setup_particle(chain.get_particle(istart))
    for i in range(istart + 1,iend):
        # Bead i-1 was already set up as Bonded (above or last iteration).
        bp = IMP.atom.Bonded.decorate_particle(chain.get_particle(i-1))
        bpr = IMP.atom.Bonded.setup_particle(chain.get_particle(i))
        # lb and 2 are passed through as-is; presumably bond length and
        # stiffness -- confirm against IMP.atom.create_custom_bond.
        b = IMP.atom.create_custom_bond(bp, bpr, lb, 2)
        bonds.add_particle(b.get_particle())

# Restraint for bonds: BondSingletonScore with a Harmonic(0, 1) applied to
# every bond collected above.
bss = IMP.atom.BondSingletonScore(IMP.core.Harmonic(0,1))
br = IMP.container.SingletonsRestraint(bss, bonds)
m.add_restraint(br) #0

# Non-bonded pair list over all beads. Per the original note the arguments
# are (singleton container, distance cutoff between touching spheres, slack).
nbl = IMP.container.ClosePairContainer(chain, 0.0, 3.0)
# Bonded neighbours are filtered out of the close-pair list.
nbl.add_pair_filter(IMP.atom.BondedPairFilter())

# Excluded volume: a one-sided harmonic on sphere distance, applied to every
# pair currently in the non-bonded list.
ps = IMP.core.SphereDistancePairScore(IMP.core.HarmonicLowerBound(0, 1))
evr = IMP.container.PairsRestraint(ps, nbl)
m.add_restraint(evr)  # restraint #1


#IMP.set_check_level(IMP.NONE)
#IMP.set_log_level(IMP.SILENT)
o = IMP.core.ConjugateGradients()
o.set_model(m)
m.show()
o.optimize(1000)
print 'cg done'
# Followed with MD
mdstep(500000,5000)
mdstep(300000,3000)
score=mdstep(5000,1000)
print 'before',score
# Angle restraints: every bead that has a neighbour on both sides within
# its own chain gets an AngleRestraint over (i-1, i, i+1), with a Harmonic
# centred on pi (a straight chain) and force constant kbend.
angle = math.pi
# Built once: the chain start indices are loop-invariant, so there is no
# need to rebuild bead_start.values() twice per bead.
chain_starts = set(bead_start.values())
# One potential shared by all angle restraints instead of one per bead.
pot = IMP.core.Harmonic(angle,kbend)
# Bead 0 always starts a chain and the last bead always ends one, so
# neither can be the middle of an angle.
for i in range(1, nbead - 1):
    # Skip bead i when it starts a chain (i-1 belongs to another chain) or
    # when bead i+1 starts the next chain (i is the last bead of its chain).
    if i in chain_starts or (i + 1) in chain_starts:
        continue
    d1 = chain.get_particle(i-1)
    d2 = chain.get_particle(i)
    d3 = chain.get_particle(i+1)
    ar = IMP.core.AngleRestraint(pot,d1,d2,d3)
    m.add_restraint(ar)

score=mdstep(300,1000)
print 'angle in ',score

score=cgstep(1000)
print 'cg score',score
mdstep(2000,20000)
score=cgstep(50000)

print 'Final score:',score
# IMP.display is used below but never imported at the top of the file.
import IMP.display
#IMP.set_log_level(IMP.TERSE)
# Export the final bead spheres (radius from the default XYZR radius key)
# to the PyMOL file "last.pym".
pym2 = IMP.display.PymolWriter("last.pym")
g3 = IMP.display.XYZRsGeometry(chain,IMP.core.XYZR.get_default_radius_key())
# NOTE: the original author reports that this name is not written out;
# the geometry ends up named '' instead of "beads".
g3.set_name("beads")
g3.set_color(IMP.display.Color(1,1,1))  # presumably RGB white
# NOTE(review): the writer is never explicitly closed; presumably the file
# is flushed when pym2 is destroyed at interpreter exit -- confirm.
pym2.add_geometry(g3)
