#=============================================================================
#
#  Copyright (c) Kitware, Inc.
#  All rights reserved.
#  See LICENSE.txt for details.
#
#  This software is distributed WITHOUT ANY WARRANTY; without even
#  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
#  PURPOSE.  See the above copyright notice for more information.
#
#=============================================================================

"""Generate ".nemesh" from meshes representing RGG models
"""

from collections import OrderedDict
import datetime
import os
import string

import smtk
import smtk.mesh
import smtk.model

class nemesh_writer:
    """Write an smtk mesh resource to a ".nemesh" text file.

    Attributes (set in __init__, checked by _validate_parameters()):
      indexed_input   0 = indexed input (element, node and boundary records
                      are written with a leading 1-based number),
                      1 = not indexed.
      debug_printing  integer 0-10 (0 = no printing, 10 = full debug
                      printing); copied as-is onto the control-info card.
    """

    def __init__(self):
        self.indexed_input = 0 # 0....Indexed Input (numbered elements and nodes)
                               # 1....Not Indexed
        self.debug_printing = 0 # Integer...0-10 (0=no printing, 10=full debug printing)

    def _validate_parameters(self):
        """Return True when both writer parameters are inside their ranges."""
        return (0 <= self.indexed_input <= 1) and (0 <= self.debug_printing <= 10)

    def write(self, filename, mesh_resource, model=None):
        """Write mesh_resource to filename in "nemesh" format.

        Parameters:
          filename       path of the text file to create (opened with 'w').
          mesh_resource  object providing location(), domains(), neumanns()
                         and meshes(); presumably an smtk mesh resource.
          model          optional object whose resource().location() is
                         recorded in the header; None writes "(none)".

        Returns False, without touching the file, if the writer parameters
        are out of range; returns True once the file has been written.
        Errors raised by open() or by the smtk calls are not caught; the
        output file is closed in either case.
        """
        if not self._validate_parameters():
            return False

        # indexed_input == 0 means records carry an explicit leading number.
        indexed = self.indexed_input == 0

        with open(filename, 'w') as out:
            self._write_header(out, mesh_resource, model)

            # control info
            out.write('%d %d\t\t! CARD TYPE 1\t control info\n' %
                      (self.indexed_input, self.debug_printing))

            # we are only interested in meshes that have associated domains
            meshes_by_domain = OrderedDict()
            for domain in mesh_resource.domains():
                meshes_by_domain[domain] = mesh_resource.meshes(domain)

            # One boundary surface is counted per cell of every neumann mesh.
            meshes_by_neumann = OrderedDict()
            n_boundary_surfaces = 0
            for neumann in mesh_resource.neumanns():
                neumann_mesh = mesh_resource.meshes(neumann)
                meshes_by_neumann[neumann] = neumann_mesh
                n_boundary_surfaces += neumann_mesh.cells().size()

            meshes_with_domains = smtk.mesh.MeshSet()
            for mesh in meshes_by_domain.values():
                meshes_with_domains.append(mesh)

            tessellation = smtk.mesh.Tessellation(True, False)
            tessellation.extract(meshes_with_domains)

            constants = smtk.mesh.MeshConstants()
            constants.extractDomain(meshes_with_domains)

            cell_types = tessellation.cellTypes()
            cell_connectivity = tessellation.connectivity()
            cell_domains = constants.cellData()
            points = tessellation.points()

            # points is a flat coordinate array read three values per node;
            # floor division keeps the node count an int (true division
            # yields a float, which range() rejects).
            n_cells = len(cell_types)
            n_points = len(points) // 3

            # mesh info
            out.write('%d %d 0 %d\t\t! CARD TYPE 2\t mesh info\n' %
                      (n_cells, n_points, n_boundary_surfaces))

            self._write_elements(out, indexed, cell_types, cell_domains)
            self._write_connectivity(out, indexed, n_cells, cell_connectivity)
            self._write_nodes(out, indexed, n_points, points)
            self._write_boundaries(out, indexed, meshes_by_neumann.values(),
                                   meshes_with_domains)

            out.write('\t\t! NECESSARY BLANK LINE AFTER DATA\n')

        return True

    def _write_header(self, out, mesh_resource, model):
        """Write the nine header lines every "nemesh" file must start with."""
        # "nemesh" files must have 9 header lines. We can put whatever we
        # want there, but we must put something there (or leave it blank).

        # header lines 1 - 2
        out.write('! Generated by CMB {}\n!\n'.format(
            datetime.datetime.now().strftime('%d-%b-%Y  %H:%M')))

        # header lines 3 - 4
        out.write('! ANL FINITE ELEMENT INPUT - NEMESH FILE DESCRIPTION\n!\n')

        # header line 5
        mesh_location = mesh_resource.location()
        if mesh_location:
            out.write('! mesh resource: %s\n' % mesh_location)
        else:
            out.write('! mesh resource: (unsaved)\n')

        # header line 6
        if model:
            model_location = model.resource().location()
            if model_location:
                out.write('! model resource: %s\n' % model_location)
            else:
                out.write('! model resource: (unsaved)\n')
        else:
            out.write('! model resource: (none)\n')

        # header lines 7 - 9
        out.write('!\n' * 3)

    def _write_elements(self, out, indexed, cell_types, cell_domains):
        """Write one element record (nemesh type code and domain) per cell."""
        # smtk cell type -> nemesh element type code
        nemesh_type = {
            int(smtk.mesh.Line): 1,
            int(smtk.mesh.Triangle): 5,
            int(smtk.mesh.Quad): 10,
            int(smtk.mesh.Tetrahedron): 15,
            int(smtk.mesh.Wedge): 20,
            int(smtk.mesh.Hexahedron): 25,
        }

        for cell_index, smtk_type in enumerate(cell_types):
            # A cell type missing from the table yields None here, which
            # makes the '%d' conversion below raise TypeError.
            type_code = nemesh_type.get(smtk_type)
            domain = cell_domains[cell_index]
            if indexed:
                record = '%d\t%d\t%d' % (cell_index + 1, type_code, domain)
            else:
                record = '%d\t%d' % (type_code, domain)
            # Only the first record of the card carries the card comment.
            if cell_index == 0:
                record += '\t\t! CARD TYPE 3\t element info'
            out.write(record + '\n')

    def _write_connectivity(self, out, indexed, n_cells, cell_connectivity):
        """Write the node numbers of each cell, one cell per line."""
        # cell_connectivity is read as consecutive records of the form
        # [node count, node id, node id, ...]; offset is the start of the
        # current cell's record.
        offset = 0
        for cell_index in range(n_cells):
            n_nodes = cell_connectivity[offset]
            prefix = '%d\t' % (cell_index + 1) if indexed else ''
            # Node ids are written shifted by +1 (presumably converting
            # 0-based smtk indices to 1-based nemesh numbering).
            nodes = '\t'.join('%d' % (cell_connectivity[offset + k + 1] + 1)
                              for k in range(n_nodes))
            suffix = '\t! CARD TYPE 4\t connectivity' if cell_index == 0 else ''
            out.write(prefix + nodes + suffix + '\n')
            offset = offset + n_nodes + 1

    def _write_nodes(self, out, indexed, n_points, points):
        """Write the three coordinates of each node, one node per line."""
        for point_index in range(n_points):
            base = 3 * point_index
            prefix = '%d\t' % (point_index + 1) if indexed else ''
            coordinates = '%f\t%f\t%f' % (points[base + 0],
                                          points[base + 1],
                                          points[base + 2])
            suffix = '\t\t! CARD TYPE 5\t node info' if point_index == 0 else ''
            out.write(prefix + coordinates + suffix + '\n')

    def _write_boundaries(self, out, indexed, neumann_meshes, meshes_with_domains):
        """Write one boundary record per (reference cell, canonical index) pair."""
        # Boundary records are numbered continuously across all neumann meshes.
        boundary_index = 0
        for mesh in neumann_meshes:
            indices = smtk.mesh.CanonicalIndices()
            indices.extract(mesh, meshes_with_domains)
            reference_cells = indices.referenceCellIndices()
            canonical_indices = indices.canonicalIndices()

            for reference_cell, canonical_index in zip(reference_cells, canonical_indices):
                # Both indices are written shifted by +1; the last column is
                # always 0.
                if indexed:
                    out.write('%d\t%d\t%d\t%d\n' % ((boundary_index + 1), (reference_cell + 1),
                                                    (canonical_index + 1), 0))
                else:
                    out.write('%d\t%d\t%d\n' % ((reference_cell + 1), (canonical_index + 1), 0))
                boundary_index = boundary_index + 1
