#!/usr/bin/env python
# -*- coding: UTF-8 no BOM -*-

# This script post-processes the results obtained with the spectral method.
# Since it reads the data written by "materialpoint_results", it can be adapted to data
# computed with the FEM solvers. So far, it can only handle elements with one IP arranged in regular order.
# written by M. Diehl, m.diehl@mpie.de

import os,sys,threading,re,numpy,time, postprocessingMath
from optparse import OptionParser, OptionGroup, Option, SUPPRESS_HELP

# -----------------------------
class extendedOption(Option):
# -----------------------------
    """optparse Option class that understands an additional 'extend' action.

    With action='extend' every option argument is split at commas and the
    pieces are appended to the destination list, so '-s a,b -s c' collects
    ['a', 'b', 'c'].  All other actions are delegated to optparse.Option.
    Recipe from the optparse documentation:
    http://docs.python.org/library/optparse.html
    """

    ACTIONS = Option.ACTIONS + ("extend",)
    STORE_ACTIONS = Option.STORE_ACTIONS + ("extend",)
    TYPED_ACTIONS = Option.TYPED_ACTIONS + ("extend",)
    ALWAYS_TYPED_ACTIONS = Option.ALWAYS_TYPED_ACTIONS + ("extend",)

    def take_action(self, action, dest, opt, value, values, parser):
        if action != "extend":
            Option.take_action(self, action, dest, opt, value, values, parser)
            return
        pieces = value.split(",")
        target = values.ensure_value(dest, [])
        target.extend(pieces)

            
# -----------------------------
class backgroundMessage(threading.Thread):
# -----------------------------
    """Thread that shows a rotating spinner in front of a message on stderr.

    Every `waittime` seconds the spinner advances and the message is redrawn;
    once the main thread has stopped running the message is erased and the
    thread ends.
    """

    def __init__(self):
        threading.Thread.__init__(self)
        self.message = ''                             # text currently on screen
        self.new_message = ''                         # text to draw on the next redraw
        self.counter = 0                              # index of the current spinner frame
        self.symbols = ['- ', '\\ ', '| ', '/ ']      # spinner frames (all two characters wide)
        self.waittime = 0.5                           # seconds between redraws
        # The spinner lives as long as the main thread.  threading.main_thread()
        # exists on Python >= 3.4; older interpreters fall back to the constructing
        # thread, which is assumed to be the main thread.
        # (The former check, threading.enumerate()[0]._Thread__stopped, relied on a
        # Python-2-private attribute and on enumerate() listing the main thread first,
        # which is not guaranteed.)
        if hasattr(threading, 'main_thread'):
            self.owner_thread = threading.main_thread()
        else:
            self.owner_thread = threading.current_thread()

    def __quit__(self):
        # erase spinner and message: backspace over them, blank them, backspace again
        length = len(self.message) + len(self.symbols[self.counter])
        sys.stderr.write(chr(8)*length + ' '*length + chr(8)*length)
        sys.stderr.flush()

    def run(self):
        while self.owner_thread.is_alive():
            time.sleep(self.waittime)
            self.update_message()
        self.__quit__()

    def set_message(self, new_message):
        self.new_message = new_message
        self.print_message()

    def print_message(self):
        length = len(self.message) + len(self.symbols[self.counter])
        sys.stderr.write(chr(8)*length + ' '*length + chr(8)*length)                                # delete former message
        sys.stderr.write(self.symbols[self.counter] + self.new_message)                             # print new message
        sys.stderr.flush()                 # no newline is written, so a buffered stderr would hide the update
        self.message = self.new_message

    def update_message(self):
        self.counter = (self.counter + 1)%len(self.symbols)
        self.print_message()



def outStdout(cmd,locals):
  """Print one output command on stdout.

  '(!)code' executes *code* (which can use the mapping passed as *locals*),
  '(?)expr' prints the value of *expr*, anything else is printed verbatim.
  NOTE(review): exec/eval run arbitrary code -- only pass commands generated
  by this script, never text read from an input file.
  """
  tag, body = cmd[0:3], cmd[3:]
  if tag == '(!)':
    exec(body)
    return
  if tag == '(?)':
    print(eval(body))
    return
  print(cmd)

