Add non-default dtype support for a few elementwise math ops. (#687)

* fix type inference
* fix Torch2Linalg conversion
* add test cases
Qiang Fu 2022-03-23 16:35:43 -04:00 committed by GitHub
parent fe8ac57e6d
commit f7c7bb800c
7 changed files with 398 additions and 62 deletions
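For context, a small eager-PyTorch check (illustrative only, not part of the diff) of the dtype behavior these changes are meant to reproduce: integer inputs to these unary math ops are promoted to float32, while float64 inputs keep their dtype.

import torch

# Illustrative sketch, not from the patch: eager PyTorch promotes integer
# inputs of these unary math ops to the default float dtype (float32).
x = torch.randint(1, 10, (3, 4), dtype=torch.int32)
print(torch.tanh(x).dtype)   # torch.float32
print(torch.log(x).dtype)    # torch.float32
print(torch.rsqrt(x).dtype)  # torch.float32

# float64 inputs keep their dtype.
y = torch.rand(3, 4, dtype=torch.float64)
print(torch.sin(y).dtype)    # torch.float64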


@ -37,6 +37,25 @@ def ElementwiseUnaryModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseUnaryIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.tanh(a)


@register_test_case(module_factory=lambda: ElementwiseUnaryIntModule())
def ElementwiseUnaryIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


# ==============================================================================


class ElementwiseBinaryModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -282,6 +301,25 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseSigmoidIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, x):
        return torch.sigmoid(x)


@register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule())
def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32))


# ==============================================================================


class ElementwiseMinimumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -545,6 +583,25 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseLogIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.log(a)


@register_test_case(module_factory=lambda: ElementwiseLogIntModule())
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


# ==============================================================================


class ElementwiseErfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -564,6 +621,25 @@ def ElementwiseErfModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseErfIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.erf(a)


@register_test_case(module_factory=lambda: ElementwiseErfIntModule())
def ElementwiseErfIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


# ==============================================================================


class ElementwiseSqrtModule(torch.nn.Module):
    def __init__(self):
@ -585,6 +661,26 @@ def ElementwiseSqrtModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseSqrtIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.sqrt(a)


@register_test_case(module_factory=lambda: ElementwiseSqrtIntModule())
def ElementwiseSqrtIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


# ==============================================================================


class ElementwiseFloorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -699,6 +795,25 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseLog2IntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.log2(a)


@register_test_case(module_factory=lambda: ElementwiseLog2IntModule())
def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


# ==============================================================================


class ElementwiseRsqrtModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -719,6 +834,26 @@ def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRsqrtIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.rsqrt(a)


@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule())
def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


# ==============================================================================


class ElementwiseAbsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -757,6 +892,25 @@ def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseReciprocalIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.reciprocal(a)


@register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule())
def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (4,), dtype=torch.int32))


# ==============================================================================


class ElementwiseDivScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -949,3 +1103,123 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))


# ==============================================================================


class ElementwiseExpModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.exp(a)


@register_test_case(module_factory=lambda: ElementwiseExpModule())
def ElementwiseExpModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


# ==============================================================================


class ElementwiseExpIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.exp(a)


@register_test_case(module_factory=lambda: ElementwiseExpIntModule())
def ElementwiseExpIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
# ==============================================================================


class ElementwiseSinModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.sin(a)


@register_test_case(module_factory=lambda: ElementwiseSinModule())
def ElementwiseSinModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


# ==============================================================================


class ElementwiseSinIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.sin(a)


@register_test_case(module_factory=lambda: ElementwiseSinIntModule())
def ElementwiseSinIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
# ==============================================================================


class ElementwiseCosModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.cos(a)


@register_test_case(module_factory=lambda: ElementwiseCosModule())
def ElementwiseCosModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


# ==============================================================================


class ElementwiseCosIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.cos(a)


@register_test_case(module_factory=lambda: ElementwiseCosIntModule())
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))


@ -34,6 +34,7 @@ TOSA_PASS_SET = {
"ElementwiseUnaryModule_basic", "ElementwiseUnaryModule_basic",
"ElementwiseBinaryModule_basic", "ElementwiseBinaryModule_basic",
"ElementwiseSigmoidModule_basic", "ElementwiseSigmoidModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseReluModule_basic", "ElementwiseReluModule_basic",
"ElementwiseFloorModule_basic", "ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic", "ElementwiseLogModule_basic",


@ -189,25 +189,60 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
  return buildNormalCdf(b, loc, x, zero, one);
}

template <typename MathOpTy>
static Value createCalculationForMathOpWithDtypeConversion(
    OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) {
  Type dtype = converter->convertType(op->getResult(0).getType())
                   .template cast<RankedTensorType>()
                   .getElementType();
  Location loc = op->getLoc();
  Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
  return b.create<MathOpTy>(loc, arg);
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, TypeConverter *converter,
    ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
  if (isa<AtenFloorOp>(op))
    return b.create<math::FloorOp>(loc, payloadArgs[0]);
  if (isa<AtenCeilOp>(op))
    return b.create<math::CeilOp>(loc, payloadArgs[0]);
  if (isa<AtenTanhOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenExpOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenLogOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenLog2Op>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenErfOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenSqrtOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenRsqrtOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenSinOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenCosOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
        b, converter, payloadArgs[0], op);
  }
  if (auto clone = dyn_cast<AtenCloneOp>(op)) {
    int64_t memoryFormat;
    if (!clone.memory_format().getType().isa<Torch::NoneType>() &&
@ -235,14 +270,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<arith::AndIOp>(loc, lhs, rhs);
  }
  if (isa<AtenAbsOp>(op))
    return b.create<math::AbsOp>(loc, payloadArgs[0]);
  if (isa<AtenSigmoidOp>(op)) {
    auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
        b, converter, payloadArgs[0], op);
    auto one =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
    auto exp = b.create<math::ExpOp>(loc, negate);
    auto added = b.create<arith::AddFOp>(loc, exp, one);
    return b.create<arith::DivFOp>(loc, one, added);
@ -763,26 +797,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    return b.create<arith::DivFOp>(loc, self, other);
  }
  if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
    Type dtype = converter->convertType(reciprocal.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Type elementType = arg.getType();
    // assert(element != 0)
    auto zero =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
    auto pred =
        b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, arg, zero);
    b.create<cf::AssertOp>(
        loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
    auto one =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
    return b.create<arith::DivFOp>(loc, one, arg);
  }
  if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
    // The approach used here is as follows:
@ -871,7 +901,7 @@ public:
          AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
          AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
          AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
          AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1545,7 +1575,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
      AtenSinOp, AtenCosOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
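The new createCalculationForMathOpWithDtypeConversion helper converts each payload value to the converted result element type before building the math op. At the tensor level that is equivalent to casting first and then applying the op; a quick eager-PyTorch spot check (illustrative, not part of the patch):

import torch

# Assumption being illustrated: applying the op to an int32 tensor matches
# applying it to the tensor explicitly cast to float32 first.
x = torch.randint(1, 10, (3, 4), dtype=torch.int32)
assert torch.allclose(torch.sqrt(x), torch.sqrt(x.to(torch.float32)))
assert torch.allclose(torch.exp(x), torch.exp(x.to(torch.float32)))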


@ -62,19 +62,16 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
                                int numResults) {
  if (numOperands > 0) {
    p << ' ';
    llvm::interleaveComma(op->getOperands(), p);
  }
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
  p << " : ";
  if (numOperands > 0)
    llvm::interleaveComma(op->getOperandTypes(), p);
  if (numOperands > 0 && numResults > 0)
    p << " -> ";
  if (numResults > 0)
    llvm::interleaveComma(op->getResultTypes(), p);
}


@ -496,28 +496,25 @@ ChangeResult TypeAnalyzer::visitOperation(
  }

  // Take dtype from first operand.
  if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
          AtenReluOp, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
          AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
          AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op,
          AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenNegOp, AtenFloorOp,
          Aten_SoftmaxBackwardDataOp, AtenDropoutOp, AtenTanhBackwardOp,
          Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
          AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
          AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
          ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
          ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
          AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
          AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
          AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
          Aten_UnsafeViewOp, AtenReshapeOp, AtenResize_Op, AtenTransposeIntOp,
          AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
          AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
          AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
          AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
          ValsemVariantAtenCopyOp>(op)) {
    ValueKnowledge knowledge =
        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
@ -525,6 +522,21 @@ ChangeResult TypeAnalyzer::visitOperation(
    return incorporateKnowledge(op->getResult(0), knowledge);
  }

  // Dtype is always float32, except for float64 and nullptr.
  if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
          AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
          AtenErfOp>(op)) {
    ValueKnowledge knowledge =
        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
    Type dtype = operands[0]->getValue().dtype;
    if (dtype) {
      knowledge.dtype = Float32Type::get(op->getContext());
      if (dtype.isa<Float64Type>())
        knowledge.dtype = dtype;
    }
    return incorporateKnowledge(op->getResult(0), knowledge);
  }

  // Take dtype from second operand.
  if (isa<AtenNllLossBackwardOp>(op)) {
    auto self = operands[1]->getValue();
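A spot check of the inference rule added above, against eager PyTorch (illustrative, not part of the patch): for these ops the result dtype is float32 for integer and float32 inputs, and stays float64 for float64 inputs.

import torch

print(torch.sigmoid(torch.ones(2, dtype=torch.int32)).dtype)    # torch.float32
print(torch.sigmoid(torch.ones(2, dtype=torch.float32)).dtype)  # torch.float32
print(torch.sigmoid(torch.ones(2, dtype=torch.float64)).dtype)  # torch.float64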


@ -72,6 +72,18 @@ module {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.exp"(%arg0: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.cos"(%arg0: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.hardtanh"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
@ -1859,7 +1871,7 @@ module {
      }
      torch.prim.If.yield
    }
    %2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
    %3 = torch.operator "aten.div.float"(%2, %arg2) : (!torch.float, !torch.float) -> !torch.float
    %4 = torch.operator "aten.ceil.float"(%3) : (!torch.float) -> !torch.int
    %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
@ -1891,7 +1903,7 @@ module {
      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
      torch.prim.If.yield
    }
    %2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
    %3 = torch.operator "aten.ceil.float"(%2) : (!torch.float) -> !torch.int
    %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>
    return %4 : !torch.list<int>


@ -303,6 +303,15 @@ def aten〇hardswish(self: List[int]) -> List[int]:
def aten〇silu(self: List[int]) -> List[int]:
    return upstream_shape_helpers.unary(self)

def aten〇exp(self: List[int]) -> List[int]:
    return upstream_shape_helpers.unary(self)

def aten〇sin(self: List[int]) -> List[int]:
    return upstream_shape_helpers.unary(self)

def aten〇cos(self: List[int]) -> List[int]:
    return upstream_shape_helpers.unary(self)

def aten〇hardtanh(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]:
    return upstream_shape_helpers.unary(self)
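All three new shape functions defer to upstream_shape_helpers.unary. As a rough sketch (an assumption about the upstream helper, not its actual source): a unary elementwise shape transfer function simply returns a copy of the input shape.

from typing import List

# Hypothetical stand-in for upstream_shape_helpers.unary: elementwise unary
# ops preserve the input shape, so the transfer function copies it.
def unary(self: List[int]) -> List[int]:
    out: List[int] = []
    for dim in self:
        out.append(dim)
    return out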