# -*- coding: utf-8 -*-
"""
Abaqus Extraction Script
1) Strict Component Validation
Raises a ValueError if a field contains elements with different numbers 
of components within the same part (e.g., mixed Truss/Shell for Stress 'S').
2) Automatic Section Point Detection
The script automatically identifies all available section points for a field 
(e.g., 'Top', 'Bottom', 'Point 1') and writes them as individual fields.
"""

import subprocess
import pickle
import os
import sys

# ================= USER CONFIG =================
# Input ODB and output HT5 paths.  The trailing commented-out names are
# other models this script has been run against.
ODB_PATH = "blocks_rigid_gcont.odb"#"selfcontact_gask.odb"#"stent-C3D8R-2step.odb"#"tennis_surfcav.odb"#"snapbuckling_b32h_deep.odb"#
OUTPUT_HT5 = "blocks_results.ht5"

# Step name -> frame selection: "all", or a list/tuple of frame indices
# (-1 selects the last frame of the step).
STEPS_SELECTION = {
    #"StaticStep": "all",
    #"Recoil": [-1],
    "Step-1": "all",
}

# Fields to extract.  Names present in derived_fields_map are scalar
# invariants computed from a base ODB field; others are read verbatim.
FIELDS = ["U", "U_MAGNITUDE", "S_VON_MISES", "S_MAX_PRINC", "LE_MAX_PRINC"]
# Instance names to extract, or "all" for every instance in the assembly.
PARTS = "all"
# ==============================================


# ================= CONFIG =================
# Intermediate pickle handed from the Abaqus Python stage (abaqus_main)
# to the HT5 conversion stage (generate_ht5).
TEMP_PICKLE = "temp_data.pkl"

# Derived field name -> (base ODB field name, extractor applied to each
# FieldValue object to obtain the scalar invariant).
derived_fields_map = {
    "LE_MAX_PRINC": ("LE", lambda val: val.maxPrincipal),
    "LE_MID_PRINC": ("LE", lambda val: val.midPrincipal),
    "LE_MIN_PRINC": ("LE", lambda val: val.minPrincipal),
    "S_MAX_PRINC": ("S", lambda val: val.maxPrincipal),
    "S_MID_PRINC": ("S", lambda val: val.midPrincipal),
    "S_MIN_PRINC": ("S", lambda val: val.minPrincipal),
    "S_VON_MISES": ("S", lambda val: val.mises),
    "U_MAGNITUDE": ("U", lambda val: val.magnitude),
}
# ==========================================


# ============================================================
# ================= ABAQUS EXTRACTION PART ===================
# ============================================================

