mirror of https://github.com/llvm/torch-mlir
parent
a2e694df40
commit
365655ca29
|
@ -957,6 +957,7 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseEluModule_basic",
|
"ElementwiseEluModule_basic",
|
||||||
"ElementwiseEluNonDefaultModule_basic",
|
"ElementwiseEluNonDefaultModule_basic",
|
||||||
"ElementwiseFloorModule_basic",
|
"ElementwiseFloorModule_basic",
|
||||||
|
"ElementwiseFloorIntModule_basic",
|
||||||
"ElementwiseLogModule_basic",
|
"ElementwiseLogModule_basic",
|
||||||
"ElementwiseBinaryStaticShapeModule_basic",
|
"ElementwiseBinaryStaticShapeModule_basic",
|
||||||
"ElementwiseMinimumModule_basic",
|
"ElementwiseMinimumModule_basic",
|
||||||
|
|
|
@ -1023,51 +1023,6 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
|
||||||
}
|
|
||||||
void AtenFloorOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
|
|
||||||
IsTrailingUnderscoreInplaceVariant,
|
|
||||||
AllowsTypeRefinement
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_NonValueTensorType:$self
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_NonValueTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
|
||||||
}
|
|
||||||
void AtenFloor_Op::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
|
def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -3657,6 +3612,52 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenFloorOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_NonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenFloor_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -1117,6 +1117,22 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenFloorOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) {
|
||||||
|
auto outputTy = op.getType().dyn_cast<ValueTensorType>();
|
||||||
|
if (outputTy && outputTy.hasDtype() &&
|
||||||
|
outputTy.getDtype().isa<mlir::IntegerType>()) {
|
||||||
|
rewriter.replaceOp(op, op.getSelf());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenMulScalarOp
|
// AtenMulScalarOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -273,7 +273,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::atan : (Tensor) -> (Tensor)",
|
"aten::atan : (Tensor) -> (Tensor)",
|
||||||
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::neg : (Tensor) -> (Tensor)",
|
"aten::neg : (Tensor) -> (Tensor)",
|
||||||
"aten::floor : (Tensor) -> (Tensor)",
|
|
||||||
"aten::ceil : (Tensor) -> (Tensor)",
|
"aten::ceil : (Tensor) -> (Tensor)",
|
||||||
"aten::bitwise_not : (Tensor) -> (Tensor)",
|
"aten::bitwise_not : (Tensor) -> (Tensor)",
|
||||||
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
|
@ -333,6 +332,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||||
|
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||||
|
|
||||||
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||||
|
|
|
@ -1420,6 +1420,24 @@ class ElementwiseFloorModule(torch.nn.Module):
|
||||||
def ElementwiseFloorModule_basic(module, tu: TestUtils):
|
def ElementwiseFloorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
class ElementwiseFloorIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.floor(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseFloorIntModule())
|
||||||
|
def ElementwiseFloorIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue