#################################################################
#                                                               #
# Template code to simulate contact interaction between         #
# a finite element truss structure and a rigid plane            #
# using node-to-segment (NTS) contact elements.                 #
#                                                               #
# "Contact Mechanics and Elements of Tribology" (CMET) course   #
#                                                               #
# V.A. Yastrebov, CNRS                                          #
# MINES Paris, PSL University, Centre des materiaux             #
# Feb 2021-Feb 2023                                             #
#                                                               #
# Licence: CC0                                                  #
#                                                               #
#################################################################

import numpy as np
import matplotlib.pyplot as plt

#################################
##                             ##
##    Classes & Functions      ##
##                             ##
#################################


class ELEMENT:
    """Two-node element of the truss mesh.

    Stores only connectivity: its own id and the ids of its two end nodes.
    The element stiffness itself is assembled by MESH.matrix_assembly.
    """

    def __init__(self, e_id, n1_id, n2_id):
        self.e_id, self.n1_id, self.n2_id = e_id, n1_id, n2_id

    def construct_matrix(self):
        """Placeholder hook for an element matrix; currently always returns 0."""
        return 0

class NODE:
    """Mesh node with reference coordinates and current displacement.

    Parameters
    ----------
    n_id : int
        Node index; MESH maps it to the global dofs ``n_id*dim + dof``.
    coord : numpy.ndarray
        Reference (undeformed) coordinates.
    displ : numpy.ndarray
        Current displacement; must have the same shape as ``coord``.

    Raises
    ------
    ValueError
        If ``coord`` and ``displ`` do not have the same shape.
    """

    def __init__(self, n_id, coord, displ):
        # Raise instead of calling the interactive-only `exit` builtin, so the
        # failure carries a traceback and can be handled by a caller.
        if coord.shape != displ.shape:
            raise ValueError(
                "Error in node constructor: coord shape {0} does not match "
                "displ shape {1}".format(coord.shape, displ.shape))
        self.n_id = n_id
        self.coord = coord
        self.displ = displ

class MESH:
    """Finite element mesh of two-node spring (truss) elements.

    Holds nodes and elements, assembles the global stiffness matrix ``K`` and
    the internal-force vector ``R``, and imposes Dirichlet conditions with a
    penalty method. Global dof numbering is ``node_id*dim + dof`` with
    ``dim = 2`` (dof 0 = ux, dof 1 = uy).

    Parameters
    ----------
    nodes : list
        Nodes indexed by id; each must expose ``displ``.
    elements : list
        Elements; each must expose ``n1_id`` and ``n2_id``.
    stiffness : float, optional
        Spring stiffness of every element. If None (default), the
        module-level ``k`` is read whenever it is needed.
    bc_penalty : float, optional
        Penalty factor for Dirichlet BC. If None (default), the
        module-level ``bc_penalty`` is read whenever it is needed.
    """

    def __init__(self, nodes, elements, stiffness=None, bc_penalty=None):
        self.nodes = nodes
        self.num_nodes = len(nodes)
        self.elements = elements
        self.num_el = len(elements)
        self.dim = 2
        # Not used by the methods of this class; kept for compatibility.
        self.dirichlet_bc = np.array([[]])
        self.stiffness = stiffness
        self.bc_penalty = bc_penalty

    def _k(self):
        """Return the element stiffness (explicit value or module-level ``k``)."""
        return k if self.stiffness is None else self.stiffness

    def _penalty(self):
        """Return the Dirichlet penalty (explicit value or module-level ``bc_penalty``)."""
        return bc_penalty if self.bc_penalty is None else self.bc_penalty

    def _element_dofs(self, el):
        """Return the global dofs of ``el``: all dofs of n1, then all dofs of n2."""
        return ([el.n1_id*self.dim + d for d in range(self.dim)]
                + [el.n2_id*self.dim + d for d in range(self.dim)])

    def compute_right_part(self):
        """Assemble the internal-force vector ``self.R`` from nodal displacements.

        For an element (n1, n2) with relative displacement du = u1 - u2 the
        spring force is k*du on n1 and -k*du on n2, so ``R`` equals the bulk
        stiffness from matrix_assembly times the displacement vector.
        Dirichlet penalty forces are added separately by
        add_Dirichlet_bc_in_residual.
        """
        self.R = np.zeros(self.dim*self.num_nodes)
        stiffness = self._k()
        for el in self.elements:
            # Use this mesh's own nodes (not a module-level list).
            du = self.nodes[el.n1_id].displ - self.nodes[el.n2_id].displ
            Rel = stiffness*np.concatenate((du, -du))
            for a, p in enumerate(self._element_dofs(el)):
                self.R[p] += Rel[a]

    def matrix_assembly(self):
        """Assemble the global stiffness matrix ``self.K`` (without BC)."""
        ndof = self.dim*self.num_nodes
        self.K = np.zeros((ndof, ndof))
        # Element matrix k*[[I, -I], [-I, I]]: for dim = 2 this is the
        # 4x4 matrix [[k,0,-k,0],[0,k,0,-k],[-k,0,k,0],[0,-k,0,k]]; the same
        # spring acts independently along x and y.
        Kel = self._k()*np.kron(np.array([[1., -1.], [-1., 1.]]), np.eye(self.dim))
        for el in self.elements:
            dofs = self._element_dofs(el)
            # Scalar accumulation (not fancy-index +=) so repeated dofs add up.
            for a, p in enumerate(dofs):
                for b, q in enumerate(dofs):
                    self.K[p, q] += Kel[a, b]

    def add_Dirichlet_bc(self, n_id, dof, value):
        """Add the Dirichlet penalty to the diagonal of ``self.K`` for (n_id, dof).

        ``value`` is unused here: the prescribed value only enters the
        residual (add_Dirichlet_bc_in_residual). Requires matrix_assembly()
        to have been called.
        """
        self.K[n_id*self.dim + dof, n_id*self.dim + dof] += self._penalty()

    def add_Dirichlet_bc_in_residual(self, n_id, dof, value):
        """Add the penalty force bc_penalty*(u - value) to ``self.R`` for (n_id, dof).

        Requires compute_right_part() to have been called.
        """
        self.R[n_id*self.dim + dof] += self._penalty()*(self.nodes[n_id].displ[dof] - value)


