diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 981a19658..4c86af55a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8302,6 +8302,54 @@ def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [ let hasCanonicalizer = 1; } +def Torch_Aten__Lshift__ScalarOp : Torch_Op<"aten.__lshift__.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Lshift__ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Lshift__ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten__Rshift__ScalarOp : Torch_Op<"aten.__rshift__.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Rshift__ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Rshift__ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ AllowsTypeRefinement, HasValueSemantics, @@ -12014,6 +12062,29 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [ 
let hasFolder = 1; } +def Torch_AtenViewDtypeOp : Torch_Op<"aten.view.dtype", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::view.dtype : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenViewDtypeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenViewDtypeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 7823138c9..1c3806ea3 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -845,6 +845,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create<arith::SubIOp>(loc, lhs, scaled); } } + if (auto lshiftScalar = dyn_cast<Aten__Lshift__ScalarOp>(op)) { + Type dtype = + cast<RankedTensorType>(converter->convertType(lshiftScalar.getType())) + .getElementType(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value other = convertScalarToDtype(b, loc, operands[1], dtype); + return b.create<arith::ShLIOp>(loc, self, other); + } + if (auto rshiftScalar = dyn_cast<Aten__Rshift__ScalarOp>(op)) { + Type dtype = + cast<RankedTensorType>(converter->convertType(rshiftScalar.getType())) + .getElementType(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value other = convertScalarToDtype(b, loc, operands[1], dtype); + return b.create<arith::ShRSIOp>(loc, self, other); + } if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) { Type dtype = cast<RankedTensorType>(converter->convertType(subScalar.getType())) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index cd23e1fa3..710fd5eae 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -691,6 +691,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit("aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") @@ -886,6 +888,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::_cast_Long : (Tensor, bool) -> (Tensor)", has_canonicalizer=True) emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) + emit("aten::view.dtype : (Tensor, int) -> (Tensor)") emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True) emit(