#=============================================================================
#
#  Copyright (c) Kitware, Inc.
#  All rights reserved.
#  See LICENSE.txt for details.
#
#  This software is distributed WITHOUT ANY WARRANTY; without even
#  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
#  PURPOSE.  See the above copyright notice for more information.
#
#=============================================================================
import os
print 'loading', os.path.basename(__file__)

import smtk
if 'pybind11' == smtk.wrappingProtocol():
  #print 'Using pybind11 bindings'
  import smtk.attribute

from cardformat import CardFormat
from conditionset import ConditionSet

# ---------------------------------------------------------------------
class Writer2D:
  '''Top level writer class for IBAMR input files (2D).

  For each component in a component sequence, the writer looks up that
  component's list of CardFormat objects in a format table and writes the
  corresponding input-file section from the simulation attributes supplied
  by the export spec. A component may name a custom writer method of this
  class (e.g. write_main, write_geometry) instead of using the default
  write_component().
  '''

  # Visualization writers checked by the "viz_writer" card, in output
  # order: (child item name, label written to the input file)
  _VIZ_APPS = (('visit', 'VisIt'), ('exodus', 'ExodusII'), ('silo', 'Silo'))

  # Ordinal names written for the integer value of "PK1" cards
  _ORDER_NAMES = {
    1: 'FIRST', 2: 'SECOND', 3: 'THIRD', 4: 'FOURTH',
    5: 'FIFTH', 6: 'SIXTH', 7: 'SEVENTH'}

# ---------------------------------------------------------------------
  def __init__(self, export_spec):
    '''Initializes the writer from an export spec.

    Required argument:
        export_spec: object providing getSimulationAttributes() and
        getExportAttributes(); the export attributes must contain an
        "ExportSpec" attribute (the first one found is used).

    Raises Exception if the export attribute manager or the ExportSpec
    attribute is missing.
    '''
    # Member data
    self.component_sequence = None
    self.export_spec_att = None
    self.format_table = None
    self.sim_atts = None
    self.out = None  # output stream, assigned by write()

    # Initialize
    self.sim_atts = export_spec.getSimulationAttributes()
    print('sim_atts %s' % (self.sim_atts,))

    # Get export attribute
    export_atts = export_spec.getExportAttributes()
    if not export_atts:
      raise Exception('ERROR: Missing Export Attribute Manager')

    att_list = export_atts.findAttributes('ExportSpec')
    if not att_list:
      raise Exception('ERROR: Missing ExportSpec attribute')

    self.export_spec_att = att_list[0]

# ---------------------------------------------------------------------
  def write(self, component_sequence, format_table):
    '''Writes the input file.

    Required arguments:
        component_sequence: iterable of component descriptors, written in
        order
        format_table: dict-like mapping format-list name to a list of
        CardFormat objects

    The output path comes from the ExportSpec "OutputFile" item; if it is
    empty, "output.ibamr" is used. Returns True once the file is written.
    Raises Exception if the ExportSpec attribute has no OutputFile item.
    '''
    ConditionSet.clear()
    self.component_sequence = component_sequence
    self.format_table = format_table

    # Get output filename/path
    filename_item = self.export_spec_att.findFile('OutputFile')
    if not filename_item:
      raise Exception('ERROR: ExportSpec attribute missing OutputFile item')

    output_filename = filename_item.value(0)
    if not output_filename:
      output_filename = 'output.ibamr'
      msg = 'No output file specified; using \"%s\"' % output_filename
      # Reported on stdout only: no logger object is in scope here
      print('WARNING: %s' % msg)
    print('output filename %s' % output_filename)

    completed = False
    with open(output_filename, 'w') as out:
      out.write('// Generated by CMB\n')
      self.out = out

      for component in self.component_sequence:
        # Get format list
        format_list = self.format_table.get(component.format_list_name)
        if format_list is None:
          msg = 'WARNING: Missing format list for component %s' % \
            component.name
          if component.format_list_name != component.name:
            msg += '  (format list name: %s )' % component.format_list_name
          print(msg)
          continue

        # Set component condition (if any)
        if component.set_condition:
          ConditionSet.set_condition(component.set_condition)

        if component.custom_component_method is not None:
          # Component names a custom writer method of this class
          method = getattr(self, component.custom_component_method, None)
          if method is None:
            print('ERROR: For component %s , custom_method %s not found' % \
              (component.name, component.custom_component_method))
          else:
            method(out, component, format_list)
        else:
          # Otherwise use the default component writer
          self.write_component(out, component, format_list)

        # Unset any component condition -- for custom methods too, so the
        # condition does not leak into later components
        if component.set_condition:
          ConditionSet.unset_condition(component.set_condition)

      completed = True
      print('Wrote output file %s' % output_filename)
    return completed

