mirror of https://github.com/llvm/torch-mlir
Add non-default dtype support for a few elementwise math ops. (#687)
* fix type inference * fix Torch2Linalg conversion * add test casespull/641/head
parent
fe8ac57e6d
commit
f7c7bb800c
|
@ -37,6 +37,25 @@ def ElementwiseUnaryModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseUnaryIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.tanh(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseUnaryIntModule())
|
||||
def ElementwiseUnaryIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseBinaryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -282,6 +301,25 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSigmoidIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.sigmoid(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule())
|
||||
def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMinimumModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -545,6 +583,25 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLogIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLogIntModule())
|
||||
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseErfModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -564,6 +621,25 @@ def ElementwiseErfModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseErfIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.erf(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseErfIntModule())
|
||||
def ElementwiseErfIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSqrtModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -585,6 +661,26 @@ def ElementwiseSqrtModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSqrtIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sqrt(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSqrtIntModule())
|
||||
def ElementwiseSqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseFloorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -699,6 +795,25 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLog2IntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log2(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLog2IntModule())
|
||||
def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseRsqrtModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -719,6 +834,26 @@ def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseRsqrtIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.rsqrt(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule())
|
||||
def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseAbsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -757,6 +892,25 @@ def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseReciprocalIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.reciprocal(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule())
|
||||
def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (4,), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseDivScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -949,3 +1103,123 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
|
|||
def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseExpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.exp(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseExpModule())
|
||||
def ElementwiseExpModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseExpIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.exp(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseExpIntModule())
|
||||
def ElementwiseExpIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSinModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sin(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSinModule())
|
||||
def ElementwiseSinModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSinIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sin(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSinIntModule())
|
||||
def ElementwiseSinIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseCosModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.cos(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseCosModule())
|
||||
def ElementwiseCosModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseCosIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.cos(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseCosIntModule())
|
||||
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
|
|
@ -34,6 +34,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseBinaryModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseReluModule_basic",
|
||||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
|
|
|
@ -189,25 +189,60 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
|
|||
return buildNormalCdf(b, loc, x, zero, one);
|
||||
}
|
||||
|
||||
template <typename MathOpTy>
|
||||
static Value createCalculationForMathOpWithDtypeConversion(
|
||||
OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) {
|
||||
Type dtype = converter->convertType(op->getResult(0).getType())
|
||||
.template cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Location loc = op->getLoc();
|
||||
Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
|
||||
return b.create<MathOpTy>(loc, arg);
|
||||
}
|
||||
|
||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||
OpBuilder &b, Location loc, TypeConverter *converter,
|
||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||
if (isa<AtenTanhOp>(op))
|
||||
return b.create<math::TanhOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenExpOp>(op))
|
||||
return b.create<math::ExpOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenFloorOp>(op))
|
||||
return b.create<math::FloorOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenCeilOp>(op))
|
||||
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenLogOp>(op))
|
||||
return b.create<math::LogOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenErfOp>(op))
|
||||
return b.create<math::ErfOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenSqrtOp>(op))
|
||||
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenRsqrtOp>(op))
|
||||
return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenTanhOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenExpOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenLogOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenLog2Op>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenErfOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenSqrtOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenRsqrtOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenSinOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenCosOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
|
||||
int64_t memoryFormat;
|
||||
if (!clone.memory_format().getType().isa<Torch::NoneType>() &&
|
||||
|
@ -235,14 +270,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
||||
}
|
||||
if (isa<AtenLog2Op>(op))
|
||||
return b.create<math::Log2Op>(loc, payloadArgs[0]);
|
||||
if (isa<AtenAbsOp>(op))
|
||||
return b.create<math::AbsOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenSigmoidOp>(op)) {
|
||||
Type elementType = payloadArgs[0].getType();
|
||||
auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto negate = b.create<arith::NegFOp>(loc, payloadArgs[0]);
|
||||
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
auto one =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
|
||||
auto exp = b.create<math::ExpOp>(loc, negate);
|
||||
auto added = b.create<arith::AddFOp>(loc, exp, one);
|
||||
return b.create<arith::DivFOp>(loc, one, added);
|
||||
|
@ -763,26 +797,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::DivFOp>(loc, self, other);
|
||||
}
|
||||
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
||||
if (!reciprocal.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
reciprocal.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementType = payloadArgs[0].getType();
|
||||
Type dtype = converter->convertType(reciprocal.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Type elementType = arg.getType();
|
||||
// assert(element != 0)
|
||||
auto zero =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
|
||||
payloadArgs[0], zero);
|
||||
auto pred =
|
||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, arg, zero);
|
||||
b.create<cf::AssertOp>(
|
||||
loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
|
||||
|
||||
auto one =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
|
||||
return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
|
||||
return b.create<arith::DivFOp>(loc, one, arg);
|
||||
}
|
||||
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
|
||||
// The approach used here is as follows:
|
||||
|
@ -871,7 +901,7 @@ public:
|
|||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp>(op))
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1545,7 +1575,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
|
||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||
AtenSinOp, AtenCosOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||
|
|
|
@ -62,19 +62,16 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
|
|||
|
||||
void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
|
||||
int numResults) {
|
||||
p << ' ';
|
||||
llvm::interleaveComma(op->getOperands(), p);
|
||||
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
|
||||
p << " : ";
|
||||
if (numOperands > 0) {
|
||||
p << ' ';
|
||||
llvm::interleaveComma(op->getOperands(), p);
|
||||
}
|
||||
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
|
||||
p << " : ";
|
||||
if (numOperands > 0)
|
||||
llvm::interleaveComma(op->getOperandTypes(), p);
|
||||
}
|
||||
if (numOperands > 0 && numResults > 0) {
|
||||
if (numOperands > 0 && numResults > 0)
|
||||
p << " -> ";
|
||||
}
|
||||
if (numResults > 0) {
|
||||
p << ' ';
|
||||
if (numResults > 0)
|
||||
llvm::interleaveComma(op->getResultTypes(), p);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -496,28 +496,25 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
}
|
||||
|
||||
// Take dtype from first operand.
|
||||
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenTanhOp,
|
||||
AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenCeilOp,
|
||||
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
|
||||
AtenSigmoidOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
|
||||
AtenFill_ScalarOp, AtenDetachOp, AtenReciprocalOp,
|
||||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
|
||||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
|
||||
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
|
||||
AtenReluOp, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
|
||||
AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
|
||||
AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op,
|
||||
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenNegOp, AtenFloorOp,
|
||||
Aten_SoftmaxBackwardDataOp, AtenDropoutOp, AtenTanhBackwardOp,
|
||||
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
|
||||
AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
|
||||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
||||
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
||||
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp,
|
||||
AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp,
|
||||
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
|
||||
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
|
||||
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
|
||||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
|
||||
AtenConstantPadNdOp, AtenIndexTensorOp,
|
||||
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
|
||||
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, AtenResize_Op, AtenTransposeIntOp,
|
||||
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
||||
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
|
||||
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
|
@ -525,6 +522,21 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
// Dtype is always float32, except for float64 and nullptr.
|
||||
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
||||
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
|
||||
AtenErfOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
Type dtype = operands[0]->getValue().dtype;
|
||||
if (dtype) {
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
if (dtype.isa<Float64Type>())
|
||||
knowledge.dtype = dtype;
|
||||
}
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
// Take dtype from second operand.
|
||||
if (isa<AtenNllLossBackwardOp>(op)) {
|
||||
auto self = operands[1]->getValue();
|
||||
|
|
|
@ -72,6 +72,18 @@ module {
|
|||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.exp"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.cos"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.hardtanh"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
@ -1859,7 +1871,7 @@ module {
|
|||
}
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%2 = torch.operator "aten.sub.float"(%arg1, %arg0) : (!torch.float, !torch.float) -> !torch.float
|
||||
%2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
|
||||
%3 = torch.operator "aten.div.float"(%2, %arg2) : (!torch.float, !torch.float) -> !torch.float
|
||||
%4 = torch.operator "aten.ceil.float"(%3) : (!torch.float) -> !torch.int
|
||||
%5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
|
||||
|
@ -1891,7 +1903,7 @@ module {
|
|||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%2 = torch.operator "aten.sub.float"(%arg1, %arg0) : (!torch.float, !torch.float) -> !torch.float
|
||||
%2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
|
||||
%3 = torch.operator "aten.ceil.float"(%2) : (!torch.float) -> !torch.int
|
||||
%4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>
|
||||
return %4 : !torch.list<int>
|
||||
|
|
|
@ -303,6 +303,15 @@ def aten〇hardswish(self: List[int]) -> List[int]:
|
|||
def aten〇silu(self: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
def aten〇exp(self: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
def aten〇sin(self: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
def aten〇cos(self: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
def aten〇hardtanh(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
|
|
Loading…
Reference in New Issue