Generate backward graph via functorch's AOT module

Adds an example demonstrating the extraction of both the forward and the
backward graph via functorch's AOT (ahead-of-time autograd) module.
Prashant Kumar 2022-04-11 14:08:27 +00:00
parent 28bf9cc1fc
commit e9c785b04b
5 changed files with 259 additions and 0 deletions
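
For context, functorch's AOT Autograd traces the model once, partitions the joint trace into separate forward and backward FX graphs, and hands each graph to a user-supplied compiler callback; this commit's AOTModule captures those graphs instead of compiling them. A minimal standalone sketch of that mechanism (assuming only that functorch is installed; f and print_compiler are illustrative names, not part of this commit):

import torch
from functorch.compile import aot_function

def print_compiler(fx_g, _example_inputs):
    # The callback receives an fx.GraphModule; print its code and return the
    # (callable) module unchanged so execution still works.
    print(fx_g.code)
    return fx_g

def f(x):
    return (x * x).sum()

aot_f = aot_function(f, fw_compiler=print_compiler, bw_compiler=print_compiler)
aot_f(torch.randn(4, requires_grad=True)).backward()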


build_tools/torchscript_e2e_heavydep_tests/bert_functorch.py
@@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from transformers import AutoModelForMaskedLM, BertConfig
import transformers
import torch
import torch.utils._pytree as pytree
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from functorch_utils import AOTModule, get_input_annotations
torch.manual_seed(0)
def getAnnotatedModule(ts_module, inputs):
    export(ts_module.forward)
    annotate_args_decorator = annotate_args(
        get_input_annotations(inputs, dynamic=False))
    annotate_args_decorator(ts_module.forward)
    return ts_module
# Register MaskedLMOutput as a pytree node so that AOT Autograd can flatten
# and unflatten the HuggingFace model output.
pytree._register_pytree_node(
    transformers.modeling_outputs.MaskedLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.MaskedLMOutput(
        loss=values[0], logits=values[1]
    ),
)
torch.manual_seed(42)
config = BertConfig(hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0)
model_type = AutoModelForMaskedLM
input_size = (1, 2)
device = "cpu"
dtype = torch.float
model = model_type.from_config(config).to(device, dtype=dtype)
input_ids = torch.randint(0, config.vocab_size, input_size).to(device)
decoder_ids = torch.randint(0, config.vocab_size, input_size).to(device)
train_inputs = {"input_ids": input_ids, "labels": decoder_ids}
def training_fn(model, inputs, labels):
    model(**inputs).loss.sum().backward()

aot_module = AOTModule(model,
                       train_inputs,
                       labels=None,
                       training_fn=training_fn)
aot_module.generate_graphs()
# ==============================================================================
# Forward test.
forw_inputs = []
for inp in aot_module.forward_inputs:
    forw_inputs.append(inp.detach())

ts_module_forw = getAnnotatedModule(aot_module.forward_graph,
                                    aot_module.forward_inputs)


@register_test_case(module_factory=lambda: ts_module_forw)
def BERT_forward_basic(module, tu: TestUtils):
    module.forward(*forw_inputs)
# ==============================================================================
# Backward test.
back_inputs = []
for inp in aot_module.backward_inputs:
    back_inputs.append(inp.detach())

ts_module_back = getAnnotatedModule(aot_module.backward_graph,
                                    aot_module.backward_inputs)


@register_test_case(module_factory=lambda: ts_module_back)
def BERT_backward_basic(module, tu: TestUtils):
    module.forward(*back_inputs)
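
As a quick sanity check on the pytree registration above (a sketch, not part of the committed test file), flattening a model output and unflattening it should round-trip losslessly:

out = model(**train_inputs)
flat, spec = pytree.tree_flatten(out)
rebuilt = pytree.tree_unflatten(flat, spec)
assert torch.equal(rebuilt.loss, out.loss)
assert torch.equal(rebuilt.logits, out.logits)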


build_tools/torchscript_e2e_heavydep_tests/fully_connected_backward.py
@@ -0,0 +1,83 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from functorch_utils import AOTModule, get_input_annotations
torch.manual_seed(0)
def getAnnotatedModule(ts_module, inputs):
    export(ts_module.forward)
    annotate_args_decorator = annotate_args(
        get_input_annotations(inputs, dynamic=True))
    annotate_args_decorator(ts_module.forward)
    return ts_module
class NeuralNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(16, 2)

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out
inputs = torch.randn(1, 10)
labels = torch.randn(1, 2)

def run_model(module, inputs, labels):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
    iters = 1
    for _ in range(iters):
        optimizer.zero_grad()
        output = module(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

aot_module = AOTModule(NeuralNet(), inputs, labels, run_model)
aot_module.generate_graphs()
# ==============================================================================
# Forward test.
forw_inputs = []
for inp in aot_module.forward_inputs:
    forw_inputs.append(inp.detach())

ts_module_forw = getAnnotatedModule(aot_module.forward_graph,
                                    aot_module.forward_inputs)


@register_test_case(module_factory=lambda: ts_module_forw)
def NeuralNet_forward_basic(module, tu: TestUtils):
    module.forward(*forw_inputs)
# ==============================================================================
# Backward test.
back_inputs = []
for inp in aot_module.backward_inputs:
    back_inputs.append(inp.detach())

ts_module_back = getAnnotatedModule(aot_module.backward_graph,
                                    aot_module.backward_inputs)


@register_test_case(module_factory=lambda: ts_module_back)
def NeuralNet_backward_basic(module, tu: TestUtils):
    module.forward(*back_inputs)
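
One detail worth noting here (an inspection sketch, not part of the committed test file): with default_partition, the backward graph's inputs include the activations the forward pass saves for gradient computation, so back_inputs is typically longer than the model's original input list:

print(len(forw_inputs), len(back_inputs))
for t in back_inputs:
    print(tuple(t.shape), t.dtype)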


build_tools/torchscript_e2e_heavydep_tests/functorch_utils.py
@@ -0,0 +1,88 @@
import torch
from functorch.compile import memory_efficient_fusion, get_decompositions, default_partition
from torch import fx
import copy
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
    """Generates the annotations, i.e., the shape and dtype of each input,
    as required by the torch-mlir module."""
    annotations_list = [None]
    for inp in inputs:
        temp_list = []
        if dynamic:
            temp_list.append([-1] * len(inp.shape))
        else:
            temp_list.append(list(inp.shape))
        temp_list.append(inp.dtype)
        temp_list.append(True)
        annotations_list.append(tuple(temp_list))
    return annotations_list
class AOTModule:
    def __init__(self, model, inputs, labels, training_fn):
        self.model = model
        self.inputs = inputs
        self.labels = labels
        self.training_fn = training_fn
        self.forward_graph = None
        self.backward_graph = None
        self.forward_inputs = None
        self.backward_inputs = None
    def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
        for node in fx_g.graph.nodes:
            if node.op == "output":
                # Output nodes always have one argument.
                node_arg = node.args[0]
                out_nodes = []
                if isinstance(node_arg, list):
                    # Don't return NoneType elements.
                    for out_node in node_arg:
                        if out_node is not None:
                            out_nodes.append(out_node)
                    # If there is a single tensor/element to be returned,
                    # don't create a tuple for it.
                    if len(out_nodes) == 1:
                        node.args = out_nodes
                    else:
                        node.args = (tuple(out_nodes), )
        fx_g.graph.lint()
        fx_g.recompile()
        return fx_g
    def get_forward_graph(self, fx_g: fx.GraphModule, inps):
        return_fx = copy.deepcopy(fx_g)
        f = self.change_fx_graph_return_to_tuple(fx_g)
        f = torch.jit.script(f)
        f = torch.jit.freeze(f.eval())
        torch.jit.save(f, "forw.pt")
        f = torch.jit.load("forw.pt")
        self.forward_graph = f
        self.forward_inputs = copy.deepcopy(inps)
        return return_fx
    def get_backward_graph(self, fx_g: fx.GraphModule, inps):
        return_fx = copy.deepcopy(fx_g)
        f = self.change_fx_graph_return_to_tuple(fx_g)
        f = torch.jit.script(f)
        f = torch.jit.freeze(f.eval())
        torch.jit.save(f, "back.pt")
        f = torch.jit.load("back.pt")
        self.backward_graph = f
        self.backward_inputs = copy.deepcopy(inps)
        return return_fx
    def generate_graphs(self):
        aot_model = memory_efficient_fusion(
            self.model,
            fw_compiler=self.get_forward_graph,
            bw_compiler=self.get_backward_graph,
            partition_fn=default_partition,
            decompositions=get_decompositions([
                torch.ops.aten.embedding_dense_backward,
                torch.ops.aten.native_layer_norm_backward
            ]))
        self.training_fn(aot_model, self.inputs, self.labels)
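
For reference, a small usage sketch (illustrative, not part of the commit) of the annotation format get_input_annotations produces: None for the module's self argument, followed by one (shape, dtype, has-value-semantics) tuple per tensor input, matching torch-mlir's annotate_args convention:

import torch
from functorch_utils import get_input_annotations

anns = get_input_annotations((torch.randn(1, 10),), dynamic=False)
print(anns)  # [None, ([1, 10], torch.float32, True)]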


@@ -22,8 +22,13 @@ python3 -m venv $venv_dir
source $venv_dir/bin/activate
# For minilm_seq_classification.py
python3 -m pip install 'transformers[torch]'
# For functorch dependent models
python3 -m pip install ninja "git+https://github.com/pytorch/functorch.git"
python3 -m pip install networkx
cd "$torch_mlir_src_root"
export PYTHONPATH=${PYTHONPATH-}
source "$torch_mlir_src_root/.env"
# For functorch_utils.py
export PYTHONPATH="$torch_mlir_src_root/build_tools/torchscript_e2e_heavydep_tests:$PYTHONPATH"
python3 -m build_tools.torchscript_e2e_heavydep_tests.main --output_dir=$serialized_test_dir


build_tools/torchscript_e2e_heavydep_tests/main.py
@@ -8,6 +8,8 @@ import argparse
from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
from . import hf_sequence_classification
from . import fully_connected_backward
from . import bert_functorch
def _get_argparse():