mirror of https://github.com/llvm/torch-mlir
Fix insertion point bug #102
The current code was inserting all build_list ops after the last constant op since it was assuming that all elements being passed in were constants. This patch replaces that patch with a new function that inserts the build_list ops before the terminator. Also modifies test_export_conv2d_fwd.py since its output no longer matches. TEST: Added test_export_cat.py which is the code in #102pull/108/head
parent
0c73c535d6
commit
c2d3820e48
|
@ -500,7 +500,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
|||
for (IValue element : list) {
|
||||
elements.push_back(mapIValueToMlirValue(loc, element));
|
||||
}
|
||||
return funcBuilder->buildConstantList(loc, elements);
|
||||
return funcBuilder->buildList(loc, elements);
|
||||
}
|
||||
if (ival.isNone()) {
|
||||
return funcBuilder->getNoneConstant(loc);
|
||||
|
|
|
@ -190,12 +190,13 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
|
|||
}
|
||||
|
||||
MlirValue
|
||||
FuncBuilder::buildConstantList(MlirLocation loc,
|
||||
FuncBuilder::buildList(MlirLocation loc,
|
||||
llvm::SmallVectorImpl<MlirValue> &elements) {
|
||||
MlirType resultType = npcompListTypeGet(context);
|
||||
OperationStateHolder state{"basicpy.build_list", loc};
|
||||
mlirOperationStateAddResults(state, 1, &resultType);
|
||||
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
||||
MlirOperation op = state.createOperation();
|
||||
return insertConstantOp(op);
|
||||
entryBlock.insertBeforeTerminator(op);
|
||||
return mlirOperationGetResult(op, 0);
|
||||
}
|
||||
|
|
|
@ -131,10 +131,8 @@ public:
|
|||
/// attribute.
|
||||
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
|
||||
|
||||
/// Builds a list with the given elements (derived from constants).
|
||||
/// The resulting list is inserted into the "constant section" of the
|
||||
/// function.
|
||||
MlirValue buildConstantList(MlirLocation loc,
|
||||
/// Builds a list with the given elements
|
||||
MlirValue buildList(MlirLocation loc,
|
||||
llvm::SmallVectorImpl<MlirValue> &elements);
|
||||
|
||||
private:
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_mlir
|
||||
|
||||
torch_mlir.debug_trace_to_stderr()
|
||||
|
||||
N = 3
|
||||
Cin = 16
|
||||
Cout = 4
|
||||
w = 10
|
||||
h = 10
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, Cin, Cout):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
|
||||
def forward(self, x):
|
||||
x0 = self.conv1(x)
|
||||
x1 = self.conv1(x)
|
||||
z = torch.cat([x0, x1])
|
||||
output = F.log_softmax(z, dim=1)
|
||||
return output
|
||||
|
||||
model = Net(Cin, Cout)
|
||||
inputs = torch.ones((N,Cin,h,w))
|
||||
loss = torch.nn.NLLLoss()
|
||||
target = torch.empty(2*N, 8, 8, dtype=torch.long).random_(0, Cout)
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
with mb.capture_function("conv_cat", [inputs, target]) as f:
|
||||
result = loss(model(inputs), target)
|
||||
f.returns([result])
|
||||
|
||||
# CHECK: "aten::convolution"
|
||||
# CHECK: "aten::convolution"
|
||||
# CHECK: basicpy.build_list
|
||||
# CHECK: "aten::_cat"
|
||||
# CHECK: "aten::_log_softmax"
|
||||
# CHECK: "aten::nll_loss2d_forward"
|
||||
mb.module.operation.print(large_elements_limit=2)
|
|
@ -30,29 +30,28 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
|||
result = model(tensor)
|
||||
f.returns([result])
|
||||
|
||||
# Generated with mlir/utils/generate-test-checks.py
|
||||
# This is very deterministic and a change test is appropriate.
|
||||
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
# CHECK-LABEL: func @conv2d_fwd(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
|
||||
# CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
|
||||
# CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32>
|
||||
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_5:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_5:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_6:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_7:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_8:.*]] = basicpy.build_list %[[VAL_6]], %[[VAL_7]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_9:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_10:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_11:.*]] = basicpy.build_list %[[VAL_9]], %[[VAL_10]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_12:.*]] = constant false
|
||||
# CHECK: %[[VAL_13:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_14:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_13]], %[[VAL_14]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_16:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_17:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
|
||||
# CHECK: %[[VAL_18:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
|
||||
# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_17]], %[[VAL_18]], %[[VAL_5]], %[[VAL_8]], %[[VAL_11]], %[[VAL_12]], %[[VAL_15]], %[[VAL_16]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||
# CHECK: %[[VAL_7:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_8:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_9:.*]] = constant false
|
||||
# CHECK: %[[VAL_10:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_11:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_12:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_13:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
|
||||
# CHECK: %[[VAL_14:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
|
||||
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_17:.*]] = basicpy.build_list %[[VAL_7]], %[[VAL_8]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_18:.*]] = basicpy.build_list %[[VAL_10]], %[[VAL_11]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_17]], %[[VAL_9]], %[[VAL_18]], %[[VAL_12]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||
# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
|
||||
# CHECK: }
|
||||
mb.module.operation.print(large_elements_limit=2)
|
||||
|
|
Loading…
Reference in New Issue