1
背景
近年來,量化感知訓練是一個較為熱點的問題,可以大大優(yōu)化量化后訓練造成精度損失的問題,使得訓練過程更加高效。
Torch.fx在這一問題上走在了前列,使用純Python語言實現(xiàn)了對于Torch.nn.Module的解析和向IR的轉(zhuǎn)換,也可以提供變換后的IR對應的Python代碼,在外部則是提供了簡潔易用的API,大大方便了量化感知訓練過程的搭建。此外,Torch.fx也有助于消除動態(tài)圖和靜態(tài)圖之間的Gap,可以比較方便地對圖進行操作以及進行算子融合。
OneFlow緊隨其后添加了針對OneFlow的fx,即One-fx,在安裝One-fx之后,用戶可以直接調(diào)用oneflow.fx,也可以直接通過import onefx as fx進行使用。
One-fx實現(xiàn)代碼中絕大部分是對于Torch.fx的fork,但根據(jù)OneFlow和PyTorch之間存在的差別進行了一些適配或優(yōu)化。本文將圍繞One-fx適配方式以及在OneFlow中的應用展開。
2
FX主要模塊
Symbolioc Trace
Graph Module
Interpreter
Proxy
Passes
其中,前4個模塊共同實現(xiàn)了fx的基本功能,Graph Module和Proxy又是Symbolic Trace的基礎,Passes則是在此基礎上的擴充。