# ---------------------------------------------------------------------
  def write_component(self, out, component, format_list):
    '''Writes a component using the default (non-custom) logic.

    - If component.att_name is set, writes one section for that named
      attribute (nothing if it is not found).
    - Else if component.att_type is set, writes one section per attribute
      of that type, ordered by attribute name.
    - Otherwise writes a single section in which each card is written for
      every attribute of the card's own att_type.
    '''
    print('Writing component %s' % component.name)

    if component.att_name is not None:
      att = self.sim_atts.findAttribute(component.att_name)
      if att:
        self.write_att(out, att, component, format_list)
    elif component.att_type is not None:
      # sorted() does not mutate whatever sequence findAttributes() returned
      att_list = sorted(
        self.sim_atts.findAttributes(component.att_type),
        key=lambda a: a.name())
      for att in att_list:
        self.write_att(out, att, component, format_list)
    else:
      self.begin_component(out, component)
      for card in format_list:
        for att in self.sim_atts.findAttributes(card.att_type):
          card.write(out, att, tab=component.tab)
      self.end_component(out)

# ---------------------------------------------------------------------
  def write_att(self, out, att, component, format_list):
    '''Writes one component section for a single attribute.
    '''
    self.begin_component(out, component)
    for card in format_list:
      self.write_card(out, att, component, card)
    self.end_component(out)

# ---------------------------------------------------------------------
  def write_card(self, out, att, component, card):
    '''Writes one card.

    Cards without an att_type are written from att; cards with an
    att_type are written once for every attribute of that type.
    '''
    tab = component.tab
    base_path = component.base_item_path
    if card.att_type is None:
      card.write(out, att, base_item_path=base_path, tab=tab)
    else:
      for card_att in self.sim_atts.findAttributes(card.att_type):
        card.write(out, card_att, base_item_path=base_path, tab=tab)

# ---------------------------------------------------------------------
  def get_value(self, card, indx=0):
    '''
    Returns value from card item by extracting attribute type and path.

    Required argument:
        card: (object) CardFormat object with att_type != None

    Optional argument:
        indx: (int) Index of value in item if item has multiple values
        Default is 0.

    Uses the first attribute of type card.att_type. Returns None (after
    printing an error) if there is no such attribute or it has no item at
    card.item_path.
    '''
    att_list = self.sim_atts.findAttributes(card.att_type)
    if not att_list:
      print('ERROR: Missing attribute type %s' % card.att_type)
      return None

    att = att_list[0]

    item = att.itemAtPath(card.item_path, '/')
    if not item:
      print('ERROR: no value found for %s/%s' % (card.att_type, card.item_path))
      return None

    concrete_item = smtk.attribute.to_concrete(item)
    return concrete_item.value(indx)

# ---------------------------------------------------------------------
  def get_att(self, component):
    '''Returns the first attribute of type component.att_type.

    component may be any object with an att_type member (components and
    cards are both passed here). Returns None, after printing an error,
    if no such attribute exists.
    '''
    att_list = self.sim_atts.findAttributes(component.att_type)
    if not att_list:
      print('ERROR: Missing %s attribute' % component.att_type)
      return None

    return att_list[0]

