mirror of https://github.com/llvm/torch-mlir
Fix insertion point bug #102
The current code was inserting all build_list ops after the last constant op since it was assuming that all elements being passed in were constants. This patch replaces that patch with a new function that inserts the build_list ops before the terminator. Also modifies test_export_conv2d_fwd.py since its output no longer matches. TEST: Added test_export_cat.py which is the code in #102pull/108/head
parent
0c73c535d6
commit
c2d3820e48
|
@ -500,7 +500,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
||||||
for (IValue element : list) {
|
for (IValue element : list) {
|
||||||
elements.push_back(mapIValueToMlirValue(loc, element));
|
elements.push_back(mapIValueToMlirValue(loc, element));
|
||||||
}
|
}
|
||||||
return funcBuilder->buildConstantList(loc, elements);
|
return funcBuilder->buildList(loc, elements);
|
||||||
}
|
}
|
||||||
if (ival.isNone()) {
|
if (ival.isNone()) {
|
||||||
return funcBuilder->getNoneConstant(loc);
|
return funcBuilder->getNoneConstant(loc);
|
||||||
|
|
|
@ -190,12 +190,13 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirValue
|
MlirValue
|
||||||
FuncBuilder::buildConstantList(MlirLocation loc,
|
FuncBuilder::buildList(MlirLocation loc,
|
||||||
llvm::SmallVectorImpl<MlirValue> &elements) {
|
llvm::SmallVectorImpl<MlirValue> &elements) {
|
||||||
MlirType resultType = npcompListTypeGet(context);
|
MlirType resultType = npcompListTypeGet(context);
|
||||||
OperationStateHolder state{"basicpy.build_list", loc};
|
OperationStateHolder state{"basicpy.build_list", loc};
|
||||||
mlirOperationStateAddResults(state, 1, &resultType);
|
mlirOperationStateAddResults(state, 1, &resultType);
|
||||||
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
||||||
MlirOperation op = state.createOperation();
|
MlirOperation op = state.createOperation();
|
||||||
return insertConstantOp(op);
|
entryBlock.insertBeforeTerminator(op);
|
||||||
|
return mlirOperationGetResult(op, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -131,11 +131,9 @@ public:
|
||||||
/// attribute.
|
/// attribute.
|
||||||
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
|
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
|
||||||
|
|
||||||
/// Builds a list with the given elements (derived from constants).
|
/// Builds a list with the given elements
|
||||||
/// The resulting list is inserted into the "constant section" of the
|
MlirValue buildList(MlirLocation loc,
|
||||||
/// function.
|
llvm::SmallVectorImpl<MlirValue> &elements);
|
||||||
MlirValue buildConstantList(MlirLocation loc,
|
|
||||||
llvm::SmallVectorImpl<MlirValue> &elements);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
torch_mlir.debug_trace_to_stderr()
|
||||||
|
|
||||||
|
N = 3
|
||||||
|
Cin = 16
|
||||||
|
Cout = 4
|
||||||
|
w = 10
|
||||||
|
h = 10
|
||||||
|
|
||||||
|
class Net(nn.Module):
|
||||||
|
def __init__(self, Cin, Cout):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
|
||||||
|
def forward(self, x):
|
||||||
|
x0 = self.conv1(x)
|
||||||
|
x1 = self.conv1(x)
|
||||||
|
z = torch.cat([x0, x1])
|
||||||
|
output = F.log_softmax(z, dim=1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
model = Net(Cin, Cout)
|
||||||
|
inputs = torch.ones((N,Cin,h,w))
|
||||||
|
loss = torch.nn.NLLLoss()
|
||||||
|
target = torch.empty(2*N, 8, 8, dtype=torch.long).random_(0, Cout)
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
with mb.capture_function("conv_cat", [inputs, target]) as f:
|
||||||
|
result = loss(model(inputs), target)
|
||||||
|
f.returns([result])
|
||||||
|
|
||||||
|
# CHECK: "aten::convolution"
|
||||||
|
# CHECK: "aten::convolution"
|
||||||
|
# CHECK: basicpy.build_list
|
||||||
|
# CHECK: "aten::_cat"
|
||||||
|
# CHECK: "aten::_log_softmax"
|
||||||
|
# CHECK: "aten::nll_loss2d_forward"
|
||||||
|
mb.module.operation.print(large_elements_limit=2)
|
|
@ -30,29 +30,28 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
||||||
result = model(tensor)
|
result = model(tensor)
|
||||||
f.returns([result])
|
f.returns([result])
|
||||||
|
|
||||||
# Generated with mlir/utils/generate-test-checks.py
|
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||||
# This is very deterministic and a change test is appropriate.
|
|
||||||
# CHECK-LABEL: func @conv2d_fwd(
|
# CHECK-LABEL: func @conv2d_fwd(
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
|
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
|
||||||
# CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
|
# CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
|
||||||
# CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32>
|
# CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32>
|
||||||
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
|
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
|
||||||
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
|
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
|
||||||
# CHECK: %[[VAL_5:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
# CHECK: %[[VAL_5:.*]] = constant 0 : i64
|
||||||
# CHECK: %[[VAL_6:.*]] = constant 0 : i64
|
# CHECK: %[[VAL_6:.*]] = constant 0 : i64
|
||||||
# CHECK: %[[VAL_7:.*]] = constant 0 : i64
|
# CHECK: %[[VAL_7:.*]] = constant 1 : i64
|
||||||
# CHECK: %[[VAL_8:.*]] = basicpy.build_list %[[VAL_6]], %[[VAL_7]] : (i64, i64) -> !basicpy.ListType
|
# CHECK: %[[VAL_8:.*]] = constant 1 : i64
|
||||||
# CHECK: %[[VAL_9:.*]] = constant 1 : i64
|
# CHECK: %[[VAL_9:.*]] = constant false
|
||||||
# CHECK: %[[VAL_10:.*]] = constant 1 : i64
|
# CHECK: %[[VAL_10:.*]] = constant 0 : i64
|
||||||
# CHECK: %[[VAL_11:.*]] = basicpy.build_list %[[VAL_9]], %[[VAL_10]] : (i64, i64) -> !basicpy.ListType
|
# CHECK: %[[VAL_11:.*]] = constant 0 : i64
|
||||||
# CHECK: %[[VAL_12:.*]] = constant false
|
# CHECK: %[[VAL_12:.*]] = constant 1 : i64
|
||||||
# CHECK: %[[VAL_13:.*]] = constant 0 : i64
|
# CHECK: %[[VAL_13:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
|
||||||
# CHECK: %[[VAL_14:.*]] = constant 0 : i64
|
# CHECK: %[[VAL_14:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
|
||||||
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_13]], %[[VAL_14]] : (i64, i64) -> !basicpy.ListType
|
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
||||||
# CHECK: %[[VAL_16:.*]] = constant 1 : i64
|
# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
|
||||||
# CHECK: %[[VAL_17:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
|
# CHECK: %[[VAL_17:.*]] = basicpy.build_list %[[VAL_7]], %[[VAL_8]] : (i64, i64) -> !basicpy.ListType
|
||||||
# CHECK: %[[VAL_18:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
|
# CHECK: %[[VAL_18:.*]] = basicpy.build_list %[[VAL_10]], %[[VAL_11]] : (i64, i64) -> !basicpy.ListType
|
||||||
# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_17]], %[[VAL_18]], %[[VAL_5]], %[[VAL_8]], %[[VAL_11]], %[[VAL_12]], %[[VAL_15]], %[[VAL_16]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_17]], %[[VAL_9]], %[[VAL_18]], %[[VAL_12]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||||
# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
|
# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
mb.module.operation.print(large_elements_limit=2)
|
mb.module.operation.print(large_elements_limit=2)
|
||||||
|
|
Loading…
Reference in New Issue