import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.interpolate as interpolate
import sys

################################################################################################
#
# Initial code written by Wolfgang Szwillus, University of Kiel
# Small modifications by Bart Root, Delft University of Technology
#
# No rights can be derived from this code or using it.
# 
################################################################################################ 

# Vertical resolution of the output grid; increase these for a finer model.
# number_of_layers_crust : number of depth intervals between the shallowest
#                          depth of the model and bound_mesh
# number_of_layers_mantle: number of depth intervals between bound_mesh and
#                          the deepest layer boundary
number_of_layers_crust = 20
number_of_layers_mantle = 10
# Depth at which the vertical sampling switches from the first to the second
# set of layers (same unit as the layer depths; the script multiplies depths
# by 1000 before subtracting them from the radius in metres)
bound_mesh = 80
radius_reference_model = 6371 # Chosen because it is used in the benchmark study; it may be wise to change it for your own model

# Echo the chosen vertical resolution
print('number_of_layers_crust',number_of_layers_crust)
print('number_of_layers_mantle',number_of_layers_mantle)

################################## NEED TO ADD YOUR OWN FOLDER ##################################
# Directory that contains the input grid files; every file name listed
# further down is joined onto this path.
#folder = '<LOCATION OF WINTERC-G>'
folder = '/home/thieulot/ASPECT/2021_root_paper/aspect/build/winterc-g/WINTERC_ver5.4_test_Bart/'

#################################################################################################
# NO CHANGES SHOULD BE NEEDED AFTER THIS LINE
#################################################################################################

def read_txt(fname):
    """Load a three-column text file and return it as a regular grid.

    Returns a tuple ``(x, y, grid)`` where ``x`` and ``y`` are the sorted
    unique values of the first and second column, and ``grid`` is the third
    column reshaped to ``(len(y), len(x))`` and flipped upside down.
    """
    table = np.loadtxt(fname)
    first_axis = np.unique(table[:, 0])
    second_axis = np.unique(table[:, 1])
    # Rows of the file are taken in order: one grid row per value of the
    # second column, one grid column per value of the first column
    grid_shape = (second_axis.size, first_axis.size)
    values = table[:, 2].reshape(grid_shape)
    return first_axis, second_axis, np.flipud(values)

def parse(x,data,shape=(360,720)):
    """Turn one element of a layer tuple into a 2-D grid.

    Parameters
    ----------
    x : number or key
        Either a numeric constant (Python ``int``/``float`` or a NumPy
        integer/floating scalar) or a key of ``data``.
    data : mapping
        Maps keys to sequences whose element at index 2 is the grid.
    shape : tuple of int, optional
        Shape of the grid that is created for a numeric constant.

    Returns
    -------
    numpy.ndarray
        ``np.ones(shape) * x`` for a numeric constant, otherwise
        ``data[x][2]`` (returned as stored, not copied).

    Raises
    ------
    KeyError
        If ``x`` is not numeric and ``data`` is a dict without that key.
    """
    # isinstance (rather than an exact type comparison) also recognises NumPy
    # scalars such as np.float32 or np.int64 as constants; previously those
    # were used as dictionary keys and failed with a KeyError.
    if isinstance(x, (int, float, np.integer, np.floating)):
        return np.ones(shape) * x
    return data[x][2]

def layers_to_grids(layers,data,shape=(360,720)):
    """Convert layer descriptions into (top, bottom, density) grid triples.

    Every element of every layer is passed through ``parse`` with the given
    ``data`` and ``shape``. A layer with three elements is read as
    (top, bottom, density); any other layer is unpacked as
    (top, bottom, density at top, density at bottom) and the two densities
    are averaged. Returns a list with one triple per layer.
    """
    result = []
    for layer in layers:
        fields = tuple(parse(item, data, shape=shape) for item in layer)
        if len(layer) == 3:
            upper, lower, density = fields
        else:
            upper, lower, density_top, density_bottom = fields
            # Represent the layer by the mean of its two boundary densities
            density = 0.5 * (density_top + density_bottom)
        result.append((upper, lower, density))
    return result

