-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathdistributed_tensor.cpp
More file actions
75 lines (64 loc) · 2.04 KB
/
distributed_tensor.cpp
File metadata and controls
75 lines (64 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include "distributed_tensor.h"
#include "base.h"
#include "exceptions.h"
#include "ir/interface_nodes.h"
#include "type.h"
namespace nvfuser {
void Sharding::setAxisIsShardedOn(
    const int64_t axis,
    const ParallelType parallel_type) {
  // Record that `axis` of this tensor is sharded over `parallel_type`.
  // Only device-parallel types make sense here, and only for tensors that
  // actually live on a (non-empty) device mesh.
  NVF_CHECK(isParallelTypeDeviceDim(parallel_type));
  NVF_CHECK(mesh_.size() > 0, "Cannot shard a non-distributed tensor.");
  // Each parallel type may shard at most one axis. try_emplace inserts only
  // when the key is absent, so a duplicate registration leaves the map
  // untouched and trips the check below.
  const auto [entry, inserted] = axis_sharded_on_.try_emplace(parallel_type, axis);
  NVF_CHECK(
      inserted,
      "Parallel type ",
      parallel_type,
      " was already used to shard axis ",
      entry->second);
}
int64_t Sharding::axisShardedOn(const ParallelType parallel_type) const {
  // Return the axis registered for `parallel_type`, or -1 when this tensor
  // is not sharded on that parallel type.
  const auto it = axis_sharded_on_.find(parallel_type);
  return it == axis_sharded_on_.end() ? -1L : it->second;
}
// Collect a Sharding for each visible fusion output. Returns an empty vector
// when no TensorView in the fusion carries a device mesh, i.e. the fusion is
// not distributed at all.
std::vector<Sharding> getOutputShardings(Fusion* fusion) {
  std::vector<TensorView*> all_tvs = fusion->allTvs();
  const bool has_meshed_tv = std::any_of(
      all_tvs.begin(), all_tvs.end(), [](TensorView* tv) {
        return tv->hasDeviceMesh();
      });
  if (!has_meshed_tv) {
    return {};
  }

  std::vector<Sharding> shardings;
  // May over-reserve when hidden outputs are skipped below; harmless.
  shardings.reserve(fusion->outputs().size());
  for (Val* output : fusion->outputs()) {
    auto* tv = dynamic_cast<TensorView*>(output);
    if (tv == nullptr) {
      // Non-tensor outputs get a default (mesh-less) Sharding placeholder.
      shardings.emplace_back();
      continue;
    }
    if (fusion->getOutputAlias(tv).visibility == OutputVisibility::kHidden) {
      // Hidden outputs are not reported to the caller at all.
      continue;
    }
    const DeviceMesh& mesh = tv->getDeviceMesh();
    Sharding& sharding = shardings.emplace_back(mesh);
    if (mesh.size() == 0) {
      continue;
    }
    // Register every device-parallel type that shards a logical axis of
    // this output; -1 means "not sharded on this parallel type".
    for (const ParallelType parallel_type : kParallelTypeDIDs) {
      const int64_t axis = getShardedLogicalAxis(tv, parallel_type);
      if (axis != -1) {
        sharding.setAxisIsShardedOn(axis, parallel_type);
      }
    }
  }
  return shardings;
}
} // namespace nvfuser