# ---------------------------------------------------------------------
  def write_main(self, out, component, format_list):
    '''Custom method for writing Main component.

    The "viz_writer" card is written by _write_viz_writer(); every other
    card goes through write_card().
    '''
    print('Writing component %s' % component.name)
    att = self.get_att(component)
    self.begin_component(out, component)

    for card in format_list:
      if 'viz_writer' == card.keyword:
        self._write_viz_writer(out, att, card, component.tab)
      else:
        self.write_card(out, att, component, card)

    self.end_component(out)

# ---------------------------------------------------------------------
  def _write_viz_writer(self, out, att, card, tab):
    '''Writes the viz_writer card: a comma-separated list of the quoted
    labels of the enabled visualization writers ("" if none is enabled).
    '''
    if att is None:
      print('ERROR: Cannot write %s without attribute' % card.keyword)
      return
    item = att.itemAtPath(card.item_path, '/')
    if not item:
      print('ERROR: no value found for %s' % card.item_path)
      return
    viz_item = smtk.attribute.to_concrete(item)

    # See which, if any, writers are enabled (fixed order for stable output)
    enabled_list = []
    for name, label in self._VIZ_APPS:
      app_item = viz_item.find(name)
      if app_item is not None and app_item.isEnabled():
        enabled_list.append(label)
    if not enabled_list:
      enabled_list.append('')  # so that there is *something*

    if card.comment:
      card.write_comment(out, card.comment)
    if card.set_condition:
      ConditionSet.set_condition(card.set_condition)
    string_value = ','.join('"%s"' % x for x in enabled_list)
    card.write_value(out, string_value, quote_string=False, tab=tab)

# ---------------------------------------------------------------------
  def write_bc_coefs(self, out, component, format_list):
    '''Custom method for writing velocity BC coefficients.

    Reads the attribute named component.att_name. If its "enable" group
    item is enabled, writes one "<x>coef_function_<i>" line per value of
    the group's "a", "b" and "g" items, with a blank line between items.
    format_list is not used.
    '''
    if not component.att_name:
      print('ERROR: Missing att_name for component %s' % component.name)
      return

    att = self.sim_atts.findAttribute(component.att_name)
    if not att:
      print('ERROR: Missing attribute with name %s' % component.att_name)
      return

    # Nothing is written unless the "enable" group item is enabled
    enable_item = att.findGroup('enable')
    if enable_item is None:
      print('ERROR: Missing enable group in attribute %s' % component.att_name)
      return
    if not enable_item.isEnabled():
      return

    print('Writing component %s' % component.name)

    # CardFormat instance used only for its write_value() formatting
    card = CardFormat('temp')
    tab = component.tab

    self.begin_component(out, component)
    for item_name in ('a', 'b', 'g'):
      if item_name != 'a':
        out.write('\n')
      item = enable_item.find(item_name)
      if item is None:
        print('ERROR: Missing %s item in enable group of %s' % \
          (item_name, component.att_name))
        continue
      coef_item = smtk.attribute.to_concrete(item)
      for i in range(coef_item.numberOfValues()):
        keyword = '%scoef_function_%d' % (item_name, i)
        value_string = '"%s"' % coef_item.value(i)
        card.write_value(
          out, value_string, keyword=keyword, quote_string=False, tab=tab)

    self.end_component(out)

