-
Notifications
You must be signed in to change notification settings - Fork 344
Expand file tree
/
Copy pathrun_dpo.py
More file actions
104 lines (82 loc) · 2.92 KB
/
run_dpo.py
File metadata and controls
104 lines (82 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pprint
from omegaconf import OmegaConf
from nemo_rl.algorithms.dpo import MasterConfig, dpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.utils import setup_preference_data
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir
def parse_args(argv=None) -> tuple[argparse.Namespace, list[str]]:
    """Parse the command line for the DPO training script.

    Only ``--config`` is declared on the parser. Every argument the parser
    does not recognise is handed back untouched so the caller can treat it
    as a configuration override.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, so a plain ``parse_args()`` call behaves as
            before.

    Returns:
        A ``(args, overrides)`` pair. ``args`` is the namespace holding
        ``config`` (``None`` when the flag is absent); ``overrides`` is the
        list of leftover argument strings in their original order.
    """
    parser = argparse.ArgumentParser(description="Run DPO training with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # parse_known_args (rather than parse_args) collects unrecognised tokens
    # instead of exiting with an error, which is what lets overrides through.
    args, overrides = parser.parse_known_args(argv)
    return args, overrides
def main():
    """Run DPO training end to end from the command line.

    Steps, in order:
      1. Parse CLI arguments; when ``--config`` is not given, fall back to
         ``configs/dpo.yaml`` located next to this script.
      2. Load the config, apply any leftover CLI arguments as overrides, and
         convert the result to a plain container with ``resolve=True``.
      3. Replace ``logger.log_dir`` with the value returned by
         ``get_next_experiment_dir``.
      4. Call ``init_ray()``, then build the tokenizer and the
         train/validation preference datasets.
      5. Call ``setup`` and pass the objects it returns to ``dpo_train``.
    """
    args, overrides = parse_args()

    if not args.config:
        args.config = os.path.join(os.path.dirname(__file__), "configs", "dpo.yaml")

    config = load_config(args.config)
    print(f"Loaded configuration from: {args.config}")

    if overrides:
        print(f"Overrides: {overrides}")
        config = parse_hydra_overrides(config, overrides)
        # Reported inside this branch so the message never appears for a run
        # that supplied no overrides.
        print("Applied CLI overrides")

    # Converted unconditionally: the code below indexes and mutates the
    # config as plain nested containers whether or not overrides were given.
    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    # Print config
    print("Final config:")
    pprint.pprint(config)

    # The configured log_dir is swapped for whatever directory
    # get_next_experiment_dir derives from it.
    config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
    print(f"📊 Using log directory: {config['logger']['log_dir']}")
    if config["checkpointing"]["enabled"]:
        print(
            f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
        )

    init_ray()

    # setup tokenizer
    tokenizer = get_tokenizer(config["policy"]["tokenizer"])

    # setup data
    dataset, val_dataset = setup_preference_data(tokenizer, config["data"])

    # `cluster` is returned by setup but is not an argument of dpo_train; it
    # simply stays bound here for the duration of the call below.
    (
        policy,
        cluster,
        train_dataloader,
        val_dataloader,
        loss_fn,
        logger,
        checkpointer,
        dpo_save_state,
        master_config,
    ) = setup(config, tokenizer, dataset, val_dataset)

    dpo_train(
        policy,
        train_dataloader,
        val_dataloader,
        tokenizer,
        loss_fn,
        master_config,
        logger,
        checkpointer,
        dpo_save_state,
    )
# Start training only when this file is executed as a script; importing the
# module must not kick off a run.
if __name__ == "__main__":
    main()