from tc_python import *
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


"""
In this example, a batch of steady-state calculations is run using the top-hat heat source model.
The heat source parameters for the top-hat heat source, as well as the processing parameters (power and speed),
are taken from Sow et al. [2020Sow]. The beam radius of the top-hat beam is adjusted to match the laser
intensity distribution of the large multimode laser spot (Fig. 2 in [2020Sow]). Finally, the predicted melt pool
width and depth are compared with the experimental data.

[2020Sow] Additive Manufacturing 36, 101532 (2020).
DOI: https://doi.org/10.1016/j.addma.2020.101532
"""

# --- Input data -------------------------------------------------------------
# CSV read below with one row per experiment: power, speed, measured melt-pool
# width/depth and their errors.
EXP_FILE = "exp_data_Sow_IN625.csv"
# Name handed to LibraryMaterialProperties.
MATERIAL = "IN625"

# --- Geometry / heat-source settings ----------------------------------------
# NOTE(review): presumably SI units (metres) like the CSV columns -- confirm
# against the tc_python additive-manufacturing documentation.
TOPHAT_BEAM_RADIUS = 4.5e-4   # 450 um if metres
BASE_PLATE_HEIGHT = 3e-3
LAYER_THICKNESS = 5e-5

# --- Numerics ---------------------------------------------------------------
NR_OF_CORES = 4  # forwarded to NumericalOptions().set_number_of_cores

def rms_error(all_experimental, all_simulated):
    """Return the root-mean-square difference between two equally shaped inputs.

    Parameters
    ----------
    all_experimental, all_simulated : array-like
        Reference and predicted values. Lists, pandas Series and NumPy arrays
        are all accepted; they are converted with ``np.asarray`` so plain
        Python lists work too.

    Returns
    -------
    float
        ``sqrt(mean((all_simulated - all_experimental) ** 2))``.

    Raises
    ------
    ValueError
        If the inputs differ in shape (NumPy would otherwise broadcast them
        silently and produce a meaningless number) or are empty (the mean of
        an empty array is NaN).
    """
    experimental = np.asarray(all_experimental, dtype=float)
    simulated = np.asarray(all_simulated, dtype=float)

    if experimental.shape != simulated.shape:
        raise ValueError(
            f"shape mismatch: experimental {experimental.shape} "
            f"vs simulated {simulated.shape}")
    if experimental.size == 0:
        raise ValueError("cannot compute the RMS error of empty inputs")

    differences = simulated - experimental
    return float(np.sqrt(np.mean(differences ** 2)))


def _plot_bar_comparison(ax, x, experimental, experimental_error, simulated,
                         quantity, tick_labels, bar_width=0.35):
    """Draw side-by-side experimental vs. calculated bars for one melt-pool quantity.

    ax: matplotlib Axes to draw on.
    x: bar-group positions, one per experiment.
    experimental, experimental_error, simulated: values in micrometres, aligned with x.
    quantity: 'Width' or 'Depth'; used in the legend, axis label and title.
    tick_labels: one x-tick label per experiment.
    bar_width: width of each bar; the two bars of a group are offset by half of it.
    """
    ax.bar(x - bar_width / 2, experimental, bar_width, yerr=experimental_error,
           capsize=5, label=f'Experimental {quantity}', color='b')
    ax.bar(x + bar_width / 2, simulated, bar_width,
           label=f'Calculated {quantity}', color='g')

    ax.set_xlabel('Experiments', fontsize=12)
    ax.set_ylabel(f'Melt pool {quantity} [μm]', fontsize=12)
    ax.set_title(f'Comparison of Experimental and Calculated {quantity}s', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, fontsize=10, rotation=45)
    ax.legend(fontsize=12)
    ax.grid(axis='y', linestyle='--', alpha=0.6)


