from tc_python import *
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


"""
In this example, a batch of steady-state calculations is run using the Core-Ring heat source model.
The heat source parameters for the Core-Ring heat source, as well as the processing parameters (power and speed),
are taken from Holla et al. [2024Holla]. Finally, the predicted melt pool width and depth are compared with the
experimental data.

[2024Holla] Integrating Materials and Manufacturing Innovation 13, 969–985 (2024).
DOI: https://doi.org/10.1007/s40192-024-00382-2
"""

# Experimental reference data (CSV) and the library material it was measured on.
EXP_FILE = "exp_data_Holla_SS316L.csv"
MATERIAL = "SS316L"

# Core-Ring heat source parameters, taken from Holla et al. [2024Holla].
# Lengths are presumably in metres (the CSV columns are SI) — TODO confirm
# against the TC-Python API documentation.
ABS_PREFACTOR = 1.2
BEAM_RADIUS_CORE = 47.6e-6
BEAM_RADIUS_RING = 25.29e-6
RING_RADIUS = 68.4e-6
RING_POWER_PERCENT = 80.0  # share of the total power delivered by the ring

# Number of CPU cores handed to the solver's numerical options.
NR_OF_CORES = 4

def rms_error(all_experimental, all_simulated):
    """Return the root-mean-square error between experimental and simulated values.

    Args:
        all_experimental: Array-like of measured values.
        all_simulated: Array-like of predicted values; must have the same shape
            as ``all_experimental``.

    Returns:
        sqrt(mean((all_simulated - all_experimental) ** 2)), in the same units
        as the inputs.

    Raises:
        ValueError: If the inputs have different shapes or are empty.
    """
    experimental = np.asarray(all_experimental, dtype=float)
    simulated = np.asarray(all_simulated, dtype=float)
    # Reject mismatched shapes explicitly: numpy would otherwise broadcast, e.g.,
    # a length-1 array against the full data and return a meaningless number.
    if experimental.shape != simulated.shape:
        raise ValueError(
            f"shape mismatch: experimental {experimental.shape} "
            f"vs simulated {simulated.shape}"
        )
    # np.mean of an empty array yields nan plus a RuntimeWarning; fail loudly instead.
    if experimental.size == 0:
        raise ValueError("cannot compute RMS error of empty inputs")
    differences = simulated - experimental
    return np.sqrt(np.mean(differences ** 2))


with TCPython() as start:
    start.set_cache_folder("cache")

    mp = LibraryMaterialProperties(MATERIAL)

    # Core-Ring heat source; power and scanning speed are set per CSV row below.
    hs = (HeatSource.core_ring_with_calculated_absorptivity()
          .set_absorptivity_pre_factor(ABS_PREFACTOR)
          .with_keyhole_model(KeyholeModel())
          .set_beam_radius_core(BEAM_RADIUS_CORE)
          .set_beam_radius_ring(BEAM_RADIUS_RING)
          .set_ring_radius(RING_RADIUS)
          .set_ring_power_percent(RING_POWER_PERCENT))

    calc = (start.with_additive_manufacturing()
            .with_steady_state_calculation()
            .with_numerical_options(NumericalOptions().set_number_of_cores(NR_OF_CORES))
            .with_material_properties(mp)
            .enable_fluid_flow_marangoni()
            .with_heat_source(hs)
            .with_mesh(CoarseMesh()))

    # Experimental data; the 1E6 factor converts the metre columns to micrometres.
    batch_data = pd.read_csv(EXP_FILE, skipinitialspace=True)
    experimental_width = batch_data['width[m]'].to_numpy() * 1E6
    experimental_depth = batch_data['depth[m]'].to_numpy() * 1E6

    # One (label, result) pair per CSV row, in row order. A list rather than a dict
    # keyed by the label: duplicate power/speed rows would otherwise overwrite each
    # other and misalign the simulated values with the experimental ones below.
    results = []
    for _, row in batch_data.iterrows():
        P = row['power[W]']
        V = row['speed[m/s]']
        print(f"Power: {P}, Speed: {V}")

        # hs is the same object passed to calc, so this updates the next calculation.
        hs.set_power(P).set_scanning_speed(V)
        # Label: power in W, speed converted from m/s to mm/s.
        pv_label = f"[{P:3.0f},{V * 1e3:4.0f}]"
        results.append((pv_label, calc.calculate()))

    # Extracting predicted data for melt pool width and depth (in micrometres)
    simulated_width = [result.get_meltpool_width() * 1E6 for _, result in results]
    simulated_depth = [result.get_meltpool_depth() * 1E6 for _, result in results]

    ## 3D plots
    for pv_label, result in results:
        print("{}; Has keyhole: {}. Melt pool width:{:.4E}, depth:{:.4E}, length:{:.4E};"
              .format(pv_label,
                      result.has_keyhole(),
                      result.get_meltpool_width(),
                      result.get_meltpool_depth(),
                      result.get_meltpool_length()))

        plotter, mesh = result.get_pyvista_plotter(shape=(2, 1))
        plotter.add_text("Temperature", font_size=14)
        plotter.add_mesh(mesh)

        plotter.subplot(1, 0)
        plotter.add_text("Iso-surface solid & liquid", font_size=14)
        contour_mesh = mesh.contour(
            isosurfaces=np.array([mp.get_solidification_temperature(), mp.get_liquidus_temperature()]))
        plotter.add_mesh(mesh.outline(), color="k")
        plotter.add_mesh(contour_mesh)
        plotter.link_views()
        plotter.camera_position = 'iso'
        plotter.show()
        del plotter

    # RMS error over widths and depths pooled together (micrometres)
    all_experimental = np.concatenate((experimental_width, experimental_depth))
    all_simulated = np.array(simulated_width + simulated_depth)
    min_val = min(all_experimental.min(), all_simulated.min())
    max_val = max(all_experimental.max(), all_simulated.max())
    # Bind to a new name so the module-level rms_error function is not shadowed.
    rms = rms_error(all_experimental, all_simulated)

    ## Parity plot
    plt.figure(figsize=(8, 8))
    plt.scatter(experimental_width, simulated_width, color='blue', marker='s', label='Melt pool width')
    plt.scatter(experimental_depth, simulated_depth, color='red', marker='o', label='Melt pool depth')
    plt.plot([min_val, max_val], [min_val, max_val], 'r--')

    plt.xlabel("Experimental Values (Width/Depth) [μm]", fontsize=12)
    plt.ylabel("Calculated Values (Width/Depth) [μm]", fontsize=12)
    plt.title("Parity Plot for Melt Pool Width and Depth"
              + f"\nRMS Error: {rms:.2f} μm", fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()