Fix insertion point bug #102

The current code was inserting all build_list ops
after the last constant op since it was assuming that all
elements being passed in were constants.

This patch replaces that patch with a new function that
inserts the build_list ops before the terminator.

Also modifies test_export_conv2d_fwd.py since its output
no longer matches.

TEST: Added test_export_cat.py which is the code in #102
pull/108/head
Harsh Menon 2020-11-02 15:30:21 -08:00 committed by Stella Laurenzo
parent 0c73c535d6
commit c2d3820e48
5 changed files with 70 additions and 25 deletions

View File

@ -500,7 +500,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
for (IValue element : list) { for (IValue element : list) {
elements.push_back(mapIValueToMlirValue(loc, element)); elements.push_back(mapIValueToMlirValue(loc, element));
} }
return funcBuilder->buildConstantList(loc, elements); return funcBuilder->buildList(loc, elements);
} }
if (ival.isNone()) { if (ival.isNone()) {
return funcBuilder->getNoneConstant(loc); return funcBuilder->getNoneConstant(loc);

View File

@ -190,12 +190,13 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
} }
MlirValue MlirValue
FuncBuilder::buildConstantList(MlirLocation loc, FuncBuilder::buildList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) { llvm::SmallVectorImpl<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(context); MlirType resultType = npcompListTypeGet(context);
OperationStateHolder state{"basicpy.build_list", loc}; OperationStateHolder state{"basicpy.build_list", loc};
mlirOperationStateAddResults(state, 1, &resultType); mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data()); mlirOperationStateAddOperands(state, elements.size(), elements.data());
MlirOperation op = state.createOperation(); MlirOperation op = state.createOperation();
return insertConstantOp(op); entryBlock.insertBeforeTerminator(op);
return mlirOperationGetResult(op, 0);
} }

View File

@ -131,11 +131,9 @@ public:
/// attribute. /// attribute.
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value); MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
/// Builds a list with the given elements (derived from constants). /// Builds a list with the given elements
/// The resulting list is inserted into the "constant section" of the MlirValue buildList(MlirLocation loc,
/// function. llvm::SmallVectorImpl<MlirValue> &elements);
MlirValue buildConstantList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements);
private: private:
FuncBuilder(MlirContext context, MlirOperation funcOp, FuncBuilder(MlirContext context, MlirOperation funcOp,

View File

@ -0,0 +1,47 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_mlir
torch_mlir.debug_trace_to_stderr()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
class Net(nn.Module):
def __init__(self, Cin, Cout):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
def forward(self, x):
x0 = self.conv1(x)
x1 = self.conv1(x)
z = torch.cat([x0, x1])
output = F.log_softmax(z, dim=1)
return output
model = Net(Cin, Cout)
inputs = torch.ones((N,Cin,h,w))
loss = torch.nn.NLLLoss()
target = torch.empty(2*N, 8, 8, dtype=torch.long).random_(0, Cout)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("conv_cat", [inputs, target]) as f:
result = loss(model(inputs), target)
f.returns([result])
# CHECK: "aten::convolution"
# CHECK: "aten::convolution"
# CHECK: basicpy.build_list
# CHECK: "aten::_cat"
# CHECK: "aten::_log_softmax"
# CHECK: "aten::nll_loss2d_forward"
mb.module.operation.print(large_elements_limit=2)

View File

@ -30,29 +30,28 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
result = model(tensor) result = model(tensor)
f.returns([result]) f.returns([result])
# Generated with mlir/utils/generate-test-checks.py # NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# This is very deterministic and a change test is appropriate.
# CHECK-LABEL: func @conv2d_fwd( # CHECK-LABEL: func @conv2d_fwd(
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> { # CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
# CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32> # CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
# CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32> # CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32>
# CHECK: %[[VAL_3:.*]] = constant 1 : i64 # CHECK: %[[VAL_3:.*]] = constant 1 : i64
# CHECK: %[[VAL_4:.*]] = constant 1 : i64 # CHECK: %[[VAL_4:.*]] = constant 1 : i64
# CHECK: %[[VAL_5:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType # CHECK: %[[VAL_5:.*]] = constant 0 : i64
# CHECK: %[[VAL_6:.*]] = constant 0 : i64 # CHECK: %[[VAL_6:.*]] = constant 0 : i64
# CHECK: %[[VAL_7:.*]] = constant 0 : i64 # CHECK: %[[VAL_7:.*]] = constant 1 : i64
# CHECK: %[[VAL_8:.*]] = basicpy.build_list %[[VAL_6]], %[[VAL_7]] : (i64, i64) -> !basicpy.ListType # CHECK: %[[VAL_8:.*]] = constant 1 : i64
# CHECK: %[[VAL_9:.*]] = constant 1 : i64 # CHECK: %[[VAL_9:.*]] = constant false
# CHECK: %[[VAL_10:.*]] = constant 1 : i64 # CHECK: %[[VAL_10:.*]] = constant 0 : i64
# CHECK: %[[VAL_11:.*]] = basicpy.build_list %[[VAL_9]], %[[VAL_10]] : (i64, i64) -> !basicpy.ListType # CHECK: %[[VAL_11:.*]] = constant 0 : i64
# CHECK: %[[VAL_12:.*]] = constant false # CHECK: %[[VAL_12:.*]] = constant 1 : i64
# CHECK: %[[VAL_13:.*]] = constant 0 : i64 # CHECK: %[[VAL_13:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
# CHECK: %[[VAL_14:.*]] = constant 0 : i64 # CHECK: %[[VAL_14:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_13]], %[[VAL_14]] : (i64, i64) -> !basicpy.ListType # CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_16:.*]] = constant 1 : i64 # CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_17:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32> # CHECK: %[[VAL_17:.*]] = basicpy.build_list %[[VAL_7]], %[[VAL_8]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_18:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32> # CHECK: %[[VAL_18:.*]] = basicpy.build_list %[[VAL_10]], %[[VAL_11]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_17]], %[[VAL_18]], %[[VAL_5]], %[[VAL_8]], %[[VAL_11]], %[[VAL_12]], %[[VAL_15]], %[[VAL_16]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]} # CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_17]], %[[VAL_9]], %[[VAL_18]], %[[VAL_12]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32> # CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
# CHECK: } # CHECK: }
mb.module.operation.print(large_elements_limit=2) mb.module.operation.print(large_elements_limit=2)