"""
Mixin class for legacy materials
"""

# Condition names that the writer methods below add to / discard from
# CardFormat.Conditions while namelists are being written.
# FLUID_PHASE: set for a 'phase.material' attribute whose 'fluid' group is enabled
FLUID_PHASE = 'fluid-phase'
# VOID_MATERIAL: set for a 'phase.void' attribute
VOID_MATERIAL = 'void'

from cardformat import CardFormat

class LegacyMaterialWriter:
# ---------------------------------------------------------------------
  def _write_legacy_materials(self, namelist, format_list):
    '''Writes interleaved PHASE, MATERIAL, MATERIAL_SYSTEM namelists

    Phase attributes that appear in no phase transition are written as
    single-phase materials (PHASE, MATERIAL, MATERIAL_SYSTEM). A chain of
    phase transitions is written as one multi-phase material: PHASE and
    MATERIAL for every phase from the lowest one upward, followed by a
    single MATERIAL_SYSTEM. A 'phase.void' attribute only gets a MATERIAL
    namelist, and only when self.skip_void_material is false.

    namelist: not used by this method
    format_list: not used by this method

    Raises RuntimeError when a transition is invalid, uses the same phase on
    both sides, has upper temp <= lower temp, when a phase is used twice on
    the same side of transitions, or when some transition is never reached
    from a lowest phase (probably a circular chain).
    '''
    print('Writing legacy MATERIAL/PHASE/MATERIAL_SYSTEM namelists')

    # First sort materials by single- vs multiple phase, using these
    # intermediate objects:
    #  lower_dict: <phase (id), transition> for phases in lower side of transition
    #  upper_dict: <phase (id), transition> for phases in upper side of transition
    # Note that we use ids for keys because python wrappers might create
    # multiple python objects for same c++ object

    def fail(msg):
      """Prints the message, then raises it as RuntimeError"""
      print(msg)
      raise RuntimeError(msg)

    def check_phase(phase_id, transition_dict, label):
      """Checks that phase not already used in transition dictionary"""
      if phase_id in transition_dict:
        phase_att = self.sim_atts.findAttribute(phase_id)
        tpl = 'ERROR: Phase attribute "{}" twice used as {} part of phase transitions'
        fail(tpl.format(phase_att.name(), label))

    lower_dict = dict()  # <phase attribute id, transition attribute>
    upper_dict = dict()  # <phase attribute id, transition attribute>
    trans_atts = self.sim_atts.findAttributes('phase-transition')
    for trans_att in trans_atts:
      if not trans_att.isValid():
        tpl = 'ERROR: invalid Phase Transition attribute "{}"'
        fail(tpl.format(trans_att.name()))

      # Get lower & upper components and their ids
      lower_att = trans_att.findComponent('lower').value()
      upper_att = trans_att.findComponent('upper').value()

      lower_id = lower_att.id()
      upper_id = upper_att.id()

      # Check for consistency
      if lower_id == upper_id:
        tpl = 'ERROR: same phase -- "{}" -- used for both sides of transition "{}"'
        fail(tpl.format(lower_att.name(), trans_att.name()))

      lower_temp = trans_att.findDouble('lower-transition-temperature').value()
      upper_temp = trans_att.findDouble('upper-transition-temperature').value()
      if upper_temp <= lower_temp:
        tpl = 'ERROR: phase transition "{}" upper temp <= lower temp ({} <= {})'
        fail(tpl.format(trans_att.name(), upper_temp, lower_temp))

      check_phase(lower_id, lower_dict, 'lower')
      check_phase(upper_id, upper_dict, 'upper')

      lower_dict[lower_id] = trans_att
      upper_dict[upper_id] = trans_att

    # Create list of "materials", each item either:
    #  a phase attribute representing a single-phase material
    #  a phase-transition attribute representing the "first" transition of multi-phase material
    # Phases that are the upper side of some transition are skipped here; they
    # are reached by walking the chain from its first transition.
    material_list = list()
    for phase_att in self.sim_atts.findAttributes('phase'):
      phase_id = phase_att.id()
      in_lower = phase_id in lower_dict
      in_upper = phase_id in upper_dict
      if not in_lower and not in_upper:
        material_list.append(phase_att)  # single-phase material
      elif in_lower and not in_upper:
        material_list.append(lower_dict[phase_id])  # multi-phase material (transition)

    def set_conditions(att=None):
      """Sets CardFormat conditions for given phase attribute"""
      CardFormat.Conditions.discard(VOID_MATERIAL)
      CardFormat.Conditions.discard(FLUID_PHASE)
      if att is None:
        return
      if att.type() == 'phase.void':
        CardFormat.Conditions.add(VOID_MATERIAL)
      elif att.type() == 'phase.material':
        fluid_item = att.findGroup('fluid')
        if fluid_item and fluid_item.isEnabled():
          CardFormat.Conditions.add(FLUID_PHASE)

    # Traverse material list
    # Also keep track of transitions, to make sure *all* get used
    trans_id_set = {t.id() for t in trans_atts}
    for att in material_list:
      att_type = att.type()
      if att_type == 'phase.void':
        # Void only gets a MATERIAL namelist (no PHASE or MATERIAL_SYSTEM)
        if not self.skip_void_material:
          set_conditions(att)
          self._write_legacy_material_namelist(att)
        continue

      if att_type == 'phase.material':
        # Write single phase material
        set_conditions(att)
        self._write_legacy_phase_namelist(att)
        self._write_legacy_material_namelist(att)
        self._write_legacy_material_system_namelist(att)
      elif att_type == 'phase-transition':
        trans_att = att
        # Write multiple phase material, starting with lowest-temp phase
        lower_att = trans_att.findComponent('lower').value()
        set_conditions(lower_att)
        self._write_legacy_phase_namelist(lower_att)
        self._write_legacy_material_namelist(lower_att)
        # Each phase is the upper side of at most one transition (enforced by
        # check_phase above), so a chain that starts at a phase with no
        # transition below it cannot loop back on itself.
        while trans_att is not None:
          trans_id_set.discard(trans_att.id())
          upper_att = trans_att.findComponent('upper').value()
          set_conditions(upper_att)
          self._write_legacy_phase_namelist(upper_att)
          self._write_legacy_material_namelist(upper_att)

          # Get next transition
          trans_att = lower_dict.get(upper_att.id())
        self._write_legacy_material_system_namelist(att, lower_dict)

        set_conditions(None)  # resets material conditions
      else:
        raise RuntimeError('ERROR - unexpected attribute type "{}"'.format(att_type))

    # Any transition atts not removed above indicate error(s)
    if trans_id_set:
      unused_atts = [self.sim_atts.findAttribute(trans_id) for trans_id in trans_id_set]
      unused_att_names = sorted(unused.name() for unused in unused_atts)
      tpl = 'ERROR - inconsistent transitions, probably circular: {}'
      msg = tpl.format(unused_att_names)
      raise RuntimeError(msg)