# ---------------------------------------------------------------------
  def write_geometry(self, out, component, format_list):
    '''Custom method for writing CartesianGeometry.

    The "domain_boxes" card is written as "[ (0,0), (nx-1,ny-1) ]", where
    nx and ny are values 0 and 1 of the card's item. Custom cards are
    reported as TODO; all other cards are written from the component
    attribute.
    '''
    print('Writing component %s' % component.name)

    att = self.get_att(component)
    tab = component.tab
    self.begin_component(out, component)

    for card in format_list:
      if 'domain_boxes' == card.keyword:
        # Card points at the grid attribute's base-grid-size item
        nx = self.get_value(card, indx=0)
        ny = self.get_value(card, indx=1)
        if nx is None or ny is None:
          print('ERROR: Cannot write %s; missing grid size' % card.keyword)
          continue
        value = '[ (0,0), (%d,%d) ]' % (nx - 1, ny - 1)
        card.write_value(out, value, quote_string=False, tab=tab)
      elif card.is_custom:
        print('TODO %s' % card.keyword)
      else:
        card.write(out, att, tab=tab)

    self.end_component(out)

# ---------------------------------------------------------------------
  def _report_missing(self, keyword, **values):
    '''Returns True, after printing an error, if any of values is None.
    '''
    missing = sorted(name for name, value in values.items() if value is None)
    if missing:
      print('ERROR: Cannot compute %s; missing %s' % \
        (keyword, ', '.join(missing)))
      return True
    return False

# ---------------------------------------------------------------------
  def _get_ref_ratio(self, card):
    '''Returns the refinement ratio for the REF_RATIO card.

    For "fixed" data this is the single value; for "table" data it is the
    largest value in any row. Returns None if it cannot be determined.
    '''
    card_att = self.get_att(card)
    data_type, data_item = self._get_selected_item(card_att, card.item_path)
    if data_item is None:
      return None
    if 'fixed' == data_type:
      return data_item.value(0)
    if 'table' == data_type:
      values = [v for row in self._get_table_rows(data_item) for v in row]
      if values:
        return max(values)
    return None

# ---------------------------------------------------------------------
  def write_toplevel(self, out, component, format_list):
    '''Custom writer for Top Level component.

    Writes a "// <name>" header (not a braced section). N, L, MAX_LEVELS
    and REF_RATIO are read from their cards; DX0 = L/N,
    NFINEST = REF_RATIO**(MAX_LEVELS-1) * N and DX = L/NFINEST are derived
    from them, so their cards must come later in format_list. PK1 cards
    write an ordinal name (FIRST..SEVENTH).
    '''
    print('Writing component %s' % component.name)

    out.write('\n')
    out.write('// %s' % component.name)
    out.write('\n')

    att = self.get_att(component)
    tab = component.tab
    # Values read from earlier cards; None until their card is processed
    N = L = max_levels = ref_ratio = nfinest = None
    for card in format_list:
      keyword = card.keyword
      if keyword == 'N':
        N = self.get_value(card)
        card.write_value(out, N, tab=tab)

      elif keyword == 'L':
        L = self.get_value(card)
        card.write_value(out, L, tab=tab)

      elif keyword == 'MAX_LEVELS':
        max_levels = self.get_value(card)
        card.write_value(out, max_levels, tab=tab)

      elif keyword == 'REF_RATIO':
        ref_ratio = self._get_ref_ratio(card)
        if ref_ratio is None:
          print('ERROR: Cannot determine %s' % keyword)
          continue
        card.write_value(out, ref_ratio, tab=tab)

      elif keyword == 'DX0':
        if self._report_missing(keyword, L=L, N=N):
          continue
        # float() avoids Python 2 integer truncation
        card.write_value(out, float(L) / N, tab=tab)

      elif keyword == 'NFINEST':
        if self._report_missing(
            keyword, REF_RATIO=ref_ratio, MAX_LEVELS=max_levels, N=N):
          continue
        nfinest = (ref_ratio ** (max_levels - 1)) * N
        card.write_value(out, nfinest, tab=tab)

      elif keyword == 'DX':
        if self._report_missing(keyword, L=L, NFINEST=nfinest):
          continue
        card.write_value(out, float(L) / nfinest, tab=tab)

      elif 'PK1' in keyword.split('_'):
        order = self.get_value(card)
        value = self._ORDER_NAMES.get(order)
        if value is None:
          print('ERROR: Unsupported order %s for %s' % (order, keyword))
          continue
        card.write_value(out, value, tab=tab)

      else:
        self.write_card(out, att, component, card)

