#=============================================================================
#
#  Copyright (c) Kitware, Inc.
#  All rights reserved.
#  See LICENSE.txt for details.
#
#  This software is distributed WITHOUT ANY WARRANTY; without even
#  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
#  PURPOSE.  See the above copyright notice for more information.
#
#=============================================================================

"""Top level writer for PyARC son files
"""

import datetime
import string

import smtk
import smtk.model

import convert_to_pyarc as pyarcHelper

class PyARCWriter:
    '''Top level writer for PyARC (.son) files.

    Two entry points produce the same document layout:

    * ``__call__(filename, model_entity, mcc3_attribute, dif3d_attribute)``
      writes explicitly supplied solver attributes;
    * ``write(scope)`` takes the output path, model, solver list and
      simulation attributes from an export ``scope`` and logs through
      ``scope.logger``.

    The document is an ``=arc`` section holding a ``geometry`` block
    (materials, surfaces and core regions) and a ``calculations`` block
    (optional ``mcc3`` and ``dif3d`` sections), terminated by ``end``.
    '''
    def __init__(self):
        self.logger = None
        self.scope = None
        # Indentation unit for every nested SON block.
        self.tab = '    '
        # Output stream of the file currently being written.
        self.out = None
        # Surface records filled while the core is exported; the tuple
        # layouts are listed in write_surfaces_and_regions_reactor.
        self.surfaces = {'hexagon': set(), 'cylinder': set(), 'plane': set()}

    def __call__(self, filename, model_entity, mcc3_attribute=None, dif3d_attribute=None):
        '''Write a PyARC file for *model_entity* to *filename*.

        :param filename: path of the .son file to (over)write.
        :param model_entity: model whose materials and core are exported.
        :param mcc3_attribute: mcc3 attribute to export, or None to omit it.
        :param dif3d_attribute: dif3d attribute to export, or None to omit it.
        :returns: the result of :meth:`check_if_valid` on the written file.
        '''
        return self._write_file(filename, model_entity, mcc3_attribute, dif3d_attribute)

    def write(self, scope):
        '''Write the PyARC file described by an export *scope*.

        Reads ``scope.logger``, ``scope.output_path``, ``scope.model_entity``,
        ``scope.solver_list`` and ``scope.sim_atts``. An mcc3 or dif3d
        section is written only when that solver is in ``scope.solver_list``
        and exactly one attribute of that type exists. When there is more
        than one, a warning is logged and the section is skipped.

        :returns: the result of :meth:`check_if_valid` on the written file.
        '''
        self.logger = scope.logger
        self.scope = scope
        mcc3 = self._find_single_attribute('mcc3')
        dif3d = self._find_single_attribute('dif3d')
        return self._write_file(scope.output_path, scope.model_entity, mcc3, dif3d)

    def _find_single_attribute(self, type_name):
        '''Return the only *type_name* attribute of the current scope, or None.

        None is returned in three cases: the solver is not selected, no such
        attribute exists, or several exist. The last case is logged as a
        warning.
        '''
        if type_name not in self.scope.solver_list:
            return None
        attributes = self.scope.sim_atts.findAttributes(type_name)
        if len(attributes) > 1:
            msg = 'More than one {} instance -- ignoring all'.format(type_name)
            # A single parenthesised argument prints identically on Python 2 and 3.
            print('WARNING: ' + msg)
            self.logger.addWarning(msg)
            return None
        if len(attributes) == 1:
            return attributes[0]
        return None

    def _write_file(self, output_path, model_entity, mcc3, dif3d):
        '''Write the complete document to *output_path*, then validate it.'''
        self.tab = '    '
        indent = self.tab

        with open(output_path, 'w') as out:
            self.out = out

            out.write('% Generated by CMB {}\n'.format(
                datetime.datetime.now().strftime('%d-%b-%Y  %H:%M')))
            out.write('=arc\n')

            # Geometry block
            self._begin_block('geometry')
            out.write('{}% Description of the core model\n'.format(indent))
            self.write_materials(indent, model_entity)
            # The surfaces are only known once the core has been processed,
            # so surfaces and regions are written together.
            self.write_surfaces_and_regions_reactor(indent, model_entity)
            self._end_block()

            # Calculations block
            self._begin_block('calculations')
            out.write('{}% The list of calculations\n'.format(indent))
            if mcc3 is not None:
                self.write_mcc3(indent, model_entity, mcc3)
            if dif3d is not None:
                self.write_dif3d(indent, dif3d)
            self._end_block()

            out.write('end\n')

        # Any failure above propagates as an exception, so reaching this
        # point means the file was written completely.
        return self.check_if_valid(output_path)

    def _begin_block(self, name, indent=''):
        '''Write ``name{`` at *indent*.'''
        self.out.write('{}{}{{\n'.format(indent, name))

    def _end_block(self, indent=''):
        '''Write the closing ``}`` at *indent*.'''
        self.out.write('{}}}\n'.format(indent))

    def _write_item(self, attribute, name, indent='', label='', quoted=False, forceList=False):
        '''Write ``label = value`` for the item *name* of *attribute*.

        Nothing is written when the item is missing or disabled. Values are
        rendered as follows:

        * void items become ``true``;
        * *quoted* wraps the first value in double quotes;
        * multi-valued items, or any item when *forceList* is set, become a
          space-separated ``[...]`` list;
        * otherwise the first value is written as-is.

        *label* defaults to *name*.
        '''
        item = attribute.find(name)
        if not item or not item.isEnabled():
            return
        label = label or name
        if item.type() == smtk.attribute.Item.VoidType:
            value = 'true'
        elif quoted:
            value = '"' + item.value(0) + '"'
        elif item.numberOfValues() > 1 or forceList:
            value = '[' + ' '.join(
                str(item.value(i)) for i in range(item.numberOfValues())) + ']'
        else:
            value = item.value(0)
        self.out.write('{}{} = {}\n'.format(indent, label, value))

    def write_materials(self, indent, model_entity):
        '''Write the ``materials`` block at *indent*.

        Material descriptions are SON-formatted strings stored on the model
        under ``smtk.session.rgg.Material.label``. Each one is written one
        level inside the ``materials`` block.
        '''
        self._begin_block('materials', indent)
        material_descriptions = smtk.model.Model(model_entity).stringProperty(
            smtk.session.rgg.Material.label)
        for material_description in material_descriptions:
            self._write_material_description(indent + self.tab, material_description)
        self._end_block(indent)

    def _write_material_description(self, indent, description):
        '''Write one SON-formatted material *description*, re-indented.

        Literal backslash-n sequences are treated as line breaks, and blank
        lines are skipped. Each line is stripped and re-indented from
        *indent*: one ``self.tab`` deeper after a line ending in ``{``, and
        one shallower for a line ending in ``}``. Indentation never goes
        shallower than *indent*, even when the braces are unbalanced.
        '''
        depth = 0
        for line in description.replace('\\n', '\n').split('\n'):
            line = line.strip()
            if not line:
                continue
            # Dedent before writing so a closing brace lines up with its opener.
            if line.endswith('}'):
                depth = max(depth - 1, 0)
            self.out.write('{}{}{}\n'.format(indent, self.tab * depth, line))
            if line.endswith('{'):
                depth += 1

    @staticmethod
    def _find_core_group(model):
        '''Return the first group of *model* whose ``rggType`` is ``_rgg_core``.

        Groups without an ``rggType`` string property are skipped. Returns
        None when no core group exists.
        '''
        for group in model.groups():
            rgg_types = group.stringProperty('rggType')
            if rgg_types and rgg_types[0] == '_rgg_core':
                return group
        return None

    @staticmethod
    def _cell_name(index):
        '''Return a unique lowercase cell name for a zero-based *index*.

        The sequence is 0 -> 'a', 25 -> 'z', 26 -> 'aa', 27 -> 'ab', and so
        on (spreadsheet-column style). Names therefore never repeat, even
        with more than 26 cells.
        '''
        letters = string.ascii_lowercase
        name = ''
        remaining = index + 1
        while remaining > 0:
            remaining, digit = divmod(remaining - 1, len(letters))
            name = letters[digit] + name
        return name

    def write_surfaces(self, tabNum, surfaces):
        """Helper function for write_surfaces_and_regions_reactor"""
        surfacesString = pyarcHelper.surfacesToString(tabNum, surfaces)
        self.out.write(surfacesString)

    def write_surfaces_and_regions_reactor(self, indent, model_entity):
        '''Write the surfaces, then the core regions, of *model_entity*.

        ``self.surfaces`` is reset and passed to
        ``pyarcHelper.Core.exportCore``, which must run before the surfaces
        are known. The surfaces are written first and the core text after
        them. Only the first ``_rgg_core`` group is exported.
        '''
        # Integer division keeps tabNum an int on Python 3 as well.
        tabNum = len(indent) // len(self.tab)

        self.surfaces = {
            # (name, orientation, normal, pitch)
            'hexagon': set(),
            # (name, axis, radius)
            'cylinder': set(),
            # (name, z)
            'plane': set(),
        }

        coreString = ''
        core_group = self._find_core_group(smtk.model.Model(model_entity))
        if core_group is not None:
            coreString = pyarcHelper.Core(core_group).exportCore(tabNum, self.surfaces)
        self.write_surfaces(tabNum, self.surfaces)
        self.out.write(coreString)

    def write_mcc3(self, indent, model_entity, mcc3):
        '''Write the ``mcc3`` calculation block for attribute *mcc3* at *indent*.

        One ``cell`` is written per sub-assembly of the first core group.
        The ``rzmflx_code_options`` sub-block is currently disabled (see the
        TODO below).
        '''
        self._begin_block('mcc3', indent)
        inner = indent + self.tab

        self._write_item(mcc3, 'force_mixture_calc', inner)
        self._write_item(mcc3, 'xslib', inner, '', True)
        self._write_item(mcc3, 'egroupname', inner)
        self._write_item(mcc3, 'egroupvals', inner)
        self._write_item(mcc3, 'scattering_order', inner)
        self._write_item(mcc3, 'inelastic_treatment', inner)
        self._write_item(mcc3, 'lumped_element_text_file', inner, '', True)

        # FIXME: Write all cells for now.
        # Use the same (first) core as the geometry export, so that several
        # core groups do not produce duplicate cells.
        core_group = self._find_core_group(smtk.model.Model(model_entity))
        if core_group is not None:
            subAssyNames = pyarcHelper.Core(core_group).getAllSubAssyNames()
            for i, subAssyName in enumerate(subAssyNames):
                self._begin_block('cell( {} )'.format(self._cell_name(i)), inner)
                self.out.write('{}associated_sub_assembly = {}\n'.format(
                    inner + self.tab, subAssyName))
                self._end_block(inner)

        rzmflx_code_options = mcc3.findGroup('rzmflx_code_options')
        # TODO: cylinders should not represent pin geometry, but rather
        #       cylindrical boundaries in the reactor. They are not used to define
        #       the reactor geometry; they are used in conjunction with planes to
        #       define a 2-dimensional lattice that is in turn used to compute
        #       neutron cross-sections.
        #
        #       From "ARC Integration into the NEAMS Workbench" (ANL/NE-17/31,
        #       9/30/2017), p. 28:
        #       How to get an equivalent RZ geometry? Calculate the number of
        #       assemblies represented by each mcc3_id. Calculate the area of each
        #       region: the area of a hexagon is
        #       $A=\frac{\sqrt{3}}{2}p^{2}$, with $p$ the pitch of the assembly.
        #       Find the "equivalent radius" by solving $A=\pi r^{2}$ for $r$.
        TODO_fix_cylinders = False
        if rzmflx_code_options and TODO_fix_cylinders:
            self._begin_block('rzmflx_code_options', inner)
            rz_indent = inner + self.tab

            self._write_item(rzmflx_code_options, 'code', rz_indent)
            self._write_item(rzmflx_code_options, 'finegroup_egroupname', rz_indent, 'egroupname')

            # Cylinder records are (name, axis, radius): order them by radius.
            cylinders = sorted(self.surfaces['cylinder'], key=lambda cylinder: cylinder[2])
            self.out.write('{}R_boundaries = [{}]\n'.format(
                rz_indent, ' '.join(cylinder[0] for cylinder in cylinders)))

            self._write_item(rzmflx_code_options, 'R_nodes_distance', rz_indent)

            # Plane records are (name, z): order them by height.
            planes = sorted(self.surfaces['plane'], key=lambda plane: plane[1])
            self.out.write('{}Z_boundaries = [{}]\n'.format(
                rz_indent, ' '.join(plane[0] for plane in planes)))

            self._write_item(rzmflx_code_options, 'Z_nodes_distance', rz_indent)
            self._write_item(rzmflx_code_options, 'SN_angular_order', rz_indent)
            self._write_item(rzmflx_code_options, 'core_2d_geometry', rz_indent, '', False, True)

            self._end_block(inner)

        self._end_block(indent)

    def write_dif3d(self, indent, dif3d):
        '''Write the ``dif3d`` calculation block for attribute *dif3d* at *indent*.

        The ``dif_options`` item selects which optional sub-block is written;
        its value is also the sub-block name. A nested ``rebus`` /
        ``decay_chain`` block is always written.
        '''
        self._begin_block('dif3d', indent)
        inner = indent + self.tab

        self._write_item(dif3d, 'power', inner)
        self._write_item(dif3d, 'geometry_type', inner)
        # When 'Cross Sections' is 'previous', that item's value is written
        # under the isotxs label; otherwise the isotxs item itself is written.
        if dif3d.find('Cross Sections').value(0) == 'previous':
            self._write_item(dif3d, 'Cross Sections', inner, 'isotxs')
        else:
            self._write_item(dif3d, 'isotxs', inner)
        self._write_item(dif3d, 'run_dif3d', inner)
        self._write_item(dif3d, 'max_axial_mesh_size', inner)

        # (sub-block name == dif_options value, items written inside it)
        option_blocks = (
            ('variant_options', ('polynomial_approx_source',
                                 'polynomial_approx_fluxes',
                                 'polynomial_approx_leakages',
                                 'angular_approx',
                                 'anisotropic_scattering_approx',
                                 'omega_acceleration')),
            ('dif_fd_options', ('hex_triangular_subdivision',)),
            ('dif_nod_options', ('course_mesh_rebalance',)),
        )
        dif_option = dif3d.find('dif_options').value(0)
        for block_name, item_names in option_blocks:
            if dif_option == block_name:
                self._begin_block(block_name, inner)
                for item_name in item_names:
                    self._write_item(dif3d, item_name, inner + self.tab)
                self._end_block(inner)
                break

        rebus = dif3d.findGroup('rebus')
        self._begin_block('rebus', inner)
        rebus_indent = inner + self.tab

        for item_name in ('cycle_length', 'shutdown_time_between_cycle',
                          'num_cycles', 'num_subintervals'):
            self._write_item(rebus, item_name, rebus_indent)

        decay_chain = rebus.find('decay_chain')
        self._begin_block('decay_chain', rebus_indent)
        chain_indent = rebus_indent + self.tab

        self._write_item(decay_chain, 'list_isotopes', chain_indent, '', False, True)
        self._write_item(decay_chain, 'list_lumped_elements', chain_indent, '', False, True)
        # 'list_dummy_elements' is currently not written.
        self._write_item(decay_chain, 'decay_chain_text_file', chain_indent, 'text_file', True)

        self._end_block(rebus_indent)
        self._end_block(inner)
        self._end_block(indent)

    def check_if_valid(self, output_path):
        '''Validate *output_path* against PyARC's ``arc.sch`` schema.

        Runs PyARC's ``bin/sonvalidxml`` on the file and parses its XML
        output. Returns True when the check passes, or when it cannot be
        performed because ``PyArc`` is not importable (a warning is logged).
        Returns False when validation fails (errors are logged).
        '''
        try:
            import PyArc
        except ImportError as e:
            if self.logger:
                msg = 'Cannot check against PyARC schema for validity'
                smtk.WarningMessage(self.logger, msg)
                smtk.WarningMessage(self.logger, str(e))
            return True

        try:
            import os
            import subprocess
            import PyArc.PyARCModel
            abs_path = os.path.abspath(os.path.dirname(PyArc.PyARCModel.__file__))
            sonvalidxml = os.path.join(abs_path, 'bin', 'sonvalidxml')
            schema = os.path.join(abs_path, 'schema', 'arc.sch')
            try:
                # Pass an argument list (no shell) so that paths containing
                # spaces or shell metacharacters reach sonvalidxml intact.
                xmlresult = subprocess.check_output([sonvalidxml, schema, output_path])
            except subprocess.CalledProcessError as e:
                if self.logger:
                    self.logger.addError(str(e))
                    self.logger.addError('\n' + e.output)
                return False

            # Parse the validator's XML output; any exception raised here is
            # reported as a validation failure by the handler below.
            from PyArc.wasppy import xml2obj
            xml2obj.xml2obj(xmlresult).arc

            if self.logger:
                msg = 'Output file passed PyARC validity check'
                smtk.InfoMessage(self.logger, msg)
            return True
        except Exception as e:
            if self.logger:
                msg = 'Output file failed PyARC validity check'
                smtk.ErrorMessage(self.logger, msg)
                smtk.ErrorMessage(self.logger, str(e))
            return False
