# Abaqus Python script that exports mesh and field-output data from an ODB
# file to a JSON file.
# Run with: abaqus python odb_to_json_v2.py

import json

from abaqusConstants import ELEMENT_NODAL, NODAL
from odbAccess import openOdb

# ------------------------
# USER INPUT
# ------------------------
odb_path = "simulation.odb"
output_json = "simulation.json"

# Requested outputs: raw ODB field names (e.g. "U", "LE") and/or keys of
# derived_fields_map below (e.g. "S_VON_MISES").
field_requests = [
    "U",
    "U_MAGNITUDE",
    "LE",
    "LE_MAX_PRINC",
    "S_VON_MISES",
]

# How element-based fields are brought to the nodes. Only "element_nodal"
# is supported for now.
extrapolation_mode = "element_nodal"

# Step + frame selection, keyed by step name. Each value is either "all" or a
# list of frame indices, e.g. [5, -1]; negative indices count from the end,
# so -1 selects the last frame.
steps_selection = {"Step-1": "all"}

# ------------------------


def _invariant(attr_name):
    """Return a callable that reads ``attr_name`` from a field value."""
    return lambda val: getattr(val, attr_name)


# Derived outputs: requested name -> (base ODB field name, per-value extractor).
derived_fields_map = {
    "LE_MAX_PRINC": ("LE", _invariant("maxPrincipal")),
    "LE_MID_PRINC": ("LE", _invariant("midPrincipal")),
    "LE_MIN_PRINC": ("LE", _invariant("minPrincipal")),
    "S_MAX_PRINC": ("S", _invariant("maxPrincipal")),
    "S_MID_PRINC": ("S", _invariant("midPrincipal")),
    "S_MIN_PRINC": ("S", _invariant("minPrincipal")),
    "S_VON_MISES": ("S", _invariant("mises")),
    "U_MAGNITUDE": ("U", _invariant("magnitude")),
}

# ------------------------
# HELPER FUNCTIONS
# ------------------------


def get_field_subset(field, inst, mode):
    """Restrict ``field`` to instance ``inst`` at the position used for export.

    Fields whose first location is NODAL are subset without a position
    argument. Any other field is requested at ELEMENT_NODAL, which is the only
    supported ``mode`` ("element_nodal").

    Raises:
        ValueError: if the field is not NODAL and ``mode`` is not
            "element_nodal".
    """
    if field.locations[0].position == NODAL:
        return field.getSubset(region=inst)
    if mode != "element_nodal":
        raise ValueError("Unknown extrapolation_mode: %s" % mode)
    return field.getSubset(region=inst, position=ELEMENT_NODAL)


def build_field_layout(field, inst, mode):
    """Describe which mesh entities carry values of ``field`` on ``inst``.

    The layout fixes the ordering of the flat value arrays written per frame.

    Returns:
        None if the subset has no values. Otherwise one of:
        - NODAL fields: {"location": "NODAL", "node_labels": [sorted labels]}
        - element-based fields:
          {"location": "ELEMENT_NODAL",
           "by_element_type": {etype: {"element_labels": [sorted labels]}}}

    Raises:
        ValueError: propagated from get_field_subset for an unknown ``mode``.
    """
    native_pos = field.locations[0].position
    sub = get_field_subset(field, inst, mode)
    # Read the value array once; it is iterated below.
    values = sub.values

    if not values:
        return None

    if native_pos == NODAL:
        return {
            "location": "NODAL",
            "node_labels": sorted(set(v.nodeLabel for v in values)),
        }

    # ELEMENT-based fields: group element labels by element type. There is one
    # value per element node (and per section point), so look up each element
    # only the first time its label is seen.
    labels_by_type = {}
    seen = set()
    for v in values:
        label = v.elementLabel
        if label in seen:
            continue
        seen.add(label)
        etype = inst.getElementFromLabel(label).type
        labels_by_type.setdefault(etype, set()).add(label)

    by_type = {}
    for etype, labels in labels_by_type.items():
        by_type[etype] = {"element_labels": sorted(labels)}

    return {
        "location": "ELEMENT_NODAL",  # only option for now
        "by_element_type": by_type,
    }


