#Computes thermodynamic quantities from the density of states.
import string
import numpy as np
import sys

from Plot import plotToFile

##Default parameters
# Alternative input preset, kept for reference:
#filesToRead = ['outputs/spinGlass/SPINGLASS_D3N4A10', 'outputs/spinGlass/SPINGLASS_D3N4A8', 'outputs/spinGlass/SPINGLASS_D3N4A8_2']
filesToRead = ['outputs/I_N10','outputs/I_N20','outputs/I_N16', 'outputs/I_N40_Normalized']
# NOTE(review): N and dim look like the lattice size and dimensionality of each
# entry in filesToRead; they are not read anywhere in this part of the file.
N = [10, 20, 16, 40]
dim = [2, 2, 2, 2]

#filesToRead = ['outputs/I_N20']
#N = [20]
#dim = [2]

# Two floats and an int; presumably [lowest T, highest T, number of points] -- confirm in Plot.
plotRange = [0.1, 10, 100]        #Temperature Plot range (passed to file output module, Plot)

##input arguments from command line, if any
def _parseCommandLine(argv, files, plot_range):
    """Apply the optional -i / -r command-line overrides to the defaults.

    Recognised arguments (anything else is ignored):
        -i FILE        read only FILE instead of the default input files
        -r A B C       replace the plot range with float(A), float(B), int(C)

    argv       -- full argument vector, including the program name at index 0
    files      -- default list of input files (not modified)
    plot_range -- default plot range (not modified)

    Returns the tuple (files, plot_range) with the overrides applied.
    Raises SystemExit with a usage message when a flag is not followed by
    enough values or the -r values are not numeric.
    """
    files = list(files)
    plot_range = list(plot_range)
    for pos in range(1, len(argv)):        # index 0 is the program name, never a flag
        flag = argv[pos]
        if flag == '-i':
            if pos + 1 >= len(argv):
                raise SystemExit("usage error: -i must be followed by a file name")
            files = [argv[pos + 1]]
        elif flag == '-r':
            values = argv[pos + 1:pos + 4]
            if len(values) != 3:
                raise SystemExit("usage error: -r must be followed by three values")
            try:
                plot_range[:3] = [float(values[0]), float(values[1]), int(values[2])]
            except ValueError:
                raise SystemExit("usage error: -r expects two floats and an integer, got "
                                 + ' '.join(values))
    return files, plot_range

filesToRead, plotRange = _parseCommandLine(sys.argv, filesToRead, plotRange)

def calcThermalAverage(energyObservableFunction, temperature):
    """Return the thermal average of an energy observable at one temperature.

    Computes  sum_i f(E_i) * w_i / sum_i w_i  with Boltzmann weights
    w_i = exp(logG[i] - energies[i] * numSites / temperature), reading the
    module-level ``energies``, ``logG`` and ``numSites``.

    energyObservableFunction -- callable f applied to each entry of ``energies``
    temperature              -- temperature used in the exponent (must be non-zero)

    Raises ValueError when ``logG`` is empty.
    """
    exponents = [lnG - energy * numSites / temperature
                 for lnG, energy in zip(logG, energies)]
    if not exponents:
        raise ValueError("calcThermalAverage: the density of states is empty")
    # The largest exponent is a common factor of numerator and denominator, so
    # it cancels in the ratio.  Removing it before exp() keeps every weight in
    # (0, 1]; otherwise large ln g or low temperatures overflow to inf and the
    # ratio becomes nan.
    shift = max(exponents)
    partition = 0.0
    weighted = 0.0
    for energy, exponent in zip(energies, exponents):
        weight = np.exp(exponent - shift)
        partition += weight
        weighted += weight * energyObservableFunction(energy)
    return weighted / partition

def normalizeG():
    """Shift the global logG in place so that its smallest entry becomes ln(2).

    The same constant is subtracted from every entry, which pins the least
    degenerate level to exactly two states (the file's original note: "force
    proper number of ground states (2)").  The list object itself is kept, only
    its contents change.
    """
    offset = min(logG) - np.log(2.0)
    logG[:] = [value - offset for value in logG]
def totalStates():
    """Return the sum of exp(logG[i]) over the global logG, i.e. the sum of g(E).

    Yields 0.0 when logG is empty.
    """
    total = 0.0
    for lnG in logG:
        total += np.exp(lnG)
    return total

##Observables to compute
def energyFunc(energy):
    """Observable E: hand the energy back unchanged."""
    return energy
def energySqr(energy):
    """Observable E^2: the energy raised to the second power."""
    return pow(energy, 2)
    
##file output
def writeDOS():
    """Write the normalised density of states to the file fileIn + "_NormG".

    One line is written per entry of the global ``energies`` / ``logG``: the
    energy, a tab, then logG[i] minus ln(sum_j exp(logG[j])), so the exponentials
    of the second column add up to one.  An empty ``logG`` produces an empty file.
    """
    if logG:
        # log-sum-exp: factor the largest ln g out of the sum so exp() stays
        # within (0, 1].  Exponentiating the raw values overflows to inf once
        # ln g exceeds ~709, which turns every written value into -inf.
        peak = max(logG)
        scaledSum = 0.0
        for lnG in logG:
            scaledSum += np.exp(lnG - peak)
        logNorm = peak + np.log(scaledSum)
    else:
        logNorm = 0.0
    # The with-block closes the file even if a write fails part-way through.
    with open(fileIn + "_NormG", 'w') as out:
        for energy, lnG in zip(energies, logG):
            out.write(str(energy) + '\t' + str(lnG - logNorm) + "\n")

def plotFunction(x):
    """Callback handed to the file output module.

    For the temperature x, returns a two-element list: the thermal average of
    the energy followed by the thermal average of the squared energy.
    """
    observables = (energyFunc, energySqr)
    return [calcThermalAverage(observable, x) for observable in observables]
    
###main execution
# For every input file: load ln g(E), normalise it, then write the thermal
# averages and the normalised density of states next to the input file.
# fileIn, energies, logG, numSites and numStates are deliberately module-level:
# the functions defined above read them as globals.
for fileIn in filesToRead:
    fileOut = fileIn + '_ThermalAverage'
    energies, logG = [], []
    # Read log[g(E)] from file generated by wang-landau algorithm: tab-separated
    # "energy<TAB>ln g" rows.  The with-block closes the file even if a row
    # fails to parse.
    with open(fileIn, 'r') as dosFile:
        for line in dosFile:
            # Skip comment lines and blank lines (a blank line would otherwise
            # make float() fail).
            if not line.strip() or line.startswith('#'):
                continue
            vals = line.split('\t')
            energies.append(float(vals[0]))
            logG.append(float(vals[1]))
    if not logG:
        raise ValueError("no density-of-states rows found in " + fileIn)
    normalizeG()
    # NOTE(review): the site count is inferred from the number of energy levels
    # read, not from N/dim -- confirm this holds for every kind of input file.
    numSites = len(logG) + 1
    numStates = totalStates()
    plotToFile(plotFunction, plotRange, fileOut)   #output thermal averages
    writeDOS()                  #output normalized DOS