| | import os |
| | import pandas as pd |
| | from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence |
| | from PIL import Image as PILImage |
| |
|
def _existing_path(data_dir, relative_path):
    """Return ``data_dir/relative_path`` if it is set and exists, else ``None``."""
    if pd.isna(relative_path):
        return None
    full_path = os.path.join(data_dir, relative_path)
    return full_path if os.path.exists(full_path) else None


def _read_xyz(xyz_path):
    """Parse an XYZ file into parallel ``(elements, coordinates)`` lists.

    The first two lines of the file are skipped. Every later line with at
    least four whitespace-separated fields contributes one entry: the first
    field is kept as the element string and the next three are converted
    with ``float``. Lines with fewer than four fields are ignored.

    Raises:
        ValueError: if one of the three coordinate fields is not a float.
    """
    elements = []
    coords = []
    # NOTE(review): assumes the files are UTF-8 (XYZ content is normally
    # plain ASCII); explicit so parsing does not depend on the locale.
    with open(xyz_path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f):
            if line_number < 2:
                continue
            parts = line.split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def load_crysmtm_dataset(data_dir, split="train"):
    """Load CrysMTM dataset for a specific split.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per CSV row. Each example always carries the same keys:

    * ``phase``, ``temperature``, ``rotation``, ``split`` - copied from the row.
    * ``image`` - RGB PIL image, or ``None`` when ``image_path`` is missing
      or the file does not exist.
    * ``xyz_coordinates`` / ``elements`` - parsed from the file at
      ``xyz_path``, or ``None`` when it is missing or does not exist.
    * ``text`` - full contents of the file at ``text_path``, or ``None``.
    * ``regression_labels`` - row values for HOMO, LUMO, Eg, Ef, Et, Eta,
      disp, vol and bond, in that order.
    * ``classification_label`` - the row's ``label`` value.

    Args:
        data_dir: Root directory that the metadata CSV and all relative
            ``*_path`` columns are resolved against.
        split: Split name used to pick the metadata CSV.

    Returns:
        A ``datasets.Dataset`` built from the list of examples.

    Raises:
        FileNotFoundError: if the metadata CSV for ``split`` does not exist.
        KeyError: if an expected column is absent from the CSV.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Hoisted: identical for every row.
    regression_properties = (
        "HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond",
    )

    def load_example(row):
        """Load a single example with all modalities."""
        # Optional modalities default to None so every example has the same
        # keys. Dataset.from_list derives its columns from the first example,
        # so a key absent there would otherwise drop that modality for the
        # whole split.
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = _existing_path(data_dir, row["image_path"])
        if image_path is not None:
            # convert() returns an independent image, so the source file can
            # be closed as soon as the conversion is done.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        xyz_path = _existing_path(data_dir, row["xyz_path"])
        if xyz_path is not None:
            elements, coords = _read_xyz(xyz_path)
            example["xyz_coordinates"] = coords
            example["elements"] = elements

        text_path = _existing_path(data_dir, row["text_path"])
        if text_path is not None:
            # NOTE(review): assumes the description files are UTF-8 - confirm.
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]

        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])
| |
|
def load_dataset(data_dir):
    """Load every available CrysMTM split from ``data_dir``.

    Tries the ``train``, ``test_id`` and ``test_ood`` splits in that order.
    A split whose loading raises ``FileNotFoundError`` is reported with a
    printed warning and left out of the result.

    Args:
        data_dir: Directory passed through to ``load_crysmtm_dataset``.

    Returns:
        A ``DatasetDict`` keyed by the names of the splits that loaded.
    """
    loaded = {}

    for split in ("train", "test_id", "test_ood"):
        try:
            split_dataset = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")
        else:
            loaded[split] = split_dataset

    return DatasetDict(loaded)
| |
|
| | |
def load_crysmtm():
    """Load the CrysMTM dataset using the current working directory as root.

    Returns:
        Whatever ``load_dataset(".")`` returns.
    """
    dataset_root = "."
    return load_dataset(dataset_root)