# ------------------------
# MAIN LOGIC
# ------------------------

def _resolve_frames(frame_selection, total_frames, step_name):
    """Turn a user frame selection into a list of valid, unique frame indices.

    "all" selects every frame. Otherwise each entry is a frame index, and
    negative values count from the end (-1 is the last frame). Out-of-range
    entries are reported and dropped. Duplicates keep their first position.
    """
    if frame_selection == "all":
        return list(range(total_frames))
    resolved = []
    for requested in frame_selection:
        idx = requested if requested >= 0 else total_frames + requested
        if not 0 <= idx < total_frames:
            print(
                "Warning: Frame %d out of range for step '%s' (%d frames)."
                % (requested, step_name, total_frames)
            )
            continue
        if idx not in resolved:
            resolved.append(idx)
    return resolved


def _to_float_list(value):
    """Flatten a scalar or a sequence of numbers into a list of Python floats."""
    try:
        return [float(c) for c in value]
    except TypeError:
        # Not iterable: a scalar (Python or numpy number).
        return [float(value)]


def _first_width(rows):
    """Return the length of the first non-empty entry in ``rows`` (1 if none)."""
    for row in rows:
        if row:
            return len(row)
    return 1


def _extract_mesh(odb):
    """Return (mesh dict for JSON, {inst_name: {elem_label: connectivity}})."""
    mesh = {}
    conn_by_inst = {}
    for inst_name, inst in odb.rootAssembly.instances.items():
        elements_out = {}
        conn_lookup = {}
        for elem in inst.elements:
            # int() keeps the JSON serializable even if labels are numpy ints.
            conn = [int(nid) for nid in elem.connectivity]
            conn_lookup[elem.label] = conn
            entry = elements_out.get(elem.type)
            if entry is None:
                entry = {"connectivity": [], "nplex": len(conn)}
                elements_out[elem.type] = entry
            entry["connectivity"].extend(conn)
        mesh[inst_name] = {
            "node_labels": [n.label for n in inst.nodes],
            "undeformed_coords": [float(c) for n in inst.nodes for c in n.coordinates],
            "elements": elements_out,
        }
        conn_by_inst[inst_name] = conn_lookup
    return mesh, conn_by_inst


def _build_step_layouts(ref_frame, ref_idx, step_name, instances):
    """Build {field_key: {inst_name: layout}} from the reference frame."""
    layouts = {}
    outputs = ref_frame.fieldOutputs
    for f_key in field_requests:
        base_name = derived_fields_map.get(f_key, (f_key, None))[0]
        if base_name not in outputs:
            print(
                "Warning: Field '%s' (for '%s') not found in frame %d of step "
                "'%s'; skipped." % (base_name, f_key, ref_idx, step_name)
            )
            continue

        field = outputs[base_name]
        layouts[f_key] = {}
        for inst_name, inst in instances.items():
            try:
                layout = build_field_layout(field, inst, extrapolation_mode)
            except Exception as e:
                # Best effort: a field need not be defined on every instance.
                print(
                    "Note: Field '%s' skipped for instance '%s': %s"
                    % (f_key, inst_name, str(e))
                )
                continue
            if layout:
                layouts[f_key][inst_name] = layout
    return layouts


def _nodal_values(sub_field, func, layout):
    """Flatten NODAL values in ``layout["node_labels"]`` order.

    Nodes without a value in this frame are zero-filled with as many
    components as the values that are present. This keeps the flat array
    aligned with the layout.
    """
    by_node = {}
    for val in sub_field.values:
        by_node[val.nodeLabel] = _to_float_list(func(val) if func else val.data)
    filler = [0.0] * _first_width(by_node.values())
    out = []
    for label in layout["node_labels"]:
        out.extend(by_node.get(label, filler))
    return out


