Figure 2 From Paper

import symd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pickle
import pandas as pd
import skunk
import svglib
import seaborn as sns
import gzip
import urllib
# Ten palette entries stored as bare hex digits, running from red to blue.
base_colors = (
    "f94144 f3722c f8961e f9844a f9c74f "
    "90be6d 43aa8b 4d908e 577590 277da1"
).split()
# Same palette with the leading "#" added so each entry is a hex color string.
colors = [f"#{hex_digits}" for hex_digits in base_colors]
# Global plot theme for every figure produced below.
# NOTE(review): sns.set() is called after the two set_style() calls; confirm
# whether those earlier calls still have any effect once sns.set() has run.
for style_name in ("white", "ticks"):
    sns.set_style(style_name)

theme_overrides = {
    "axes.facecolor": "#f5f4e9",
    "grid.color": "#AAAAAA",
    "axes.edgecolor": "#333333",
    "figure.facecolor": "#FFFFFF",
    "axes.grid": False,
    "axes.prop_cycle": plt.cycler("color", plt.cm.Dark2.colors),
    "font.family": "monospace",
}
sns.set(rc=theme_overrides)
# Names of the 17 groups used in this notebook, ordered so that titles[i]
# is the name of group number i + 1 (the code below looks groups up that way).
titles = (
    "p1 p2 pm pg cm pmm pmg pgg cmm "
    "p4 p4m p4g p3 p3m1 p31m p6 p6m"
).split()
# Record which symd release is installed when this notebook is run.
print(symd.__version__)
1.1.3

Make Atlas

def scale(x, ib):
    """Apply the matrix ``ib`` to every row of ``x`` and take ``fmod`` 1.

    Each output row is ``ib @ row``, computed for all rows at once as
    ``x @ ib.T`` instead of one Python-level call per row. Unlike the
    per-row formulation, this also accepts an input with zero rows.

    Args:
        x: Array of shape ``(n, d)`` holding one coordinate vector per row.
        ib: ``(d, d)`` matrix applied to each row (the caller in this file
            passes the inverse of the cell matrix).

    Returns:
        Array of shape ``(n, d)`` with ``np.fmod(ib @ row, 1.0)`` in each row.
    """
    transformed = np.asarray(x) @ np.asarray(ib).T
    # NOTE(review): np.fmod keeps the sign of its argument, so negative
    # coordinates stay negative; confirm that is intended rather than np.mod.
    return np.fmod(transformed, 1.0)


def compute_symm(positions, gnum, cell, ndim, n):
    """Measure how far a configuration is from exact group symmetry.

    The first ``n`` rows of ``positions`` are treated as base particles and
    row ``j * n + i`` as the image of base particle ``i`` under group member
    ``j``. For every base particle the images are mapped back through the
    inverse of their group member, averaged together with the base particle,
    and the images are then regenerated from that average. The result is the
    mean squared difference between the scaled input positions and these
    regenerated, exactly symmetric positions.

    Args:
        positions: Array with one particle per row; only the first ``ndim``
            columns are used.
        gnum: Group number passed to ``symd.groups.load_group``.
        cell: Cell values, reshaped to ``(ndim, ndim)`` and inverted to scale
            the positions.
        ndim: Number of spatial dimensions.
        n: Number of base particles.

    Returns:
        Mean squared deviation over all ``n * len(members)`` particles.
    """
    group = symd.groups.load_group(gnum, ndim)
    cell = np.array(cell).reshape(ndim, ndim)
    ib = np.linalg.inv(cell)
    s = scale(positions[:, :ndim], ib)
    # NOTE(review): other code in this file reads `.genpos` as an attribute;
    # confirm subscript access is supported by the installed symd version.
    members = [symd.groups.str2mat(e) for e in group["genpos"]]
    # The inverses do not depend on the particle, so compute them once.
    # Member 0 is skipped, as in the loops below.
    inverses = [np.linalg.inv(member) for member in members[1:]]
    folded_positions = np.zeros_like(s)
    for i in range(n):
        # Fold every image of particle i back onto the base particle and average.
        folded_positions[i, :] = s[i, :]
        for j, im in enumerate(inverses, start=1):
            k = j * n + i
            folded_positions[i, :] += im[:ndim, :ndim] @ s[k] + im[ndim, :ndim]
        folded_positions[i, :] /= len(members)
        # Regenerate each image from the averaged base position.
        for j in range(1, len(members)):
            k = j * n + i
            w = members[j][:ndim, :ndim]
            folded_positions[k] = w @ folded_positions[i] + members[j][ndim, :ndim]
    # Compare every particle that was folded. The earlier `[:k]` slice dropped
    # the last one and left `k` undefined for single-member groups or n == 0.
    total = n * len(members)
    return np.mean((s[:total] - folded_positions[:total]) ** 2)