# --------------------------------------------------------------------
  def _write_legacy_material_namelist(self, phase_att):
    '''Emit one MATERIAL namelist describing the given phase attribute

    Side effects: the attribute id is stored in self.material_number_dict
    with the dictionary's current size as its material number, and the
    attribute becomes the background material if none has been set yet.
    '''
    namelist_title = 'MATERIAL'
    print('Writing namelist {} for attribute {}'.format(namelist_title, phase_att.name()))
    self._start_namelist(namelist_title)
    cards = self.format_table.get(namelist_title)

    # The material number is the count of materials registered so far
    material_number = len(self.material_number_dict)
    self.material_number_dict[phase_att.id()] = material_number

    CardFormat.write_value(self.out, 'material_name', phase_att.name())
    CardFormat.write_value(self.out, 'material_number', material_number)

    # The immobile flag is only written for 'phase.material' attributes;
    # it is true unless the fluid-phase condition is currently set
    if phase_att.type() == 'phase.material':
      is_fluid = FLUID_PHASE in CardFormat.Conditions
      CardFormat.write_value(self.out, 'immobile', not is_fluid, as_boolean=True)

    for card in cards:
      card.write(self.out, phase_att)

    # Density card: 0 when the void condition is set, 1 otherwise
    if VOID_MATERIAL in CardFormat.Conditions:
      density = 0.0
    else:
      density = 1.0
    CardFormat.write_value(self.out, 'density', density)

    # First material written claims the background role when it is unassigned
    if self.background_material_id is None:
      self.background_material_id = phase_att.id()

    is_background = phase_att.id() == self.background_material_id
    if is_background:
      CardFormat.write_value(self.out, 'material_feature','background')

    self._finish_namelist()

