[MLIR][TORCH] Add E2E support for aten.fill.Tensor op

This commit adds the decomposition for `aten.fill.Tensor` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1536/head
Vivek Khandelwal 2022-10-26 17:48:49 +05:30
parent 87ab714ed6
commit c86177730d
8 changed files with 130 additions and 4 deletions

View File

@ -618,4 +618,7 @@ LTC_XFAIL_SET = {
"UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic", "UpSampleNearest2dStaticSize_basic",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
} }

View File

@ -2471,6 +2471,53 @@ def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
}]; }];
} }
def Torch_AtenFillTensorOp : Torch_Op<"aten.fill.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenFillTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::fill_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenFill_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -954,6 +954,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype); Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
return b.create<arith::SelectOp>(loc, mask, fillValue, input); return b.create<arith::SelectOp>(loc, mask, fillValue, input);
} }
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
AtenFillTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(fillTensor.getType())
.cast<RankedTensorType>()
.getElementType();
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
}
if (auto triu = dyn_cast<AtenTriuOp>(op)) { if (auto triu = dyn_cast<AtenTriuOp>(op)) {
// Check if the rank of the input tensor is valid. // Check if the rank of the input tensor is valid.
@ -1046,7 +1053,8 @@ public:
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp>(op)) AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(
op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1521,7 +1529,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp>(); AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>(); target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context); patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -700,7 +700,8 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp, AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp, AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp>(op)) { AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp,
AtenFillTensorOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
} }

View File

@ -6292,6 +6292,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.zero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.zero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fill.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"

View File

@ -794,6 +794,9 @@ def atenmasked_fillTensor(self: List[int], mask: List[int], value: List[in
def atenzero(self: List[int]) -> List[int]: def atenzero(self: List[int]) -> List[int]:
return self return self
def atenfillTensor(self: List[int], value: List[int]) -> List[int]:
return self
def atenfillScalar(self: List[int], value: float) -> List[int]: def atenfillScalar(self: List[int], value: float) -> List[int]:
return self return self

View File

@ -292,7 +292,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::square : (Tensor) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)",
"aten::unsqueeze : (Tensor, int) -> (Tensor)", "aten::unsqueeze : (Tensor, int) -> (Tensor)",
"aten::zero : (Tensor) -> (Tensor)", "aten::zero : (Tensor) -> (Tensor)",
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)" "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
]: ]:
emit_with_mutating_variants(key) emit_with_mutating_variants(key)
# Elementwise tensor compute ops that don't have the standard mutating # Elementwise tensor compute ops that don't have the standard mutating

View File

@ -2397,3 +2397,62 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module):
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils): def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64)) module.forward(torch.randn(3, 2, 4).to(torch.float64))
# ==============================================================================
class Fill_TensorFloat32WithFloat32(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([], torch.float32, True),
])
def forward(self, tensor, value):
return torch.ops.aten.fill_(tensor, value)
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat32())
def Fill_TensorFloat32WithFloat32_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4), tu.rand())
class Fill_TensorFloat32WithFloat64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([], torch.float64, True),
])
def forward(self, tensor, value):
return torch.ops.aten.fill_(tensor, value)
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat64())
def Fill_TensorFloat32WithFloat64_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4), tu.rand().to(torch.float64))
class Fill_TensorFloat32WithInt64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([], torch.int64, True),
])
def forward(self, tensor, value):
return torch.ops.aten.fill_(tensor, value)
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithInt64())
def Fill_TensorFloat32WithInt64_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4), tu.randint())