[TORCH-MLIR] Add ODS for aten.clamp.Tensor op. (#1894)

This commit adds the ODS definition for the aten.clamp.Tensor op.

Signed-off-by: Prateek Gupta <prateek.gupta2@cerebras.net>
pull/1900/head
Prateek Gupta 2023-02-24 22:48:24 +05:30 committed by GitHub
parent 0afb85d45f
commit 207229297e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 0 deletions

View File

@ -2014,6 +2014,55 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
}];
}
def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenClampTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenClamp_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -283,6 +283,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)",
"aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
"aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)",