# ---------------------------------------------------------------------
  def _write_legacy_phase_namelist(self, phase_att):
    '''Emit one PHASE namelist for the given phase attribute

    The name card is always written, truncated to 31 characters. The cards
    from the format table are written only when the void condition is not set.
    '''
    namelist_title = 'PHASE'
    print('Writing namelist {} for attribute {}'.format(namelist_title, phase_att.name()))
    self._start_namelist(namelist_title)

    truncated_name = phase_att.name()[:31]
    CardFormat.write_value(self.out, 'name', truncated_name, tab=4)

    is_void = VOID_MATERIAL in CardFormat.Conditions
    if not is_void:
      cards = self.format_table.get(namelist_title)

      # Reset the property index before writing this phase's cards
      CardFormat.PropertyIndex = 0
      for card in cards:
        card.write(self.out, phase_att)

    self._finish_namelist()

# ---------------------------------------------------------------------
  def _write_legacy_material_system_namelist(self, att, lower_transition_dict=None):
    '''Write MATERIAL_SYSTEM namelist for a single- or multi-phase material

    att: either phase att for single-phase, or first transition att for multi-phase
    lower_transition_dict: <phase att id, transition att>, used to step from
      the upper phase of one transition to the next transition in the chain.
      None (the default) is treated as empty, in which case only the
      transition passed as att is written.

    Raises RuntimeError if att has an unexpected type, or if the transition
    temperatures decrease anywhere along the chain.
    '''
    title = 'MATERIAL_SYSTEM'
    print('Writing', title)
    self._start_namelist(title)

    # Avoid a shared mutable default; the dictionary is only read here
    transitions = lower_transition_dict if lower_transition_dict is not None else {}

    # Simple case - att is single "phase" attribute
    att_type = att.type()
    if att_type in ('phase.material', 'phase.void'):
      number = self.material_number_dict.get(att.id())
      name = 'material {}'.format(number)
      CardFormat.write_value(self.out, 'name', name, tab=6)
      # Truncate to 31 characters, the same way the PHASE namelist and the
      # multi-phase branch below do, so the reference matches the PHASE name
      CardFormat.write_value(self.out, 'phases', att.name()[:31], tab=6)
      self._finish_namelist()
      return

    if att_type != 'phase-transition':
      raise RuntimeError('Unexpected attribute type "{}"'.format(att_type))

    # Multiple phase case
    # Traverse transitions to build data lists
    phase_list = list()
    lower_temps = list()
    upper_temps = list()
    latent_heats = list()
    smoothing_radius = None
    numbers = list()  # material numbers

    trans_att = att
    lower_att = trans_att.findComponent('lower').value()
    lower_name = lower_att.name()
    numbers.append(self.material_number_dict.get(lower_att.id()))
    phase_list.append('"{}"'.format(lower_name[:31]))
    while trans_att is not None:
      upper_att = trans_att.findComponent('upper').value()
      upper_name = upper_att.name()
      phase_list.append('"{}"'.format(upper_name[:31]))
      numbers.append(self.material_number_dict.get(upper_att.id()))

      lower_temps.append(trans_att.findDouble('lower-transition-temperature').value())
      upper_temps.append(trans_att.findDouble('upper-transition-temperature').value())
      latent_heats.append(trans_att.findDouble('latent-heat').value())

      # Smoothing radius currently part of transition, but only a single value
      # applies to the material system. This code takes the smoothing radius
      # from the *first* transition that has its item enabled.
      if smoothing_radius is None:
        smoothing_radius_item = trans_att.findDouble('smoothing-radius')
        if smoothing_radius_item and smoothing_radius_item.isEnabled():
          smoothing_radius = smoothing_radius_item.value()

      # Get next transition
      trans_att = transitions.get(upper_att.id())

    # Check that transition temps never decrease along the chain.
    # temps interleaves the lists: low[0], high[0], low[1], high[1], ...
    temps = [val for pair in zip(lower_temps, upper_temps) for val in pair]
    for prev_temp, temp in zip(temps, temps[1:]):
      if temp < prev_temp:
        tpl = 'ERROR: temperatures for material {} not monotonically increasing: {}'
        msg = tpl.format(lower_name, temps)
        print(msg)
        raise RuntimeError(msg)

    number_string = '+'.join(str(n) for n in sorted(numbers))
    name = 'material {}'.format(number_string)
    CardFormat.write_value(self.out, 'name', name)
    CardFormat.write_value(self.out, 'phases', phase_list)
    CardFormat.write_value(self.out, 'transition_temps_low', lower_temps)
    CardFormat.write_value(self.out, 'transition_temps_high', upper_temps)
    CardFormat.write_value(self.out, 'latent_heat', latent_heats)
    if smoothing_radius is not None:
      CardFormat.write_value(self.out, 'smoothing_radius', smoothing_radius)

    self._finish_namelist()
