mirror of https://github.com/llvm/torch-mlir

Generate backward graph via functorch-aot module

Adds examples that demonstrate extracting both the forward and the backward graph of a model via functorch's AOT (ahead-of-time autograd) module.

parent 28bf9cc1fc
commit e9c785b04b
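
As background, the flow used throughout this change hands the traced forward and backward torch.fx graphs to user-supplied compiler callbacks. A minimal sketch of that mechanism, assuming the functorch.compile.aot_function API of this era (the toy function and printing callback are illustrative, not part of the commit):

import torch
from functorch.compile import aot_function

def make_printer(name):
    # Compiler callbacks receive an fx.GraphModule plus example inputs and
    # must return a callable; this one just prints the traced code.
    def compiler(fx_g, example_inputs):
        print(f"=== {name} graph ===")
        print(fx_g.code)
        return fx_g
    return compiler

def toy_fn(x, w):
    return torch.tanh(x @ w).sum()

aot_fn = aot_function(toy_fn,
                      fw_compiler=make_printer("forward"),
                      bw_compiler=make_printer("backward"))
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
aot_fn(x, w).backward()  # tracing runs on first use; both callbacks fire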
@@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from transformers import AutoModelForMaskedLM, BertConfig
import transformers
import torch
import torch.utils._pytree as pytree

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

from functorch_utils import AOTModule, get_input_annotations

torch.manual_seed(0)


def getAnnotatedModule(ts_module, inputs):
    # Attach the @export and @annotate_args markers that torch-mlir's
    # e2e framework expects on the module's forward method.
    export(ts_module.forward)
    annotate_args_decorator = annotate_args(
        get_input_annotations(inputs, dynamic=False))
    annotate_args_decorator(ts_module.forward)
    return ts_module


# AOT autograd flattens model outputs with pytree, so HuggingFace's
# MaskedLMOutput dataclass needs flatten/unflatten functions registered.
# Note: the flatten lambda emits [loss, logits], so unflattening must take
# loss from values[0] and logits from values[1].
pytree._register_pytree_node(
    transformers.modeling_outputs.MaskedLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.MaskedLMOutput(
        loss=values[0], logits=values[1]
    ),
)

torch.manual_seed(42)

config = BertConfig(hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0)
model_type = AutoModelForMaskedLM
input_size = (1, 2)
device = "cpu"
dtype = torch.float

model = model_type.from_config(config).to(device, dtype=dtype)
input_ids = torch.randint(0, config.vocab_size, input_size).to(device)
decoder_ids = torch.randint(0, config.vocab_size, input_size).to(device)
train_inputs = {"input_ids": input_ids, "labels": decoder_ids}


def inference_fn(model, input, labels):
    # Despite the name, this runs a backward pass; AOTModule invokes it as
    # the training_fn so that both graphs get traced.
    return model(**input).loss.sum().backward()


aot_module = AOTModule(model,
                       train_inputs,
                       labels=None,
                       training_fn=inference_fn)
aot_module.generate_graphs()

# ==============================================================================
# Forward test.
forw_inputs = []
for inp in aot_module.forward_inputs:
    forw_inputs.append(inp.detach())

ts_module_forw = getAnnotatedModule(aot_module.forward_graph,
                                    aot_module.forward_inputs)


@register_test_case(module_factory=lambda: ts_module_forw)
def BERT_forward_basic(module, tu: TestUtils):
    module.forward(*forw_inputs)


# ==============================================================================
# Backward test.
back_inputs = []
for inp in aot_module.backward_inputs:
    back_inputs.append(inp.detach())

ts_module_back = getAnnotatedModule(aot_module.backward_graph,
                                    aot_module.backward_inputs)


@register_test_case(module_factory=lambda: ts_module_back)
def BERT_backward_basic(module, tu: TestUtils):
    module.forward(*back_inputs)
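
The pytree registration in the test above matters because AOT autograd flattens the model's structured output and rebuilds it afterwards. A quick illustrative check (not part of the commit) of what the registered flatten/unflatten pair does:

out = transformers.modeling_outputs.MaskedLMOutput(
    loss=torch.tensor(0.5), logits=torch.zeros(1, 2))
flat, spec = pytree.tree_flatten(out)
# flat is [loss, logits]; unflattening restores an equivalent MaskedLMOutput.
rebuilt = pytree.tree_unflatten(flat, spec)
assert torch.equal(rebuilt.logits, out.logits)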
@@ -0,0 +1,83 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from functorch_utils import AOTModule, get_input_annotations

torch.manual_seed(0)


def getAnnotatedModule(ts_module, inputs):
    # Same helper as in the BERT test, but with dynamic=True so every
    # dimension in the annotations becomes -1 (dynamic).
    export(ts_module.forward)
    annotate_args_decorator = annotate_args(
        get_input_annotations(inputs, dynamic=True))
    annotate_args_decorator(ts_module.forward)
    return ts_module


class NeuralNet(torch.nn.Module):
    # A small two-layer MLP: 10 -> 16 -> ReLU -> 2.

    def __init__(self):
        super(NeuralNet, self).__init__()
        self.l1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(16, 2)

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out


input = torch.randn(1, 10)
labels = torch.randn(1, 2)


def run_model(module, input, labels):
    # One SGD step; the loss.backward() call is what triggers tracing of
    # the backward graph inside AOTModule.generate_graphs().
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
    iters = 1
    for _ in range(iters):
        optimizer.zero_grad()
        output = module(input)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()


aot_module = AOTModule(NeuralNet(), input, labels, run_model)
aot_module.generate_graphs()

# ==============================================================================
# Forward test.
forw_inputs = []
for inp in aot_module.forward_inputs:
    forw_inputs.append(inp.detach())

ts_module_forw = getAnnotatedModule(aot_module.forward_graph,
                                    aot_module.forward_inputs)


@register_test_case(module_factory=lambda: ts_module_forw)
def NeuralNet_forward_basic(module, tu: TestUtils):
    module.forward(*forw_inputs)


# ==============================================================================
# Backward test.
back_inputs = []
for inp in aot_module.backward_inputs:
    back_inputs.append(inp.detach())

ts_module_back = getAnnotatedModule(aot_module.backward_graph,
                                    aot_module.backward_inputs)


@register_test_case(module_factory=lambda: ts_module_back)
def NeuralNet_backward_basic(module, tu: TestUtils):
    module.forward(*back_inputs)
@@ -0,0 +1,88 @@
import torch
from functorch.compile import memory_efficient_fusion, get_decompositions, default_partition
from torch import fx
import copy


def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
    """Generates the annotations, i.e. the shape and dtype of each input,
    as required by the torch-mlir module annotator."""
    annotations_list = [None]
    for i in inputs:
        temp_list = []
        if dynamic:
            temp_list.append([-1] * len(i.shape))
        else:
            temp_list.append(list(i.shape))
        temp_list.append(i.dtype)
        temp_list.append(True)
        annotations_list.append(tuple(temp_list))
    return annotations_list


class AOTModule:

    def __init__(self, model, inputs, labels, training_fn):
        self.model = model
        self.inputs = inputs
        self.labels = labels
        self.training_fn = training_fn
        self.forward_graph = None
        self.backward_graph = None
        self.forward_inputs = None
        self.backward_inputs = None

    def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
        for node in fx_g.graph.nodes:
            if node.op == "output":
                # Output nodes always have one argument.
                node_arg = node.args[0]
                out_nodes = []
                if isinstance(node_arg, list):
                    # Don't return NoneType elements.
                    for out_node in node_arg:
                        if not isinstance(out_node, type(None)):
                            out_nodes.append(out_node)
                    # If there is a single tensor/element to be returned,
                    # don't create a tuple for it.
                    if len(out_nodes) == 1:
                        node.args = out_nodes
                    else:
                        node.args = (tuple(out_nodes), )
        fx_g.graph.lint()
        fx_g.recompile()
        return fx_g

    def get_forward_graph(self, fx_g: fx.GraphModule, inps):
        # Hand an untouched copy back to AOT autograd, and stash a scripted,
        # frozen copy (round-tripped through disk) for torch-mlir.
        return_fx = copy.deepcopy(fx_g)
        f = self.change_fx_graph_return_to_tuple(fx_g)
        f = torch.jit.script(f)
        f = torch.jit.freeze(f.eval())
        torch.jit.save(f, "forw.pt")
        f = torch.jit.load("forw.pt")
        self.forward_graph = f
        self.forward_inputs = copy.deepcopy(inps)
        return return_fx

    def get_backward_graph(self, fx_g: fx.GraphModule, inps):
        return_fx = copy.deepcopy(fx_g)
        f = self.change_fx_graph_return_to_tuple(fx_g)
        f = torch.jit.script(f)
        f = torch.jit.freeze(f.eval())
        torch.jit.save(f, "back.pt")
        f = torch.jit.load("back.pt")
        self.backward_graph = f
        self.backward_inputs = copy.deepcopy(inps)
        return return_fx

    def generate_graphs(self):
        # memory_efficient_fusion drives AOT autograd; the fw/bw compiler
        # callbacks above capture the traced graphs as a side effect of
        # running training_fn once.
        aot_model = memory_efficient_fusion(
            self.model,
            fw_compiler=self.get_forward_graph,
            bw_compiler=self.get_backward_graph,
            partition_fn=default_partition,
            decompositions=get_decompositions([
                torch.ops.aten.embedding_dense_backward,
                torch.ops.aten.native_layer_norm_backward
            ]))
        self.training_fn(aot_model, self.inputs, self.labels)
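
For reference, a sketch of what get_input_annotations produces (the input shapes below are illustrative): the leading None slot stands for the module's self argument expected by torch-mlir's annotate_args, and the trailing True marks the tensor as having value semantics.

import torch
from functorch_utils import get_input_annotations

inps = (torch.randn(1, 10),)
get_input_annotations(inps, dynamic=False)
# -> [None, ([1, 10], torch.float32, True)]
get_input_annotations(inps, dynamic=True)
# -> [None, ([-1, -1], torch.float32, True)]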
@@ -22,8 +22,13 @@ python3 -m venv $venv_dir
source $venv_dir/bin/activate
# For minilm_seq_classification.py
python3 -m pip install 'transformers[torch]'
# For the functorch-dependent models.
python3 -m pip install ninja "git+https://github.com/pytorch/functorch.git"
python3 -m pip install networkx

cd "$torch_mlir_src_root"
export PYTHONPATH=${PYTHONPATH-}
source "$torch_mlir_src_root/.env"
# For the functorch utilities (functorch_utils.py).
export PYTHONPATH="$torch_mlir_src_root/build_tools/torchscript_e2e_heavydep_tests:$PYTHONPATH"
python3 -m build_tools.torchscript_e2e_heavydep_tests.main --output_dir=$serialized_test_dir
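
After this script runs, $serialized_test_dir holds the serialized test artifacts, which the regular e2e runner consumes instead of re-tracing the heavy models. As a sketch, with the runner module path and flag recalled from the torch-mlir tree of this era (treat both as assumptions):

python3 -m e2e_testing.torchscript.main --serialized-test-dir=$serialized_test_dir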
@@ -8,6 +8,8 @@ import argparse
from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to

from . import hf_sequence_classification
from . import fully_connected_backward
from . import bert_functorch


def _get_argparse():
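
The imports above are all that is needed to wire the new tests in: each module's @register_test_case decorators run at import time, and serialize_all_tests_to then writes out every registered test. A further heavydep test module (hypothetical name below) would be added the same way:

from . import my_new_heavydep_test  # hypothetical module; registers its cases on import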