#################################
##                             ##
##          Core code          ##
##                             ##
#################################

# ---- Material property ----------------------------------------------------
# Spring stiffness of every truss element: F = k*dx
k = 10.0

# ---- Boundary conditions --------------------------------------------------
# Dirichlet BC are imposed by penalty; the factor is 1000 times the element
# stiffness so that the constrained dofs are far stiffer than the structure.
bc_penalty = 1000.0 * k
# Every dof listed in Dirichlet_bc_dofs is clamped on every node listed in
# Dirichlet_bc_nodes.
Dirichlet_bc_nodes, Dirichlet_bc_dofs = [0, 1], [0, 1]

# Neumann BC: node Neumann_bc_nodes[m] is loaded on dof Neumann_bc_dofs[m];
# the load is scaled by the pseudo-time up to Force_max.
Neumann_bc_nodes, Neumann_bc_dofs = [5], [1]
Force_max = -15.0

# Construct mesh
dim = 2  # 2D problem: two dofs (ux, uy) per node
num_nodes = 11 # Use odd integer number to make it compatible with hardcoded mesh construction
length = 1.  # node spacing along x and height of the truss

# Rigid wall: the line through (x_wall, y_wall) with unit normal `normal`.
# A node penetrates the wall when its distance along `normal` is negative.
# Points of the wall line satisfy
# y = y_wall + tangent[1]/tangent[0] * (x-x_wall)
x_wall = 5
y_wall = -1
normal = np.array([-1,5])
normal = normal/np.linalg.norm(normal)

# Unit tangent: normal rotated by -90 degrees, so (tangent, normal) is a
# right-handed orthonormal basis.
tangent = np.array([normal[1], -normal[0]])


# Node 0 sits at the origin.
coord = np.zeros(dim)
nodes = [NODE(0, coord.copy(), np.zeros(dim))]