def abaqus_main():
    """Extraction stage, executed inside Abaqus Python.

    Opens ODB_PATH, walks the selected steps/frames, averages element
    results to the nodes, and pickles the resulting data tree to
    TEMP_PICKLE for the later HT5 conversion stage.

    Raises:
        ValueError: unknown/empty step, or mixed-component tensors in a part.
        TypeError: malformed frame selection in STEPS_SELECTION.
        IndexError: frame index out of range for a step.
    """
    # Abaqus-only modules: imported lazily so the controller process
    # (plain Python, no Abaqus) can still import this file.
    from odbAccess import openOdb
    from abaqusConstants import NODAL, ELEMENT_NODAL, ELEMENT

    def build_global_frames(odb, steps_selection):
        """Flatten the step/frame selection into [(step_name, frame), ...]
        plus a parallel time table {"values": [...], "names": [...]}.

        Frame times are offset by the accumulated duration of previous
        steps, so "values" forms one global, increasing time axis.
        """
        global_frames = []
        times = {"values": [], "names": []}
        total_time_accumulator = 0.0

        for step_name, frame_selection in steps_selection.items():

            if step_name not in odb.steps:
                raise ValueError("Step '{}' not found in ODB.".format(step_name))

            step = odb.steps[step_name]
            num_frames = len(step.frames)

            if num_frames == 0:
                raise ValueError("Step '{}' contains no frames.".format(step_name))

            if frame_selection != "all" and not isinstance(frame_selection, (list, tuple)):
                raise TypeError(
                    "Frame selection for step '{}' must be 'all' or a list/tuple of indices, got {}."
                    .format(step_name, type(frame_selection).__name__)
                )

            # Frame selection.  Negative indices count from the end
            # (Python-style); -1 is the last frame, as before.
            if frame_selection == "all":
                frames_to_process = list(range(num_frames))
            else:
                frames_to_process = [
                    num_frames + idx if idx < 0 else idx for idx in frame_selection
                ]

            for frame_idx in frames_to_process:

                if frame_idx < 0 or frame_idx >= num_frames:
                    raise IndexError(
                        "Frame index {} out of range for step '{}' (available: 0 to {})."
                        .format(frame_idx, step_name, num_frames - 1)
                    )

                frame = step.frames[frame_idx]
                global_time = total_time_accumulator + float(frame.frameValue)

                global_frames.append((step_name, frame))
                times["values"].append(global_time)
                times["names"].append("{}-frame{}".format(step_name, frame_idx))

            total_time_accumulator += step.timePeriod

        return global_frames, times

    def main():
        odb = openOdb(ODB_PATH)
        try:
            global_frames, times = build_global_frames(odb, STEPS_SELECTION)

            data_tree = {
                "metadata": {"solver": "Abaqus", "ht5_version": "1.0", "times": times},
                "parts": {}
            }

            instances = odb.rootAssembly.instances.values()
            if PARTS != "all":
                instances = [i for i in instances if i.name in PARTS]

            for inst in instances:
                p_name = inst.name
                part_dict = {"nodes": {"points": []}, "connectivity": {}, "fields": {}}
                node_map = {node.label: i for i, node in enumerate(inst.nodes)}
                part_dict["nodes"]["points"] = [node.coordinates for node in inst.nodes]

                # Connectivity per element type.  Also cache element label ->
                # 0-based node indices so element-centred field values can be
                # spread to nodes without a per-value getFromLabel() lookup.
                elem_node_indices = {}
                for elem in inst.elements:
                    e_type = elem.type.lower()
                    if e_type not in part_dict["connectivity"]:
                        part_dict["connectivity"][e_type] = {"connectivity": []}
                    conn = [node_map[lbl] for lbl in elem.connectivity]
                    part_dict["connectivity"][e_type]["connectivity"].append(conn)
                    elem_node_indices[elem.label] = conn

                num_nodes = len(inst.nodes)
                NAN_VAL = float('nan')

                for f_name in FIELDS:
                    base_name, func = derived_fields_map.get(f_name, (f_name, None))
                    has_section_points = base_name in ["S", "LE"]

                    # --- AUTOMATIC SECTION POINT DETECTION ---
                    # Probe the first frame that carries the base field and
                    # collect every section-point description found there.
                    sp_list = [None]
                    if has_section_points and global_frames:
                        detected_sps = set()
                        for _, test_frame in global_frames:
                            if base_name in test_frame.fieldOutputs:
                                test_output = test_frame.fieldOutputs[base_name].getSubset(region=inst)
                                for v in test_output.values:
                                    if getattr(v, "sectionPoint", None) is not None:
                                        detected_sps.add(v.sectionPoint.description)
                            if detected_sps: break
                        if detected_sps:
                            sp_list = sorted(list(detected_sps))

                    for sp in sp_list:
                        out_field_name = "{}_{}".format(f_name, sp) if sp else f_name
                        field_storage = {"values": [], "location": "node", "components": []}
                        data_found = False

                        for step_name, frame in global_frames:
                            if base_name not in frame.fieldOutputs:
                                # Component count may be unknown here; store a
                                # placeholder and backfill NaNs after the loop.
                                # (Appending [[nan]]*num_nodes produced ragged
                                # arrays once a later frame turned out to be
                                # multi-component.)
                                field_storage["values"].append(None)
                                continue

                            f_output = frame.fieldOutputs[base_name]

                            if not field_storage["components"]:
                                if func:
                                    field_storage["components"] = [out_field_name]
                                else:
                                    field_storage["components"] = list(f_output.componentLabels) or [out_field_name]

                            num_comp = len(field_storage["components"])
                            sum_array = [[0.0] * num_comp for _ in range(num_nodes)]
                            count_array = [0] * num_nodes

                            # Prefer nodal / element-nodal data; fall back to
                            # element-position data below if this fails or is
                            # empty for the instance.
                            inst_f_output = None
                            try:
                                pos = f_output.locations[0].position
                                if pos == NODAL:
                                    inst_f_output = f_output.getSubset(region=inst)
                                else:
                                    inst_f_output = f_output.getSubset(region=inst, position=ELEMENT_NODAL)
                            except Exception:
                                pass  # best-effort probe; ELEMENT fallback runs next

                            if not inst_f_output or not inst_f_output.values:
                                try:
                                    inst_f_output = f_output.getSubset(region=inst, position=ELEMENT)
                                except Exception:
                                    inst_f_output = None

                            if not inst_f_output or not inst_f_output.values:
                                field_storage["values"].append([[NAN_VAL] * num_comp for _ in range(num_nodes)])
                                continue

                            for val in inst_f_output.values:
                                val_sp = getattr(val, "sectionPoint", None)

                                if sp:
                                    if val_sp:
                                        # Normal case: filter by section point
                                        if val_sp.description != sp:
                                            continue
                                    else:
                                        # Fallback: no section point available
                                        # -> accept value (duplicated across SPs)
                                        pass

                                v_data = func(val) if func else val.data
                                if not hasattr(v_data, '__len__'): v_data = [v_data]

                                current_len = len(v_data)
                                if current_len != num_comp:
                                    msg = ("Component mismatch in Part '{0}', Field '{1}'. "
                                           "Extraction cannot continue for mixed-component tensors.")
                                    raise ValueError(msg.format(p_name, out_field_name))

                                node_lbl = getattr(val, "nodeLabel", None)
                                nodes_to_process = []
                                if node_lbl:
                                    n_idx = node_map.get(node_lbl)
                                    if n_idx is not None: nodes_to_process.append(n_idx)
                                else:
                                    # Element-centred value: spread it to all of
                                    # the element's nodes via the cached map.
                                    nodes_to_process.extend(elem_node_indices.get(val.elementLabel, []))

                                for n_idx in nodes_to_process:
                                    for c in range(num_comp):
                                        sum_array[n_idx][c] += v_data[c]
                                    count_array[n_idx] += 1

                            if any(count_array):
                                data_found = True
                                field_storage["values"].append([
                                    [(sum_array[i][c]/count_array[i]) if count_array[i]>0 else NAN_VAL for c in range(num_comp)]
                                    for i in range(num_nodes)
                                ])
                            else:
                                field_storage["values"].append([[NAN_VAL] * num_comp for _ in range(num_nodes)])

                        if data_found:
                            # Backfill the frames whose field was missing
                            # entirely, now that the component count is known,
                            # so every frame has the same (nodes x comp) shape.
                            num_comp = len(field_storage["components"])
                            field_storage["values"] = [
                                vals if vals is not None
                                else [[NAN_VAL] * num_comp for _ in range(num_nodes)]
                                for vals in field_storage["values"]
                            ]
                            part_dict["fields"][out_field_name] = field_storage

                data_tree["parts"][p_name] = part_dict

            with open(TEMP_PICKLE, 'wb') as f:
                # Protocol 2 so the (Python 2 based) Abaqus interpreter can
                # write what the Python 3 conversion stage reads back.
                pickle.dump(data_tree, f, protocol=2)
        finally:
            # Close the ODB even if extraction fails part-way through.
            odb.close()

    main()