# Input files inside ``folder``; entry k belongs to key k of ``names`` below
fnames = [
    'ETOPO2_km_continental.xyz',
    'ETOPO2_km_depth_Bed.xyz',
    'ETOPO2_km_depth_Ice.xyz',
    'Global_Moho_CRUST1.0.xyz',
    'rho_20km_out.xyz',
    'rho_36km_out.xyz',
    'rho_56km_out.xyz',
    'rho_80km_out.xyz',
    'rho_c_out.xyz',
    'rho_submoho_out.xyz',
    'rho_110km_out.xyz',
    'rho_150km_out.xyz',
    'rho_200km_out.xyz',
    'rho_260km_out.xyz',
    'rho_330km_out.xyz',
    'rho_400km_out.xyz',
    'z_20km',
    'z_36km',
    'z_56km',
    'z_80km',
]
# Keys under which the grids are stored in ``data`` (same order as ``fnames``)
names = [
    'continental',
    'bed',
    'ice',
    'moho',
    'rho_20',
    'rho_36',
    'rho_56',
    'rho_80',
    'rho_c',
    'rho_submoho',
    'rho_110',
    'rho_150',
    'rho_200',
    'rho_260',
    'rho_330',
    'rho_400',
    'z_20',
    'z_36',
    'z_56',
    'z_80',
]

# Construct data dictionary: key -> (x, y, grid) tuple returned by read_txt
data = {}
for key, fname in zip(names, fnames):
    data[key] = read_txt(os.path.join(folder, fname))

# Remove the ice wherever it is too thin: where (ice - bed) > -0.01 the ice
# grid is replaced by the bed grid, so that the ice layer vanishes there.
# (0.01 corresponds to 10 m if the grids are in km, as the file names suggest.)
# This affects mostly the coastal regions of Greenland and Antarctica,
# and it removes the weird structures in the ice model.

ice_x, ice_y, ice_surface = data['ice']
bed_surface = data['bed'][2]
incorrect_ice = (ice_surface - bed_surface) > -0.01
# Build the corrected grid as a new array; the loaded grids stay untouched
new_ice = np.where(incorrect_ice, bed_surface, ice_surface)
data['ice'] = (ice_x, ice_y, new_ice)

# Each layer is described by a 4-tuple or a 3-tuple
# Order is top - bottom - density (top) - density (bottom)
# OR
# top - bottom - density (if the density is constant)
# Every element of the tuple is a string or a number: a string refers
# to a key of data, a number is a constant that is expanded to a full grid.

layers = [('continental','ice',1030.0),
         ('ice','bed',910.0),
          ('bed','moho','rho_c'),
          ('moho','z_20','rho_submoho','rho_20'),
          ('z_20','z_36','rho_20','rho_36'),
          ('z_36','z_56','rho_36','rho_56'),
          ('z_56','z_80','rho_56','rho_80'),
          ('z_80',110.0,'rho_80','rho_110'),
          (110.0,150.0,'rho_110','rho_150'),
          (150.0,200.0,'rho_150','rho_200'),
          (200.0,260.0,'rho_200','rho_260'),
          (260.0,330.0,'rho_260','rho_330'),
          (330.0,400.0,'rho_330','rho_400')
         ]

# Expand constants to the shape of the loaded grids instead of relying on
# the default shape of layers_to_grids, so other grid resolutions work too
grid_shape = data['continental'][2].shape
grids = layers_to_grids(layers,data,shape=grid_shape)
# Stack into one array indexed as [layer, (top, bottom, density), row, column]
grids = np.array(grids)
lon = data['continental'][0]
lat = data['continental'][1]
# 2-D coordinate grids with one row per lat value and one column per lon value
loni,lati = np.meshgrid(lon,lat)

###############################################################################
##### Crust
###############################################################################

# Depth refers to the center of the cell (vertically)
# Nodes: number_of_layers_crust equal intervals from the minimum of the first
# layer's bottom grid down to bound_mesh
z = np.linspace(grids[0][1].min(), bound_mesh, number_of_layers_crust + 1)
nx, ny, nz = len(lon), len(lat), len(z)
cube = np.zeros((nz, ny, nx))

# Coordinate cubes with the same shape as the density cube
lon_cube = np.zeros(cube.shape)
lat_cube = np.zeros(cube.shape)
z_cube = np.zeros(cube.shape)
lon_cube[:, :, :] = loni
lat_cube[:, :, :] = lati
z_cube[:, :, :] = z[:, None, None]

# Construct grid cube file (version 3)
# Equi-thickness layers

dz = z[1] - z[0]
half_dz = 0.5 * dz
last_node = len(z) - 1
layer_top = grids[:, 0, :, :]
layer_bottom = grids[:, 1, :, :]
layer_rho = grids[:, 2, :, :]

for j, z_node in enumerate(z):
    # Every cell spans half a spacing above and below its node, except the
    # deepest one, which ends at the node itself
    z1 = z_node - half_dz
    z2 = z_node if j == last_node else z_node + half_dz
    # Clip each layer's bottom (U) and top (L) to the cell interval [z1, z2]
    U = np.maximum(z1, np.minimum(z2, layer_bottom))
    L = np.minimum(z2, np.maximum(z1, layer_top))
    # Differences of cubed radii are proportional to spherical-shell volumes,
    # so the cell value is a volume-weighted sum of the layer densities
    contribution = (radius_reference_model - U)**3 - (radius_reference_model - L)**3
    cell_shell = (radius_reference_model - z2)**3 - (radius_reference_model - z1)**3
    cube[j, :, :] = (contribution * layer_rho).sum(0) / cell_shell

# FOR TESTING TOTAL MASS OF THE MODEL
# (thickness times density summed over the layers, and the same for the cube;
# neither value is used further below)
total_mass = np.sum((layer_bottom - layer_top) * layer_rho, 0)
total_mass_cube = dz * cube.sum(0)

# Save data cube to ASCII file that can be imported to ASPECT
# (the np.savetxt call that would fill this file is commented out, so the
# file is only created here)
cube_fname = 'Cube_%d_crust_%d_mantle_%d_bound.xyz'%(number_of_layers_crust,number_of_layers_mantle,bound_mesh)
fid_save = open(cube_fname,'w')

#np.savetxt(fid_save,np.vstack((lon_cube.flatten(),lat_cube.flatten(),z_cube.flatten(),cube.flatten())).T,fmt='%.1f %.1f %.3f %.3f')

############################
#conversion to ASPECT format
############################

# Total number of depth nodes: crust nodes plus mantle nodes
nzzz = number_of_layers_crust + number_of_layers_mantle + 1

# Output file; the header line lists the number of points per coordinate
myfile = open("bench4c.ascii", "w")
myfile.write('# POINTS: %d %d %d \n' %(nzzz,nx,ny))

# Arrays that hold the crust part followed by the mantle part
combined_shape = (nzzz, ny, nx)
ass_z_cube = np.zeros(combined_shape)
ass_lat_cube = np.zeros(combined_shape)
ass_lon_cube = np.zeros(combined_shape)
ass_cube = np.zeros(combined_shape)

# The crust values occupy the first nz depth levels
for target, source in ((ass_z_cube, z_cube),
                       (ass_lon_cube, lon_cube),
                       (ass_lat_cube, lat_cube),
                       (ass_cube, cube)):
    target[:nz] = source[:nz]

###############################################################################
#### Mantle
###############################################################################

# Depth refers to the center of the cell (vertically)
# Nodes: number_of_layers_mantle equal intervals from bound_mesh down to the
# maximum of the last layer's bottom grid; the node at bound_mesh itself is
# dropped because the crust part already contains it
z = np.linspace(bound_mesh, grids[-1][1].max(), number_of_layers_mantle + 1)[1:]
nx, ny, nz = len(lon), len(lat), len(z)
cube = np.zeros((nz, ny, nx))

# Coordinate cubes with the same shape as the density cube
lon_cube = np.zeros(cube.shape)
lat_cube = np.zeros(cube.shape)
z_cube = np.zeros(cube.shape)
lon_cube[:, :, :] = loni
lat_cube[:, :, :] = lati
z_cube[:, :, :] = z[:, None, None]

# Construct grid cube file (version 3)
# Equi-thickness layers

dz = z[1] - z[0]
half_dz = 0.5 * dz
last_node = len(z) - 1
layer_top = grids[:, 0, :, :]
layer_bottom = grids[:, 1, :, :]
layer_rho = grids[:, 2, :, :]

# NOTE(review): the first cell here starts at bound_mesh + 0.5*dz, whereas the
# deepest cell of the crust part ends at bound_mesh, so the depth interval in
# between is not averaged into any cell - confirm that this is intended.
for j, z_node in enumerate(z):
    # Every cell spans half a spacing above and below its node, except the
    # deepest one, which ends at the node itself
    z1 = z_node - half_dz
    z2 = z_node if j == last_node else z_node + half_dz
    # Clip each layer's bottom (U) and top (L) to the cell interval [z1, z2]
    U = np.maximum(z1, np.minimum(z2, layer_bottom))
    L = np.minimum(z2, np.maximum(z1, layer_top))
    # Differences of cubed radii are proportional to spherical-shell volumes,
    # so the cell value is a volume-weighted sum of the layer densities
    contribution = (radius_reference_model - U)**3 - (radius_reference_model - L)**3
    cell_shell = (radius_reference_model - z2)**3 - (radius_reference_model - z1)**3
    cube[j, :, :] = (contribution * layer_rho).sum(0) / cell_shell

# FOR TESTING TOTAL MASS OF THE MODEL
# (neither value is used further below)
total_mass = np.sum((layer_bottom - layer_top) * layer_rho, 0)
total_mass_cube = dz * cube.sum(0)

# Save data cube to ASCII file that can be imported to ASPECT
#fid_save.write("\n")
#np.savetxt(fid_save,np.vstack((lon_cube.flatten(),lat_cube.flatten(),z_cube.flatten(),cube.flatten())).T,fmt='%.1f %.1f %.3f %.3f')
fid_save.close()

############################

# The mantle values follow directly after the crust depth levels
first_mantle = number_of_layers_crust + 1
for target, source in ((ass_z_cube, z_cube),
                       (ass_lon_cube, lon_cube),
                       (ass_lat_cube, lat_cube),
                       (ass_cube, cube)):
    target[first_mantle:first_mantle + nz] = source[:nz]

# Depth -> radius: scale the depths by 1000 and subtract them from the
# reference radius in the same unit. The reference radius is taken from
# radius_reference_model (as for the radii file) instead of a hard-coded
# 6371000, so that changing the constant keeps both outputs consistent.
ass_z_cube*=1000
ass_z_cube[:]=radius_reference_model*1000-ass_z_cube[:]
# Divide by 180/pi, i.e. convert degrees to radians (assuming the input
# coordinates are in degrees - TODO confirm)
ass_lon_cube/=(360/2/np.pi)
# Shift by 90 before the conversion to radians
# (presumably the angle convention expected by ASPECT - confirm)
ass_lat_cube+=90
ass_lat_cube/=(180/np.pi)
# Replace the density rho by (1 - rho/3300)/3e-5
# NOTE(review): looks like inverting rho = 3300*(1 - 3e-5*T) with a reference
# density and an expansivity - confirm against the ASPECT parameter file
ass_cube/=3300
ass_cube[:]=1-ass_cube[:]
ass_cube/=3e-5

# One line per point; the depth index runs fastest (from the last depth
# level to the first), then the lon index, then the lat index.
# The file is closed afterwards, also if writing fails, so that all data
# is flushed to disk.
try:
    for j in range(0,ny):
        for i in range(0,nx):
            for k in range(nzzz-1,-1,-1):
                myfile.write("%.3f %.5f %.5f %.6f \n"  %( ass_z_cube[k,j,i],ass_lon_cube[k,j,i],ass_lat_cube[k,j,i],ass_cube[k,j,i] ))
finally:
    myfile.close()

########################### print radii file ##################################################

# Depth nodes of both parts, built exactly like the nodes of the two cubes
crust_z = np.linspace(grids[0][1].min(),bound_mesh,number_of_layers_crust+1)
mantle_z = np.linspace(bound_mesh,grids[-1][1].max(),number_of_layers_mantle+1)
mantle_z = np.delete(mantle_z, 0)

# Radius of every node: reference radius minus depth, scaled by 1e3
radii = (radius_reference_model - np.concatenate([crust_z, mantle_z]))*1e3

radii_fname = 'Radii_%d_crust_%d_mantle_%d_bound.xyz'%(number_of_layers_crust,number_of_layers_mantle,bound_mesh)
#np.savetxt(fid_depth,radii[::-1],fmt='%.f')
#aspect does not need the first nor last radius!!
# The remaining radii are written from the last index to the first, each
# followed by " , ". The context manager closes the file even if
# formatting or writing fails.
with open(radii_fname,'w') as fid_depth:
    fid_depth.write("".join("%.2f , " % radii[i] for i in range(radii.size-2,0,-1)))