def outFile(cmd,locals):
  """Write one output command to the file object locals['filepointer'].

  '(!)code' executes *code* (which can use the mapping passed as *locals*),
  '(?)expr' writes the value of *expr* as one line, anything else is written
  verbatim as one line.
  NOTE(review): exec/eval run arbitrary code -- only pass commands generated
  by this script, never text read from an input file.
  """
  if cmd[0:3] == '(!)':
    exec(cmd[3:])
  elif cmd[0:3] == '(?)':
    result = eval(cmd[3:])
    # eval may return any object; str() keeps this consistent with outStdout,
    # which prints non-string results instead of failing on '+'
    locals['filepointer'].write(str(result)+'\n')
  else:
    locals['filepointer'].write(cmd+'\n')
  return


def output(cmds,locals,dest):
  """Emit a possibly nested list of output commands, depth first.

  cmds   -- list whose items are commands or further lists of commands;
            every non-list item is converted with str() before emission
  locals -- mapping handed unchanged to the writer (outFile reads 'filepointer')
  dest   -- 'File' selects outFile, 'Stdout' selects outStdout; any other value
            raises KeyError as soon as a non-list item is reached
  """
  for entry in cmds:
    if isinstance(entry,list):
      output(entry,locals,dest)
      continue
    writer = {'File': outFile, 'Stdout': outStdout}[dest]
    writer(str(entry),locals)
  return


def transliterateToFloat(x):
  """Convert *x* to float, returning 0.0 for anything that cannot be converted.

  Only conversion failures are mapped to 0.0; interrupts such as
  KeyboardInterrupt still propagate (the former bare 'except:' swallowed them).
  """
  try:
    return float(x)
  except (TypeError, ValueError, OverflowError):
    return 0.0

# ++++++++++++++++++++++++++++++++++++++++++++++++++++
def vtk_writeASCII_mesh(mesh,data,res):
# ++++++++++++++++++++++++++++++++++++++++++++++++++++
  """ function writes data array defined on a hexahedral mesh (geometry)

  Builds the lines of a legacy ASCII VTK unstructured grid made of hexahedra
  (cell type 12).

  mesh -- array with mesh[i,j,k] giving the 3 coordinates of grid node (i,j,k),
          for i <= res[0], j <= res[1], k <= res[2]
  data -- dict: field kind -> dict: field name -> array indexed [x,y,z]; every
          field is written as CELL_DATA under the header '<KIND> <name> float'
          NOTE(review): the [:,j,k] slicing below writes one line of plain values
          only for 3-D (scalar) arrays; vector/tensor arrays would be written as
          numpy row reprs -- confirm before passing them.
  res  -- number of cells along the three axes
  Returns a nested list of strings meant for output().
  """
  N1 = (res[0]+1)*(res[1]+1)*(res[2]+1)
  N  = res[0]*res[1]*res[2]
  row   = res[0]+1                 # node-index offset to the neighbour along the 2nd axis
  plane = (res[0]+1)*(res[1]+1)    # node-index offset to the neighbour along the 3rd axis

  cmds = [\
          '# vtk DataFile Version 3.1',
          'powered by 3Dvisualize',
          'ASCII',
          'DATASET UNSTRUCTURED_GRID',
          'POINTS %i float'%N1,
          # nodes are listed with the first index running fastest
          [[['\t'.join(map(str,mesh[i,j,k])) for i in range(res[0]+1)] for j in range(res[1]+1)] for k in range(res[2]+1)],
          'CELLS %i %i'%(N,N*9),
          ]

# cells: 8 node indices each, the face at the lower 3rd index first, then the upper face.
# The strides must follow the node numbering above (first index fastest); the former
# res[1]-based strides were only correct for cubic grids.
  for z in range(res[2]):
    for y in range(res[1]):
      for x in range(res[0]):
        base = z*plane + y*row + x       # node at the cell's lowest corner
        cmds.append('8 '+'\t'.join(map(str,[ \
                                            base,
                                            base+1,
                                            base+row+1,
                                            base+row,
                                            base+plane,
                                            base+plane+1,
                                            base+plane+row+1,
                                            base+plane+row,
                                          ])))
  cmds += [\
           'CELL_TYPES %i'%N,
           ['12']*N,
           'CELL_DATA %i'%N,
          ]

  for kind in data:
    for item in data[kind]:
      cmds += [\
               '%s %s float'%(kind.upper(),item),
               'LOOKUP_TABLE default',
               [[['\t'.join(map(str,data[kind][item][:,j,k]))] for j in range(res[1])] for k in range(res[2])],
              ]

  return cmds
  
