Add a number of kernels and new patterns.

* convolution, convolution_backward, _log_softmax, _log_softmax_backward_data, nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward, copy_
* Extends the recognition logic and metadata for handling inplace transformations, optional tensors, ints, lists and dropped args.
* The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen.
* The result *almost* comes out as a pure tensor program with the exception of the copy_ op, which I will do some followup work to deal with.
* More progress on #97
pull/108/head
Stella Laurenzo 2020-11-03 19:24:28 -08:00
parent 3dab9056f0
commit 6c702b149f
17 changed files with 920 additions and 454 deletions

View File

@ -11,4 +11,5 @@ export PYTHONPATH="${build_dir}/python"
python -m torch_mlir_utils.codegen.torch_signature_ods_gen \ python -m torch_mlir_utils.codegen.torch_signature_ods_gen \
--ods_td_file="${aten_dir}/GeneratedATenOps.td" \ --ods_td_file="${aten_dir}/GeneratedATenOps.td" \
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" --ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" \
--debug_op_reg_file="${aten_dir}/ATenOpRegistrations.txt"

View File

@ -189,8 +189,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
return constValue; return constValue;
} }
MlirValue MlirValue FuncBuilder::buildList(MlirLocation loc,
FuncBuilder::buildList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) { llvm::SmallVectorImpl<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(context); MlirType resultType = npcompListTypeGet(context);
OperationStateHolder state{"basicpy.build_list", loc}; OperationStateHolder state{"basicpy.build_list", loc};

View File

@ -13,6 +13,7 @@ import logging
import re import re
import sys import sys
import textwrap import textwrap
import traceback
# Note that this utility exists only in the c-extension. # Note that this utility exists only in the c-extension.
from _torch_mlir import get_registered_ops from _torch_mlir import get_registered_ops
@ -75,6 +76,55 @@ def generate_ops(g: "OpGenerator"):
g.ordinary_unary_op(f"aten::{uname}(Tensor)", g.ordinary_unary_op(f"aten::{uname}(Tensor)",
f"{snakecase_to_camelcase(uname)}Op", uname) f"{snakecase_to_camelcase(uname)}Op", uname)
# Convolution ops. Note that these are special in PyTorch and the importer,
# and we model them after the signatures of the convolution_overrideable
# ops (generic for non-CPU/GPU backends) but set the names according to
# how they come in.
g.print_banner("NN ops")
g.ordinary_immutable_op(
"aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
"ConvolutionOp",
"convolution",
alias_kernel_names=["aten::convolution"])
g.ordinary_immutable_op(
"aten::convolution_backward_overrideable(Tensor,Tensor,Tensor,int[],int[],int[],bool,int[],int,bool[])",
"ConvolutionBackwardOp",
"convolution_backward",
alias_kernel_names=["aten::convolution_backward"],
# These do return as None but are not coded optional in the registry :(
override_return_types=["Tensor?", "Tensor?", "Tensor?"])
g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)",
"LogSoftmaxOp", "log_softmax")
g.ordinary_immutable_op(
"aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
"LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
# Loss functions.
g.print_banner("Loss function ops")
g.ordinary_immutable_op(
"aten::nll_loss_forward(Tensor,Tensor,Tensor?,int,int)",
"NllLossForwardOp", "nll_loss_forward")
# Note also a grad_input 8-arg variant.
g.ordinary_immutable_op(
"aten::nll_loss_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
"NllLossBackwardOp", "nll_loss_backward")
g.ordinary_immutable_op(
"aten::nll_loss2d_forward(Tensor,Tensor,Tensor?,int,int)",
"NllLoss2dForwardOp", "nll_loss2d_forward")
# Note also a grad_input 8-arg variant.
g.ordinary_immutable_op(
"aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
"NllLoss2dBackwardOp", "nll_loss2d_backward")
# One-off in-place ops (note that many in-place arithmetic ops are handled
# as a transformation from their immutable forms).
g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
"CopyInplaceOp",
"copy.inplace",
drop_arg_indices=[2])
def dump_registered_ops(outfile, reg_ops_dict): def dump_registered_ops(outfile, reg_ops_dict):
for k in sorted(reg_ops_dict.keys()): for k in sorted(reg_ops_dict.keys()):
@ -104,7 +154,20 @@ class OpGenerator:
) )
em.print("") em.print("")
def ordinary_binary_op(self, kernel_sig, ods_name, op_name): def define_op(self, kernel_sig, ods_name, op_name, **kwargs):
return InflightOpDef(self,
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
**kwargs)
def ordinary_binary_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
**kwargs):
""""Binary"-ops. These ops typically have: """"Binary"-ops. These ops typically have:
- '.Tensor' variant where the second arg is a Tensor - '.Tensor' variant where the second arg is a Tensor
- '.Scalar' variant where the second arg is a Scalar - '.Scalar' variant where the second arg is a Scalar
@ -124,10 +187,14 @@ class OpGenerator:
- Setting all arguments and returns to kImmutableTensor - Setting all arguments and returns to kImmutableTensor
- Enabling kPromoteScalarToTensor on the second argument. - Enabling kPromoteScalarToTensor on the second argument.
""" """
reg_record = self._get_reg_record(kernel_sig) opdef = self.define_op(
ods_ins, arg_type_flags = self._map_sigtypes( kernel_sig=kernel_sig,
reg_record["arguments"], ods_name=ods_name,
type_transforms={ op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(type_transforms={
"Tensor:0": "AnyTorchImmutableTensor", "Tensor:0": "AnyTorchImmutableTensor",
"Tensor:1": "AnyTorchImmutableTensor", "Tensor:1": "AnyTorchImmutableTensor",
"Scalar:1": "AnyTorchImmutableTensor", "Scalar:1": "AnyTorchImmutableTensor",
@ -137,62 +204,113 @@ class OpGenerator:
":0": ["kImmutableTensor"], ":0": ["kImmutableTensor"],
":1": ["kImmutableTensor", "kPromoteScalar"], ":1": ["kImmutableTensor", "kPromoteScalar"],
}) })
ods_outs, return_type_flags = self._map_sigtypes( opdef.return_transforms(type_transforms={
reg_record["returns"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor", "Tensor:0": "AnyTorchImmutableTensor",
}, },
flag_transforms={ flag_transforms={
":0": ["kImmutableTensor"], ":0": ["kImmutableTensor"],
}) })
self.ods_emitter.emit_opdef(ods_name, opdef.emit()
op_name,
reg_record,
ods_ins=ods_ins,
ods_outs=ods_outs,
traits=["NoSideEffect"])
self.impl_emitter.emit_kernel_methods(ods_name,
reg_record,
arg_type_flags=arg_type_flags,
return_type_flags=return_type_flags,
promote_trailing_out_tensor=True)
def ordinary_unary_op(self, kernel_sig, ods_name, op_name): def ordinary_immutable_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
**kwargs):
""""An ordinary immutable-tensor based op."""
opdef = self.define_op(
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.transforms(type_transforms={
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
"int": "AnyTorchIntType",
"int[]": "AnyTorchIntListType",
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
},
flag_transforms={
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
})
opdef.emit()
def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs):
"""In-place ops (ending in '_').
These ops take a mutable first argument and then standard immutable
conversions for subsequent. When emitting into MLIR, the return value is
dropped.
"""
opdef = self.define_op(kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
**kwargs)
opdef.arg_transforms(type_transforms={
":0": "AnyTorchMutableTensor",
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
"int": "AnyTorchIntType",
"int[]": "AnyTorchIntListType",
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
},
flag_transforms={
":0": [],
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
})
opdef.return_transforms(
type_transforms={
":0": "DROP_UNUSED", # Ignored because we clear the outs below.
},
flag_transforms={
":0": ["kDropReturnAndAliasArg0"],
})
opdef.map_signatures()
opdef.ods_outs = [] # Clear the computed outs.
opdef.emit()
def ordinary_unary_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
**kwargs):
"""Unary ops. """Unary ops.
These take and return a tensor and typically have an out and inplace These take and return a tensor and typically have an out and inplace
variant (they may not but we generate patterns to match anyway). variant (they may not but we generate patterns to match anyway).
""" """
reg_record = self._get_reg_record(kernel_sig) opdef = self.define_op(
ods_ins, arg_type_flags = self._map_sigtypes( kernel_sig=kernel_sig,
reg_record["arguments"], ods_name=ods_name,
type_transforms={ op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(type_transforms={
"Tensor:0": "AnyTorchImmutableTensor", "Tensor:0": "AnyTorchImmutableTensor",
}, },
flag_transforms={ flag_transforms={
":0": ["kImmutableTensor"], ":0": ["kImmutableTensor"],
}) })
ods_outs, return_type_flags = self._map_sigtypes( opdef.return_transforms(type_transforms={
reg_record["returns"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor", "Tensor:0": "AnyTorchImmutableTensor",
}, },
flag_transforms={ flag_transforms={
":0": ["kImmutableTensor"], ":0": ["kImmutableTensor"],
}) })
self.ods_emitter.emit_opdef(ods_name, opdef.emit()
op_name,
reg_record,
ods_ins=ods_ins,
ods_outs=ods_outs,
traits=["NoSideEffect"])
self.impl_emitter.emit_kernel_methods(ods_name,
reg_record,
arg_type_flags=arg_type_flags,
return_type_flags=return_type_flags,
promote_trailing_out_tensor=True)
def _get_reg_record(self, kernel_sig): def get_reg_record(self, kernel_sig):
"""Gets the op-dict for a given registered op name. """Gets the op-dict for a given registered op name.
Args: Args:
@ -212,8 +330,14 @@ class OpGenerator:
raise ValueError(f"Could not find registry op matching '{kernel_sig}'. " raise ValueError(f"Could not find registry op matching '{kernel_sig}'. "
f"Possible matches:\n {dym_message}") f"Possible matches:\n {dym_message}")
def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str], def _map_sigtypes(
flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]: self,
siglist: List[Dict],
type_transforms: Dict[str, str],
flag_transforms: Dict[str, List[str]],
drop_indices: Sequence[int] = (),
override_types: Optional[Sequence[str]] = None,
) -> List[Tuple[str]]:
"""Maps a list of signature entries to ods dags and flag lists. """Maps a list of signature entries to ods dags and flag lists.
The torch signature list contains dicts that minimally have keys 'name' and The torch signature list contains dicts that minimally have keys 'name' and
@ -233,14 +357,22 @@ class OpGenerator:
- An ods dag list of (ods_name, ods_type) tuples - An ods dag list of (ods_name, ods_type) tuples
- List of (torch_type, [conversion_flag]) for specifying conversions. - List of (torch_type, [conversion_flag]) for specifying conversions.
""" """
# Make sure any override types are sane.
if override_types:
assert len(override_types) == len(siglist), (
"Mismatch override and actual types")
# Generate to ods dag list. # Generate to ods dag list.
ods_dag_list = [] ods_dag_list = []
for i, sigitem in enumerate(siglist): for i, sigitem in enumerate(siglist):
if i in drop_indices:
# Do not emit in ODS.
continue
torch_name = sigitem["name"] torch_name = sigitem["name"]
torch_type = sigitem["type"] torch_type = (sigitem["type"]
if override_types is None else override_types[i])
# Look up the type transform. # Look up the type transform.
ods_type = (type_transforms.get(f"{torch_type}:{i}") or ods_type = _first_non_none(type_transforms.get(f"{torch_type}:{i}"),
type_transforms.get(f":{i}") or type_transforms.get(f":{i}"),
type_transforms.get(torch_type)) type_transforms.get(torch_type))
if not ods_type: if not ods_type:
raise ValueError(f"Signature item {i}, type {torch_type} did not match " raise ValueError(f"Signature item {i}, type {torch_type} did not match "
@ -250,16 +382,130 @@ class OpGenerator:
# Generate the type conversion flags. # Generate the type conversion flags.
type_flag_list = [] type_flag_list = []
for i, sigitem in enumerate(siglist): for i, sigitem in enumerate(siglist):
torch_type = sigitem["type"] torch_type = (sigitem["type"]
if override_types is None else override_types[i])
# Look up the type transform. # Look up the type transform.
flags = (flag_transforms.get(f"{torch_type}:{i}") or if i in drop_indices:
flag_transforms.get(f":{i}") or flag_transforms.get(torch_type)) flags = ["kDrop"]
if not flags: else:
flags = _first_non_none(flag_transforms.get(f"{torch_type}:{i}"),
flag_transforms.get(f":{i}"),
flag_transforms.get(torch_type))
if flags is None:
flags = [] flags = []
type_flag_list.append((torch_type, flags)) type_flag_list.append((torch_type, flags))
return ods_dag_list, type_flag_list return ods_dag_list, type_flag_list
class InflightOpDef:
def __init__(self,
g: OpGenerator,
kernel_sig,
ods_name,
op_name,
traits=(),
alias_kernel_names=(),
promote_trailing_out_tensor=False,
override_arg_types=None,
override_return_types=None,
drop_arg_indices=(),
drop_return_indices=()):
super().__init__()
self.g = g
self.kernel_sig = kernel_sig
self.ods_name = ods_name
self.op_name = op_name
self.traits = list(traits)
self.alias_kernel_names = list(alias_kernel_names)
self.promote_trailing_out_tensor = promote_trailing_out_tensor
self.override_arg_types = override_arg_types
self.override_return_types = override_return_types
self.drop_arg_indices = drop_arg_indices
self.drop_return_indices = drop_return_indices
self.reg_record = g.get_reg_record(self.kernel_sig)
self._emitted = False
self._traceback = traceback.extract_stack()[0:-2]
# Arg and flag transform dicts.
self.arg_type_transforms = dict()
self.arg_flag_transforms = dict()
self.return_type_transforms = dict()
self.return_flag_trasforms = dict()
# Signature mapping.
self._sigs_mapped = False
self.ods_ins = None
self.ods_outs = None
self.arg_type_flags = None
self.return_type_flags = None
def __del__(self):
if not self._emitted:
print("WARNING: Op defined but not emitted. Defined at:", file=sys.stderr)
for line in traceback.format_list(self._traceback):
sys.stderr.write(line)
def transforms(self, type_transforms=None, flag_transforms=None):
self.arg_transforms(type_transforms=type_transforms,
flag_transforms=flag_transforms)
self.return_transforms(type_transforms=type_transforms,
flag_transforms=flag_transforms)
return self
def arg_transforms(self, type_transforms=None, flag_transforms=None):
"""Adds arg type and flag transforms dicts."""
if type_transforms:
self.arg_type_transforms.update(type_transforms)
if flag_transforms:
self.arg_flag_transforms.update(flag_transforms)
return self
def return_transforms(self, type_transforms=None, flag_transforms=None):
"""Adds return type and flag transform dicts."""
if type_transforms:
self.return_type_transforms.update(type_transforms)
if flag_transforms:
self.return_flag_trasforms.update(flag_transforms)
return self
def map_signatures(self):
assert not self._sigs_mapped, "Signatures already mapped"
self._sigs_mapped = True
self.ods_ins, self.arg_type_flags = self.g._map_sigtypes(
self.reg_record["arguments"],
type_transforms=self.arg_type_transforms,
flag_transforms=self.arg_flag_transforms,
override_types=self.override_arg_types,
drop_indices=self.drop_arg_indices)
self.ods_outs, self.return_type_flags = self.g._map_sigtypes(
self.reg_record["returns"],
type_transforms=self.return_type_transforms,
flag_transforms=self.return_flag_trasforms,
override_types=self.override_return_types,
drop_indices=self.drop_return_indices)
return self
def emit(self):
assert not self._emitted, "Op already emitted"
self._emitted = True
if not self._sigs_mapped:
self.map_signatures()
self.g.ods_emitter.emit_opdef(self.ods_name,
self.op_name,
self.reg_record,
ods_ins=self.ods_ins,
ods_outs=self.ods_outs,
traits=self.traits)
self.g.impl_emitter.emit_kernel_methods(
self.ods_name,
self.reg_record,
arg_type_flags=self.arg_type_flags,
return_type_flags=self.return_type_flags,
promote_trailing_out_tensor=self.promote_trailing_out_tensor,
alias_kernel_names=self.alias_kernel_names)
class EmitterBase: class EmitterBase:
_INDENT = " " _INDENT = " "
@ -355,7 +601,8 @@ class CCImplEmitter(EmitterBase):
reg_record, reg_record,
arg_type_flags: List[Tuple[str, List[Tuple[str]]]], arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
return_type_flags: List[Tuple[str, List[Tuple[str]]]], return_type_flags: List[Tuple[str, List[Tuple[str]]]],
promote_trailing_out_tensor=False): promote_trailing_out_tensor=False,
alias_kernel_names: Sequence[str] = ()):
# getTorchKernelMetadata() method. # getTorchKernelMetadata() method.
self.print( self.print(
f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{") f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{")
@ -374,6 +621,9 @@ class CCImplEmitter(EmitterBase):
with self.indent(): with self.indent():
self.print("Torch::BuildKernelMetadata m;") self.print("Torch::BuildKernelMetadata m;")
self.print(f"m.kernelName = {self.quote(kernel_name)};") self.print(f"m.kernelName = {self.quote(kernel_name)};")
for alias_kernel_name in alias_kernel_names:
self.print(
f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});")
if promote_trailing_out_tensor: if promote_trailing_out_tensor:
self.print("m.promoteTrailingOutTensor = true;") self.print("m.promoteTrailingOutTensor = true;")
# Arg types/flags. # Arg types/flags.
@ -393,7 +643,7 @@ class CCImplEmitter(EmitterBase):
self.print("return m;") self.print("return m;")
self.print("})();") self.print("})();")
self.print("return metadata;") self.print("return metadata;")
self.print("}") self.print("}\n")
def _format_cpp_str_initlist(self, strings): def _format_cpp_str_initlist(self, strings):
quoted = [self.quote(s) for s in strings] quoted = [self.quote(s) for s in strings]
@ -416,6 +666,13 @@ def snakecase_to_camelcase(ident: str):
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident)) return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
def _first_non_none(*args):
for arg in args:
if arg is not None:
return arg
return None
def _load_ops_as_dict(): def _load_ops_as_dict():
# Returns a list of dicts, each with a name that is a tuple of the form: # Returns a list of dicts, each with a name that is a tuple of the form:
# (kernel_signature, variant) # (kernel_signature, variant)

