mirror of https://github.com/llvm/torch-mlir
Add rsub
parent
53b4275ef5
commit
6dde5b347e
|
@ -310,3 +310,36 @@ class ElementwiseClampModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseClampModule())
|
||||
def ElementwiseClampModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5, low=-10, high=10))
|
||||
|
||||
|
||||
class RsubModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.rsub(x, 3.0, alpha=1.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: RsubModule())
|
||||
def RsubModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
class RsubModule_noalpha(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.rsub(x, 2.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: RsubModule_noalpha())
|
||||
def RsubModule_noalpha_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
|
|
@ -824,6 +824,22 @@ def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
|
|||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -1438,6 +1438,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
return result;
|
||||
}
|
||||
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
||||
if (!rsub.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
rsub.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value self = payloadArgs[0];
|
||||
Value other = promoteScalarToDtype(b, loc, operands[1], self.getType());
|
||||
Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType());
|
||||
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
|
||||
return b.create<arith::SubFOp>(loc, other, mult);
|
||||
}
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForElementwiseOp");
|
||||
return nullptr;
|
||||
|
@ -1647,7 +1661,7 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenClampOp>(op))
|
||||
AtenMaximumOp, AtenClampOp, AtenRsubScalarOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -2783,7 +2797,7 @@ public:
|
|||
target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenClampOp>();
|
||||
AtenMaximumOp, AtenClampOp, AtenRsubScalarOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenUnsqueezeOp>();
|
||||
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -230,7 +230,7 @@ public:
|
|||
DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
|
||||
AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
|
||||
AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
|
||||
AtenLayerNormOp, AtenClampOp>(op)) {
|
||||
AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
|
|
@ -469,8 +469,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
# variants.
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
|
||||
|
||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::gelu : (Tensor) -> (Tensor)")
|
||||
|
||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue