照片由 DeepMind on Unsplash
矩阵乘法是许多系统中使用的基本运算,从神经网络到科学计算例程。 为矩阵乘法寻找高效且可证明正确的算法可以对加快计算速度和提高效率产生巨大影响,但这是一项非常具有挑战性的任务。 可能算法的空间是巨大的,而用于发现算法的传统方法,如人工设计的启发式或组合搜索,往往不是最优的。
DeepMind最近提出的基于人工智能的自动搜索解决方案远远超出了人类的直觉。该解决方案由一个名为 AlphaTensor 的深度强化学习代理组成,它构建在 零度. 该代理经过训练可以玩单人游戏 TensorGame,其目标是发现计算效率高的矩阵乘法算法。
AlphaTensor 特别擅长通过将大矩阵乘法分解为更小的乘法来处理大矩阵。 此外,一旦在特定硬件设备上进行微调,AlphaTensor 可用于实现最先进的矩阵乘法性能。
AlphaTensor 具有加速深度学习计算的巨大潜力。 在深度学习中,许多耗时的操作可以映射到矩阵乘法。 通过使用 AlphaTensor 优化这些操作,可以显着提高深度学习模型的整体性能。
最近,OpenAlphaTensor, AlphaTensor 的第一个开源实现, 已发布,这可能会彻底改变深度学习模型的计算能力。
矩阵乘法张量
对于矩阵乘法优化方面的非专家来说,理解矩阵乘法等运算如何映射到三维张量可能并不简单。 我将尝试用简单的文字和示例来解释它。
让我们考虑乘积 C = A*B,其中为简单起见,A 和 B 都是大小为 N 的方阵。乘法运算可以映射到形状为 (N^3, N^2, N^2) 的 2D 张量中。 第一个张量维度表示展平矩阵 A,第二个维度表示展平矩阵 B,第三个维度表示展平矩阵 C。
对于每个条目,张量只有二进制值(1 或 0)。 请注意,张量表示乘法运算,因此它与矩阵 A 和 B 的值无关。
张量的每个条目都对应于运算的系数。 例如,要计算 C[1,1],需要将 A[1,1] 和 B[1,1] 相乘。 因此,对应于 A[0,0,0]、B[1,1] 和 C[1,1] 的张量项 [1,1] 的值为 1。相反,要计算 C[1,1 ,2,1],不需要 A[1]。 因此,张量行 T[N+0, :, XNUMX] 将仅包含零。
下图显示了 N=2 的张量示例。
图片来自 DeepMind 纸 发表于 自然
如上图(b)和(c)所示,可以使用3D张量的分解来实现计算乘积的算法。 更具体地说,下面的算法可用于将张量分解(矩阵 U、V、W)转换为矩阵乘法算法。
DeepMind 中引入的用于计算矩阵乘积 C=AB 的参数化元算法 纸
张量游戏
寻找有效的矩阵乘法算法的问题极具挑战性,因为要考虑的可能算法的数量远大于宇宙中的原子数量,即使对于矩阵乘法的小实例也是如此。
DeepMind 将这个问题转化为单人游戏,并称之为 TensorGame。 在这个游戏中,玩家选择如何组合不同的矩阵条目以将它们相乘。 根据获得正确乘法结果所需的运算次数分配分数。 当达到零张量或已进行最大移动次数时,游戏结束。 最终的因式分解是基于对残差等级的估计和某些优化标准(例如渐近时间复杂度或实际运行时间)进行评估的。
TensorGame 中的初始位置对应于在某种随机基础上表示的矩阵乘法张量。
在游戏的每个步骤 t 中,玩家写下三个向量
,它指定 rank-1 张量 . 游戏状态通过减去玩家选择的向量来更新:
哪里
是矩阵乘法张量。如果游戏以 p 步结束,这意味着 Matrix Multiplication Tensor
可以分解为 p rank-1 张量 ,即它至少有秩 p。然后可以将 TensorGame 解释为秩分解算法,而 AlphaTensor 可以看作是估计张量秩的算法。
AlphaTensor 架构
到目前为止,我们已经了解了 TensorGame 并阐明了如何将其解决方案视为矩阵乘法算法。 现在让我们探讨 AlphaTensor 的主要概念,该算法用于游戏。
AlphaTensor 架构基本上是一种编码器-解码器 Transformer 架构,其中:
- 编码器将游戏状态作为输入 ,模型之前采取的n个动作(通常n=7),以及当前动作的时间索引t。 信息以张量的形式堆叠在一起,形状为 (n+1, N^2, N^2, N^2)。 然后将该张量重新整形并转换(使用三个线性层)为形状为 (N^2, N^2, c) 的张量,其中 c 是模型的内部维度。
- 解码器以自回归的方式从编码器给出的嵌入向量生成 n_steps 动作。 每个动作对应三元组的一个token 代表分解游戏张量的三元组之一(即降低其等级)
该模型通过交替反向传播和模型作用进行训练。 模型表演用于生成数据,然后用于训练模型。 在实践中,模型是用综合生成的数据和模型在表演过程中生成的数据的混合物来训练的。 执行步骤是通过获取对应于矩阵运算的 3D 张量并在其上玩 n_actors 游戏来完成的。 每个参与者在标准基础上或在替代基础上玩游戏(基础的变化以给定的概率应用)。 然后收集结果,并可用于合成数据的训练步骤。
动作步骤基于 AlphaZero 的蒙特卡罗树搜索 (MCTS),经过修改以支持大型动作空间。简而言之,在选择操作之前,将从模型输出中探索 n_sims 路径,最多未来探索 5 个步骤。然后,考虑生成的路径来调整模型生成的概率。然后选择最有希望的未来路径的行动来继续游戏。
在训练模型时,奖励实际上是负奖励(惩罚)。 它的绝对值随着解决游戏所需的每个额外步骤而增加。 如果模型需要 m 个步骤来解决 TensorGame,则与游戏相关的奖励为 r=-m。 如果模型无法在 max_rank 步骤中解决 TensorGame,则通过估计剩余张量的等级来计算奖励。 秩估计为构成张量的矩阵的秩之和。 该估计值是张量真实秩的上限。
在微调模型时,最终状态的惩罚奖励还应该考虑模型产生的算法的延迟。奖励公式变为rt'=rt+λbt,其中rt是前面描述的奖励方案,bt是基准奖励(仅在最终状态下非零),并且 λ 是用户指定的系数。
为 GPU 和 TPU 量身定制的 AlphaTensor 发现算法的加速 (%),摘自 DeepMind 的论文。 加速是相对于相同硬件上的标准(例如 GPU 的 cuBLAS)矩阵乘法测量的,并与 施特拉森平方算法。 资源: DeepMind.
我最近发布 OpenAlpha张量,AlphaTensor 的第一个开源实现。 在本节中,我将介绍实施过程。 正如我们之前讨论的那样,AlphaTensor 架构非常简单,它基于具有编码器-解码器架构的标准转换器。 AlphaTensor 最有趣的组件是编码器部分的第一层和动作的采样方式。
让我们从第一个编码层开始。
# x.size = (N, T, S, S, S)
# scalars.size = (N, s)
batch_size = x.shape[0]
S = x.shape[-1]
T = x.shape[1]
x1 = x.permute(0, 2, 3, 4, 1).reshape(batch_size, S, S, S * T)
x2 = x.permute(0, 4, 2, 3, 1).reshape(batch_size, S, S, S * T)
x3 = x.permute(0, 3, 4, 2, 1).reshape(batch_size, S, S, S * T)
input_list = [x1, x2, x3]
for i in range(3): temp = self.linears_1[i](scalars).reshape(batch_size, S, S, 1) input_list[i] = torch.cat([input_list[i], temp], dim=-1) input_list[i] = self.linears_2[i](input_list[i])
x1, x2, x3 = input_list
在上面的代码片段中,我们展示了如何将输入张量分解为三个张量,然后将其用作转换层的查询、键和值输入。
- 在表示展平矩阵(A、B、C)的三个张量维度上,输入张量沿着每个维度与表示先前动作的维度一起展平。 这样,在输入张量的每个扁平化副本中,对于所选维度的所有 S 值,所选维度是最后 T-1 个值和实际值的聚合,其中 S=N^2。 从哲学上讲,就好像对于每个维度,我们都专注于该维度中先前行动中发生的事情。
- 标量被映射到维度 S^2 的三个不同空间中,然后重新整形以与在前一点获得的张量连接。 从概念上讲,标量被映射到维度为 S^2 的嵌入空间,然后嵌入信息被分块为 S 向量并堆叠在一起,类似于标记化文本时发生的情况。
- 标量标记与重组后的输入张量连接,然后作为线性层的输入,用于在模型的内部维度中映射标量+通道历史焦点信息。
这三个步骤可以解释为一种向模型提供有关标量的信息(如在 TensorGame 时间步长中)和关注每个通道的先前操作的方式。
关于动作的产生方式,有趣的是 AlphaTensor 生成三元组 u、v、w 作为输出,其目的是降低张量等级。 这三个向量的大小为 S,由于它们是串联的,因此模型必须生成一个大小为 3*S 的向量。 AlphaTensor 是用 RL 算法训练的,因此所有可能的动作都必须用枚举空间中的概率来表示,即模型产生不同动作的概率。 这意味着 3S 空间中的每个向量都应映射到不同的动作。 这导致大小为 |F|^(3S) 的动作空间,其中 |F| 是u,v,w的元素可以取的不同值的个数。 通常,值被限制为 (-2, -1, 0, 1, 2),导致 5 个元素的基数。
这是一个主要的挑战:要为大小为 5 的矩阵的矩阵乘积生成动作概率,我们需要 5^75 * 4 字节的内存,这意味着大约 10^44 GB 的内存。 显然,我们无法管理如此大的行动空间。
我们如何解决这个问题? 为了减少动作概率的内存占用,我们可以将三元组分成更小的块,对它们进行“标记化”,并将这些块视为变换器体系结构中生成的标记,即标记作为自回归解码器的输入方式。 在上面的示例中,我们可以将三元组拆分为 15 个块,从而将内存消耗减少到 15 * 5^(75/15) * 4,即 187.5 KB。
def _eval_forward(self, e: torch.Tensor): bs = e.shape[0] future_g = ( torch.zeros((bs, self.n_samples, self.n_steps)).long().to(e.device) ) ps = torch.ones((bs, self.n_samples)).to(e.device) e = e.unsqueeze(1).repeat(1, self.n_samples, 1, 1) future_g = future_g.view(-1, self.n_steps) ps = ps.view(-1) e = e.view(-1, e.shape[-2], e.shape[-1]) for i in range(self.n_steps): o_s, z_s = self.core(future_g[:, : i + 1], e) future_g[:, i], p_i = sample_from_logits(o_s[:, i]) ps *= p_i future_g = future_g.view(bs, self.n_samples, self.n_steps) ps = ps.view(bs, self.n_samples) return ( future_g, ps, z_s[:, 0].view(bs, self.n_samples, *z_s.shape[2:]).mean(1), )
上面我们展示了生成完整动作的代码片段。 在代码中,self.core 包含解码器层,张量 e 表示编码器层的输出。 零可以被认为是NLP 模型中的 token 和表示 n_steps 块的 n_steps 动作是以渐进的方式生成的。
该模型返回三个数量:
- 生成的动作
- 与完整动作相关的概率
- 为生成将用于计算模型值的第一个动作(第一个块)而生成的逻辑。
值得在 n_samples 参数上多说几句。 该参数用于动作步骤,它允许模型生成不同版本的三元组,然后用于探索动作过程中使用的蒙特卡洛树搜索算法中的动作空间。 根据模型生成的策略对 n_samples 个不同的动作进行采样。
作用步骤
整个算法中最棘手的部分可能是用于解决 TensorGame 的 Acting 步骤。 该算法在 AlphaTensor 论文中没有深入解释,因为它基于几篇 DeepMind 之前的论文,这些论文只是被引用并作为已知给出。 在这里,我将重建所有缺失的部分并逐步解释我们的实现。
我们可以将动作步骤组织成三个不同的部分:
- 蒙特卡洛树搜索
- 游戏模拟
- 改进的策略计算
让我们一一分析。
蒙特卡洛树搜索 (MCTS)
蒙特卡洛树搜索 (MCTS) 是一种广泛用于玩游戏的人工智能技术,尤其是在棋盘游戏和视频游戏中。 该算法创建了一个游戏树来模拟潜在的动作和结果,并使用随机抽样来评估每个动作的预期奖励。 然后,该算法迭代地选择具有最高预期奖励的移动并模拟结果,直到它达到最终状态或指定的停止条件。 模拟用于估计每一步获胜的概率并指导决策过程。 MCTS 已被证明在可能的移动和结果数量很大的复杂游戏中是有效的,并且它已被用于成功的游戏人工智能系统,例如 AlphaGo。
在 AlphaTensor 中,使用了原始 MCTS 的修改版本。 特别是,不是从整个动作空间中随机选择动作,而是在模型直接生成的子集中选择动作(通过前面介绍的 n_samples)。 然后在改进的策略计算步骤中应用对策略升级的更正。
在我们的实现中,我们决定将关于蒙特卡洛树的所有信息保存在一个字典中,该字典以 TensorGame 状态的哈希版本作为键,以与状态本身相关的信息作为值。 每个蒙特卡洛步骤都从一个节点开始,模拟 n_sim 迷你游戏,以 5 步的视野探索未来。 如果该节点已经在之前的模拟中被探索过,则 n_sim 会根据之前的探索次数进行调整。 对于每个节点,访问次数存储在 N_s_a 张量中,因为该张量包含每个节点子操作的访问次数(在模型采样的次数中)。
def monte_carlo_tree_search( model: torch.nn.Module, state: torch.Tensor, n_sim: int, t_time: int, n_steps: int, game_tree: Dict, state_dict: Dict,
): """Runs the monte carlo tree search algorithm. Args: model (torch.nn.Module): The model to use for the simulation. state (torch.Tensor): The initial state. n_sim (int): The number of simulations to run. t_time (int): The current time step. n_steps (int): The maximum number of steps to simulate. game_tree (Dict): The game tree. state_dict (Dict): The dictionary containing the states. """ state_hash = to_hash(extract_present_state(state)) if state_hash in state_dict: with torch.no_grad(): N_s_a = state_dict[state_hash][3] n_sim -= int(N_s_a.sum()) n_sim = max(n_sim, 0) for _ in range(n_sim): simulate_game(model, state, t_time, n_steps, game_tree, state_dict) # return next state possible_states_dict, _, repetitions, N_s_a, q_values, _ = state_dict[ state_hash ] possible_states = _recompose_possible_states(possible_states_dict) next_state_idx = select_future_state( possible_states, q_values, N_s_a, repetitions, return_idx=True ) next_state = possible_states[next_state_idx] return next_state
上面的代码显示了我们对算法的实现。 为了代码简单,策略修正在 simulate_game 函数中执行。
游戏模拟
simulate_game 函数负责探索由代表 TensorGame 特定状态的节点组成的树。 它还会在遇到叶节点时运行模型,并将所有节点信息存储在 state_dict 字典中。 让我们深入了解一下它的实现:
@torch.no_grad()
def simulate_game( model, state: torch.Tensor, t_time: int, max_steps: int, game_tree: Dict, states_dict: Dict, horizon: int = 5,
): """Simulates a game from a given state. Args: model: The model to use for the simulation. state (torch.Tensor): The initial state. t_time (int): The current time step. max_steps (int): The maximum number of steps to simulate. game_tree (Dict): The game tree. states_dict (Dict): The states dictionary. horizon (int): The horizon to use for the simulation. """ idx = t_time max_steps = min(max_steps, t_time + horizon) state_hash = to_hash(extract_present_state(state)) trajectory = [] # selection while state_hash in game_tree: ( possible_states_dict, old_idx_to_new_idx, repetition_map, N_s_a, q_values, actions, ) = states_dict[state_hash] possible_states = _recompose_possible_states(possible_states_dict) state_idx = select_future_state( possible_states, q_values, N_s_a, repetition_map, return_idx=True ) trajectory.append((state_hash, state_idx)) # state_hash, action_idx future_state = extract_present_state(possible_states[state_idx]) state = possible_states[state_idx] state_hash = to_hash(future_state) idx += 1 # expansion if idx = max_steps: trajectory.append((state_hash, None)) if not game_is_finished(extract_present_state(state)): state = state.to(model.device) scalars = get_scalars(state, idx).to(state.device) actions, probs, q_values = model(state, scalars) ( possible_states, cloned_idx_to_idx, repetitions, not_dupl_indexes, ) = extract_children_states_from_actions( state, actions, ) not_dupl_actions = actions[:, not_dupl_indexes].to("cpu") not_dupl_q_values = torch.zeros(not_dupl_actions.shape[:-1]).to( "cpu" ) N_s_a = torch.zeros_like(not_dupl_q_values).to("cpu") present_state = extract_present_state(state) states_dict[to_hash(present_state)] = ( _reduce_memory_consumption_before_storing(possible_states), cloned_idx_to_idx, repetitions, N_s_a, not_dupl_q_values, not_dupl_actions, ) game_tree[to_hash(present_state)] = [ to_hash(extract_present_state(fut_state)) for fut_state in possible_states ] leaf_q_value = q_values else: leaf_q_value = -int(torch.linalg.matrix_rank(state).sum()) # backup backward_pass(trajectory, states_dict, leaf_q_value=leaf_q_value)
每个模拟分为三个部分:
- 选择
- 扩展
- 备份工具
在选择部分,模拟在已经生成的树节点上运行,并使用以下函数选择以下节点:
def select_future_state( possible_states: List[torch.Tensor], q_values: torch.Tensor, N_s_a: torch.Tensor, repetitions: Dict[int, list], c_1: float = 1.25, c_2: float = 19652, return_idx: bool = False,
) -> torch.Tensor: """Select the future state maximizing the upper confidence bound."""
# q_values (1, K, 1) pi = torch.tensor( [ len(repetitions[i]) for i in range(len(possible_states)) if i in repetitions ] ).to(q_values.device) ucb = q_values.reshape(-1) + pi * torch.sqrt( torch.sum(N_s_a) / (1 + N_s_a) ) * (c_1 + torch.log((torch.sum(N_s_a) + c_2 + 1) / c_2)) if return_idx: return ucb.argmax() return possible_states[ucb.argmax()]
在实践中,最大化 ucb 函数的动作:
对于给定的状态被选中。 这里 Q 表示模型生成的 Q 值,π 表示使用模型策略采样的动作的随机分布。 N(s, a) 表示节点从节点 s 到动作 a 的访问次数。
一旦选择阶段到达叶节点,如果模拟没有达到终止条件(根据最大探索,即未来地平线或游戏结束),则该模型将用于选择 n_samples 个替代节点(它们将是叶连续迭代中的节点)。 这称为扩展阶段,因为新节点被添加到树中。 然后,在当前模拟中不再探索其他节点,但叶 q_value 被发送到以下模拟步骤:备份。
备份是每个模拟的最后阶段。 在备份期间,如果叶节点是终端状态,则计算最终奖励; 否则叶 q 值被用作估计的奖励。 然后奖励在模拟轨迹上反向传播更新状态 q_values 和更新访问计数器 N(s, a)。 在下面的代码片段中,我们展示了奖励反向传播的代码。
def backward_pass(trajectory, states_dict, leaf_q_value: torch.Tensor): """Backward pass of the montecarlo algorithm"""
reward = 0 for idx, (state, action_idx) in enumerate(reversed(trajectory)): if action_idx is None: # leaf node reward += leaf_q_value else: ( _, old_idx_to_new_idx, _, N_s_a, q_values, _, ) = states_dict[state] if isinstance(reward, torch.Tensor): reward = reward.to(q_values.device) action_idx = int(action_idx) if action_idx in old_idx_to_new_idx: not_dupl_index = old_idx_to_new_idx[int(action_idx)] else: not_dupl_index = action_idx reward -= 1 q_values[:, not_dupl_index] = ( N_s_a[:, not_dupl_index] * q_values[:, not_dupl_index] + reward ) / (N_s_a[:, not_dupl_index] + 1) N_s_a[:, not_dupl_index] += 1
改进的策略计算
一旦运行了所有模拟并且 MCTS 提供了近期的有趣快照,就该更新与预测节点关联的策略并返回它们,以便它们可以在训练期间使用。 改进后的策略,遵循中描述的方法 休伯特等人, 用于管理大型动作空间。 事实上,对于小的搜索空间,在 MCTS 期间可以从动作空间中随机采样一个动作并评估其影响。 在更大的动作空间中采用类似的方法会导致所有轨迹在不同路径上发散,并且需要无限数量的轨迹才能获得有意义的统计数据,然后更新策略。 由于这里我们使用 sample-MCTS 来避免分散,即根据模型策略对 n_samples 个动作进行采样,然后 MCTS 在探索树时仅选择一个采样动作,因此我们需要在计算时考虑样本校正训练模型时将使用的最终更新策略。
在实践中,改进的策略被计算为
哪里
def compute_improved_policy( state_dict: Dict, states: List[str], model_n_steps: int, model_n_logits: int, N_bar: int,
): """Compute the improved policy given the state_dict, the list of states. The improved policy is computed as (N_s_a / N_s_a.sum())^(1/tau) where tau is (log(N_s_a.sum()) / log(N_bar)) if N_s_a.sum() > N_bar else 1. """ policies = torch.zeros(len(states), model_n_steps, model_n_logits) N_bar = torch.tensor(N_bar) for idx, state in enumerate(states): N_s_a = state_dict[state][3] actions = state_dict[state][5] if N_s_a.sum() > N_bar: tau = (torch.log(N_s_a.sum()) / torch.log(N_bar)).item() else: tau = 1 N_s_a = N_s_a ** (1 / tau) improved_policy = N_s_a / N_s_a.sum() for sample_id in range(actions.shape[1]): action_ids = actions[0, sample_id] for step_id, action_id in enumerate(action_ids): policies[idx, step_id, action_id] += improved_policy[ 0, sample_id ] return policies
请注意,在我们的实现中,在从 N_s_a 张量计算出策略后,我们必须将其映射回原始动作张量。 事实上,N_s_a 只考虑模型采样的动作,而最终的策略还必须包含未探索动作的概率。
ChatGPT 训练算法的差异
AlphaTensor 是 DeepMind 人工智能方法 AlphaGo/AlphaZero 家族的最新成员。这些方法基于蒙特卡罗树搜索 (MCTS) 算法,该算法经过 DeepMind 的改进和增强,可以处理日益复杂的任务。另一个人工智能系统 OpenAI 的 ChatGPT 因其卓越的性能而引起了广泛关注,它采用了一种不同的方法进行训练,称为人类反馈强化学习(RLHF)。
RLHF 是一种微调技术,用于调整语言模型以遵循一组书面说明。 它使用人类偏好作为奖励信号来微调模型,从而使语言模型的行为与特定人群的既定偏好保持一致,而不是更广泛的“人类价值观”概念。
相比之下,MCTS 是一种基于树的搜索算法,用于确定游戏中的最佳动作。 它模拟潜在的移动并根据结果更新每个移动的值,指导最佳移动的选择。
RLHF 从人工编写的演示和人工智能模型之间人工标记的比较中收集数据,并训练奖励模型来预测给定人群的偏好。 然后使用奖励模型微调 AI 模型。 另一方面,MCTS 使用模拟和评估来确定最佳决策。
尽管它们是不同的方法,但 RLHF 和 MCTS 也有相似之处。 两种人工智能技术都使用决策和解决问题的方法,并且都使用试错法来探索不同的选项并根据可用信息做出决策。 两者都是迭代过程,随着时间的推移会随着收集到更多信息和经验而改进。
RLHF 和 MCTS 之间的选择取决于手头的任务。 当没有明确的指标来评估模型性能时,RLHF 是理想的选择,而 MCTS 已被证明在类似游戏的任务中有效,在这些任务中,对未来的知识和探索使模型具有显着优势。
AlphaTensor 训练的代码优化
实施 AlphaTensor 训练算法需要在训练速度和内存消耗之间找到完美的折衷。 如模型部分所示,简单地考虑动作标记化可以节省大量内存,但过度激进的动作空间减少会导致准确性下降和性能下降。 后者的发生是因为所有标记都是由模型解码器以自回归方式顺序生成的。 因此,一旦动作空间上的 softmax 不再是瓶颈,推理时间就会随着每个动作的标记数量线性增长。
在设置 AlphaTensor 训练时,发现主要困难在于处理 acting 过程。 如果张量没有以正确的格式存储,MCTS 很容易导致不受控制的内存使用量增长。 另一方面,如果每次模拟期间存储的张量数量减少太多,MCTS 可能会花费无限多的时间重新计算所需的状态。
让我们以游戏模拟步骤为例,通过查看未来可能的场景来探索游戏。对于每个状态,如果我们不保存模型生成的动作,而决定仅保存用于对策略中的动作进行采样的随机种子,那么每次探索树节点时,我们都必须重新计算策略并然后对动作进行采样。显然,我们决定存储采样的操作,以节省时间,并避免在 MCTS 探索并行化的情况下管理不同进程之间的模型共享。然而,仅仅保存动作并不足以获得足够有效的动作步骤。事实上,将 n_steps 个动作转换为 (u, v, w) 三元组、减少游戏张量状态以及从 n_samples 个动作创建新的 3D 张量的时间很容易成为整个训练的瓶颈。其次,我们不想存储每个采样动作的所有可能的未来状态,因为这会对算法使用的内存产生巨大影响。假设我们设置 n_samples=32、n=7 和 N=5,并且记住 N 是我们想要减少的方阵乘积的大小,n 是模型记住的先前动作的数量。在这种情况下,每个状态张量的形式为 (8, 25, 25, 25),乘以 32 将得到 328252525图中每个节点 4 个字节。 现在,考虑到扩展阶段的每个模拟都会生成一个新节点(并且 n_sim=200),我们最终的内存消耗为 200328252525*4 = 3.2GB 仅第一个 MCTS 节点。 在最坏的情况下,在探索活动的 max_rank 节点(其中 max_rank=150)时,这将导致 RAM 内存(或 GPU 内存,如果所有张量都存储在 GPU 上)中的总内存消耗为 150 * 3.2GB = 480GB . 我们在配备 128 GB RAM 和 48 GB GPU 内存的工作站上运行训练,因此我们必须减少内存消耗。
由于我们不想增加执行时间,因此我们采用了一种利用所生成的状态张量中的冗余的优化。事实上,这些张量有 n-1 个共同的先前动作,因此可以存储一次,并且不会为每个存储的张量重复操作。这导致内存减少 2/7~28%,这意味着在最坏的情况下可以存储 137GB。此时,通过简单地修剪树中未使用的部分(例如未选择的轨迹)并将张量存储在 CPU 内存中,我们就能够避免训练期间出现任何内存错误。
随着 OpenAlphaTensor 现已开源,为进一步开发开辟了几个令人兴奋的途径。
一个自然的过程是在目标硬件设备上对 OpenAlphaTensor 进行微调。 这有望带来极具竞争力的计算性能。 我将发布更多关于 OpenAlphaTensor 在各种硬件上的性能 GitHub上. 在撰写本文时,OpenAlphaTensor 正在接受训练。
另一个重要的进步是支持远程编译,允许用户构建针对边缘设备优化的算法。 这可以通过将 OpenAlphaTensor 模型存储在服务器上来实现,而矩阵乘法算法在不同的硬件上进行评估。
扩展对不同编译器的支持以计算基于延迟的奖励校正也很重要。 不同的编译器可以在给定的硬件上产生不同的优化算法。 例如,DeepMind 论文展示了在 TPU 和 Nvidia GPU 上使用 JAX 和 XLA 编译器的可喜成果。 在 Nvidia 上使用 NCCL 或在 CPU 上使用 LLVM 对此进行评估会很有趣。
最后,扩展模型和训练算法以支持更大的矩阵大小仍然是一个主要的开放挑战。 目前,OpenAlphaTensor 支持的最大矩阵大小为 5,但可以通过将较大的矩阵乘法拆分为大小小于 5 的微小 MM 组来应用。这种方法不是最优的,直接对对应于完整的 MM 理论上可以带来更好的结果。
迭戈菲奥里 是 Nebuly AI 的首席技术官,该公司致力于让 AI 优化成为每个开发者工具包的一部分。
- SEO 支持的内容和 PR 分发。 今天得到放大。
- 柏拉图区块链。 Web3 元宇宙智能。 知识放大。 访问这里。
- Sumber: https://www.kdnuggets.com/2023/03/first-open-source-implementation-deepmind-alphatensor.html?utm_source=rss&utm_medium=rss&utm_campaign=first-open-source-implementation-of-deepminds-alphatensor
- :是
- ][p
- $UP
- 1
- 3d
- 8
- a
- Able
- 关于我们
- 以上
- 绝对
- 加速
- 根据
- 因此
- 账号管理
- 实现
- 操作
- 行动
- 通
- 添加
- 额外
- 调整
- 采用
- 推进
- 优点
- 后
- 经纪人
- 聚合
- 侵略性
- AI
- 人工智能系统
- 目标
- 算法
- 算法
- 所有类型
- 允许
- 允许
- 单
- 已经
- 替代
- 其中
- 量
- 分析
- 和
- 另一个
- 应用的
- 的途径
- 方法
- 架构
- 保健
- 刊文
- 人造的
- 人工智能
- AS
- 分配
- 相关
- At
- 自动化
- 可使用
- 避免
- 背部
- 备份工具
- 基于
- 基本上
- 基础
- BE
- 因为
- 成为
- before
- 作为
- 如下。
- 基准
- 最佳
- 更好
- 之间
- 超越
- 板
- 棋盘游戏
- 界
- 更广泛
- BT
- 建立
- 建
- by
- 被称为
- CAN
- 不能
- 案件
- 原因
- 造成
- 一定
- 挑战
- 挑战
- 更改
- 渠道
- ChatGPT
- 孩子
- 选择
- 选择
- 选择
- 引
- 清除
- 明确地
- 码
- 收集
- 结合
- 提交
- 相当常见
- 公司
- 相比
- 竞争的
- 复杂
- 复杂
- 组件
- 由
- 妥协
- 计算
- 计算能力
- 计算
- 计算
- 概念
- 概念
- 流程条件
- 信心
- 考虑
- 考虑
- 考虑
- 考虑
- 消费
- 包含
- 继续
- 对比
- 转换
- 核心
- 相应
- 对应
- 可以
- Counter
- 中央处理器
- 创建
- 创造
- 标准
- 首席技术官
- 电流
- 目前
- data
- 处理
- 决定
- 决定
- 决定
- 决策
- 决定
- 深
- 深入学习
- DeepMind
- 依靠
- 描述
- 确定
- 开发商
- 研发支持
- 设备
- 设备
- 信息通信技术部
- 不同
- 困难
- 尺寸
- 尺寸
- 直接
- 通过各种方式找到
- 发现
- 讨论
- 分配
- 分
- 向下
- 下降
- ,我们将参加
- e
- 每
- 此前
- 容易
- 边缘
- 有效
- 高效
- 或
- element
- 分子
- 嵌入式
- 结束
- 增强
- 巨大
- 更多
- 条目
- 错误
- 评估
- 估计
- 醚(ETH)
- 评估
- 评估
- 评估
- 评价
- 甚至
- 所有的
- 例子
- 例子
- 令人兴奋的
- 执行
- 扩张
- 预期
- 体验
- 说明
- 解释
- 功勋
- 勘探
- 探索
- 探讨
- 探索
- 表示
- 延长
- 延长
- 非常
- 相当
- 家庭
- 快
- 反馈
- 少数
- 数字
- 最后
- 寻找
- 姓氏:
- 浮动
- 专注焦点
- 遵循
- 以下
- Footprint
- 针对
- 申请
- 格式
- 公式
- 发现
- 止
- ,
- 功能
- 根本
- 进一步
- 进一步的发展
- 未来
- 游戏
- Games
- 生成
- 产生
- 产生
- 发电
- 得到
- 越来越
- 给
- 特定
- 给予
- 目标
- GOES
- 非常好
- GPU
- 图形处理器
- 图形
- 大
- 团队
- 组的
- 成长
- 事业发展
- 指南
- 手
- 处理
- 发生
- 发生
- 硬件
- 硬件设备
- 硬件设备
- 有
- 有
- 点击此处
- 最高
- 地平线
- 创新中心
- How To
- 但是
- HTTPS
- 巨大
- 人
- i
- 生病
- 理想
- IDX
- 图片
- 影响力故事
- 实施
- 履行
- 重要
- 改善
- 改善
- in
- 增加
- 增加
- 日益
- 独立
- 指数
- 信息
- 初始
- 输入
- 代替
- 说明
- 房源搜索
- 有趣
- 内部
- 介绍
- 直觉
- IT
- 迭代
- 它的
- 本身
- JPG
- 掘金队
- 保持
- 键
- 知识
- 已知
- 语言
- 大
- 大
- 名:
- 潜伏
- 最新
- 层
- 层
- 铅
- 知道
- 学习
- 清单
- 看
- 寻找
- 占地
- 制成
- 主要
- 主要
- 使
- 制作
- 管理
- 管理的
- 许多
- 地图
- 制图
- 矩阵
- 问题
- 最多
- 意
- 有意义的
- 手段
- 会员
- 内存
- 方法
- 方法
- 公
- 失踪
- 混合物
- 模型
- 模型
- 改性
- 模块
- 更多
- 更高效
- 此外
- 最先进的
- 移动
- 移动
- 乘以
- 自然
- 自然
- 近
- 必要
- 需求
- 打印车票
- 负
- 网络
- 神经
- 神经网络
- 全新
- 下页
- NLP
- 节点
- 节点
- 非专家
- 概念
- 数
- Nvidia公司
- 获得
- of
- 优惠精选
- on
- 一
- 打开
- 开放源码
- OpenAI
- 操作
- 运营
- 最佳
- 优化
- 优化
- 优化
- 附加选项
- 原版的
- 其他名称
- 除此以外
- 产量
- 最划算
- 纸类
- 文件
- 参数
- 部分
- 特别
- 尤其
- 部分
- 员工
- 性能
- 执行
- 相
- 件
- 柏拉图
- 柏拉图数据智能
- 柏拉图数据
- 播放
- 播放机
- 播放
- 点
- 政策
- 政策
- 位置
- 可能
- 潜力
- 功率
- 实用
- 在练习上
- 预测
- 都曾预测
- 喜好
- 呈现
- 以前
- 可能性
- 大概
- 市场问题
- 过程
- 过程
- 生产
- 生成
- 产品
- 级数
- 进步
- 有希望
- 建议
- 可证明的
- 成熟
- 发布
- 出版
- 内存
- 随机
- 行列
- 宁
- 达到
- 上游
- 最近
- 减少
- 减少
- 减少
- 精
- 强化学习
- 发布
- 其余
- 遗迹
- 卓越
- 纪念
- 远程
- 重复
- 代表
- 代表
- 必须
- 需要
- 提供品牌战略规划
- 受限
- 导致
- 导致
- 成果
- 回报
- 回报
- 革命化
- 积分
- 行
- rt
- 运行
- s
- 同
- 保存
- 保存
- 脚本
- 情景
- 方案
- 搜索
- 其次
- 部分
- 种子
- 选
- 选择
- 选择
- 自
- 集
- 设置
- 几个
- 形状
- 共享
- 短
- 应该
- 显示
- 如图
- 作品
- 信号
- 显著
- 显著
- 类似
- 相似之处
- 简易
- 简单
- 只是
- 模拟
- 自
- 情况
- 尺寸
- 尺寸
- 小
- 小
- 快照
- So
- 方案,
- 解决
- 解决
- 一些
- 来源
- 太空
- 剩余名额
- 具体的
- 特别是
- 指定
- 速度
- 花
- 花费
- 分裂
- 广场
- 堆叠
- 阶段
- 标准
- 开始
- 启动
- 州/领地
- 国家的最先进的
- 说
- 州
- 统计
- 步
- 步骤
- 停车
- 商店
- 存储
- 商店
- 简单的
- 成功
- 这样
- SUPPORT
- 支持
- 合成的
- 综合数据
- 综合地
- 系统
- 产品
- 量身定制
- 采取
- 需要
- 服用
- 目标
- 任务
- 任务
- 技术
- 终端
- 条款
- 这
- 未来
- 图
- 信息
- 国家
- 其
- 他们
- 从而
- 因此
- 博曼
- 第三
- 三
- 三维
- 通过
- 次
- 耗时的
- 至
- 一起
- 象征
- 符号化
- 符号化
- 令牌
- 也有
- 工具箱
- 最佳
- 火炬
- 合计
- 传统
- 培训
- 熟练
- 产品培训
- 火车
- 轨道
- 转化
- 治疗
- true
- 理解
- 宇宙
- 未使用
- 更新
- 更新
- 最新动态
- 更新
- 升级
- us
- 用法
- 使用
- 用户
- 平时
- 折扣值
- 价值观
- 各个
- 版本
- 视频
- 视频游戏
- 参观
- 访客
- W
- 方法..
- 什么是
- 这
- 而
- 广泛
- 维基百科上的数据
- 将
- 胜利
- 话
- 工作站
- 价值
- 将
- 写作
- 书面
- X
- 和风网
- 零