def _element_nodal_values(sub_field, func, layout, inst, conn_lookup):
    """Flatten ELEMENT_NODAL values per element type, in connectivity order.

    Missing (element, node) entries are zero-filled with the width of the
    first present entry of the same element type, so arrays stay aligned.
    """
    by_elem = {}
    for val in sub_field.values:
        node_map = by_elem.setdefault(val.elementLabel, {})
        node_map.setdefault(val.nodeLabel, []).extend(
            _to_float_list(func(val) if func else val.data)
        )

    by_type = {}
    for etype in sorted(layout["by_element_type"]):
        rows = []
        for el in layout["by_element_type"][etype]["element_labels"]:
            node_map = by_elem.get(el, {})
            conn = conn_lookup.get(el)
            if conn is None:
                conn = inst.getElementFromLabel(el).connectivity
            # Walk the element connectivity to keep node order stable:
            # Node1(SNEG, SPOS), Node2(SNEG, SPOS)...
            for node_id in conn:
                rows.append(node_map.get(node_id))
        filler = [0.0] * _first_width(rows)
        values = []
        for row in rows:
            values.extend(row or filler)
        by_type[etype] = values
    return by_type


def _extract_frame(frame, idx, step_name, field_layouts, instances, conn_by_inst):
    """Return {"time": ..., "fields": {...}} for one frame."""
    frame_data = {"time": float(frame.frameValue), "fields": {}}
    outputs = frame.fieldOutputs
    for f_key, inst_layouts in field_layouts.items():
        base_name, func = derived_fields_map.get(f_key, (f_key, None))
        if base_name not in outputs:
            # Output frequency can differ between frames; don't abort the export.
            print(
                "Warning: Field '%s' missing in step '%s', frame %d; skipped."
                % (base_name, step_name, idx)
            )
            continue

        field = outputs[base_name]
        per_inst = {}
        frame_data["fields"][f_key] = per_inst
        for inst_name, layout in inst_layouts.items():
            inst = instances[inst_name]
            sub_field = get_field_subset(field, inst, extrapolation_mode)
            if layout["location"] == "NODAL":
                per_inst[inst_name] = _nodal_values(sub_field, func, layout)
            else:
                per_inst[inst_name] = _element_nodal_values(
                    sub_field, func, layout, inst, conn_by_inst.get(inst_name, {})
                )
    return frame_data


# Open read-only: this script never modifies the ODB. The try/finally below
# guarantees the ODB is closed even if extraction fails.
odb = openOdb(odb_path, readOnly=True)
try:
    data = {
        "meta": {
            "job": odb_path,
            "extrapolation_mode": extrapolation_mode,
            "fields_requested": field_requests,
        },
        "mesh": {},
        "steps": {},
    }
    instances = odb.rootAssembly.instances

    # 1) Mesh Extraction (also caches connectivity for the frame loop)
    data["mesh"], conn_by_inst = _extract_mesh(odb)

    # 2) Step/Field Extraction
    for step_name, frame_selection in steps_selection.items():
        if step_name not in odb.steps:
            print("Warning: Step '%s' not found." % step_name)
            continue

        step = odb.steps[step_name]
        frames_to_export = _resolve_frames(
            frame_selection, len(step.frames), step_name
        )
        step_data = {
            "frames_selected": frames_to_export,
            "field_layouts": {},
            "frames": {},
        }

        if not frames_to_export:
            print("Warning: No valid frames selected for step '%s'." % step_name)
        else:
            # Layouts (value ordering) come from the first selected frame.
            ref_idx = frames_to_export[0]
            step_data["field_layouts"] = _build_step_layouts(
                step.frames[ref_idx], ref_idx, step_name, instances
            )
            for idx in frames_to_export:
                print("Processing Step: %s, Frame: %d" % (step_name, idx))
                step_data["frames"][str(idx)] = _extract_frame(
                    step.frames[idx],
                    idx,
                    step_name,
                    step_data["field_layouts"],
                    instances,
                    conn_by_inst,
                )

        data["steps"][step_name] = step_data

    with open(output_json, "w") as f:
        json.dump(data, f, separators=(",", ":"))

    print("\nDone. Results saved to: %s" % output_json)
finally:
    odb.close()