View File

@ -0,0 +1 @@
ATenOpRegistrations.txt

View File

@ -61,55 +61,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
} }
// Our jit library only supports 6 argument convolutions, rather than 9
// arguments supported by pytorch. This operation allows us to represent this
// limitation temporarily.
def aten_ConvolutionOp: aten_Op<"_convolution", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$bias,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation
);
let summary = "Convolution operator";
let description = [{
Convolution operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
uint64_t getOperandTransferVolume(unsigned int idx, bool read);
uint64_t getResultTransferVolume(unsigned int idx, bool read);
}];
}
// Our jit library only supports 6 argument convolutions, rather than 9
// arguments supported by pytorch. This operation allows us to represent this
// limitation temporarily.
def aten_ConvolutionBackwardOp: aten_Op<"_convolution_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor:$dx, AnyTensor:$dw, AnyTensor:$db)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$input,
AnyTensor:$weight,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation
);
let summary = "ConvolutionBackward operator";
let description = [{
ConvolutionBackward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>, def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> { Results<(outs AnyTensor)> {
let arguments = ( let arguments = (

View File

@ -37,6 +37,7 @@ const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() { Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -55,6 +56,7 @@ const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata DivOp::getTorchKernelMetadata() { Torch::KernelMetadata DivOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -73,6 +75,7 @@ const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() { Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -91,6 +94,7 @@ const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata MulOp::getTorchKernelMetadata() { Torch::KernelMetadata MulOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -109,6 +113,7 @@ const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() { Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -127,6 +132,7 @@ const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() { Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -145,6 +151,7 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Unary arithmetic ops // Unary arithmetic ops
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -167,6 +174,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata AcosOp::getTorchKernelMetadata() { Torch::KernelMetadata AcosOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -185,6 +193,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata AngleOp::getTorchKernelMetadata() { Torch::KernelMetadata AngleOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -203,6 +212,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata AsinOp::getTorchKernelMetadata() { Torch::KernelMetadata AsinOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -221,6 +231,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata AtanOp::getTorchKernelMetadata() { Torch::KernelMetadata AtanOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -239,6 +250,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() { Torch::KernelMetadata CeilOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -257,6 +269,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ConjOp::getTorchKernelMetadata() { Torch::KernelMetadata ConjOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -275,6 +288,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata CosOp::getTorchKernelMetadata() { Torch::KernelMetadata CosOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -293,6 +307,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata CoshOp::getTorchKernelMetadata() { Torch::KernelMetadata CoshOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -311,6 +326,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() { Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -329,6 +345,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ErfOp::getTorchKernelMetadata() { Torch::KernelMetadata ErfOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -347,6 +364,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() { Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -365,6 +383,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() { Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -383,6 +402,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ExpOp::getTorchKernelMetadata() { Torch::KernelMetadata ExpOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -401,6 +421,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() { Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -419,6 +440,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata FloorOp::getTorchKernelMetadata() { Torch::KernelMetadata FloorOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -437,6 +459,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata FracOp::getTorchKernelMetadata() { Torch::KernelMetadata FracOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -455,6 +478,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() { Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -473,6 +497,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata LogOp::getTorchKernelMetadata() { Torch::KernelMetadata LogOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -491,6 +516,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata Log10Op::getTorchKernelMetadata() { Torch::KernelMetadata Log10Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -509,6 +535,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() { Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -527,6 +554,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata Log2Op::getTorchKernelMetadata() { Torch::KernelMetadata Log2Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -545,6 +573,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata NegOp::getTorchKernelMetadata() { Torch::KernelMetadata NegOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -563,6 +592,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ReluOp::getTorchKernelMetadata() { Torch::KernelMetadata ReluOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -581,6 +611,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() { Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -599,6 +630,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata RoundOp::getTorchKernelMetadata() { Torch::KernelMetadata RoundOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -617,6 +649,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() { Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -635,6 +668,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() { Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -653,6 +687,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata SignOp::getTorchKernelMetadata() { Torch::KernelMetadata SignOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -671,6 +706,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata SinOp::getTorchKernelMetadata() { Torch::KernelMetadata SinOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -689,6 +725,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata SinhOp::getTorchKernelMetadata() { Torch::KernelMetadata SinhOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -707,6 +744,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() { Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -725,6 +763,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata TanOp::getTorchKernelMetadata() { Torch::KernelMetadata TanOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -743,6 +782,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata TanhOp::getTorchKernelMetadata() { Torch::KernelMetadata TanhOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -761,6 +801,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
Torch::KernelMetadata TruncOp::getTorchKernelMetadata() { Torch::KernelMetadata TruncOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata(); return getTorchBuildKernelMetadata();
} }
@ -779,3 +820,184 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
})(); })();
return metadata; return metadata;
} }
// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::convolution_overrideable";
m.aliasKernelNames.push_back("aten::convolution");
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ConvolutionBackwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::convolution_backward_overrideable";
m.aliasKernelNames.push_back("aten::convolution_backward");
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor?", "Tensor?", "Tensor?"});
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata LogSoftmaxOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &LogSoftmaxOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::_log_softmax";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "int", "bool"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata LogSoftmaxBackwardDataOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::_log_softmax_backward_data";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "int", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata NllLossForwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLossForwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss_forward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor", "Tensor"});
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NllLossBackwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLossBackwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss_backward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NllLoss2dForwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLoss2dForwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss2d_forward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor", "Tensor"});
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NllLoss2dBackwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss2d_backward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &CopyInplaceOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::copy_";
m.addArgTypes({"Tensor", "Tensor", "bool"});
m.addArgConversions({KVC::kNone, KVC::kImmutableTensor, KVC::kDrop});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kDropReturnAndAliasArg0});
return m;
})();
return metadata;
}

View File

@ -450,3 +450,147 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods<Torc
); );
} }
// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------
def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::convolution_overrideable";
let arguments = (ins
AnyTorchImmutableTensor:$input,
AnyTorchImmutableTensor:$weight,
AnyTorchOptionalImmutableTensor:$bias,
AnyTorchIntListType:$stride,
AnyTorchIntListType:$padding,
AnyTorchIntListType:$dilation,
AnyTorchBoolType:$transposed,
AnyTorchIntListType:$output_padding,
AnyTorchIntType:$groups
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$input,
AnyTorchImmutableTensor:$weight,
AnyTorchIntListType:$stride,
AnyTorchIntListType:$padding,
AnyTorchIntListType:$dilation,
AnyTorchBoolType:$transposed,
AnyTorchIntListType:$output_padding,
AnyTorchIntType:$groups,
AnyTorchBoolListType:$output_mask
);
let results = (outs
AnyTorchOptionalImmutableTensor:$grad_input,
AnyTorchOptionalImmutableTensor:$grad_weight,
AnyTorchOptionalImmutableTensor:$grad_bias
);
}
def aten_LogSoftmaxOp: aten_Op<"log_softmax", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::_log_softmax";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchIntType:$dim,
AnyTorchBoolType:$half_to_float
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::_log_softmax_backward_data";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$output,
AnyTorchIntType:$dim,
AnyTorchImmutableTensor:$self
);
let results = (outs
AnyTorchImmutableTensor
);
}
// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss_forward";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index
);
let results = (outs
AnyTorchImmutableTensor:$output,
AnyTorchImmutableTensor:$total_weight
);
}
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss_backward";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index,
AnyTorchImmutableTensor:$total_weight
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss2d_forward";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index
);
let results = (outs
AnyTorchImmutableTensor:$output,
AnyTorchImmutableTensor:$total_weight
);
}
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss2d_backward";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index,
AnyTorchImmutableTensor:$total_weight
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::copy_";
let arguments = (ins
AnyTorchMutableTensor:$self,
AnyTorchImmutableTensor:$src
);
let results = (outs
);
}

View File

@ -44,52 +44,6 @@ def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface
}]; }];
} }
def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$bias,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$transposed,
AnyType:$output_padding,
AnyScalar:$groups
);
let summary = "aten convolution_overrideable operator";
let description = [{
ConvolutionOverrideableOp
aten convolution_overrideable operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$input,
AnyTensor:$weight,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$transposed,
AnyType:$output_padding,
AnyScalar:$groups
);
let summary = "aten convolution_backward_overrideable operator";
let description = [{
ConvolutionBackwardOverrideableOp
aten convolution_backward_overrideable operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>, def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> { Results<(outs AnyTensor)> {
let arguments = ( let arguments = (
@ -123,35 +77,6 @@ def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>,
}]; }];
} }
def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim,
AnyScalar:$half_to_float
);
let summary = "aten _log_softmax operator";
let description = [{
LogSoftmaxOp
aten _log_softmax operator
}];
}
def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$output,
AnyScalar:$dim,
AnyTensor:$self
);
let summary = "aten _log_softmax_backward_data operator";
let description = [{
LogSoftmaxBackwardDataOp
aten _log_softmax_backward_data operator
}];
}
def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>, def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> { Results<(outs AnyTensor)> {
let arguments = ( let arguments = (
@ -445,86 +370,6 @@ def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>,
}]; }];
} }
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index
);
let summary = "aten nll_loss_forward operator";
let description = [{
NllLossForwardOp
aten nll_loss_forward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index,
AnyTensor:$total_weight
);
let summary = "aten nll_loss_backward operator";
let description = [{
NllLossBackwardOp
aten nll_loss_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index
);
let summary = "aten nll_loss2d_forward operator";
let description = [{
NllLoss2dForwardOp
aten nll_loss2d_forward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index,
AnyTensor:$total_weight
);
let summary = "aten nll_loss2d_backward operator";
let description = [{
NllLoss2dBackwardOp
aten nll_loss2d_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>, def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> { Results<(outs AnyTensor)> {
let arguments = ( let arguments = (

View File

@ -12,15 +12,4 @@
include "mlir/Dialect/StandardOps/IR/Ops.td" include "mlir/Dialect/StandardOps/IR/Ops.td"
include "npcomp/Dialect/ATen/IR/ATenOps.td" include "npcomp/Dialect/ATen/IR/ATenOps.td"
// The pytorch convolution operator has 9 arguments, but we only have a jit
// library that supports the first six at the moment.
def : Pat<(aten_ConvolutionOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
$a7, $a8, $a9),
(aten_ConvolutionOp $a1, $a2, $a3, $a4, $a5, $a6)>;
def : Pat<(aten_ConvolutionBackwardOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
$a7, $a8, $a9),
(aten_ConvolutionBackwardOp $a1, $a2, $a3, $a4, $a5, $a6)>;
#endif #endif

View File

@ -18,7 +18,7 @@ namespace Torch {
/// Conversion rule to apply to a value (argument or return). /// Conversion rule to apply to a value (argument or return).
namespace KernelValueConversion { namespace KernelValueConversion {
enum BitMask { enum BitMask : uint32_t {
// No coercion. // No coercion.
kNone = 0, kNone = 0,
@ -32,7 +32,16 @@ enum BitMask {
// to a 0d tensor. // to a 0d tensor.
kPromoteScalar = 8, kPromoteScalar = 8,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar) // Drops the return value and aliases to argument 0.
// TODO: Remove this in favor of general alias metadata processing (note that
// the vast majority are this one case so it isn't so bad to have a special
// case for it if necessary).
kDropReturnAndAliasArg0 = 16,
// Drops the argument/return.
kDrop = 32,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kDrop)
}; };
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
} // namespace KernelValueConversion } // namespace KernelValueConversion
@ -74,6 +83,9 @@ struct BuildKernelMetadata : public KernelMetadata {
SmallVector<KernelValueConversion::BitMask, 4> argConversions; SmallVector<KernelValueConversion::BitMask, 4> argConversions;
SmallVector<KernelValueConversion::BitMask, 4> returnConversions; SmallVector<KernelValueConversion::BitMask, 4> returnConversions;
/// Additional alias kernel names to match.
SmallVector<StringRef, 1> aliasKernelNames;
void addArgConversions( void addArgConversions(
std::initializer_list<KernelValueConversion::BitMask> ilist) { std::initializer_list<KernelValueConversion::BitMask> ilist) {
argConversions.insert(argConversions.end(), ilist); argConversions.insert(argConversions.end(), ilist);

View File

@ -72,6 +72,11 @@ def AnyTorchImmutableTensor : AnyTypeOf<[
AnyTensor, AnyTensor,
], "allowable torch immutable tensor">; ], "allowable torch immutable tensor">;
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
AnyTorchImmutableTensor,
Basicpy_NoneType,
], "allowable torch immutable tensor (or None)">;
def AnyTorchMutableTensor : AnyTypeOf<[ def AnyTorchMutableTensor : AnyTypeOf<[
// "Numpy-style" mutable NDArray. While not offering the full generality // "Numpy-style" mutable NDArray. While not offering the full generality
// of a Torch tensor, it models the same access patterns and implies the // of a Torch tensor, it models the same access patterns and implies the
@ -95,7 +100,28 @@ def AnyTorchScalarType : AnyTypeOf<[
AnySignlessInteger, AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">; ], "Any primitive type suitable to be passed as a Torch Scalar">;
def AnyTorchBoolType : AnyTypeOf<[
I1,
Basicpy_BoolType,
], "Any permissible bool type">;
def AnyTorchBoolListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any bool list type (bool[])">;
def AnyTorchIntType : AnyTypeOf<[
AnySignedInteger,
AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
def AnyTorchIntListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any int list type (int[])">;
def AnyTorchType : AnyTypeOf<[ def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType, AnyTorchScalarType,
AnyTorchTensorType, AnyTorchTensorType,
Basicpy_ListType, Basicpy_ListType,

View File

@ -170,40 +170,6 @@ std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
return toReturn; return toReturn;
} }
// _convolution
std::map<std::string, uint64_t> ConvolutionOp::getStatistics() {
return getConv2dStatistics(this, /*groups*/ 1);
}
std::map<std::string, uint64_t> ConvolutionOverrideableOp::getStatistics() {
// FIXME
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue();
return getConv2dStatistics(this, groups);
}
uint64_t ConvolutionOp::getOperandTransferVolume(unsigned int idx, bool read) {
return getConv2dOperandTransferVolume<ConvolutionOp>(this, idx, read);
}
uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
return getConv2dResultTransferVolume<ConvolutionOp>(this, idx, write);
}
// _convolution_backward
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
return getConv2dBackwardStatistics(*this, 1);
}
std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue();
return getConv2dBackwardStatistics(*this, groups);
}
// div_ // div_
std::map<std::string, uint64_t> DivUnderOp::getStatistics() { std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
@ -559,35 +525,6 @@ std::map<std::string, uint64_t> NativeBatchNormBackwardOp::getStatistics() {
return toReturn; return toReturn;
} }
std::map<std::string, uint64_t> NllLossForwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> NllLossBackwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dForwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dBackwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
// std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() { // std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() {
// return getReLUOpStatistics(*this); // return getReLUOpStatistics(*this);
// } // }

View File

@ -12,6 +12,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h" #include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h" #include "npcomp/Dialect/ATen/Transforms/Passes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h" #include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
@ -28,6 +29,14 @@ using namespace mlir::NPCOMP::Torch;
namespace { namespace {
bool isTorchTensorType(StringRef torchType) {
return torchType == "Tensor" || torchType == "Tensor?";
}
bool isTorchOptionalType(StringRef torchType) {
return torchType.endswith("?");
}
struct TypeConversion { struct TypeConversion {
Type targetType; Type targetType;
std::function<Value(Location loc, Value originalValue, std::function<Value(Location loc, Value originalValue,
@ -49,9 +58,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
// Immutable tensor conversion. // Immutable tensor conversion.
if (flag & KVC::kImmutableTensor) { if (flag & KVC::kImmutableTensor) {
// TODO: Support the kPromoteScalar flag. // TODO: Support the kPromoteScalar flag.
if (sourceTorchType != "Tensor" || targetTorchType != "Tensor") if (!isTorchTensorType(sourceTorchType) ||
!isTorchTensorType(targetTorchType))
return None; return None;
// If the target is optional and the type is NoneType, passthrough.
if (isTorchOptionalType(targetTorchType) &&
sourceMlirType.isa<Basicpy::NoneType>())
return TypeConversion{sourceMlirType, nullptr};
// Already immutable. // Already immutable.
if (sourceMlirType.isa<TensorType>()) if (sourceMlirType.isa<TensorType>())
return TypeConversion{sourceMlirType, nullptr}; return TypeConversion{sourceMlirType, nullptr};
@ -86,30 +101,51 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
Type sourceMlirType) { Type sourceMlirType) {
using KVC = KernelValueConversion::BitMask; using KVC = KernelValueConversion::BitMask;
// Default trivial case. // Default trivial case.
if (sourceTorchType == targetTorchType && flag == 0) if (sourceTorchType == targetTorchType && flag == 0) {
LLVM_DEBUG(llvm::dbgs() << " * Return types already match\n");
return TypeConversion{sourceMlirType, nullptr}; return TypeConversion{sourceMlirType, nullptr};
}
// Immutable tensor conversion. // Immutable tensor conversion.
if (flag & KVC::kImmutableTensor) { if (flag & KVC::kImmutableTensor) {
if (sourceTorchType != "Tensor" || targetTorchType != "Tensor") LLVM_DEBUG(llvm::dbgs()
<< " * Return conversion flag kImmutableTensor\n");
if (!isTorchTensorType(sourceTorchType) ||
!isTorchTensorType(targetTorchType)) {
LLVM_DEBUG(llvm::dbgs()
<< " * Source or target not a Tensor type\n");
return None; return None;
}
// Already immutable. // Already immutable.
if (sourceMlirType.isa<TensorType>()) if (sourceMlirType.isa<TensorType>()) {
LLVM_DEBUG(llvm::dbgs() << " * Source is already immutable\n");
return TypeConversion{sourceMlirType, nullptr}; return TypeConversion{sourceMlirType, nullptr};
}
// Convert NdArray type. // Convert NdArray type.
if (auto ndArrayType = sourceMlirType.dyn_cast<Numpy::NdArrayType>()) { if (sourceMlirType.isa<Basicpy::NoneType>() &&
isTorchOptionalType(targetTorchType)) {
LLVM_DEBUG(llvm::dbgs() << " * None Tensor type passthrough\n");
return TypeConversion{sourceMlirType, nullptr};
} else if (auto ndArrayType =
sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
auto tensorType = ndArrayType.toTensorType(); auto tensorType = ndArrayType.toTensorType();
auto callback = [=](Location loc, Value newOpResultValue, auto callback = [=](Location loc, Value newOpResultValue,
PatternRewriter &rewriter) -> Value { PatternRewriter &rewriter) -> Value {
return rewriter.create<Numpy::CreateArrayFromTensorOp>( return rewriter.create<Numpy::CreateArrayFromTensorOp>(
loc, ndArrayType, newOpResultValue); loc, ndArrayType, newOpResultValue);
}; };
LLVM_DEBUG(llvm::dbgs() << " * Convert return type\n");
return TypeConversion{tensorType, callback}; return TypeConversion{tensorType, callback};
} else {
LLVM_DEBUG(llvm::dbgs()
<< " * Return type is not a supported tensor type\n");
return None;
} }
} }
LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
return None; return None;
} }
@ -142,11 +178,18 @@ public:
const BuildKernelMetadata &buildMetadata) { const BuildKernelMetadata &buildMetadata) {
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
<< "Register kernel call translation for: " << opName << "\n"); << "Register kernel call translation for: " << opName << "\n");
{
CandidateTransformList &candidates = CandidateTransformList &candidates =
kernelTransforms[buildMetadata.kernelName]; kernelTransforms[buildMetadata.kernelName];
candidates.emplace_back(opName, buildMetadata); candidates.emplace_back(opName, buildMetadata);
} }
for (StringRef aliasKernelName : buildMetadata.aliasKernelNames) {
CandidateTransformList &candidates = kernelTransforms[aliasKernelName];
candidates.emplace_back(opName, buildMetadata);
}
}
LogicalResult transformKernelCall(KernelCallOp kernelCall, LogicalResult transformKernelCall(KernelCallOp kernelCall,
PatternRewriter &rewriter) const { PatternRewriter &rewriter) const {
StringRef kernelName = kernelCall.kernelName(); StringRef kernelName = kernelCall.kernelName();
@ -229,6 +272,8 @@ public:
"arg arity mismatch"); "arg arity mismatch");
// Convert fixed return types. // Convert fixed return types.
using PostConversionCallback = std::function<void()>;
SmallVector<PostConversionCallback, 4> postConversionCallbacks;
struct ConversionInfo { struct ConversionInfo {
Value originalValue; Value originalValue;
TypeConversion conversion; TypeConversion conversion;
@ -241,25 +286,49 @@ public:
KVC flag = candidate.buildMetadata.getReturnConversion(i); KVC flag = candidate.buildMetadata.getReturnConversion(i);
Value sourceValue = kernelCall.getResult(i); Value sourceValue = kernelCall.getResult(i);
Type sourceMlirType = kernelCall.getResultTypes()[i]; Type sourceMlirType = kernelCall.getResultTypes()[i];
auto conversion = convertTorchReturnType(sourceTorchType, targetTorchType, if (flag & KVC::kDropReturnAndAliasArg0) {
flag, sourceMlirType); // Reduce result arity and alias any uses to arg0.
if (kernelCall.args().empty()) {
LLVM_DEBUG(llvm::dbgs()
<< " - Cannot alias arg0 (no arguments)\n");
return failure();
}
Value arg0 = kernelCall.args()[0];
postConversionCallbacks.push_back(
[sourceValue, arg0]() { sourceValue.replaceAllUsesWith(arg0); });
} else {
// General, arity-preserving type conversion.
auto conversion = convertTorchReturnType(
sourceTorchType, targetTorchType, flag, sourceMlirType);
if (!conversion) { if (!conversion) {
LLVM_DEBUG(llvm::dbgs() << " - Return type[" << i LLVM_DEBUG(llvm::dbgs()
<< "] incompatible: source=" << sourceTorchType << " - Return type[" << i << "] incompatible: source="
<< ", target=" << targetTorchType << sourceTorchType << ", target=" << targetTorchType
<< ", flag=" << flag << "\n"); << ", flag=" << flag << "\n");
return failure(); return failure();
} }
resultTypes.push_back(conversion->targetType); resultTypes.push_back(conversion->targetType);
resultConversions.push_back({sourceValue, std::move(*conversion)}); resultConversions.push_back({sourceValue, std::move(*conversion)});
} }
}
// Convert fixed arg types. // Convert fixed arg types.
SmallVector<ConversionInfo, 4> operandInfos; SmallVector<ConversionInfo, 4> operandInfos;
for (size_t i = 0; i < fixedArgArity; ++i) { for (size_t i = 0, operandIndex = 0; i < fixedArgArity; ++i) {
// Drop this arg?
if (candidate.buildMetadata.argConversions[i] & KVC::kDrop)
continue;
if (kernelCall.getNumOperands() <= operandIndex) {
LLVM_DEBUG(llvm::dbgs()
<< " - Arg operand " << i
<< " does not exist in kernel call (missing default?)\n");
return failure();
}
// Normal type conversion of the operand.
operandInfos.emplace_back(); operandInfos.emplace_back();
ConversionInfo &info = operandInfos.back(); ConversionInfo &info = operandInfos.back();
info.originalValue = kernelCall.getOperand(i); info.originalValue = kernelCall.getOperand(operandIndex++);
Type sourceMlirType = info.originalValue.getType(); Type sourceMlirType = info.originalValue.getType();
auto conversion = convertTorchArgType( auto conversion = convertTorchArgType(
/*sourceTorchType=*/sourceMetadata.argTypes[i], /*sourceTorchType=*/sourceMetadata.argTypes[i],
@ -312,6 +381,10 @@ public:
origOpResultValue.replaceAllUsesWith(convertedValue); origOpResultValue.replaceAllUsesWith(convertedValue);
} }
// Post conversion callbacks.
for (auto &callback : postConversionCallbacks)
callback();
// Done. // Done.
rewriter.eraseOp(kernelCall); rewriter.eraseOp(kernelCall);
return success(); return success();

View File

@ -1,25 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-_convolution-0": {
// CHECK-NEXT: "activation_in": 32768,
// CHECK-NEXT: "activation_out": 65536,
// CHECK-NEXT: "ops:+": 65536,
// CHECK-NEXT: "ops:MAC": 6422528,
// CHECK-NEXT: "parameters_in": 1584,
// CHECK-NEXT: "reads": 34352,
// CHECK-NEXT: "writes": 65536
module {
func @graph(%arg0: tensor<1x2x128x128xf32>, %arg1: tensor<16x2x7x7xf32>, %arg2: tensor<16xf32>) -> tensor<1x16x64x64xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
%3 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%4 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi64>} : () -> !aten.list<i32>
%5 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%6 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%7 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%8 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%9 = "aten._convolution"(%arg0, %arg1, %arg2, %0, %1, %2) : (tensor<1x2x128x128xf32>, tensor<16x2x7x7xf32>, tensor<16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x16x64x64xf32>
"std.return"(%9) : (tensor<1x16x64x64xf32>) -> ()
}
}

View File

@ -1,23 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-convolution_backward_overrideable-0": {
// CHECK-NEXT: "activation_in": 5568,
// CHECK-NEXT: "grad": 5380,
// CHECK-NEXT: "ops:+": 768,
// CHECK-NEXT: "ops:MAC": 345600,
// CHECK-NEXT: "parameters_in": 576,
// CHECK-NEXT: "reads": 6144,
// CHECK-NEXT: "writes": 5380
// CHECK-NEXT: }
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
module {
func @graph(%arg0: tensor<3x4x8x8xf32>, %arg1: tensor<3x16x10x10xf32>, %arg2: tensor<4x16x3x3xf32>) -> tensor<4x16x3x3xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "bool", value = false} : () -> i1
%3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%10:3 = "aten.convolution_backward_overrideable"(%arg0, %arg1, %arg2, %0, %1, %0, %2, %1, %3) {layer_name = "L5-convolution_backward_overrideable-0"} : (tensor<3x4x8x8xf32>, tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1, !aten.list<i32>, i32) -> (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>)
return %10#1 : tensor<4x16x3x3xf32>
}
}

View File

@ -1,26 +0,0 @@
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
module {
func @graph(%arg0: tensor<10xf32>, %arg1: tensor<128xf32>, %arg2: tensor<4x1x28x28xf32>, %arg3: tensor<32x1x3x3xf32>, %arg4: tensor<32xf32>, %arg5: tensor<64x32x3x3xf32>, %arg6: tensor<64xf32>, %arg7: tensor<128x9216xf32>, %arg8: tensor<10x128xf32>) -> tensor<4x10xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "bool", value = false} : () -> i1
%3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%4 = "aten.constant"() {type = "bool", value = true} : () -> i1
%5 = "aten._convolution"(%arg2, %arg3, %arg4, %0, %1, %0) {layer_name = "L0-_convolution-0"} : (tensor<4x1x28x28xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x32x26x26xf32>
%6 = "aten.relu"(%5) {layer_name = "L1-relu-0"} : (tensor<4x32x26x26xf32>) -> tensor<4x32x26x26xf32>
%7 = "aten._convolution"(%6, %arg5, %arg6, %0, %1, %0) {layer_name = "L2-_convolution-1"} : (tensor<4x32x26x26xf32>, tensor<64x32x3x3xf32>, tensor<64xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x64x24x24xf32>
%8 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi32>} : () -> !aten.list<i32>
%9:2 = "aten.max_pool2d_with_indices"(%7, %8, %8, %1, %0, %2) {layer_name = "L3-max_pool2d_with_indices-0"} : (tensor<4x64x24x24xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1) -> (tensor<4x64x12x12xf32>, tensor<4x64x12x12xi64>)
%10 = "aten.constant"() {type = "List[i32]", value = dense<[4, 9216]> : vector<2xi32>} : () -> !aten.list<i32>
%11 = "aten.view"(%9#0, %10) {layer_name = "L4-view-0"} : (tensor<4x64x12x12xf32>, !aten.list<i32>) -> tensor<4x9216xf32>
%12 = "aten.t"(%arg7) {layer_name = "L5-t-0"} : (tensor<128x9216xf32>) -> tensor<9216x128xf32>
%13 = "aten.addmm"(%arg1, %11, %12, %3, %3) {layer_name = "L6-addmm-0"} : (tensor<128xf32>, tensor<4x9216xf32>, tensor<9216x128xf32>, i32, i32) -> tensor<4x128xf32>
%14 = "aten.relu"(%13) {layer_name = "L7-relu-1"} : (tensor<4x128xf32>) -> tensor<4x128xf32>
%15 = "aten.t"(%arg8) {layer_name = "L8-t-1"} : (tensor<10x128xf32>) -> tensor<128x10xf32>
%16 = "aten.addmm"(%arg0, %14, %15, %3, %3) {layer_name = "L9-addmm-1"} : (tensor<10xf32>, tensor<4x128xf32>, tensor<128x10xf32>, i32, i32) -> tensor<4x10xf32>
%17 = "aten._log_softmax"(%16, %3, %2) {layer_name = "L10-_log_softmax-0"} : (tensor<4x10xf32>, i32, i1) -> tensor<4x10xf32>
return %17 : tensor<4x10xf32>
}
}

View File

@ -13,3 +13,86 @@ func @graph(%arg0: !numpy.ndarray<*:?>, %arg1 : !numpy.ndarray<*:?>, %arg2 : si6
// CHECK: return %[[RESULT_MUT]] // CHECK: return %[[RESULT_MUT]]
return %0 : !numpy.ndarray<*:?> return %0 : !numpy.ndarray<*:?>
} }
// -----
// CHECK-LABEL: func @nll_loss2d_forward
// Contains a Tensor? type mapped to None.
func @nll_loss2d_forward(
%arg0: !numpy.ndarray<[3,4,8,8]:f32>,
%arg1: !numpy.ndarray<[3,8,8]:i64>,
%arg2: !basicpy.NoneType,
%arg3: i64,
%arg4: i64) -> (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>) {
// CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
// CHECK: %[[TOUTPUT:.*]], %[[TTOTAL_WEIGHT:.*]] = "aten.nll_loss2d_forward"(%[[TARG0]], %[[TARG1]], %arg2, %arg3, %arg4) : (tensor<3x4x8x8xf32>, tensor<3x8x8xi64>, !basicpy.NoneType, i64, i64) -> (tensor<f32>, tensor<f32>)
// CHECK: %[[AOUTPUT:.*]] = numpy.create_array_from_tensor %[[TOUTPUT]]
// CHECK: %[[ATOTAL_WEIGHT:.*]] = numpy.create_array_from_tensor %[[TTOTAL_WEIGHT]]
%0:2 = torch.kernel_call "aten::nll_loss2d_forward"
%arg0, %arg1, %arg2, %arg3, %arg4 :
(!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,8,8]:i64>, !basicpy.NoneType, i64, i64) ->
(!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>)
{sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor"]}
// CHECK: return %[[AOUTPUT]], %[[ATOTAL_WEIGHT]]
return %0#0, %0#1 : !numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>
}
// -----
// CHECK-LABEL: func @convolution
// Contains a Tensor?, bool, int and list types.
func @convolution(
%arg0: !numpy.ndarray<[3,16,10,10]:f32>, %arg1: !numpy.ndarray<[4,16,3,3]:f32>,
%arg2: !numpy.ndarray<[4]:f32>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType,
%arg5: !basicpy.ListType, %arg6: i1, %arg7: !basicpy.ListType, %arg8: i64) -> !numpy.ndarray<[3,4,8,8]:f32> {
// CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
// CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
// CHECK: %[[TRESULT:.*]] = "aten.convolution"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> tensor<3x4x8x8xf32>
%0 = torch.kernel_call "aten::convolution"
%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8 :
(!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>,
!numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType,
!basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
{sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
return %0 : !numpy.ndarray<[3,4,8,8]:f32>
}
// -----
// CHECK-LABEL: func @convolution_backward
// Interesting because it has optional tensor returns.
func @convolution_backward(
%arg0: !numpy.ndarray<[3,4,8,8]:f32>,
%arg1: !numpy.ndarray<[3,16,10,10]:f32>,
%arg2: !numpy.ndarray<[4,16,3,3]:f32>,
%arg3: !basicpy.ListType,
%arg4: !basicpy.ListType,
%arg5: !basicpy.ListType,
%arg6: i1,
%arg7: !basicpy.ListType,
%arg8: i64,
%arg9: !basicpy.ListType) -> (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {
// CHECK: %[[GRAD_INPUT:.*]], %[[GRAD_WEIGHT:.*]], %[[GRAD_BIAS:.*]] = "aten.convolution_backward"
// Note that this kernel call masks out the input gradients, which will return as NoneType.
%0:3 = torch.kernel_call "aten::convolution_backward"
%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 :
(!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64, !basicpy.ListType) ->
(!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {sigArgTypes = ["Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor", "Tensor"]}
// CHECK: %[[AGRAD_WEIGHT:.*]] = numpy.create_array_from_tensor %[[GRAD_WEIGHT]]
// CHECK: %[[AGRAD_BIAS:.*]] = numpy.create_array_from_tensor %[[GRAD_BIAS]]
// Key thing: The return returns the raw NoneType from the masked input gradient
// and it does not get converted to an array.
// CHECK: return %[[GRAD_INPUT]], %[[AGRAD_WEIGHT]], %[[AGRAD_BIAS]]
return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
}
// -----
// CHECK-LABEL: func @copy_inplace
// Mutable/in-place op conversion, dropping result.
func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
// CHECK: "aten.copy.inplace"(%arg0, %[[TARG1]]) : (!numpy.ndarray<[4]:f32>, tensor<4xf32>) -> ()
%0 = torch.kernel_call "aten::copy_" %arg0, %arg1 : (!numpy.ndarray<[4]:f32>, !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {sigArgTypes = ["Tensor", "Tensor", "bool"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
// CHECK: return %arg0
return %0 : !numpy.ndarray<[4]:f32>
}