# ++++++++++++++++++++++++++++++++++++++++++++++++++++
def gmsh_writeASCII_mesh(mesh,data,res):
# ++++++++++++++++++++++++++++++++++++++++++++++++++++
  """ function writes data array defined on a hexahedral mesh (geometry)

  Builds the lines of an ASCII Gmsh mesh file (MSH format 2.1) made of 8-node
  hexahedra (element type 5), followed by one $ElementData view per field.

  mesh -- array with mesh[i,j,k] giving the 3 coordinates of grid node (i,j,k),
          for i <= res[0], j <= res[1], k <= res[2]
  data -- dict: field kind -> dict: field name -> array indexed [x,y,z,...]; the
          product of the trailing dimensions after the first three is written as
          the number of components per element
  res  -- number of cells along the three axes
  Returns a nested list of strings meant for output().

  The former version could not run (map() over an int, 'item' used before
  assignment) and wrote 0-based node numbers, no element tag count, no
  $EndElements and a malformed $ElementData header.
  """
  nx, ny, nz = int(res[0]), int(res[1]), int(res[2])
  N1 = (nx+1)*(ny+1)*(nz+1)
  N  = nx*ny*nz

  def node(i,j,k):
    # Gmsh node numbers are 1-based; the first index runs fastest
    return 1 + i + j*(nx+1) + k*(nx+1)*(ny+1)

  cmds = [\
          '$MeshFormat',
          '2.1 0 8',                  # version, file type 0 (ASCII), size of a double
          '$EndMeshFormat',
          '$Nodes',
          '%i'%N1,
          ['%i\t%s'%(node(i,j,k),'\t'.join(map(str,mesh[i,j,k])))
             for k in range(nz+1) for j in range(ny+1) for i in range(nx+1)],
          '$EndNodes',
          '$Elements',
          '%i'%N,
          ]

  # element line: number, type 5 (8-node hexahedron), 2 tags (physical and
  # elementary entity, both 1), then the corner nodes: lower face first, then upper face
  n_elem = 0
  for z in range(nz):
    for y in range(ny):
      for x in range(nx):
        n_elem += 1
        cmds.append('\t'.join(map(str,[n_elem, 5, 2, 1, 1,
                                       node(x,  y,  z  ), node(x+1,y,  z  ),
                                       node(x+1,y+1,z  ), node(x,  y+1,z  ),
                                       node(x,  y,  z+1), node(x+1,y,  z+1),
                                       node(x+1,y+1,z+1), node(x,  y+1,z+1)])))
  cmds.append('$EndElements')

  for kind in data:
    for item in data[kind]:
      field = numpy.asarray(data[kind][item])
      ncomp = int(numpy.prod(field.shape[3:]))     # 1 for a plain [x,y,z] array
      cmds += [\
               '$ElementData',
               '1',                    # one string tag:
               '"%s"'%item,            #   name of the view
               '1',                    # one real tag:
               '0.0',                  #   time value
               '3',                    # three integer tags:
               '0',                    #   time step
               '%i'%ncomp,             #   number of components per element
               '%i'%N,                 #   number of elements that carry values
              ]
      n_elem = 0                       # same element numbering as in $Elements
      for z in range(nz):
        for y in range(ny):
          for x in range(nx):
            n_elem += 1
            cmds.append('%i\t%s'%(n_elem,'\t'.join(map(str,numpy.ravel(field[x,y,z])))))
      cmds.append('$EndElementData')

  return cmds
 