# ============================================================
# ================= HT5 CONVERSION PART ======================
# ============================================================

def generate_ht5():
    """Conversion stage: read TEMP_PICKLE and write it out as OUTPUT_HT5.

    HDF5 layout: /metadata/{solver, ht5_version, times/{values, names}} and
    /parts/<name>/{nodes/points, connectivity/<etype>/connectivity,
    fields/<field>/{values, components}}.

    On success the intermediate pickle is removed; on failure it is kept
    so the (slow) Abaqus extraction does not have to be rerun to debug.
    """
    import h5py
    import numpy as np

    if not os.path.exists(TEMP_PICKLE):
        print("!!! Error: Pickle file not found.")
        return

    # encoding='latin1' lets Python 3 read pickles written by the
    # Python 2 based Abaqus interpreter.
    with open(TEMP_PICKLE, 'rb') as f:
        data = pickle.load(f, encoding='latin1')

    with h5py.File(OUTPUT_HT5, 'w') as h5:
        meta = h5.create_group("metadata")
        meta.create_dataset("solver", data=data['metadata']['solver'])
        meta.create_dataset("ht5_version", data=data['metadata']['ht5_version'])

        times_grp = meta.create_group("times")
        times_grp.create_dataset("values", data=np.array(data['metadata']['times']['values'], dtype='f4'))
        times_grp.create_dataset("names", data=np.array(data['metadata']['times']['names'], dtype=h5py.string_dtype(encoding='utf-8')))

        parts_root = h5.create_group("parts")

        for p_name, p_content in data['parts'].items():
            part_grp = parts_root.create_group(p_name)

            nodes_grp = part_grp.create_group("nodes")
            nodes_grp.create_dataset("points", data=np.array(p_content['nodes']['points'], dtype='f4'))

            conn_root = part_grp.create_group("connectivity")
            for e_type, e_data in p_content['connectivity'].items():
                etype_grp = conn_root.create_group(e_type)
                etype_grp.create_dataset("connectivity", data=np.array(e_data['connectivity'], dtype='i4'))

            fields_grp = part_grp.create_group("fields")
            for f_name, f_data in p_content['fields'].items():
                fg = fields_grp.create_group(f_name)
                fg.attrs["location"] = f_data['location']

                # (n_frames, n_nodes, n_components) float32, gzip-compressed.
                val_array = np.array(f_data['values'], dtype='f4')
                fg.create_dataset("values", data=val_array, compression="gzip", compression_opts=4)
                fg.create_dataset("components", data=np.array(f_data['components'], dtype='S'))

    # Conversion succeeded: drop the intermediate file (previously it was
    # left behind on disk after every run).
    try:
        os.remove(TEMP_PICKLE)
    except OSError:
        pass  # best-effort cleanup only

    print(f">>> Success! {OUTPUT_HT5} is ready.")


