pull/394/head
George Petterson 2021-11-02 03:14:23 -04:00 committed by Yi Zhang
parent 53b4275ef5
commit 6dde5b347e
5 changed files with 67 additions and 5 deletions


@@ -310,3 +310,36 @@ class ElementwiseClampModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseClampModule())
def ElementwiseClampModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5, low=-10, high=10))


class RsubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.rsub(x, 3.0, alpha=1.0)


@register_test_case(module_factory=lambda: RsubModule())
def RsubModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


class RsubModule_noalpha(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.rsub(x, 2.0)


@register_test_case(module_factory=lambda: RsubModule_noalpha())
def RsubModule_noalpha_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
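For reference, the semantics these two tests exercise: torch.rsub(x, other, alpha) computes other - alpha * x elementwise, with alpha defaulting to 1. A quick eager-mode check, standalone and not part of the e2e suite:

import torch

x = torch.rand(3, 4)
# rsub swaps the operands of sub: rsub(x, other, alpha) == other - alpha * x
assert torch.allclose(torch.rsub(x, 3.0, alpha=1.0), 3.0 - 1.0 * x)
# alpha defaults to 1, matching the RsubModule_noalpha test
assert torch.allclose(torch.rsub(x, 2.0), 2.0 - x)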


@@ -824,6 +824,22 @@ def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$other,
    AnyTorchScalarType:$alpha
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)";
}

def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -1438,6 +1438,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    }
    return result;
  }
  if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
    if (!rsub.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      rsub.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value self = payloadArgs[0];
    Value other = promoteScalarToDtype(b, loc, operands[1], self.getType());
    Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType());
    Value mult = b.create<arith::MulFOp>(loc, self, alpha);
    return b.create<arith::SubFOp>(loc, other, mult);
  }

  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
  return nullptr;
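A minimal Python sketch (for illustration only, not part of the change) of the arithmetic the payload above emits: both scalar operands are promoted to the tensor's element type, self is multiplied by alpha, and the product is subtracted from other, i.e. the operand order that distinguishes rsub from sub:

import torch

def rsub_reference(self_tensor, other, alpha=1.0):
    # Mirror of the payload: promote the scalars to the tensor dtype,
    # multiply self by alpha, then subtract the product from other.
    other = torch.tensor(other, dtype=self_tensor.dtype)
    alpha = torch.tensor(alpha, dtype=self_tensor.dtype)
    return other - alpha * self_tensor

x = torch.rand(3, 4)
assert torch.allclose(rsub_reference(x, 3.0), torch.rsub(x, 3.0, alpha=1.0))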
@@ -1647,7 +1661,7 @@ struct ConvertElementwiseOp : ConversionPattern {
    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
             AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
             AtenMaximumOp, AtenClampOp>(op))
             AtenMaximumOp, AtenClampOp, AtenRsubScalarOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2783,7 +2797,7 @@ public:
    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                        AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
                        AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
                        AtenMaximumOp, AtenClampOp>();
                        AtenMaximumOp, AtenClampOp, AtenRsubScalarOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenUnsqueezeOp>();
    patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);


@@ -230,7 +230,7 @@ public:
            DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
            AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
            AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
            AtenLayerNormOp, AtenClampOp>(op)) {
            AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }


@@ -469,8 +469,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    # variants.
    emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::gelu : (Tensor) -> (Tensor)")
    emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
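For illustration only (the helper below is hypothetical, not the project's generator code), the registry signature string added above decomposes into exactly the pieces that appear in the generated Torch_AtenRsubScalarOp definition earlier in this change: the op name, one operand type per argument, and one result type:

def split_signature(sig):
    # "aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)"
    name, rest = sig.split(" : ")
    args, results = rest.split(" -> ")
    return name, args.strip("()").split(", "), results.strip("()").split(", ")

print(split_signature("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)"))
# -> ('aten::rsub.Scalar', ['Tensor', 'Scalar', 'Scalar'], ['Tensor'])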