# Hard coded mesh
# Hard coded nodes: node i (i >= 1) sits at x = (i-1)*length and alternates
# between y = length (odd i) and y = 0 (even i). The bounding box of the
# undeformed mesh is tracked on the way (only xmax/ymax are ever updated).
xmin = ymin = 0
xmax = ymax = 0
for i in range(1, num_nodes):
    coord[0] = (i - 1)*length
    coord[1] = length if (i - 1) % 2 == 0 else 0.0
    xmax = max(xmax, coord[0])
    ymax = max(ymax, coord[1])
    nodes.append(NODE(i, coord.copy(), np.zeros(dim)))

# Hard coded elements (truss elements); an element's id equals its position
# in `elements`.
# Left end: bar between node 0 (origin) and node 1 (0, length).
elements = [ELEMENT(0, 0, 1)]
# Each bay j, with i = 2*j + 1, adds four bars: (i-1, i+1) along y = 0,
# diagonal (i+1, i), (i, i+2) along y = length, and diagonal (i+2, i+1).
for j in range((num_nodes - 3)//2):
    i = 2*j + 1
    for n1, n2 in ((i - 1, i + 1), (i + 1, i), (i, i + 2), (i + 2, i + 1)):
        elements.append(ELEMENT(len(elements), n1, n2))
# add 2 more elements closing the right end of the truss
for n1, n2 in ((num_nodes - 3, num_nodes - 1), (num_nodes - 2, num_nodes - 1)):
    elements.append(ELEMENT(len(elements), n1, n2))

# Make a FE mesh out of it and assemble its global stiffness matrix
mesh = MESH(nodes, elements)
mesh.matrix_assembly()

# Apply Dirichlet BC: clamp (prescribed value 0) every listed dof of every
# listed node, node by node.
for nbc, dof in [(n, d) for n in Dirichlet_bc_nodes for d in Dirichlet_bc_dofs]:
    mesh.add_Dirichlet_bc(nbc, dof, 0)

# Solver and plotting parameters
# (Neumann BC are applied inside the load-stepping loop, as Force_max*time.)

num_inc = 2                      # number of load levels in TIME (including time = 0)
max_iter = 10                    # maximum number of Newton iterations per load level
penalty = 1000*k                 # contact penalty stiffness
vscale=.01                       # scale of the plotted contact-force arrows
convergence_tolerance = 0.01     # threshold on the relative residual ||R - F|| / ||F||
TIME = np.linspace(0,1,num_inc)  # pseudo-time; the Neumann load is Force_max*time
Lmax = max(ymax,xmax)            # largest extent of the undeformed mesh, sets plot limits
xx = np.linspace(0,1.5*Lmax,2) # simply to plot the rigid wall

# Incremental loading: at every load level `time` the Neumann load is
# Force_max*time. A contact-free linear predictor is followed by Newton
# iterations in which penetrating nodes receive penalty contact terms.
for time in TIME:
    # External load vector for this load level (Neumann BC); it only depends
    # on `time`, so it is built once and copied in every Newton iteration.
    F_ext = np.zeros(mesh.num_nodes*dim)
    for nbc in range(len(Neumann_bc_nodes)):
        F_ext[Neumann_bc_nodes[nbc]*dim + Neumann_bc_dofs[nbc]] = Force_max*time

    # Predictor: rebuild K (it may still hold contact terms from the previous
    # load level), re-apply the Dirichlet penalty and solve without contact.
    mesh.matrix_assembly()
    for nbc in Dirichlet_bc_nodes:
        for dof in Dirichlet_bc_dofs:
            mesh.add_Dirichlet_bc(nbc, dof, 0)
    U = np.linalg.solve(mesh.K, F_ext)

    # Newton method
    for iterations in range(max_iter):
        # Sync nodal displacements with the current iterate U so that the plot,
        # the residual and the contact detection all use the same state.
        for nod in mesh.nodes:
            nod.displ = np.array([U[nod.n_id*dim], U[nod.n_id*dim+1]])

        fig, _ = plt.subplots()
        plt.xlim([-0, 1.5*Lmax])
        plt.ylim([-1.5*Lmax/2., 1.5*Lmax/2.])
        plt.xlabel("x")
        plt.ylabel("y")
        # Rigid wall: filled region below the wall line
        plt.fill_between(xx, y_wall + tangent[1]/tangent[0]*(xx-x_wall), -10*Lmax/2., facecolor="b")

        # Undeformed (dashed blue) and current (black) configurations
        for el in mesh.elements:
            n1 = mesh.nodes[el.n1_id]
            n2 = mesh.nodes[el.n2_id]
            X = np.array([n1.coord[0], n2.coord[0]])
            Y = np.array([n1.coord[1], n2.coord[1]])
            x = np.array([n1.coord[0]+U[n1.n_id*dim], n2.coord[0]+U[n2.n_id*dim]])
            y = np.array([n1.coord[1]+U[n1.n_id*dim+1], n2.coord[1]+U[n2.n_id*dim+1]])
            plt.plot(X, Y, "--", color="b")
            plt.plot(x, y, "o-", color="k")

        # Applied loads (green arrows) at the current nodal positions
        for node_id_F in Neumann_bc_nodes:
            current_coord = mesh.nodes[node_id_F].coord + mesh.nodes[node_id_F].displ
            plt.quiver(current_coord[0], current_coord[1], 0, Force_max, color="g")

        # Bulk tangent matrix and internal forces, with Dirichlet penalty terms
        mesh.matrix_assembly()
        for nbc in Dirichlet_bc_nodes:
            for dof in Dirichlet_bc_dofs:
                mesh.add_Dirichlet_bc(nbc, dof, 0)
        mesh.compute_right_part()
        for nbc in Dirichlet_bc_nodes:
            for dof in Dirichlet_bc_dofs:
                mesh.add_Dirichlet_bc_in_residual(nbc, dof, 0)

        # Contact detection: node-to-rigid-plane penalty contact
        F = F_ext.copy()
        for node in mesh.nodes:
            # Current (deformed) position of the node
            xi = node.coord[0] + node.displ[0]
            yi = node.coord[1] + node.displ[1]

            # Orthogonal projection of the node onto the wall line
            xii = (xi - x_wall)*tangent[0] + (yi - y_wall)*tangent[1]
            proj_x = x_wall + tangent[0]*xii
            proj_y = y_wall + tangent[1]*xii
            plt.plot(proj_x, proj_y, "v", color="r")

            # Signed gap = (x - proj).normal; negative means penetration
            gap = (xi - proj_x)*normal[0] + (yi - proj_y)*normal[1]
            if gap < 0:
                ii = node.n_id*mesh.dim      # UX of the node
                jj = node.n_id*mesh.dim + 1  # UY of the node
                # Penalty contact reaction penalty*|gap|*normal, added to the
                # external force vector
                fcx = abs(gap)*normal[0]*penalty
                fcy = abs(gap)*normal[1]*penalty
                F[ii] += fcx
                F[jj] += fcy
                plt.quiver(xi, yi, -vscale*fcx, -vscale*fcy, color="r", scale=1)

                # Contact tangent matrix penalty * (normal outer normal): the gap
                # is linear in the displacement with gradient `normal`, so this
                # is the Hessian of the penalty energy penalty/2*gap^2.
                mesh.K[ii, ii] += penalty*normal[0]**2
                mesh.K[ii, jj] += penalty*normal[0]*normal[1]
                mesh.K[jj, ii] += penalty*normal[0]*normal[1]
                mesh.K[jj, jj] += penalty*normal[1]**2

        # Save before showing, then release the figure so figures do not
        # accumulate across iterations.
        fig.savefig("Convergence.png")
        plt.show()
        plt.close(fig)

        # Show the residual
        Residual = mesh.R - F
        res_norm = np.linalg.norm(Residual)
        F_norm = np.linalg.norm(F)
        print("Iter: {0}, Residual norm: {1:5.3e}".format(iterations, res_norm))
        # The relative criterion is undefined for F = 0 (e.g. time = 0); the
        # exact-zero test covers that case without a 0/0 division.
        if res_norm == 0 or (F_norm > 0 and res_norm/F_norm < convergence_tolerance):
            print("Convergence reached.")
            break
        # Newton update; contact terms (if any) are already in mesh.K and F.
        dU = np.linalg.solve(mesh.K, -Residual)
        U += dU
    else:
        print("Warning: no convergence after {0} Newton iterations (time = {1}).".format(max_iter, time))