# ============================================================
# ================= CONTROLLER (ONLY NEW PART) ===============
# ============================================================

def run_abaqus_extraction():
    """Re-invoke this script under Abaqus Python with the --extract flag.

    Returns:
        bool: True on success, False if the 'abaqus' command is missing
        or exits with a non-zero status.
    """
    print(">>> Launching Abaqus Python extraction...")
    try:
        # No shell=True: combining a list of args with shell=True makes
        # POSIX shells run bare 'abaqus' and silently drop the remaining
        # arguments, and a missing executable then surfaces as
        # CalledProcessError (shell exit 127) instead of FileNotFoundError,
        # so the dedicated branch below was unreachable.
        subprocess.check_call(['abaqus', 'python', __file__, '--extract'])
        print(">>> Extraction successful.")
    except subprocess.CalledProcessError as e:
        print(f"!!! Error: Abaqus failed to run. {e}")
        return False
    except FileNotFoundError:
        print("!!! Error: 'abaqus' command not found. Is it in your System PATH?")
        return False
    return True


# ============================================================
# ================= ENTRY POINT ==============================
# ============================================================

if __name__ == "__main__":
    # "--extract" marks the re-invocation of this script under Abaqus
    # Python; otherwise we are the controller: run extraction, then
    # convert the pickle to HT5 only if extraction succeeded.
    extract_stage = "--extract" in sys.argv
    if extract_stage:
        abaqus_main()
    elif run_abaqus_extraction():
        generate_ht5()