def rmsd(p1, p2):
    """Return the per-frame mean squared deviation between ``p1`` and ``p2``.

    The squared differences are averaged over axes 1 and 2, leaving one value
    per entry along axis 0. Despite the name, no square root is taken.
    """
    deviation = p1 - p2
    return np.mean(deviation**2, axis=(1, 2))
def crystal(
    n,
    group,
    w=None,
    retries=5,
    steps=10**6,
    steps2=5 * 10**3,
    ndims=2,
    starting_density=0.2,
    positions=False,
):
    """Build a crystal under a symmetry group, then test how stable it is.

    Each attempt runs three simulations: an NPT run constrained to ``group``,
    a shorter NVT run continuing from it, and a final run in group 1 that
    starts from the final positions of the constrained simulation. The
    deviation of that last trajectory from its first frame is returned as the
    stability measure. Attempts are repeated with a new random seed when a
    stage raises ``RuntimeError``.

    Args:
        n: Target number of particles in the unit cell. It is divided by the
            number of general positions of the group (minimum 2), and
            ``sum(w)`` is added when ``w`` is given.
        group: Group number.
        w: Wyckoff specification passed on to symd, or ``None``.
        retries: Number of attempts (seeds ``0 .. retries - 1``).
        steps: Steps of the NPT run; the NVT run uses ``steps // 4``.
        steps2: Steps of the final group-1 run.
        ndims: Number of spatial dimensions, used for the group lookup, the
            cell construction and both simulations.
        starting_density: Density passed to ``symd.groups.get_cell``.
        positions: If true, log the first simulation every 250 steps and
            return the concatenated trajectories instead of the tuple below.

    Returns:
        ``None`` if every attempt failed. If ``positions`` is true, the
        concatenation of the positions of both simulations. Otherwise a tuple
        ``(config, final_frame, number_density, csm, traj, time)`` where
        ``config`` is the last frame of the constrained simulation and the
        remaining entries describe the group-1 run.
    """
    # Logging interval of the group-1 run; the returned time axis must match it.
    melt_log_period = 10
    # n is meant to be the number of particles in the UNIT CELL,
    # so it has to be adjusted for the group size
    group_size = len(symd.groups.load_group(group, ndims).genpos)
    n = max(2, n // group_size)
    if w is not None:
        n += sum(w)
        name = f"{group}-{n}-{sum(w)}"
    else:
        name = f"{group}-{n}"
    print("Simulating", n, "particles:", name)
    # break out the try/except because we will accept failed NPT (because it jams so hard)
    for i in range(retries):
        np.random.seed(i)
        cell = symd.groups.get_cell(starting_density, group, ndims, n, w)
        # NPT
        md = symd.Symd(
            nparticles=n,
            cell=cell,
            ndims=ndims,
            images=2,
            force="lj",
            wyckoffs=w,
            group=group,
            steps=steps,
            exeDir=f"crystal-{name}",
            pressure=0.25,
            temperature=0.1,
            start_temperature=0.5,
        )
        try:
            md.remove_overlap()
        except RuntimeError:
            continue
        if positions:
            md.log_positions(period=250)
        else:
            md.log_positions()
        try:
            md.run()
        except RuntimeError:
            # A failed NPT run is still accepted as long as it got dense enough.
            d = md.number_density()
            if d < 0.5:
                print("Not dense enough, retrying", d)
                continue

        # NVT
        md.runParams["start_temperature"] = 0.05
        md.runParams["temperature"] = 1e-4
        md.runParams["box_update_period"] = 0
        md.runParams["langevin_gamma"] = 0.5
        md.runParams["steps"] = steps // 4
        md.log_positions(filename="equil.xyz")
        try:
            md.run()
        except RuntimeError:
            continue
        config = md.positions[-1]

        # Stability
        fp = np.loadtxt(md.runParams["final_positions"])
        # changing group, so need to read projected cell
        cell = md.read_cell(bravais=True)
        n_final = fp.shape[0]
        md2 = symd.Symd(
            nparticles=n_final,
            cell=cell,
            ndims=ndims,
            images=2,
            force="lj",
            wyckoffs=None,
            group=1,
            steps=steps2,
            exeDir=f"melt-{name}",
            temperature=None,
            start_temperature=0.0,
        )
        md2.log_positions(period=melt_log_period)
        md2.runParams["start_positions"] = md.runParams["final_positions"]
        try:
            md2.run()
        except RuntimeError:
            continue
        traj = md2.positions
        # Deviation of every logged frame from the first frame of this run.
        csm = rmsd(traj[:, :n_final], traj[0, :n_final])
        if positions:
            return np.concatenate((md.positions, traj))
        return (
            config,
            traj[-1],
            md2.number_density(),
            csm,
            traj,
            np.arange(0, steps2, melt_log_period) * md2.runParams["time_step"],
        )
    return None
# Single trial run for n=10 and group 6 ("pmm" in `titles`).
# crystal() returns None when every retry fails, in which case this
# tuple unpacking raises TypeError.
config, config2, nd, csm, traj, time = crystal(10, 6)
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[5], line 1
----> 1 config, config2, nd, csm, traj, time = crystal(10, 6)

Cell In[4], line 14, in crystal(n, group, w, retries, steps, steps2, ndims, starting_density, positions)
      1 def crystal(
      2     n,
      3     group,
   (...)
     12     # trying to have n be number in UNIT CELLL
     13     # so have to adjust for group size
---> 14     m = len(symd.groups.load_group(group, ndims).genpos)
     15     n = max(2, n // m)
     16     if w is not None:

File /opt/hostedtoolcache/Python/3.12.4/x64/lib/python3.12/site-packages/symd/groups.py:517, in load_group(gnum, dim)
    507 """
    508 Load one of the 2D planar groups or 3D space groups that tile space. The :obj:`Group`
    509 contains the name of the Bravais lattice, the general positions,
   (...)
    514 :return: The :obj:`Group`
    515 """
    516 gnum = str(gnum)
--> 517 from importlib_resources import files
    518 import symd.data
    520 fp = files(symd.data).joinpath(f"{dim}dgroups.json")

ModuleNotFoundError: No module named 'importlib_resources'
# Quick look at the trial run: both configurations and the deviation curve
# are drawn on the same axes, titled with the number density.
plt.figure(figsize=(8, 8))
plt.title(f"{nd}")
for snapshot in (config, config2):
    plt.plot(snapshot[:, 0], snapshot[:, 1], ".")
plt.plot(time, csm)
# Turned off
# Regenerates the atlas data from scratch: one crystal() job per combination of
# particle count, group and number of occupied Wyckoff sites.
if False:
    from multiprocessing import Pool

    frames = []
    results = []
    trajs = {}

    with Pool() as pool:
        for N in [8, 16, 32, 128]:
            for i, t in enumerate(titles):
                W = len(symd.groups.load_group(i + 1, 2)["specpos"])
                for j in range(1 + W):
                    wycks = None if j == 0 else [1] * j
                    name = f"{t}-w{j}-n{N}"
                    job = pool.apply_async(crystal, (N, i + 1, wycks))
                    results.append((t, name, N, j, job))

        # Results must be collected while the pool is still alive.
        for t, name, N, j, ar in results:
            print("Getting result for ", name)
            res = ar.get()
            if res is None:
                # crystal() gave up on this combination after all retries.
                continue
            config, config2, nd, csm, traj, time = res
            T = len(csm)

            frames.append(
                pd.DataFrame(
                    {
                        "Group": [t] * T,
                        "Traj": [name] * T,
                        "rho": [nd] * T,
                        "N": [N] * T,
                        "Wyckoffs": [str(j)] * T,
                        "RMSD": csm,
                        "Time": time,
                    }
                )
            )
            trajs[name] = traj

    if not frames:
        raise RuntimeError("no simulation produced a result; nothing to save")
    # Concatenate once at the end instead of re-copying the table per result.
    cdf = pd.concat(frames)
    cdf.reset_index(inplace=True)
    cdf.to_pickle("atlas2d.pkl.gz")
    # NOTE(review): this writes an uncompressed "atlas2d.traj.pkl", while the
    # plotting section reads "atlas2d.traj.pkl.gz" through gzip.
    with open("atlas2d.traj.pkl", "wb") as f:
        pickle.dump(trajs, f, pickle.HIGHEST_PROTOCOL)

Plot Figure

# `import urllib` alone does not guarantee that the `request` submodule is
# loaded, so import it explicitly before using it.
import urllib.request

# Fetch the precomputed atlas (trajectories and summary table).
urllib.request.urlretrieve(
    "https://www.dropbox.com/s/m6gi1ecv66ylm06/atlas2d.traj.pkl.gz?dl=1",
    "atlas2d.traj.pkl.gz",
)
urllib.request.urlretrieve(
    "https://www.dropbox.com/s/smn8tvljxhlp75f/atlas2d.pkl.gz?dl=1", "atlas2d.pkl.gz"
)
# NOTE(security): both loads below unpickle downloaded data, which can execute
# arbitrary code. Only run this against a source you trust.
cdf = pd.read_pickle("atlas2d.pkl.gz")
with gzip.open("atlas2d.traj.pkl.gz", "rb") as f:
    trajs = pickle.load(f)
# Number of distinct trajectories with rho above 0.5 (shown as cell output).
cdf.query("rho > 0.5").Traj.unique().shape
def plot_config(pos, figsize=(1.5, 1.5), color="#333333"):
    """Draw one particle configuration as borderless dots.

    The first two columns of ``pos`` are plotted on a fixed window of
    ``[-3, 3]`` in both directions, with equal aspect ratio and the axes
    hidden. The figure is left open and current, so a caller can still
    modify it (for example by adding a title).

    Args:
        pos: 2-D array with one particle per row.
        figsize: Figure size handed to ``plt.subplots``.
        color: Marker color.

    Returns:
        Whatever ``skunk.pltsvg(fig=fig)`` returns for the new figure
        (presumably the SVG text).
    """
    # Unpacking also rejects anything that is not a 2-D array.
    n_points, n_dims = pos.shape
    background = "#f5f4e9"
    window = (-3, 3)
    dot_style = dict(
        color=color,
        marker="o",
        markersize=8,
        markeredgewidth=0.0,
        linestyle="None",
    )

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(pos[:, 0], pos[:, 1], **dot_style)
    ax.set_facecolor(background)
    fig.patch.set_facecolor(background)
    ax.axis("off")
    ax.set_aspect("equal")
    ax.set_xlim(*window)
    ax.set_ylim(*window)
    plt.tight_layout()
    return skunk.pltsvg(fig=fig)


# Show one example configuration inline.
skunk.display(plot_config(trajs["pg-w0-n16"][0]))
# The 25 trajectories with the smallest RMSD at the last recorded time.
final_rows = cdf[cdf.Time == cdf.Time.max()]
top = final_rows.sort_values(by=["RMSD"])[:25]
for r, t in zip(top.RMSD.values, top.Traj.values):
    # plot_config leaves its figure current, so the title lands on it.
    plot_config(trajs[t][0])
    plt.title(f"{t} = {r}")
# (label, color) pairs collected by annotate_config, one per call. They are
# used afterwards to swap each placeholder box for a rendered configuration.
replaces = []


def annotate_config(rmsd, traj, time, wyckoffs, color):
    """Attach a placeholder box to the lowest late-time RMSD point of a facet.

    Among the rows with ``time > 15`` the one with the smallest ``rmsd`` is
    selected. A 25x25 ``skunk.Box`` labelled with that row's trajectory name
    is added to the current axes, with an arrow pointing at the selected
    point. If the selected RMSD is above 1 the arrow points at ``(0, 0)``
    instead and the box is placed further right. The trajectory name and the
    color ``"C<wyckoffs>"`` are appended to the module-level ``replaces``.

    Args:
        rmsd: RMSD values of the facet (pandas Series, as ``.where`` and
            ``.iloc`` are used).
        traj: Trajectory names, aligned with ``rmsd``.
        time: Times, aligned with ``rmsd``.
        wyckoffs: Wyckoff labels, aligned with ``rmsd``.
        color: Accepted but not used in the body.
    """
    best = rmsd.where(time > 15).argmin()
    label = traj.iloc[best]
    anchor_x = time.iloc[best]
    anchor_y = rmsd.iloc[best]
    if anchor_y > 1:
        # Point at the origin rather than at the selected point.
        anchor_x = 0
        anchor_y = 0
    replaces.append((label, f"C{wyckoffs.iloc[best]}"))

    axes = plt.gca()
    box_x = 0.7 if anchor_x == 0 else 0.2
    annotation = mpl.offsetbox.AnnotationBbox(
        skunk.Box(25, 25, label),
        (anchor_x, anchor_y),
        pad=0,
        bboxprops=dict(edgecolor="#333", linewidth=1),
        xybox=(box_x, 0.7),
        xycoords="data",
        boxcoords="axes fraction",
        arrowprops=dict(arrowstyle="->", color="#333"),
    )
    axes.add_artist(annotation)


# One RMSD-versus-time panel per group, restricted to trajectories with rho > 0.5.
dense = cdf.query("rho > 0.5")
g = sns.relplot(
    data=dense,
    kind="line",
    x="Time",
    y="RMSD",
    hue="Wyckoffs",
    hue_order=list(map(str, range(9))),
    style="N",
    col="Group",
    col_wrap=6,
    height=1.75,
    aspect=1,
    linewidth=1,
    palette="Dark2",
)
plt.ylim(-0.1, 1)
# sns.move_legend(g, "lower right", bbox_to_anchor=(0.9,0.1), ncol=3)
# annotate_config adds a placeholder box per panel and records it in `replaces`.
g.map(annotate_config, "RMSD", "Traj", "Time", "Wyckoffs")
g.set_axis_labels(r"Time [$\tau$]", "RMSD")

# Render the figure, then substitute each placeholder with its configuration plot.
main_svg = skunk.pltsvg()
svg = skunk.insert(
    {
        label: plot_config(trajs[label][0], color=shade)
        for label, shade in replaces
    },
    svg=main_svg,
)
skunk.display(svg)
with open("atlas.svg", "w") as f:
    f.write(svg)
from cairosvg import svg2png

# Rasterized copy of the same figure.
svg2png(bytestring=svg, write_to="atlas_raster.png", dpi=300)

TOC

print(titles)
# Trajectory names follow "<group>-w<wyckoffs>-n<N>"; recover the group
# number and N from the name.
name = "p3m1-w0-n128"
group_i = 1 + titles.index(name.split("-")[0])
n_i = int(name.split("-n")[1])
group = symd.groups.load_group(group_i, 2)
gp = len(group.genpos)
pos = trajs[name][0]

N, D = pos.shape

# build out colors: the same run of M palette entries is repeated for every
# consecutive block of M particles
M = max(2, n_i // gp)
print(n_i, gp, N, M, N // M)
palette_block = [colors[i % len(colors)] for i in range(M)]
c = list(palette_block)
for i in range(1, N // M):
    c.extend(palette_block)
print(len(c))

figsize = (3.33, 1.88)
aspect = figsize[0] / figsize[1]
fig, ax = plt.subplots(figsize=figsize, dpi=300)
ax.scatter(pos[:, 0], pos[:, 1], marker="o", s=2, c=c)
# Quantile-based limits are computed here but not applied; the fixed window
# further down is what sets the axes.
q = 0.4
quantile_levels = [0.5 - q, 0.5 + q]
xlim = np.quantile(pos[:, 0], quantile_levels)
ylim = np.quantile(pos[:, 1], quantile_levels)
yd = (ylim[1] - ylim[0]) / aspect
yc = (ylim[1] + ylim[0]) / 2
ylim = (yc - yd, yc + yd)
# ax.set_xlim(*xlim)
# ax.set_ylim(*ylim)
# Fixed window: half-width x, with the height scaled by the figure aspect ratio.
x = 20
ax.set_xlim(-x, x)
ax.set_ylim(-x / aspect, x / aspect)
ax.set_aspect("equal")
background = "#f5f4e9"
ax.set_facecolor(background)
fig.patch.set_facecolor(background)
ax.axis("off")
plt.tight_layout()
plt.savefig("toc.tiff")
plt.show()