import os
import shutil
from datetime import datetime

import matplotlib.pyplot as plt
import MPSPyLib as mps
import numpy as np

# Wall-clock start time; the total run duration is printed at the end of the script.
st = datetime.now()
# Build spin-1/2 operators, then add Pauli matrices derived from them:
# sigma^z = 2 S^z and sigma^x = S^+ + S^-.
Operators = mps.BuildSpinOperators(0.5)
Operators['sigmaz'] = 2* Operators['sz']
Operators['sigmax'] = (Operators['splus'] + Operators['sminus'])

# Define the Hamiltonian
#   H = -sum_i J[i, i+1] sigma^z_i sigma^z_{i+1} - g sum_i sigma^x_i
# Only the transverse-field term is added here; the nearest-neighbour
# sigma^z sigma^z couplings are added bond by bond (each with its own
# parameter) after add_bond_term is defined.
H = mps.MPO(Operators)
#H.AddMPOTerm("bond", ['sigmaz', 'sigmaz'], hparam='J', weight= -1.0)# uniform Ising coupling (disabled in favour of per-bond couplings)
H.AddMPOTerm('site', 'sigmax', hparam='g', weight= -1.0)

# Number of sites in the chain.
system_size = 12
def add_bond_term(hamiltonian, pair, size, param_name, *, operator='sigmaz', weight=-1.0):
    """Add a two-site ``operator (x) operator`` coupling to ``hamiltonian``.

    The term is registered as an 'MBString': a length-``size`` list of
    operator names with ``operator`` on both sites of ``pair`` and the
    identity 'I' on every other site. The resulting term is
    ``weight * param * op_left op_right``, where ``param`` is the
    Hamiltonian parameter named ``param_name``.

    Args:
        hamiltonian: MPO-like object exposing ``AddMPOTerm`` (e.g. ``mps.MPO``).
        pair: Two distinct site indices ``(left, right)`` in ``[0, size)``;
            any length-2 sequence of integers (tuple, list, numpy array).
        size: Number of sites in the chain.
        param_name: Name of the Hamiltonian parameter that scales this term.
        operator: Operator name placed on both sites (default 'sigmaz').
        weight: Fixed prefactor of the term (default -1.0).

    Raises:
        ValueError: If the two sites coincide or either lies outside the chain.
    """
    left, right = (int(site) for site in pair)
    # A repeated site would silently overwrite itself and leave a single
    # operator instead of a two-site coupling.
    if left == right:
        raise ValueError(f"bond needs two distinct sites, got {left} twice")
    # Reject out-of-range sites explicitly; negative indices would otherwise
    # wrap around and place the operator on the wrong site.
    if not (0 <= left < size and 0 <= right < size):
        raise ValueError(f"sites {left}, {right} outside chain of size {size}")
    bond_term = ['I'] * size
    bond_term[left] = operator
    bond_term[right] = operator
    hamiltonian.AddMPOTerm('MBString', bond_term, hparam=param_name, weight=weight)

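# Disorder-averaging sketch (disabled): draw one +/-1 coupling per bond
# (system_size - 1 bonds on the open chain) for each of ntrials realisations.
# ntrials = 1000
# Jlists = []
# for _ in range(ntrials):
#     Jlists.append(np.random.choice([-1, 1], system_size - 1))
# #Jlist = [0,1,1,1,1,1,1,1,1,1,1,1]
#Jlist = np.random.choice([-1,1],(system_size - 1)) #Anderson Localization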


# Register one nearest-neighbour sigma^z sigma^z bond per adjacent pair of the
# open chain, each scaled by its own parameter "J[i, i+1]"; the values of these
# parameters are supplied per run in the parameter dictionaries below, whose
# keys must use exactly the same names. The pair is derived from the loop
# index rather than tracked in a separately incremented array.
for i in range(system_size - 1):
    print(f"[{i}, {i+1}]")
    pair = (i, i + 1)
    add_bond_term(H, pair, system_size, f"J[{i}, {i+1}]")

# Next-nearest-neighbour sigma^z sigma^z couplings (disabled sketch):
# pair_next = np.array([0, 2])
# J_next = np.ones(system_size - 2)
# for i in range(system_size-2):
#     add_bond_term(H, pair_next, system_size, f"J_next[{i}, {i+2}]")
#     pair_next = pair_next + 1


# Observables and convergence parameters
myObservables = mps.Observables(Operators)
myObservables.AddObservable('DensityMatrix_i', [])
myObservables.AddObservable('DensityMatrix_ij', [])
myObservables.AddObservable('corr', ['sigmaz', 'sigmaz'], 'zz')
myObservables.AddObservable('site', 'sigmax', 'sx')
myObservables.AddObservable('MI', True)
# Set maximum bond dimension
#myConv = mps.MPSConvParam(max_bond_dimension= 500, max_num_sweeps=10, variance_tol=1e-16)
myConv = mps.MPSConvParam(max_bond_dimension=30)
# Specify the Hamiltonian's parameter lists: a scan over the transverse field g
glist = np.linspace(0.01, 2, 21)
#glist = [0.25]
parameters = []
ntrials = 10  # only used by the disabled disorder-averaging loop

# One disorder realisation: an independent random +/-1 coupling per bond.
# The open chain has system_size - 1 bonds, J[0, 1] ... J[L-2, L-1]; drawing
# system_size values would create a spurious parameter J[L-1, L] that has
# no matching MPO term.
# for _ in range(ntrials):
Jlist = np.random.choice([-1, 1], system_size - 1)
# Keys must match the parameter names registered via add_bond_term.
couplings = {f"J[{i}, {i+1}]": J for i, J in enumerate(Jlist)}
for name, J in couplings.items():
    print(f"{name} = {J}")

for g in glist:
    params = {
        'simtype'                   : 'Finite',
        'job_ID'                    : 'dis_Ising',
        # Deterministic per-g ID (previously a random float) so each run's
        # output can be traced back to its g value.
        'unique_ID'                 : f'g_{g:.6f}',
        'Write_Directory'           : 'TMP/',
        'Output_Directory'          : 'OUT/',
        'MPSObservables'            : myObservables,
        'MPSConvergenceParameters'  : myConv,
        # System size and Hamiltonian parameters
        'L'                         : system_size,
        'g'                         : g,
        }
    # Every run of the g scan shares the same coupling realisation.
    params.update(couplings)
    parameters.append(params)



# Delete existing temporary and output files so results from a previous run
# do not linger. shutil.rmtree(..., ignore_errors=True) mirrors `rm -f -r`
# (no error if the directory is missing) without spawning a shell, and also
# works on systems without `rm`.
shutil.rmtree('OUT', ignore_errors=True)
shutil.rmtree('TMP', ignore_errors=True)
# Write Fortran-readable input files
MainFiles = mps.WriteFiles(parameters, Operators, H)
# Run MPS calculations
mps.runMPS(MainFiles)
# Read back the static observables, one result per parameter dictionary.
Outputs = mps.ReadStaticObservables(parameters)
# NOTE(review): the triple-quoted string below is disabled post-processing kept
# as a bare string expression; it has no effect when the script runs. If
# re-enabled it would, for each entry of Outputs, collect the single-site
# matrices "rho_<i>" and the two-site matrices "rho_<i>_<j>" into a
# system_size x system_size grid (4x4 complex zeros on the diagonal) and save
# both to Box_<system_size>_rho_i.npy / Box_<system_size>_rho_ij.npy.
# Before enabling, confirm two things:
#   - If these are Hermitian density matrices, .conjugate().transpose() returns
#     the same matrix, so entries [i][j] and [j][i] end up identical rather than
#     site-swapped. Confirm this is the intended ordering.
#   - The "glist 41" inline count is stale: glist currently has 21 values.
'''
rho_i_data = []
rho_ij_data = []
for output in Outputs: # a specific glist 41
    tmp_rho_i = []
    tmp_rho_ij = []
    for i in range(system_size):  # for each particle in system 12
        tmp_rho_i.append(output[f"rho_{i+1}"])
        tmp_rho_ij_row = []
        for j in range(system_size):# 12
            if j < i:
                rho = output[f"rho_{j+1}_{i+1}"]
            elif i < j:
                rho = output[f"rho_{i+1}_{j+1}"]
                rho = rho.conjugate().transpose()
            elif j == i:
                rho = np.zeros([4, 4], dtype=complex)
            tmp_rho_ij_row.append(rho)
        tmp_rho_ij.append(tmp_rho_ij_row)

    rho_i_data.append(tmp_rho_i)
    rho_ij_data.append(tmp_rho_ij)
    
np.save(f"Box_{system_size}_rho_i.npy", rho_i_data)
np.save(f"Box_{system_size}_rho_ij.npy", rho_ij_data)
'''

# Quantity plotted per g value: the summed transverse magnetisation, i.e. the
# sum of the per-site <sigma^x> values stored under 'sx'. Assumes Outputs is in
# the same order as parameters (and hence glist).
# An alternative based on the 'zz' correlations is kept commented out:
#OPlist = (Output['zz'][0][range(1,10,2)])
OPlist = [sum(Output['sx']) for Output in Outputs]

# Make a plot
#plt.plot(range(0,10,2), OPlist, '-o')
plt.plot(glist, OPlist, '-o')

plt.xlabel("g")
plt.ylabel("magnetization")
plt.title("magnetization for Ising with random couplings")
# No extension given: matplotlib appends the default format from
# rcParams['savefig.format'] (png unless configured otherwise).
plt.savefig('random_coupling_Ising_mag')

fi = datetime.now()
print(f'Duration: {fi - st}')