Symbolic Trace的基本概念如上圖所示,最基本的模型運行過程就是從模型定義到模型執(zhí)行這樣一個流程。
fx則是進行了非侵入式的解析,將模型執(zhí)行過程轉(zhuǎn)成一張圖,這張圖中包含了很多個Node,每一個Node都包含了模型中的子模塊或者函數(shù)調(diào)用信息,然后用戶可以很方便地獲取到所有的Node,并對其進行一些變換操作,最后通過GraphModule重新生成一個模型定義,并對其執(zhí)行。
其中,在進行模型解析的時候,節(jié)點之間變量傳遞也均使用代理后的變量,如y = oneflow.relu(x),實際上x和y是Proxy(x)和Proxy(y)。
3
One-fx實現(xiàn)方式
這里給出一個Fx最簡單的用例,以方便后續(xù)對于實現(xiàn)方式的介紹。
import oneflow
class MyModule(oneflow.nn.Module):
def __init__(self):
super().__init__()
self.linear = oneflow.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
y = oneflow.ones([2, 3])
x = oneflow.relu(x)
return y
m = MyModule()
traced = oneflow.fx.symbolic_trace(m)
print(traced.code)
"""
def forward(self, x):
linear = self.linear(x); x = None
relu = oneflow.relu(linear); linear = None
_tensor_constant0 = self._tensor_constant0
return _tensor_constant0
"""
?
函數(shù)代理
代理,即fx中的Proxy模塊,目的是在每次進行函數(shù)或模塊調(diào)用的時候添加一些額外操作,使得對模型的解析和重建得以進行,而包裝則是適配代理的一種方式。
torch.fx中,對于nn.Module的包裝比較易于理解,每當待解析Module中出現(xiàn)了繼承自nn.Module的對象,那么就將其__call__函數(shù)替換成包裝過的函數(shù)。然而,對于pytorch的函數(shù)的代理的實現(xiàn)要更“繞”一些,是借助了__torch_function__這一機制
限于篇幅原因這里不專門對其進行介紹。比較關鍵的點是,OneFlow中沒有這一機制,如果需要添加,那么會是規(guī)模很大的、侵入性的,于是One-fx的實現(xiàn)就需要找其它路徑。
我們使用的解決方式是搜索oneflow,oneflow.nn.functional,oneflow._C等模塊中的Callable,并去除其中屬于類的部分,然后對其余函數(shù)進行包裝,在每次解析模型之前,會將這些模塊的__dict__中對應項替換成包裝后的函數(shù),并且在解析模型之后重新將這些項進行還原。對于constructor類型的函數(shù),如ones,randn等則不進行代理,直接運行,在最終構(gòu)建圖的時候作為constant來處理。
對于函數(shù)的包裝部分源碼實現(xiàn)如下,每次運行代理后的函數(shù),會先判斷該函數(shù)的入?yún)⒅杏袥]有Proxy變量,如果有,那么將會創(chuàng)建一個call_function類型的節(jié)點并返回Proxy包裝后的節(jié)點,否則直接調(diào)用原函數(shù)并返回結(jié)果。
def _create_wrapped_func(orig_fn):
@functools.wraps(orig_fn)
def wrapped(*args, **kwargs):
# 判斷參數(shù)中是否存在proxy變量
proxy = _find_proxy(args, kwargs)
if proxy is not None:
# 如果參數(shù)中有Proxy變量,創(chuàng)建節(jié)點并返回Proxy包裝后的節(jié)點
return_proxy = proxy.tracer.create_proxy(
"call_function", orig_fn, args, kwargs
)
return_proxy.node.meta["is_wrapped"] = True
return return_proxy
# 如果沒有Proxy變量,直接調(diào)用原函數(shù)
return orig_fn(*args, **kwargs)
return wrapped
其中,return_proxy = proxy.tracer.create_proxy("call_function", orig_fn, args, kwargs)這行代碼指定了使用與入?yún)⑾嗤腡racer來創(chuàng)建節(jié)點并返回結(jié)果,create_proxy函數(shù)定義的主要部分如下,創(chuàng)建節(jié)點并在Proxy包裝后返回。
def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
assert isinstance(kwargs_, dict)
# 創(chuàng)建節(jié)點
node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
if not proxy_factory_fn:
proxy = self.proxy(node)
else:
proxy = proxy_factory_fn(node)
return proxy
而其中的create_node方法,實際上是調(diào)用了Tracer.graph.create_node,在圖中創(chuàng)建節(jié)點,主要部分代碼如下,其中op就是fx IR中的op,代表了節(jié)點類型,而target則是節(jié)點的操作主體,在上面的例子中就是orig_func。
因此,當我們自定義的Module中的forward函數(shù)中的所有調(diào)用都被包裝之后,實際上再運行forward的時候,就會依次在Tracer.graph中創(chuàng)建節(jié)點,這也正是symbolic_trace的基本思路。
def create_node(self, op: str, target: 'Target',
args: Optional[Tuple['Argument', ...]] = None,
kwargs: Optional[Dict[str, 'Argument']] = None,
name: Optional[str] = None,
type_expr: Optional[Any] = None) -> Node:
# 此處有一些assert
# 創(chuàng)建一個節(jié)點名稱,避免重復
candidate = name if name is not None else self._target_to_str(target)
name = self._graph_namespace.create_name(candidate, None)
# 創(chuàng)建節(jié)點
n = Node(self, name, op, target, args, kwargs, type_expr)
# 建立名稱與節(jié)點的映射關系
self._graph_namespace.associate_name_with_obj(name, n)
return n
而對于symbolic_trace過程,其核心就是Tracer.trace。這個方法可以分為兩部分,一個是預處理部分,一個是主干部分。其中預處理過程大致定義如下,主要任務是初始化Graph、確立模型以及forward函數(shù)和創(chuàng)建包裝后的參數(shù)。
如前面所提及的,symbolic trace的基本思路是借助Proxy變量以及包裝后的函數(shù),在每次調(diào)用的時候都創(chuàng)建一個節(jié)點,因此,forward函數(shù)的輸入也需要用Proxy進行包裝,這一步定義在Tracer.create_args_for_root中。
?
def trace(
self,
root: Union[oneflow.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
# 確定模塊主體以及forward函數(shù),其中fn即forward函數(shù)
if isinstance(root, oneflow.nn.Module):
self.root = root
assert hasattr(
type(root), self.traced_func_name
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
fn = getattr(type(root), self.traced_func_name)
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
else:
self.root = oneflow.nn.Module()
fn = root
tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
# 在Tracer中初始化一張圖
self.graph = Graph(tracer_cls=tracer_cls)
self.tensor_attrs: Dict[oneflow.Tensor, str] = {}
# 這個子函數(shù)用于收集模型中所有Tensor類型的變量
def collect_tensor_attrs(m: oneflow.nn.Module, prefix_atoms: List[str]):
for k, v in m.__dict__.items():
if isinstance(v, oneflow.Tensor):
self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
for k, v in m.named_children():
collect_tensor_attrs(v, prefix_atoms + [k])
collect_tensor_attrs(self.root, [])
assert isinstance(fn, FunctionType)
# 獲取fn所在模塊的所有可讀變量
fn_globals = fn.__globals__
# 創(chuàng)建包裝后的參數(shù)
fn, args = self.create_args_for_root(
fn, isinstance(root, oneflow.nn.Module), concrete_args
)
隨后則是trace的主干部分,這一部分大致代碼如下,主要任務是對函數(shù)、方法、模塊進行必要的包裝,然后在Graph中創(chuàng)建節(jié)點,完成整個圖的信息。
其中,我們會創(chuàng)建一個Patcher環(huán)境并在其中進行這些過程,這是因為對于函數(shù)和方法的包裝會直接改變掉某些包中對應函數(shù)或方法的行為,為了不讓這種行為的改變溢出到trace的范圍之外,在每次進行包裝的時候會在Patcher中記錄本次操作,然后在_Patcher.__exit__中根據(jù)記錄的操作一一還原現(xiàn)場。
# 下面代碼仍然是`trace`函數(shù)的一部分
# 定義對于`nn.Module`的getattr方法的包裝
@functools.wraps(_orig_module_getattr)
def module_getattr_wrapper(mod, attr):
attr_val = _orig_module_getattr(mod, attr)
return self.getattr(attr, attr_val, parameter_proxy_cache)
# 定義對于`nn.Module`的forward方法的包裝
@functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
def forward(*args, **kwargs):
return _orig_module_call(mod, *args, **kwargs)
_autowrap_check(
patcher,
getattr(getattr(mod, "forward", mod), "__globals__", {}),
self._autowrap_function_ids,
)
return self.call_module(mod, forward, args, kwargs)
# 這里Patcher的作用是在退出這一環(huán)境的時候恢復現(xiàn)場,避免包裝函數(shù)、方法的影響溢出到`trace`之外。
with _Patcher() as patcher:
# 對`__getattr__`和`nn.Module.__call__`這兩個方法默認進行包裝
patcher.patch_method(
oneflow.nn.Module,
"__getattr__",
module_getattr_wrapper,
deduplicate=False,
)
patcher.patch_method(
oneflow.nn.Module, "__call__", module_call_wrapper, deduplicate=False
)
# 對預定好需要進行包裝的函數(shù)進行包裝
_patch_wrapped_functions(patcher)
_autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
# 遍歷所有需要對其中函數(shù)進行自動包裝的package
for module in self._autowrap_search:
if module is oneflow:
dict = {}
# 當package為oneflow時,對此進行特殊處理,單獨分出一個字典存放原本`oneflow.__dict__`中的內(nèi)容
for name, value in module.__dict__.items():
if not isinstance(value, oneflow.nn.Module) and not value in _oneflow_no_wrapped_functions:
dict[name] = value
_autowrap_check_oneflow(
patcher, dict, module.__dict__, self._autowrap_function_ids
)
else:
_autowrap_check(
patcher, module.__dict__, self._autowrap_function_ids
)
# 創(chuàng)建節(jié)點,這里的`create_node`調(diào)用實際上只是創(chuàng)建了最后一個節(jié)點,即輸出節(jié)點。
# 但是這里`fn`就是forward函數(shù),在運行這一函數(shù)的時候,就會如前面所說依次創(chuàng)建節(jié)點。
self.create_node(
"output",
"output",
(self.create_arg(fn(*args)),),
{},
type_expr=fn.__annotations__.get("return", None),
)
?
其中,_patch_wrapped_functions的實現(xiàn)如下:
def _patch_wrapped_functions(patcher: _Patcher):
# `_wrapped_fns_to_patch`中包含了所有需要自動包裝的函數(shù)
for frame_dict, name in _wrapped_fns_to_patch:
if name not in frame_dict:
if hasattr(builtins, name):
# 對于built-in函數(shù),不存在于frame_dict中,單獨進行處理來根據(jù)名稱獲取函數(shù)本身
orig_fn = getattr(builtins, name)
else:
# 如果是oneflow中指定需要包裝的函數(shù),那么就進行獲取,否則拋出名稱無法識別的異常
is_oneflow_wrapped_function, func = is_oneflow_wrapped_function_and_try_get(name)
if is_oneflow_wrapped_function:
orig_fn = func
else:
raise NameError("Cannot deal with the function %s."%name)
else:
# 如果函數(shù)名稱已經(jīng)存在于frame_dict中,直接通過字典查詢來獲得函數(shù)
orig_fn = frame_dict[name]
# 創(chuàng)建包裝后的函數(shù)并進行`patch`,即定義當trace過程結(jié)束的時候,如何還原現(xiàn)場
patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
# 對于類中的方法,直接包裝并patch。
for cls, name in _wrapped_methods_to_patch:
patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
?
全局包裝
在模型的forward函數(shù)中,我們有時不僅會用到框架自帶的模塊或者函數(shù),有點時候還需要用到自定義的函數(shù)或者built-in函數(shù),對于這種情況如果不進行處理,那么自然無法接受Proxy(x)的入?yún)?。fx中提供了fx.wrap這一API,當用戶需要調(diào)用這部分函數(shù)的時候,可以實現(xiàn)使用fx.wrap(func)使其被包裝。
例如:
import oneflow
oneflow.fx.wrap(len)
class MyModule(oneflow.nn.Module):
def __init__(self):
super().__init__()
self.linear = oneflow.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x) + len(x.shape)
return x
traced = oneflow.fx.symbolic_trace(MyModule())
print(traced.code)
"""
def forward(self, x):
linear = self.linear(x)
getattr_1 = x.shape; x = None
len_1 = len(getattr_1); getattr_1 = None
add = linear + len_1; linear = len_1 = None
return add
"""
?
但是其局限性在于,如果Module的源代碼是來自其它庫,那么在調(diào)用的地方使用fx.wrap是不起作用的,在oneflow和torch中都會有這一問題。然而flowvision中有多處使用了built-in function,因此我們添加了一個API,即global_wrap,原理比較簡單,就是直接對某個函數(shù)所在的包的__dict__進行修改,用法如下:
# MyModule來自其它包
with oneflow.fx.global_wrap(len):
m = MyModule()
traced = oneflow.fx.symbolic_trace(m)
print(traced.code)
"""
def forward(self, x):
linear = self.linear(x); x = None
getattr_1 = linear.shape
len_1 = len(getattr_1); getattr_1 = None
relu = oneflow.relu(linear); linear = None
add = relu + len_1; relu = len_1 = None
return add
"""
?
使用with關鍵字的原因是這種實現(xiàn)方式是直接修改了某個包的__dict__,對于其它地方的調(diào)用也會產(chǎn)生影響,因此需要將其限制在一定范圍內(nèi)。此外,包裝后的函數(shù)包含了對類型的判定等一系列操作,也會極大影響built-in函數(shù)的性能。
其它適配
其它地方的處理都比較簡單,不需要對實現(xiàn)方式做修改,只需要將細節(jié)部分對齊即可,這也體現(xiàn)出oneflow和pytorch在前端部分的高度兼容性。
4
IR設計
fx的IR設計遵循以下幾個原則:
避免支持長尾分布,復雜的樣例。主要關注經(jīng)典模型的程序捕獲和變換。
使用機器學習從業(yè)者已經(jīng)熟悉的工具和概念,例如Python的數(shù)據(jù)結(jié)構(gòu)和 PyTorch 中公開記錄的算子 。
使程序捕獲過程具有高度可配置性,以便用戶可以為長尾需求實現(xiàn)自己的解決方案。
fx的IR主要由幾個部分組成;
opcode:即當前操作的類型,可以是placeholder, get_attr, call_function, call_method, call_module, output
name:即給當前操作的命名。
target:當前操作的實體,例如對于call_function類型的操作,可能這一屬性會是
args和kwargs:指定當前操作的參數(shù)。
通過print_tabular這一API可以很方便美觀地打印出fx中的IR,例如對于以下的MyModule模型,我們可以打印出其IR:
import oneflow
class MyModule(oneflow.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = oneflow.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
y = oneflow.ones([2, 3])
x = oneflow.topk(x, 10)
return x.relu() + y
traced = oneflow.fx.symbolic_trace(MyModule())
traced.graph.print_tabular()
"""
opcode name target args kwargs
------------- ----------------- ------------------------ ------------------------- --------
placeholder x x () {}
call_module linear linear (x,) {}
call_function topk (linear, 10) {}
call_method relu relu (topk,) {}
get_attr _tensor_constant0 _tensor_constant0 () {}
call_function add (relu, _tensor_constant0) {}
output output output (add,) {}
"""
盡管fx的IR不算強大(例如不能處理動態(tài)控制流),但是定義非常簡潔,實現(xiàn)簡單,對于用戶來講上手門檻相對低很多。
5
One-fx應用舉例
OP替換
下面的例子展示了如何將add操作全部替換成mul操作。
import oneflow
from oneflow.fx import symbolic_trace
import operator
class M(oneflow.nn.Module):
def forward(self, x, y):
return x + y, oneflow.add(x, y), x.add(y)
if __name__ == '__main__':
traced = symbolic_trace(M())
patterns = set([operator.add, oneflow.add, "add"])
for n in traced.graph.nodes:
if any(n.target == pattern for pattern in patterns):
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(oneflow.mul, n.args, n.kwargs)
n.replace_all_uses_with(new_node)
traced.graph.erase_node(n)
traced.recompile()
traced.graph.print_tabular()
print(traced.code)
?
性能分析
以下代碼展示如何使用fx進行模型的性能分析,將原本的模型通過symbolic_trace解析成各個節(jié)點,再在其中插入測試性能的操作。
import oneflow
import flowvision.models as models
import statistics, tabulate, time
from typing import Any, Dict, List
class ProfilingInterpreter(oneflow.fx.Interpreter):
def __init__(self, mod : oneflow.nn.Module):
gm = oneflow.fx.symbolic_trace(mod)
super().__init__(gm)
# 記錄總運行時間
self.total_runtime_sec : List[float] = []
# 記錄各個節(jié)點運行時間
self.runtimes_sec : Dict[oneflow.fx.Node, List[float]] = {}
# 重寫`run`方法,本質(zhì)上是對基類`run`方法的簡單封裝,在運行前后記錄時間點。
# 這一方法是Graph整體運行的入口。
def run(self, *args) -> Any:
t_start = time.time()
return_val = super().run(*args)
t_end = time.time()
self.total_runtime_sec.append(t_end - t_start)
return return_val
# 同上,重寫`run_node`方法,不需要自己寫細節(jié)實現(xiàn),只需要在對基類的`run_node`調(diào)用前后記錄時間點即可
# 這一方法是Graph中運行每個Node的入口。
def run_node(self, n : oneflow.fx.Node) -> Any:
t_start = time.time()
return_val = super().run_node(n)
t_end = time.time()
self.runtimes_sec.setdefault(n, [])
self.runtimes_sec[n].append(t_end - t_start)
return return_val
# 定義如何打印性能測試結(jié)果
def summary(self, should_sort : bool = False) -> str:
# 存儲每個節(jié)點的打印信息
node_summaries : List[List[Any]] = []
# 由于模塊會被調(diào)用多次,所以這里計算一下平均的運行總時長
mean_total_runtime = statistics.mean(self.total_runtime_sec)
for node, runtimes in self.runtimes_sec.items():
mean_runtime = statistics.mean(runtimes)
# 計算節(jié)點運行時間占總時間的比例
pct_total = mean_runtime / mean_total_runtime * 100
# 記錄節(jié)點信息、節(jié)點平均運行時長和節(jié)點運行時間占總時間的比例
node_summaries.append(
[node.op, str(node), mean_runtime, pct_total])
# 如果需要,安按照運行時間進行排序
if should_sort:
node_summaries.sort(key=lambda s: s[2], reverse=True)
# 以下是借助tabulate庫進行格式化來美化顯示效果
headers : List[str] = [
'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
]
return tabulate.tabulate(node_summaries, headers=headers)
if __name__ == '__main__':
rn18 = models.resnet18()
rn18.eval()
input = oneflow.randn(5, 3, 224, 224)
output = rn18(input)
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
?
效果如下:

算子融合
以下代碼演示如何借助fx將模型中的卷積層和BN層進行融合,對于這種組合,并不需要引入新的算子,只需要對原本conv的權重進行操作即可。
import sys
import oneflow
import oneflow.nn as nn
import numpy as np
import copy
from typing import Dict, Any, Tuple
# 通過直接對權重進行運算的方式進行Conv和BN的融合
def fuse_conv_bn_eval(conv, bn):
assert(not (conv.training or bn.training)), "Fusion only for eval!"
fused_conv = copy.deepcopy(conv)
fused_conv.weight, fused_conv.bias =
fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
return fused_conv
# 權重融合方式
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
if conv_b is None:
conv_b = oneflow.zeros_like(bn_rm)
if bn_w is None:
bn_w = oneflow.ones_like(bn_rm)
if bn_b is None:
bn_b = oneflow.zeros_like(bn_rm)
bn_var_rsqrt = oneflow.rsqrt(bn_rv + bn_eps)
conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
return oneflow.nn.Parameter(conv_w), oneflow.nn.Parameter(conv_b)
# 根據(jù)字符串對名稱進行分割,比如`foo.bar.baz` -> (`foo.bar`, `baz`)
def _parent_name(target : str) -> Tuple[str, str]:
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
def replace_node_module(node: oneflow.fx.Node, modules: Dict[str, Any], new_module: oneflow.nn.Module):
assert(isinstance(node.target, str))
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, new_module)
# 定義對模型進行融合操作的過程
def fuse(model: oneflow.nn.Module) -> oneflow.nn.Module:
model = copy.deepcopy(model)
# 先通過fx.symbolic_trace獲取一個GraphModule
fx_model: oneflow.fx.GraphModule = oneflow.fx.symbolic_trace(model)
modules = dict(fx_model.named_modules())
# 遍歷GraphModule中的所有節(jié)點,分別進行操作
for node in fx_model.graph.nodes:
# 跳過所有不是module的節(jié)點
if node.op != 'call_module':
continue
# 檢測到conv+bn的結(jié)構(gòu)后進行融合操作
if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
# conv的輸出同時被其它節(jié)點使用,即conv后連接兩個節(jié)點時無法融合
if len(node.args[0].users) > 1:
continue
conv = modules[node.args[0].target]
bn = modules[node.target]
fused_conv = fuse_conv_bn_eval(conv, bn)
replace_node_module(node.args[0], modules, fused_conv)
# 對圖中的邊進行置換,對于用到bn輸出的節(jié)點,要更改它們的輸入
node.replace_all_uses_with(node.args[0])
# 移除舊的節(jié)點
fx_model.graph.erase_node(node)
fx_model.graph.lint()
# 重新建圖(構(gòu)造模型)
fx_model.recompile()
return fx_model
if __name__ == '__main__':
# 以下引入flowvision中的resnet 18模型,并進行融合前后的benchmark比較
import flowvision.models as models
import time
rn18 = models.resnet18().cuda()
rn18.eval()
inp = oneflow.randn(10, 3, 224, 224).cuda()
output = rn18(inp)
def benchmark(model, iters=20):
for _ in range(10):
model(inp)
oneflow.cuda.synchronize()
begin = time.time()
for _ in range(iters):
model(inp)
return str(time.time()-begin)
fused_rn18 = fuse(rn18)
unfused_time = benchmark(rn18)
fused_time = benchmark(fused_rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
assert unfused_time > fused_time
?
6
未來計劃
基于fx進行8bit量化感知訓練和部署
基于fx進行算子融合
eager模式下基于fx獲得模型更精確的FLOPs和MACs結(jié)果
審核編輯:劉清
-
python
+關注
關注
57文章
4877瀏覽量
90071 -
pytorch
+關注
關注
2文章
813瀏覽量
14856 -
OneFlow
+關注
關注
0文章
9瀏覽量
9048
原文標題:適配PyTorch FX,OneFlow讓量化感知訓練更簡單
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
Pytorch模型訓練實用PDF教程【中文】
YOLOv6中的用Channel-wise Distillation進行的量化感知訓練
9個用Pytorch訓練快速神經(jīng)網(wǎng)絡的技巧
如何讓PyTorch模型訓練變得飛快?
Pytorch量化感知訓練的詳解
基于PyTorch的深度學習入門教程之PyTorch簡單知識
PyTorch教程15.10之預訓練BERT
PyTorch教程21.7之序列感知推薦系統(tǒng)
適配PyTorch FX讓量化感知訓練更簡單
評論