mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.fill.Tensor op
This commit adds the decomposition for `aten.fill.Tensor` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1536/head
parent
87ab714ed6
commit
c86177730d
|
@ -618,4 +618,7 @@ LTC_XFAIL_SET = {
|
||||||
"UpSampleNearest2dDynamicSize_basic",
|
"UpSampleNearest2dDynamicSize_basic",
|
||||||
"UpSampleNearest2dStaticFactor_basic",
|
"UpSampleNearest2dStaticFactor_basic",
|
||||||
"UpSampleNearest2dStaticSize_basic",
|
"UpSampleNearest2dStaticSize_basic",
|
||||||
|
"Fill_TensorFloat32WithFloat32_basic",
|
||||||
|
"Fill_TensorFloat32WithFloat64_basic",
|
||||||
|
"Fill_TensorFloat32WithInt64_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -2471,6 +2471,53 @@ def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenFillTensorOp : Torch_Op<"aten.fill.Tensor", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenFillTensorOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::fill_.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenFill_TensorOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
|
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -954,6 +954,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
|
Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
|
||||||
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
|
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
|
||||||
}
|
}
|
||||||
|
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
|
||||||
|
AtenFillTensorOp::Adaptor adaptor(operands);
|
||||||
|
Type dtype = converter->convertType(fillTensor.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
}
|
||||||
|
|
||||||
if (auto triu = dyn_cast<AtenTriuOp>(op)) {
|
if (auto triu = dyn_cast<AtenTriuOp>(op)) {
|
||||||
// Check if the rank of the input tensor is valid.
|
// Check if the rank of the input tensor is valid.
|
||||||
|
@ -1046,7 +1053,8 @@ public:
|
||||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
||||||
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
||||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp>(op))
|
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(
|
||||||
|
op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -1521,7 +1529,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
||||||
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp>();
|
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||||
|
AtenFillTensorOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||||
|
|
|
@ -700,7 +700,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
|
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
|
||||||
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
|
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
|
||||||
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
|
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
|
||||||
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp>(op)) {
|
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp,
|
||||||
|
AtenFillTensorOp>(op)) {
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6292,6 +6292,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.zero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.zero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" return %arg0 : !torch.list<int>\n"
|
" return %arg0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.fill.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" return %arg0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||||
" return %arg0 : !torch.list<int>\n"
|
" return %arg0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
|
|
@ -794,6 +794,9 @@ def aten〇masked_fill〇Tensor(self: List[int], mask: List[int], value: List[in
|
||||||
def aten〇zero(self: List[int]) -> List[int]:
|
def aten〇zero(self: List[int]) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def aten〇fill〇Tensor(self: List[int], value: List[int]) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
|
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -292,7 +292,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::square : (Tensor) -> (Tensor)",
|
"aten::square : (Tensor) -> (Tensor)",
|
||||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
||||||
"aten::zero : (Tensor) -> (Tensor)",
|
"aten::zero : (Tensor) -> (Tensor)",
|
||||||
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)"
|
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||||
|
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
|
|
|
@ -2397,3 +2397,62 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
||||||
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
|
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class Fill_TensorFloat32WithFloat32(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, tensor, value):
|
||||||
|
return torch.ops.aten.fill_(tensor, value)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat32())
|
||||||
|
def Fill_TensorFloat32WithFloat32_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4), tu.rand())
|
||||||
|
|
||||||
|
|
||||||
|
class Fill_TensorFloat32WithFloat64(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, tensor, value):
|
||||||
|
return torch.ops.aten.fill_(tensor, value)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat64())
|
||||||
|
def Fill_TensorFloat32WithFloat64_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4), tu.rand().to(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
|
class Fill_TensorFloat32WithInt64(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, tensor, value):
|
||||||
|
return torch.ops.aten.fill_(tensor, value)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithInt64())
|
||||||
|
def Fill_TensorFloat32WithInt64_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4), tu.randint())
|
||||||
|
|
Loading…
Reference in New Issue