Add TorchList type and prim::ListConstruct #218

pull/219/head
Yi Zhang 2021-06-04 22:57:21 +00:00 committed by Sean Silva
parent 370e3270ab
commit e0ff5248fb
23 changed files with 291 additions and 122 deletions

View File

@@ -468,7 +468,8 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
for (IValue element : list) {
elements.push_back(mapIValueToMlirValue(loc, element));
}
return funcBuilder->buildList(loc, elements);
return funcBuilder->buildList(loc,
typeMapper.mapFromTorchType(loc, list.elementType()), elements);
}
if (ival.isNone()) {
return funcBuilder->getNoneConstant(loc);
@@ -511,7 +512,9 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
}
if (ival.isList()) {
return npcompListTypeGet(funcBuilder->getContext());
return npcompListTypeGet(
typeMapper.mapFromTorchType(
loc, ival.toList().elementType()));
}
if (ival.isNone()) {
return npcompNoneTypeGet(funcBuilder->getContext());

View File

@@ -130,10 +130,10 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
return insertConstantOp(OpBuilder(context).createStdConstant(loc, value));
}
MlirValue FuncBuilder::buildList(MlirLocation loc,
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(context);
OperationStateHolder state{"basicpy.build_list", loc};
MlirType resultType = npcompListTypeGet(elementType);
OperationStateHolder state{"torch.prim.ListConstruct", loc};
mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data());
MlirOperation op = state.createOperation();
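
With the new signature, the caller supplies the list's element type explicitly. A minimal call-site sketch (the funcBuilder, loc, lhs, and rhs names here are assumptions for illustration, not part of this change):

// Build a !torch.list<i64> from two i64 SSA values (hypothetical call site).
MlirContext context = funcBuilder->getContext();
MlirType i64Type = mlirIntegerTypeGet(context, 64);
std::vector<MlirValue> elements = {lhs, rhs};
MlirValue list = funcBuilder->buildList(loc, i64Type, elements);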

View File

@@ -117,7 +117,8 @@ public:
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
/// Builds a list with the given elements
MlirValue buildList(MlirLocation loc, std::vector<MlirValue> &elements);
MlirValue buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements);
private:
FuncBuilder(MlirContext context, MlirOperation funcOp,

View File

@@ -270,8 +270,10 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
elems.push_back(importIValue(elem));
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "basicpy.build_list", loc,
npcompListTypeGet(context), elems);
createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc,
npcompListTypeGet(
typeMapper.mapFromTorchType(
loc, list.elementType())), elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTuple()) {

View File

@@ -82,6 +82,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
// Builtin interpreter ops with no operator/schema.
switch (kind) {
case c10::prim::ListUnpack:
case c10::prim::ListConstruct:
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()));
return;
@@ -96,10 +97,6 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
// Ops trivially lowered through `basicpy` dialect.
switch (kind) {
case c10::prim::ListConstruct: {
createAndMapTrivialNode(node, "basicpy.build_list");
return;
}
case c10::prim::TupleConstruct: {
createAndMapTrivialNode(node, "basicpy.build_tuple");
return;

View File

@@ -181,8 +181,9 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
return npcompBoolTypeGet(context);
}
case TypeKind::ListType: {
// TODO: Don't lose the element type information.
return npcompListTypeGet(context);
return npcompListTypeGet(
mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::TupleType: {
// TODO: Don't lose the element type information.
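
Since the list mapping now recurses into the element type, nested TorchScript list types translate fully. A sketch of what a caller would observe (the typeMapper and loc values are assumptions):

// A nested int[][] maps to !torch.list<!torch.list<i64>> (hypothetical usage).
c10::TypePtr inner = c10::ListType::create(c10::IntType::get()); // int[]
c10::TypePtr outer = c10::ListType::create(inner);               // int[][]
MlirType mapped = typeMapper.mapFromTorchType(loc, outer);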

View File

@@ -224,7 +224,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
"float": "AnyFloat",
"t[]": "Basicpy_ListType",
"t[]": "AnyTorchListType",
"t": "AnyTorchType",
"t1": "AnyTorchType",
"t2": "AnyTorchType",
@@ -455,6 +455,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::len.t : (t[]) -> (int)",
has_folder=True,
has_canonicalizer=True)
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):

View File

@@ -40,7 +40,7 @@ with mb.capture_function("conv_cat", [inputs, target]) as f:
# CHECK: "aten.convolution"
# CHECK: "aten.convolution"
# CHECK: basicpy.build_list
# CHECK: torch.prim.ListConstruct
# CHECK: "aten._cat"
# CHECK: "aten._log_softmax"
# CHECK: "aten.nll_loss2d_forward"

View File

@@ -45,11 +45,11 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
# CHECK: %[[VAL_10:.*]] = constant 1 : i64
# CHECK: %[[VAL_11:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
# CHECK: %[[VAL_12:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
# CHECK: %[[VAL_13:.*]] = basicpy.build_list %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_14:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType, !basicpy.ListType, i64) -> !torch.tensor<[3,4,8,8],f32>
# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !torch.list<i64>
# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !torch.list<i64>
# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !torch.list<i64>
# CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !torch.list<i64>
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !basicpy.BoolType, !torch.list<i64>, i64) -> !torch.tensor<[3,4,8,8],f32>
# CHECK: return %[[VAL_17]] : !torch.tensor<[3,4,8,8],f32>
# CHECK: }

View File

@@ -16,14 +16,13 @@ class TestModule(torch.nn.Module):
super().__init__()
self.l = [1, 2]
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
# TODO: Don't lose element type.
# CHECK: torch.attr "l" : !basicpy.ListType
# CHECK: torch.attr "l" : !torch.list<i64>
# CHECK: }
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[N1]], %[[N2]] : (i64, i64) -> !torch.list<i64>
# CHECK: torch.nn_module {
# CHECK: torch.slot "l", %[[LIST]] : !basicpy.ListType
# CHECK: torch.slot "l", %[[LIST]] : !torch.list<i64>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">

View File

@@ -21,8 +21,8 @@ mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
# CHECK: %[[L2:.*]] = basicpy.build_list
# CHECK: %[[L1:.*]] = basicpy.build_list
# CHECK: %[[L2:.*]] = torch.prim.ListConstruct
# CHECK: %[[L1:.*]] = torch.prim.ListConstruct
# CHECK: torch.nn_module {
# CHECK: torch.slot "l2", %[[L2]]
# CHECK: torch.slot "l1", %[[L1]]

View File

@@ -11,9 +11,9 @@ mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.ListType {
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.ListType
# CHECK: return %[[RET]] : !basicpy.ListType
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<!torch.tensor> {
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
# CHECK: return %[[RET]] : !torch.list<!torch.tensor>
@mb.import_function
@torch.jit.script

View File

@@ -84,8 +84,8 @@ def prim_TupleIndex(tup: typing.Tuple[int, int]):
return tup[0]
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
# CHECK-SAME: %[[ARG:.*]]: !torch.list<i64>) -> i64 {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<i64> -> i64, i64
# CHECK: return %[[RET]]#1 : i64
@mb.import_function
@torch.jit.script
@@ -122,11 +122,11 @@ def prim_device(x):
# CHECK-LABEL: func @__torch__.prim_min(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !basicpy.ListType -> i64
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
# CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
@mb.import_function
@@ -136,11 +136,11 @@ def prim_min(x: int):
# CHECK-LABEL: func @__torch__.prim_max(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !basicpy.ListType -> i64
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
# CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
@mb.import_function

View File

@@ -62,10 +62,10 @@ MlirType npcompDictTypeGet(MlirContext context);
/*============================================================================*/
/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsAList(MlirType t);
int npcompTypeIsABasicpyList(MlirType t);
/** Gets the generic Python "list" type. */
MlirType npcompListTypeGet(MlirContext context);
MlirType npcompBasicpyListTypeGet(MlirContext context);
/*============================================================================*/
/* NDArray type. */
@@ -137,6 +137,16 @@ int npcompTypeIsAOptional(MlirType t);
/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompOptionalTypeGet(MlirType containedType);
/*============================================================================*/
/* torch.list type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.list<T> type. */
int npcompTypeIsAList(MlirType t);
/** Gets the !torch.list<T> type with contained T. */
MlirType npcompListTypeGet(MlirType containedType);
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/
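
A sketch of the new typed list entry points in use (the ctx value and an included <cassert> are assumptions here):

// Construct !torch.list<i64> through the C API and check its kind.
MlirType i64 = mlirIntegerTypeGet(ctx, 64);
MlirType listOfI64 = npcompListTypeGet(i64);
assert(npcompTypeIsAList(listOfI64));
assert(!npcompTypeIsABasicpyList(listOfI64));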

View File

@@ -382,7 +382,7 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
]> {
let summary = "Generated op for `aten::len.t : (t[]) -> (int)`";
let arguments = (ins
Basicpy_ListType:$a
AnyTorchListType:$a
);
let results = (outs
AnyTorchIntType:$result
@@ -392,3 +392,33 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
let hasCanonicalizer = 1;
}
def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::__getitem__.t : (t[], int) -> (t)`";
let arguments = (ins
AnyTorchListType:$list,
AnyTorchIntType:$idx
);
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = "$list `,` $idx attr-dict `:` type($list) `,` type($idx) `->` type($result)";
let hasCanonicalizer = 1;
}
def Torch_Aten_SetItemTOp : Torch_Op<"aten._set_item.t", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::_set_item.t : (t[], int, t) -> (t[])`";
let arguments = (ins
AnyTorchListType:$l,
AnyTorchIntType:$idx,
AnyTorchType:$el
);
let results = (outs
AnyTorchListType:$result
);
let assemblyFormat = "$l `,` $idx `,` $el attr-dict `:` type($l) `,` type($idx) `,` type($el) `->` type($result)";
}

View File

@@ -322,6 +322,26 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
}];
}
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
NoSideEffect,
AllowsTypeRefinement,
]> {
let summary = "TorchScript prim::ListConstruct op";
let arguments = (ins
Variadic<AnyTorchType>:$elements
);
let results = (outs
AnyTorchListType:$result
);
let verifier = "return ::verify(*this);";
let assemblyFormat = [{
$elements attr-dict `:` functional-type(operands, results)
}];
}
def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
let summary = "TorchScript prim::GetAttr op";

View File

@@ -21,6 +21,31 @@ class Torch_Type<string name, string typeMnemonic,
let mnemonic = typeMnemonic;
}
class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> {
let parameters = (ins "::mlir::Type":$containedType);
let printer = [{
$_printer << getMnemonic() << "<" << getImpl()->containedType << ">";
}];
let parser = [{
if (parser.parseLess())
return Type();
Type containedType;
if ($_parser.parseType(containedType))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, containedType);
}];
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
return Base::get(containedType.getContext(), containedType);
}]>
];
}
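
The shared class gives every contained-type wrapper the same <...> syntax plus a context-inferring builder. A sketch of the resulting C++ usage (the ctx value is an assumption):

// The MLIRContext is inferred from the contained type (hypothetical usage).
Type i64 = IntegerType::get(ctx, 64);
Type listTy = Torch::ListType::get(i64);    // prints as !torch.list<i64>
Type optTy = Torch::OptionalType::get(i64); // prints as !torch.optional<i64>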
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
let summary = "torch.nn.Module";
let description = [{
@@ -195,32 +220,16 @@ def AnyTorchTensorType : Type<
// For now, we only need it as a stand-in type to allow importing
// the `_is_full_backward_hook` optional bool type that Torch puts on
// all classes.
def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
def Torch_OptionalType : Torch_TypeWithContainedType<"Optional", "optional"> {
let summary = "!torch.optional<T>";
let description = [{
}];
let parameters = (ins "::mlir::Type":$containedType);
}
let printer = [{
$_printer << "optional<" << getImpl()->containedType << ">";
def Torch_ListType : Torch_TypeWithContainedType<"List", "list"> {
let summary = "!torch.list<T>";
let description = [{
}];
let parser = [{
if (parser.parseLess())
return Type();
Type containedType;
if ($_parser.parseType(containedType))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, containedType);
}];
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
return Base::get(containedType.getContext(), containedType);
}]>
];
}
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
@@ -279,6 +288,13 @@ def AnyTorchScalarType : AnyTypeOf<[
AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;
def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">;
class ListOf<list<Type> allowedTypes, string descr> :
ContainerType<AnyTypeOf<allowedTypes>, IsListTypePred,
"$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()",
descr, "::mlir::NPCOMP::Torch::ListType">;
def AnyTorchNumberType : AnyTypeOf<[
AnySignedInteger,
AnyFloat,
@@ -293,33 +309,29 @@ def AnyTorchBoolType : AnyTypeOf<[
Basicpy_BoolType,
], "Any permissible bool type">;
def AnyTorchBoolListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any bool list type (bool[])">;
def AnyTorchBoolListType : ListOf<[AnyTorchBoolType], "Any bool list type (bool[])">;
def AnyTorchIntType : AnyTypeOf<[
AnySignedInteger,
AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
def AnyTorchIntListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any int list type (int[])">;
def AnyTorchIntListType : ListOf<[AnyTorchIntType], "Any int list type (int[])">;
def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_ListType,
Basicpy_TupleType,
Basicpy_NoneType,
Basicpy_BytesType,
Torch_NnModuleType,
Torch_OptionalType,
Torch_ListType,
Torch_DeviceType,
Torch_LinearParamsType,
], "Any type that is legal to pass to a Torch kernel">;
def AnyTorchListType : ListOf<[AnyType], "Any Torch list type">;
#endif // TORCH_TYPES

View File

@@ -69,10 +69,10 @@ MlirType npcompDictTypeGet(MlirContext context) {
/*============================================================================*/
/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsAList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
/** Gets the generic Python "list" type. */
MlirType npcompListTypeGet(MlirContext context) {
MlirType npcompBasicpyListTypeGet(MlirContext context) {
return wrap(Basicpy::ListType::get(unwrap(context)));
}
@@ -175,6 +175,19 @@ MlirType npcompOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType)));
}
/*============================================================================*/
/* torch.list type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.list<T> type. */
int npcompTypeIsAList(MlirType t) {
return unwrap(t).isa<Torch::ListType>();
}
/** Gets the !torch.list<T> type with contained T. */
MlirType npcompListTypeGet(MlirType containedType) {
return wrap(Torch::ListType::get(unwrap(containedType)));
}
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/

View File

@@ -11,6 +11,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
@@ -131,6 +132,21 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
//===----------------------------------------------------------------------===//
// PrimListConstructOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(PrimListConstructOp op) {
auto resultType = op.getResult().getType();
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
// Note: comparing TypeIDs only checks that each operand has the same kind of
// type as the declared element type (e.g. any IntegerType matches), not that
// the types are identical.
auto matchResultElementType = [&](Type type) {
return type.getTypeID() == resultElementType.getTypeID();
};
if (llvm::all_of(op->getOperandTypes(), matchResultElementType))
return success();
return failure();
}
//===----------------------------------------------------------------------===//
// ClassTypeOp
//===----------------------------------------------------------------------===//
@@ -246,9 +262,9 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
// `len([1,1,1])` -> `3`
if (auto buildList = getOperand().getDefiningOp<Basicpy::BuildListOp>()) {
if (auto listConstruct = getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
return IntegerAttr::get(IntegerType::get(getContext(), 64),
buildList.getNumOperands());
listConstruct.getNumOperands());
}
return nullptr;
}
@@ -288,8 +304,8 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
listElements.push_back(rewriter.create<::mlir::ConstantOp>(
op->getLoc(), rewriter.getI64IntegerAttr(size)));
}
rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
op, Basicpy::ListType::get(rewriter.getContext()), listElements);
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
op, Torch::ListType::get(rewriter.getI64Type()), listElements);
return success();
});
// One-off pattern to erase if dead.
@@ -425,5 +441,30 @@ LogicalResult FromBuiltinTensorOp::inferReturnTypes(
return success();
}
//===----------------------------------------------------------------------===//
// Aten__Getitem__TOp
//===----------------------------------------------------------------------===//
void Aten__Getitem__TOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
auto torchList = op.getOperand(0);
if (!torchList.hasOneUse())
return failure();
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
if (!listConstruct)
return failure();
APInt indexAP;
if (!matchPattern(op.getOperand(1), m_ConstantInt(&indexAP)))
return failure();
int64_t index = indexAP.getSExtValue();
// Guard against constant indices that are negative or out of range;
// getOperand would assert on them.
if (index < 0 || index >= static_cast<int64_t>(listConstruct.getNumOperands()))
return failure();
rewriter.replaceOp(op, {listConstruct.getOperand(index)});
return success();
});
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"

View File

@@ -2,20 +2,20 @@
// CHECK that multiple nested initialization ops are properly handled.
// CHECK-LABEL: torch.global_slot @l : !basicpy.ListType {
// CHECK: %[[L0:.*]] = basicpy.build_list : () -> !basicpy.ListType
// CHECK: %[[L1:.*]] = basicpy.build_list %[[L0]], %[[L0]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
// CHECK: %[[L2:.*]] = basicpy.build_list %[[L1]], %[[L1]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
// CHECK: torch.global_slot.init %[[L2]] : !basicpy.ListType
// CHECK-LABEL: torch.global_slot @l : !torch.list<!torch.list<!torch.list<!torch.tensor>>> {
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.tensor>
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<!torch.tensor>, !torch.list<!torch.tensor>) -> !torch.list<!torch.list<!torch.tensor>>
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<!torch.list<!torch.tensor>>, !torch.list<!torch.list<!torch.tensor>>) -> !torch.list<!torch.list<!torch.list<!torch.tensor>>>
// CHECK: torch.global_slot.init %[[L2]] : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
// CHECK: }
torch.class_type @c {
torch.attr "l" : !basicpy.ListType
torch.attr "l" : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
}
%l0 = basicpy.build_list : () -> !basicpy.ListType
%l1 = basicpy.build_list %l0, %l0 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
%l2 = basicpy.build_list %l1, %l1 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
%l0 = torch.prim.ListConstruct : () -> !torch.list<!torch.tensor>
%l1 = torch.prim.ListConstruct %l0, %l0 : (!torch.list<!torch.tensor>, !torch.list<!torch.tensor>) -> !torch.list<!torch.list<!torch.tensor>>
%l2 = torch.prim.ListConstruct %l1, %l1 : (!torch.list<!torch.list<!torch.tensor>>, !torch.list<!torch.list<!torch.tensor>>) -> !torch.list<!torch.list<!torch.list<!torch.tensor>>>
torch.nn_module {
torch.slot "l", %l2 : !basicpy.ListType
torch.slot "l", %l2 : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
} : !torch.nn.Module<"c">

View File

@@ -6,7 +6,7 @@ torch.class_type @parent {
func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
// expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}}
basicpy.build_list %arg0 : (!torch.nn.Module<"parent">) -> !basicpy.ListType
torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<!torch.nn.Module<"parent">>
return
}

View File

@@ -3,31 +3,31 @@
// CHECK-LABEL: func @torch.aten.__is__
// CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
// CHECK: return %[[FALSE]] : !basicpy.BoolType
func @torch.aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
%0 = torch.aten.__is__ %arg0, %arg1 : !basicpy.ListType, !basicpy.NoneType -> !basicpy.BoolType
func @torch.aten.__is__(%arg0: !torch.list<i64>, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
%0 = torch.aten.__is__ %arg0, %arg1 : !torch.list<i64>, !basicpy.NoneType -> !basicpy.BoolType
return %0 : !basicpy.BoolType
}
// CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<i64> {
// CHECK: %[[C2:.*]] = constant 2 : i64
// CHECK: %[[C3:.*]] = constant 3 : i64
// CHECK: %[[LIST:.*]] = basicpy.build_list %[[C2]], %[[C3]] : (i64, i64) -> !basicpy.ListType
// CHECK: return %[[LIST]] : !basicpy.ListType
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !basicpy.ListType
return %0 : !basicpy.ListType
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (i64, i64) -> !torch.list<i64>
// CHECK: return %[[LIST]] : !torch.list<i64>
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<i64> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !torch.list<i64>
return %0 : !torch.list<i64>
}
// One size unknown, so cannot canonicalize.
// TODO: For unknown sizes, insert the equivalent of a "dim" op.
// Then this will only require static rank.
// CHECK-LABEL: func @torch.aten.size$unknown_size(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
return %0 : !basicpy.ListType
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !torch.list<i64>
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<i64>
return %0 : !torch.list<i64>
}
// CHECK-LABEL: func @torch.aten.len.t$of_size(
@@ -35,8 +35,8 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy
// CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
// CHECK: return %[[DIM]] : i64
func @torch.aten.len.t$of_size(%arg0: !torch.vtensor<*,f32>) -> i64 {
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !basicpy.ListType
%1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !torch.list<i64>
%1 = torch.aten.len.t %0 : !torch.list<i64> -> i64
return %1 : i64
}
@@ -54,8 +54,8 @@ func @torch.aten.dim$with_shape(%arg0: !torch.vtensor<[?,?,?],f32>) -> i64 {
// CHECK: %[[LEN:.*]] = constant 4 : i64
// CHECK: return %[[LEN]] : i64
func @torch.aten.len.t$of_build_list(%arg0: i64) -> i64 {
%0 = basicpy.build_list %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !basicpy.ListType
%1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
%0 = torch.prim.ListConstruct %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !torch.list<i64>
%1 = torch.aten.len.t %0 : !torch.list<i64> -> i64
return %1 : i64
}
@@ -75,3 +75,41 @@ func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vte
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
return %1 : !torch.vtensor
}
// CHECK-LABEL: func @torch.aten.__getitem__.t(
// CHECK: %[[C5:.*]] = constant 5 : i64
// CHECK: return %[[C5]] : i64
func @torch.aten.__getitem__.t() -> i64 {
%c4_i64 = constant 4 : i64
%c5_i64 = constant 5 : i64
%c1_i64 = constant 1 : i64
%0 = torch.prim.ListConstruct %c4_i64, %c5_i64 : (i64, i64) -> !torch.list<i64>
%1 = torch.aten.__getitem__.t %0, %c1_i64 : !torch.list<i64>, i64 -> i64
return %1 : i64
}
// Not canonicalized because the index is a function argument, not a constant.
// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test0(
// CHECK: %[[C4:.*]] = constant 4 : i64
// CHECK: %[[C5:.*]] = constant 5 : i64
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]] : (i64, i64) -> !torch.list<i64>
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %[[LIST]], %arg0 : !torch.list<i64>, i64 -> i64
// CHECK: return %[[ITEM]] : i64
func @torch.aten.__getitem__.t$no_change_test0(%arg0: i64) -> i64 {
%c4_i64 = constant 4 : i64
%c5_i64 = constant 5 : i64
%0 = torch.prim.ListConstruct %c4_i64, %c5_i64 : (i64, i64) -> !torch.list<i64>
%1 = torch.aten.__getitem__.t %0, %arg0 : !torch.list<i64>, i64 -> i64
return %1 : i64
}
// Not canonicalized because the list is a function argument, not a locally constructed list.
// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test1(
// CHECK: %[[C5:.*]] = constant 5 : i64
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %arg0, %[[C5]] : !torch.list<i64>, i64 -> i64
// CHECK: return %[[ITEM]] : i64
func @torch.aten.__getitem__.t$no_change_test1(%arg0: !torch.list<i64>) -> i64 {
%c5_i64 = constant 5 : i64
%0 = torch.aten.__getitem__.t %arg0, %c5_i64 : !torch.list<i64>, i64 -> i64
return %0 : i64
}

View File

@@ -72,10 +72,10 @@ func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg
func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!torch.vtensor {
%c0_i64 = constant 0 : i64
%c1_i64 = constant 1 : i64
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
%0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
%1 = torch.prim.ListConstruct %c0_i64, %c0_i64 : (i64, i64) -> !torch.list<i64>
%2 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, i64 ->!torch.vtensor
return %3 :!torch.vtensor
}
@@ -86,10 +86,10 @@ func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!tor
func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
%c0_i64 = constant 0 : i64
%c1_i64 = constant 1 : i64
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
%0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
%1 = torch.prim.ListConstruct %c0_i64, %c0_i64 : (i64, i64) -> !torch.list<i64>
%2 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, i64 ->!torch.vtensor
return %3 :!torch.vtensor
}
@@ -101,12 +101,12 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
%c3_i64 = constant 3 : i64
%c2_i64 = constant 2 : i64
%bool_false = basicpy.bool_constant false
%21 = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
%22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
%23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%21 = torch.prim.ListConstruct %c3_i64, %c3_i64 : (i64, i64) -> !torch.list<i64>
%22 = torch.prim.ListConstruct %c2_i64, %c2_i64 : (i64, i64) -> !torch.list<i64>
%23 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
%24 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
%27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType -> !torch.vtensor
%27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !basicpy.BoolType -> !torch.vtensor
return %27 : !torch.vtensor
}
@@ -115,9 +115,9 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f
func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
%c1_i64 = constant 1 : i64
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
// CHECK: torch.aten.adaptive_avg_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
%1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType -> !torch.vtensor
%1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<i64> -> !torch.vtensor
return %1 : !torch.vtensor
}