pull/406/head snapshot-20211108.72
George Petterson 2021-11-08 03:03:22 -05:00 committed by Yi Zhang
parent f41958037a
commit e23cabf3a9
5 changed files with 50 additions and 3 deletions

View File

@ -429,3 +429,19 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class ElementwiseLog2Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.log2(a)
@register_test_case(module_factory=lambda: ElementwiseLog2Module())
def ElementwiseLog2Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

View File

@ -850,6 +850,34 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
let assemblyFormat = "$self `,` $min `,` $max attr-dict `:` type($self) `,` type($min) `,` type($max) `->` type($result)";
}
def Torch_AtenLog2Op : Torch_Op<"aten.log2", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::log2 : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log2_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
AllowsTypeRefinement,
HasValueSemantics

View File

@ -1284,6 +1284,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::LogOp>(loc, payloadArgs[0]);
if (isa<AtenSqrtOp>(op))
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
if (isa<AtenLog2Op>(op))
return b.create<math::Log2Op>(loc, payloadArgs[0]);
if (isa<AtenSigmoidOp>(op)) {
Type elementType = payloadArgs[0].getType();
auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
@ -1700,7 +1702,7 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -2861,7 +2863,7 @@ public:
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp>();
AtenPowTensorScalarOp, AtenLog2Op>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenUnsqueezeOp>();
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);

View File

@ -230,7 +230,7 @@ public:
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp>(op)) {
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}

View File

@ -465,6 +465,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)",
]:
emit_with_mutating_variants(key)
# Elementwise tensor compute ops that don't have the standard mutating