# +++++++++++++++++++++++++++++++++++++++++++++++++++
def vtk_writeASCII_points(coordinates,data,res):
# +++++++++++++++++++++++++++++++++++++++++++++++++++
  """ Build the lines of a legacy ASCII VTK file with one vertex cell per point.

  coordinates -- array with coordinates[i,j,k] giving the 3 coordinates of point (i,j,k)
  data        -- dict: field kind -> dict: field name -> array indexed [x,y,z];
                 each field is written as POINT_DATA under '<KIND> <name> float'
  res         -- number of points along the three axes
  Returns a nested list of strings meant for output().
  """
  count = res[0]*res[1]*res[2]
  xs, ys, zs = range(res[0]), range(res[1]), range(res[2])

  # points are listed with the first index running fastest
  point_lines = [[['\t'.join(map(str,coordinates[i,j,k])) for i in xs] for j in ys] for k in zs]
  header = ['# vtk DataFile Version 3.1',
            'powered by 3Dvisualize',
            'ASCII',
            'DATASET UNSTRUCTURED_GRID']
  cmds = header + [
                   'POINTS %i float'%count,
                   point_lines,
                   'CELLS %i %i'%(count,count*2),
                   ['1\t%i'%n for n in range(count)],   # one single-point cell per point
                   'CELL_TYPES %i'%count,
                   ['1']*count,                          # VTK cell type 1: vertex
                   'POINT_DATA %i'%count,
                  ]

  for kind, fields in data.items():
    for name, field in fields.items():
      cmds.extend([
                   '%s %s float'%(kind.upper(),name),
                   'LOOKUP_TABLE default',
                   [[['\t'.join(map(str,field[:,j,k]))] for j in ys] for k in zs],
                  ])

  return cmds
  
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
def vtk_writeASCII_box(diag,defgrad):
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
  """ corner box for the average defgrad

  Builds the lines of a legacy ASCII VTK file holding the 8 corners of a box
  as vertex cells.

  diag    -- edge lengths of the undeformed box along the three axes; one
             corner sits at the origin
  defgrad -- 3x3 tensor; every corner p is written as numpy.dot(defgrad, p)
  Returns a nested list of strings meant for output().

  The former version ignored *defgrad* and read the global 'defgrad_av'.
  """
  points = numpy.array([\
                         [0.0,0.0,0.0,],\
                         [diag[0],0.0,0.0,],\
                         [diag[0],diag[1],0.0,],\
                         [0.0,diag[1],0.0,],\
                         [0.0,0.0,diag[2],],\
                         [diag[0],0.0,diag[2],],\
                         [diag[0],diag[1],diag[2],],\
                         [0.0,diag[1],diag[2],],\
                       ])

  cmds = [\
    '# vtk DataFile Version 3.1',
    'powered by 3Dvisualize',
    'ASCII',
    'DATASET UNSTRUCTURED_GRID',
    'POINTS 8 float',
    ['\t'.join(map(str,numpy.dot(defgrad,points[p]))) for p in range(8)],
    'CELLS 8 16',
    ['1\t%i'%i for i in range(8)],
    'CELL_TYPES 8',
    ['1']*8,
  ]

  return cmds



# ----------------------- MAIN -------------------------------

parser = OptionParser(option_class=extendedOption, usage='%prog [options] datafile', description = """
Produce VTK file from data field.

$Id$
""")
parser.add_option('-s', '--scalar', action='extend', dest='scalar', type='string', \
                  help='list of scalars to visualize')
parser.add_option('-d', '--deformation', dest='defgrad', type='string', \
                  help='heading of deformation gradient columns [%default]')
parser.add_option('-g', '--grain', dest='grain', type='int', \
                  help='grain of interest [%default]')

parser.set_defaults(defgrad = 'f')
parser.set_defaults(grain = 1)
parser.set_defaults(scalar = [])
parser.set_defaults(vector = [])      # no command-line option sets 'vector' or 'tensor' yet
parser.set_defaults(tensor = [])

(options, args) = parser.parse_args()

# Kept as top-level statements (not wrapped in a function) so that names such as
# defgrad_av remain module-level, as they were before.
for filename in args:
  if not os.path.exists(filename):
    continue
  with open(filename) as datafile:
    content = datafile.readlines()

  # the first line must contain "<n> head"; the column labels are on line index n
  # and the data rows follow them
  m = re.search(r'(\d+)\shead',content[0],re.I) if content else None
  if m is None:
    continue
  print(filename)

  headrow = int(m.group(1))
  if headrow >= len(content):
    print('missing column headings..!')
    continue
  headings = content[headrow].split()
  column = {}
  maxcol = 0                                   # number of leading columns that are read

  ipcol = -1                                   # reset per file so a missing 'ip.x' is detected
  for col,head in enumerate(headings):
    if head == 'ip.x':
      ipcol = col
      maxcol = max(maxcol,col+3)
      break

  if ipcol < 0:
    print('missing ip coordinates..!')
    continue

  column['tensor'] = {}
  for label in [options.defgrad] + options.tensor:
    column['tensor'][label] = -1
    for col,head in enumerate(headings):
      if head == label or head == '%i_1_%s'%(options.grain,label):
        column['tensor'][label] = col
        maxcol = max(maxcol,col+9)
        break

  if column['tensor'][options.defgrad] < 0:
    print('missing deformation gradient..!')
    continue

  column['vector'] = {}
  for label in options.vector:
    column['vector'][label] = -1
    for col,head in enumerate(headings):
      if head == label or head == '%i_1_%s'%(options.grain,label):
        column['vector'][label] = col
        maxcol = max(maxcol,col+3)
        break

  column['scalar'] = {}
  for label in options.scalar:
    column['scalar'][label] = -1
    for col,head in enumerate(headings):
      if head == label or head == '%i_%s'%(options.grain,label):
        column['scalar'][label] = col
        maxcol = max(maxcol,col+1)
        break

  # an unmatched label keeps column -1, and slicing with -1 would silently read the
  # last column -- report such fields and drop them instead
  for kind in ('tensor','vector','scalar'):
    for label in [lbl for lbl in column[kind] if column[kind][lbl] < 0]:
      print('missing %s %s..!'%(kind,label))
      del column[kind][label]

  values = numpy.array([[transliterateToFloat(entry) for entry in line.split()[:maxcol]] \
                        for line in content[headrow+1:] if line.strip()],'d')
  if len(values) == 0:
    print('missing data..!')
    continue

  grid = [{},{},{}]                            # distinct ip coordinates along each axis
  for row in values:
    for axis in range(3):
      grid[axis][str(row[ipcol+axis])] = True

  res = numpy.array([len(grid[0]),\
                     len(grid[1]),\
                     len(grid[2]),],'i')
  extent = [max(map(float,g.keys()))-min(map(float,g.keys())) for g in grid]
  # res points spaced extent/(res-1) apart span res such spacings
  # NOTE(review): an axis with a single point (res == 1) divides by zero here
  dim = numpy.array(numpy.array(extent,'d')*res/(res-numpy.ones(3)), 'd')

  print('resolution %s'%(res,))
  print('dimension %s'%(dim,))
  shape = (res[0],res[1],res[2])
  fcol = column['tensor'][options.defgrad]
  defgrad_av = postprocessingMath.tensor_avg(res[0],res[1],res[2],\
                  numpy.reshape(values[:,fcol:fcol+9],shape+(3,3)))
  centroids = postprocessingMath.deformed_fft(res[0],res[1],res[2],dim,\
                  numpy.reshape(values[:,fcol:fcol+9],shape+(3,3)),defgrad_av,1.0)
  ms = postprocessingMath.mesh(res[0],res[1],res[2],dim,defgrad_av,centroids)

  fields = {\
             'tensors': {},\
             'vectors': {},\
             'scalars': {},\
           }
  for me in options.tensor:
    if me in column['tensor']:
      c = column['tensor'][me]
      fields['tensors'][me] = numpy.reshape(values[:,c:c+9],shape+(3,3))
  for me in options.vector:
    if me in column['vector']:
      c = column['vector'][me]
      fields['vectors'][me] = numpy.reshape(values[:,c:c+3],shape+(3,))
  for me in options.scalar:
    if me in column['scalar']:
      fields['scalars'][me] = numpy.reshape(values[:,column['scalar'][me]],shape)

  out = {}
  out['mesh']   = vtk_writeASCII_mesh(ms,fields,res)
  out['points'] = vtk_writeASCII_points(centroids,fields,res)
  out['box']    = vtk_writeASCII_box(dim,defgrad_av)

  for what in out:
    with open(os.path.splitext(filename)[0]+'_%s.vtk'%what, 'w') as vtk:
      output(out[what],{'filepointer':vtk},'File')