# ---------------------------------------------------------------------
  def write_grid(self, out, component, format_list):
    '''Custom writer for GriddingAlgorithm component.

    Custom cards are written as level tables by write_table(); all other
    cards are written from the component attribute.
    '''
    print('Writing component %s' % component.name)
    att = self.get_att(component)
    tab = component.tab
    self.begin_component(out, component)

    for card in format_list:
      if card.is_custom:
        # All custom grids cards are for "table" subcomponents
        self.write_table(out, card, att)
      else:
        card.write(out, att, tab=tab)

    self.end_component(out)

# ---------------------------------------------------------------------
  def _get_selected_item(self, att, item_path):
    '''Returns (data_type, data_item) for a selector string item.

    The string item at item_path holds a value (e.g. "fixed" or "table")
    that names one of its child items; data_item is that child passed
    through smtk.attribute.to_concrete(). Returns (None, None), after
    printing an error, if the attribute, string item or child is missing.
    '''
    if att is None:
      print('ERROR: Missing attribute for item %s' % item_path)
      return None, None
    type_item = att.findString(item_path)
    if not type_item:
      print('ERROR: Missing string item of type %s' % item_path)
      return None, None
    data_type = type_item.value(0)
    # type_item.findChild(data_type, smtk.attribute.ACTIVE_CHILDREN) fails
    # the to_concrete() call for unknown reason, so look the child up by
    # path instead
    child_path = '%s/%s' % (type_item.name(), data_type)
    item = att.itemAtPath(child_path, '/')
    if not item:
      print('ERROR: no value found for %s' % child_path)
      return None, None
    return data_type, smtk.attribute.to_concrete(item)

# ---------------------------------------------------------------------
  def _get_table_rows(self, data_item):
    '''Returns (value0, value1) pairs from a group item whose groups each
    contain a two-value "row" item.
    '''
    rows = []
    for i in range(data_item.numberOfGroups()):
      row_item = smtk.attribute.to_concrete(data_item.find(i, 'row'))
      rows.append((row_item.value(0), row_item.value(1)))
    return rows

# ---------------------------------------------------------------------
  def write_table(self, out, card, att):
    '''Writes subcomponent with table of value pairs by level.

    This presumes a specific attribute format, used by several items:
    refinement-ratio, largest-patch-size, smallest-patch-size.
    Each line reads "level_<n> = <v0>,<v1>". Levels start at 1 for
    ratio_to_coarser and at 0 otherwise. Nothing is written if the data
    item cannot be found.
    '''
    data_type, data_item = self._get_selected_item(att, card.item_path)
    if data_item is None:
      # Error already reported; skip instead of leaving an unclosed "{"
      return

    if 'fixed' == data_type:
      # Single value used for both entries of one level
      value = data_item.value(0)
      rows = [(value, value)]
    elif 'table' == data_type:
      rows = self._get_table_rows(data_item)
    else:
      rows = []

    first_level = 1 if card.keyword == 'ratio_to_coarser' else 0
    indent = '  '
    out.write('%s%s {\n' % (indent, card.keyword))
    for level, (val0, val1) in enumerate(rows, first_level):
      out.write('%s%slevel_%d = %s,%s\n' % (indent, indent, level, val0, val1))
    out.write('%s}\n' % indent)

# ---------------------------------------------------------------------
  def begin_component(self, out, component):
    '''Writes a blank line and the "<name> {" opening line of a section.
    '''
    out.write('\n')
    out.write('%s {' % component.name)
    out.write('\n')

# ---------------------------------------------------------------------
  def end_component(self, out, indent=''):
    '''Writes the closing "}" of a section, prefixed by indent.
    '''
    out.write('%s}\n' % indent)