with TCPython() as start:
    start.set_cache_folder("cache")
    mp = LibraryMaterialProperties(MATERIAL)
    hs = (HeatSource.tophat_with_calculated_absorptivity()
          .with_keyhole_model(KeyholeModel())
          .set_beam_radius(TOPHAT_BEAM_RADIUS))

    calc = (start.with_additive_manufacturing()
            .with_steady_state_calculation()
            .with_numerical_options(NumericalOptions().set_number_of_cores(NR_OF_CORES))
            .with_material_properties(mp)
            .disable_fluid_flow_marangoni()
            .with_heat_source(hs)
            .set_height(BASE_PLATE_HEIGHT)
            .set_layer_thickness(LAYER_THICKNESS)
            .with_mesh(CoarseMesh()))

    ## Experimental data (speed converted to mm/s, lengths to micrometres)
    batch_data = pd.read_csv(EXP_FILE, skipinitialspace=True)
    power = batch_data['power[W]']
    speed = batch_data['speed[m/s]'] * 1E3
    experimental_width = batch_data['width[m]'] * 1E6
    experimental_depth = batch_data['depth[m]'] * 1E6
    error_width = batch_data['error_width[m]'] * 1E6
    error_depth = batch_data['error_depth[m]'] * 1E6

    # One (label, result) pair per CSV row, in row order. A list rather than a
    # dict keyed by the label: repeated power/speed combinations must not
    # overwrite each other, otherwise the simulated values would no longer
    # line up with the experimental rows below.
    results = []
    for _, row in batch_data.iterrows():
        P = row['power[W]']
        V = row['speed[m/s]']
        print(f"Power: {P}, Speed: {V}")

        hs.set_power(P).set_scanning_speed(V)
        PV = "[{0:3.0f},".format(P) + "{0:4.0f}]".format(V * 1e3)
        results.append((PV, calc.calculate()))

    # Extracting predicted data for melt pool width and depth (in micrometres)
    simulated_width = [result.get_meltpool_width() * 1E6 for _, result in results]
    simulated_depth = [result.get_meltpool_depth() * 1E6 for _, result in results]

    ## 3D plots
    for PV, result in results:
        print("{}; Has keyhole: {}. Melt pool width:{:.4E}, depth:{:.4E}, length:{:.4E};"
              .format(PV,
                      result.has_keyhole(),
                      result.get_meltpool_width(),
                      result.get_meltpool_depth(),
                      result.get_meltpool_length()))

        plotter, mesh = result.get_pyvista_plotter(shape=(2, 1))
        plotter.add_text("Temperature", font_size=14)
        plotter.add_mesh(mesh)

        plotter.subplot(1, 0)
        plotter.add_text("Iso-surface solid & liquid", font_size=14)
        contour_mesh = mesh.contour(
            isosurfaces=np.array([mp.get_solidification_temperature(), mp.get_liquidus_temperature()]))
        plotter.add_mesh(mesh.outline(), color="k")
        plotter.add_mesh(contour_mesh)
        plotter.link_views()
        plotter.camera_position = 'iso'
        plotter.show()
        del plotter

    # RMS error over widths and depths together. Stored as `rms` so the
    # module-level rms_error() function is not rebound to a float.
    all_experimental = np.concatenate((experimental_width.values, experimental_depth.values))
    all_simulated = np.array(simulated_width + simulated_depth)
    min_val = min(all_experimental.min(), all_simulated.min())
    max_val = max(all_experimental.max(), all_simulated.max())
    rms = rms_error(all_experimental, all_simulated)

    ## Parity plot
    plt.figure(figsize=(8, 8))
    plt.scatter(experimental_width, simulated_width, color='blue', marker='s', label='Melt pool width')
    plt.scatter(experimental_depth, simulated_depth, color='red', marker='o', label='Melt pool depth')
    # Black identity line so it is not confused with the red depth markers.
    plt.plot([min_val, max_val], [min_val, max_val], 'k--')

    plt.xlabel("Experimental Values (Width/Depth) [μm]", fontsize=12)
    plt.ylabel("Calculated Values (Width/Depth) [μm]", fontsize=12)
    plt.title("Parity Plot for Melt Pool Width and Depth"
              + f"\nRMS Error: {rms:.2f} μm", fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()

    ## Bar plots for melt pool width and depth
    # Round (not truncate) like the PV label above, e.g. 299.9 W -> "300W".
    tick_labels = [f"P={p:.0f}W\nV={s:.0f}mm/s" for p, s in zip(power, speed)]
    x = np.arange(len(experimental_width))
    _, axes = plt.subplots(2, 1, figsize=(10, 10))

    _plot_bar_comparison(axes[0], x, experimental_width, error_width,
                         simulated_width, 'Width', tick_labels)
    _plot_bar_comparison(axes[1], x, experimental_depth, error_depth,
                         simulated_depth, 'Depth', tick_labels)

    plt.tight_layout()
    plt.show()