mirror of https://github.com/llvm/torch-mlir
Add a number of kernels and new patterns.

* convolution, convolution_backward, _log_softmax, _log_softmax_backward_data, nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward, copy_
* Extends the recognition logic and metadata for handling in-place transformations, optional tensors, ints, lists, and dropped args.
* The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen.
* The result *almost* comes out as a pure tensor program, with the exception of the copy_ op, which I will do some follow-up work to deal with.
* More progress on #97
parent 3dab9056f0
commit 6c702b149f
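For orientation, the registrations added below all funnel through a small builder object (InflightOpDef) introduced in torch_signature_ods_gen.py. A minimal sketch of that flow, roughly equivalent to the g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)", ...) call in the diff; the OpGenerator instance g and its ODS/C++ emitters are assumed to be set up as in that file:

# Sketch only; mirrors what ordinary_immutable_op does for aten::_log_softmax.
opdef = g.define_op(kernel_sig="aten::_log_softmax(Tensor,int,bool)",
                    ods_name="LogSoftmaxOp",
                    op_name="log_softmax",
                    promote_trailing_out_tensor=True,
                    traits=["NoSideEffect"])
# Map torch signature types to ODS types and kernel conversion flags.
opdef.transforms(type_transforms={
    "Tensor": "AnyTorchImmutableTensor",
    "int": "AnyTorchIntType",
    "bool": "AnyTorchBoolType",
},
                 flag_transforms={
                     "Tensor": ["kImmutableTensor"],
                 })
opdef.emit()  # emits the ODS op def and the C++ getTorchBuildKernelMetadata()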
@ -11,4 +11,5 @@ export PYTHONPATH="${build_dir}/python"
python -m torch_mlir_utils.codegen.torch_signature_ods_gen \
--ods_td_file="${aten_dir}/GeneratedATenOps.td" \
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc"
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" \
--debug_op_reg_file="${aten_dir}/ATenOpRegistrations.txt"

@ -189,9 +189,8 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
return constValue;
}

MlirValue
FuncBuilder::buildList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) {
MlirValue FuncBuilder::buildList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(context);
OperationStateHolder state{"basicpy.build_list", loc};
mlirOperationStateAddResults(state, 1, &resultType);

@ -13,6 +13,7 @@ import logging
import re
import sys
import textwrap
import traceback

# Note that this utility exists only in the c-extension.
from _torch_mlir import get_registered_ops

@ -75,6 +76,55 @@ def generate_ops(g: "OpGenerator"):
g.ordinary_unary_op(f"aten::{uname}(Tensor)",
f"{snakecase_to_camelcase(uname)}Op", uname)

# Convolution ops. Note that these are special in PyTorch and the importer,
# and we model them after the signatures of the convolution_overrideable
# ops (generic for non-CPU/GPU backends) but set the names according to
# how they come in.
g.print_banner("NN ops")
g.ordinary_immutable_op(
"aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
"ConvolutionOp",
"convolution",
alias_kernel_names=["aten::convolution"])
g.ordinary_immutable_op(
"aten::convolution_backward_overrideable(Tensor,Tensor,Tensor,int[],int[],int[],bool,int[],int,bool[])",
"ConvolutionBackwardOp",
"convolution_backward",
alias_kernel_names=["aten::convolution_backward"],
# These do return as None but are not coded optional in the registry :(
override_return_types=["Tensor?", "Tensor?", "Tensor?"])

g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)",
"LogSoftmaxOp", "log_softmax")
g.ordinary_immutable_op(
"aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
"LogSoftmaxBackwardDataOp", "log_softmax_backward_data")

# Loss functions.
g.print_banner("Loss function ops")
g.ordinary_immutable_op(
"aten::nll_loss_forward(Tensor,Tensor,Tensor?,int,int)",
"NllLossForwardOp", "nll_loss_forward")
# Note also a grad_input 8-arg variant.
g.ordinary_immutable_op(
"aten::nll_loss_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
"NllLossBackwardOp", "nll_loss_backward")

g.ordinary_immutable_op(
"aten::nll_loss2d_forward(Tensor,Tensor,Tensor?,int,int)",
"NllLoss2dForwardOp", "nll_loss2d_forward")
# Note also a grad_input 8-arg variant.
g.ordinary_immutable_op(
"aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
"NllLoss2dBackwardOp", "nll_loss2d_backward")

# One-off in-place ops (note that many in-place arithmetic ops are handled
# as a transformation from their immutable forms).
g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
"CopyInplaceOp",
"copy.inplace",
drop_arg_indices=[2])


def dump_registered_ops(outfile, reg_ops_dict):
for k in sorted(reg_ops_dict.keys()):

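The signature strings above are resolved against the op registry exported by get_registered_ops. A rough sketch of one record's shape, inferred from get_reg_record, _map_sigtypes, and the _load_ops_as_dict comment later in this file; the variant string and any field names beyond 'arguments'/'returns' and 'name'/'type' are assumptions:

# Illustrative only: approximate registry record for aten::copy_.
example_record = {
    ("aten::copy_(Tensor,Tensor,bool)", ""): {
        "arguments": [
            {"name": "self", "type": "Tensor"},
            {"name": "src", "type": "Tensor"},
            {"name": "non_blocking", "type": "bool"},
        ],
        "returns": [
            {"name": "self", "type": "Tensor"},
        ],
    },
}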
@ -104,7 +154,20 @@ class OpGenerator:
|
|||
)
|
||||
em.print("")
|
||||
|
||||
def ordinary_binary_op(self, kernel_sig, ods_name, op_name):
|
||||
def define_op(self, kernel_sig, ods_name, op_name, **kwargs):
|
||||
return InflightOpDef(self,
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
**kwargs)
|
||||
|
||||
def ordinary_binary_op(self,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
promote_trailing_out_tensor=True,
|
||||
traits=(),
|
||||
**kwargs):
|
||||
""""Binary"-ops. These ops typically have:
|
||||
- '.Tensor' variant where the second arg is a Tensor
|
||||
- '.Scalar' variant where the second arg is a Scalar
|
||||
|
@ -124,75 +187,130 @@ class OpGenerator:
|
|||
- Setting all arguments and returns to kImmutableTensor
|
||||
- Enabling kPromoteScalarToTensor on the second argument.
|
||||
"""
|
||||
reg_record = self._get_reg_record(kernel_sig)
|
||||
ods_ins, arg_type_flags = self._map_sigtypes(
|
||||
reg_record["arguments"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
"Tensor:1": "AnyTorchImmutableTensor",
|
||||
"Scalar:1": "AnyTorchImmutableTensor",
|
||||
"Scalar": "AnyTorchScalarType",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
":1": ["kImmutableTensor", "kPromoteScalar"],
|
||||
})
|
||||
ods_outs, return_type_flags = self._map_sigtypes(
|
||||
reg_record["returns"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
self.ods_emitter.emit_opdef(ods_name,
|
||||
op_name,
|
||||
reg_record,
|
||||
ods_ins=ods_ins,
|
||||
ods_outs=ods_outs,
|
||||
traits=["NoSideEffect"])
|
||||
self.impl_emitter.emit_kernel_methods(ods_name,
|
||||
reg_record,
|
||||
arg_type_flags=arg_type_flags,
|
||||
return_type_flags=return_type_flags,
|
||||
promote_trailing_out_tensor=True)
|
||||
opdef = self.define_op(
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.arg_transforms(type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
"Tensor:1": "AnyTorchImmutableTensor",
|
||||
"Scalar:1": "AnyTorchImmutableTensor",
|
||||
"Scalar": "AnyTorchScalarType",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
":1": ["kImmutableTensor", "kPromoteScalar"],
|
||||
})
|
||||
opdef.return_transforms(type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_unary_op(self, kernel_sig, ods_name, op_name):
|
||||
def ordinary_immutable_op(self,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
promote_trailing_out_tensor=True,
|
||||
traits=(),
|
||||
**kwargs):
|
||||
""""An ordinary immutable-tensor based op."""
|
||||
opdef = self.define_op(
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.transforms(type_transforms={
|
||||
"Tensor": "AnyTorchImmutableTensor",
|
||||
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||
"int": "AnyTorchIntType",
|
||||
"int[]": "AnyTorchIntListType",
|
||||
"bool": "AnyTorchBoolType",
|
||||
"bool[]": "AnyTorchBoolListType",
|
||||
},
|
||||
flag_transforms={
|
||||
"Tensor": ["kImmutableTensor"],
|
||||
"Tensor?": ["kImmutableTensor"],
|
||||
})
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs):
|
||||
"""In-place ops (ending in '_').
|
||||
|
||||
These ops take a mutable first argument and then standard immutable
|
||||
conversions for subsequent. When emitting into MLIR, the return value is
|
||||
dropped.
|
||||
"""
|
||||
opdef = self.define_op(kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
**kwargs)
|
||||
opdef.arg_transforms(type_transforms={
|
||||
":0": "AnyTorchMutableTensor",
|
||||
"Tensor": "AnyTorchImmutableTensor",
|
||||
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||
"int": "AnyTorchIntType",
|
||||
"int[]": "AnyTorchIntListType",
|
||||
"bool": "AnyTorchBoolType",
|
||||
"bool[]": "AnyTorchBoolListType",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": [],
|
||||
"Tensor": ["kImmutableTensor"],
|
||||
"Tensor?": ["kImmutableTensor"],
|
||||
})
|
||||
opdef.return_transforms(
|
||||
type_transforms={
|
||||
":0": "DROP_UNUSED", # Ignored because we clear the outs below.
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kDropReturnAndAliasArg0"],
|
||||
})
|
||||
opdef.map_signatures()
|
||||
opdef.ods_outs = [] # Clear the computed outs.
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_unary_op(self,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
promote_trailing_out_tensor=True,
|
||||
traits=(),
|
||||
**kwargs):
|
||||
"""Unary ops.
|
||||
|
||||
These take and return a tensor and typically have an out and inplace
|
||||
variant (they may not but we generate patterns to match anyway).
|
||||
"""
|
||||
reg_record = self._get_reg_record(kernel_sig)
|
||||
ods_ins, arg_type_flags = self._map_sigtypes(
|
||||
reg_record["arguments"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
ods_outs, return_type_flags = self._map_sigtypes(
|
||||
reg_record["returns"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
self.ods_emitter.emit_opdef(ods_name,
|
||||
op_name,
|
||||
reg_record,
|
||||
ods_ins=ods_ins,
|
||||
ods_outs=ods_outs,
|
||||
traits=["NoSideEffect"])
|
||||
self.impl_emitter.emit_kernel_methods(ods_name,
|
||||
reg_record,
|
||||
arg_type_flags=arg_type_flags,
|
||||
return_type_flags=return_type_flags,
|
||||
promote_trailing_out_tensor=True)
|
||||
opdef = self.define_op(
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.arg_transforms(type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
opdef.return_transforms(type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
opdef.emit()
|
||||
|
||||
def _get_reg_record(self, kernel_sig):
|
||||
def get_reg_record(self, kernel_sig):
|
||||
"""Gets the op-dict for a given registered op name.
|
||||
|
||||
Args:
|
||||
|
@ -212,8 +330,14 @@ class OpGenerator:
|
|||
raise ValueError(f"Could not find registry op matching '{kernel_sig}'. "
|
||||
f"Possible matches:\n {dym_message}")
|
||||
|
||||
def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str],
|
||||
flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]:
|
||||
def _map_sigtypes(
|
||||
self,
|
||||
siglist: List[Dict],
|
||||
type_transforms: Dict[str, str],
|
||||
flag_transforms: Dict[str, List[str]],
|
||||
drop_indices: Sequence[int] = (),
|
||||
override_types: Optional[Sequence[str]] = None,
|
||||
) -> List[Tuple[str]]:
|
||||
"""Maps a list of signature entries to ods dags and flag lists.
|
||||
|
||||
The torch signature list contains dicts that minimally have keys 'name' and
|
||||
|
@ -233,15 +357,23 @@ class OpGenerator:
|
|||
- An ods dag list of (ods_name, ods_type) tuples
|
||||
- List of (torch_type, [conversion_flag]) for specifying conversions.
|
||||
"""
|
||||
# Make sure any override types are sane.
|
||||
if override_types:
|
||||
assert len(override_types) == len(siglist), (
|
||||
"Mismatch override and actual types")
|
||||
# Generate to ods dag list.
|
||||
ods_dag_list = []
|
||||
for i, sigitem in enumerate(siglist):
|
||||
if i in drop_indices:
|
||||
# Do not emit in ODS.
|
||||
continue
|
||||
torch_name = sigitem["name"]
|
||||
torch_type = sigitem["type"]
|
||||
torch_type = (sigitem["type"]
|
||||
if override_types is None else override_types[i])
|
||||
# Look up the type transform.
|
||||
ods_type = (type_transforms.get(f"{torch_type}:{i}") or
|
||||
type_transforms.get(f":{i}") or
|
||||
type_transforms.get(torch_type))
|
||||
ods_type = _first_non_none(type_transforms.get(f"{torch_type}:{i}"),
|
||||
type_transforms.get(f":{i}"),
|
||||
type_transforms.get(torch_type))
|
||||
if not ods_type:
|
||||
raise ValueError(f"Signature item {i}, type {torch_type} did not match "
|
||||
f"a type transform {type_transforms}")
|
||||
|
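To make the mapping concrete, a small worked example of what _map_sigtypes produces for the aten::_log_softmax arguments under the ordinary_immutable_op transforms; the exact ODS operand-name formatting is handled by the emitter and is only approximated here:

# Hypothetical input/output sketch for _map_sigtypes.
siglist = [{"name": "self", "type": "Tensor"},
           {"name": "dim", "type": "int"},
           {"name": "half_to_float", "type": "bool"}]
# ods dag list: one (ods_name, ods_type) tuple per non-dropped entry.
ods_ins = [("self", "AnyTorchImmutableTensor"),
           ("dim", "AnyTorchIntType"),
           ("half_to_float", "AnyTorchBoolType")]
# flag list: one (torch_type, [conversion_flags]) tuple per entry.
arg_type_flags = [("Tensor", ["kImmutableTensor"]),
                  ("int", []),
                  ("bool", [])]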
@ -250,16 +382,130 @@ class OpGenerator:
|
|||
# Generate the type conversion flags.
|
||||
type_flag_list = []
|
||||
for i, sigitem in enumerate(siglist):
|
||||
torch_type = sigitem["type"]
|
||||
torch_type = (sigitem["type"]
|
||||
if override_types is None else override_types[i])
|
||||
# Look up the type transform.
|
||||
flags = (flag_transforms.get(f"{torch_type}:{i}") or
|
||||
flag_transforms.get(f":{i}") or flag_transforms.get(torch_type))
|
||||
if not flags:
|
||||
flags = []
|
||||
if i in drop_indices:
|
||||
flags = ["kDrop"]
|
||||
else:
|
||||
flags = _first_non_none(flag_transforms.get(f"{torch_type}:{i}"),
|
||||
flag_transforms.get(f":{i}"),
|
||||
flag_transforms.get(torch_type))
|
||||
if flags is None:
|
||||
flags = []
|
||||
type_flag_list.append((torch_type, flags))
|
||||
return ods_dag_list, type_flag_list
|
||||
|
||||
|
||||
class InflightOpDef:
|
||||
|
||||
def __init__(self,
|
||||
g: OpGenerator,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
traits=(),
|
||||
alias_kernel_names=(),
|
||||
promote_trailing_out_tensor=False,
|
||||
override_arg_types=None,
|
||||
override_return_types=None,
|
||||
drop_arg_indices=(),
|
||||
drop_return_indices=()):
|
||||
super().__init__()
|
||||
self.g = g
|
||||
self.kernel_sig = kernel_sig
|
||||
self.ods_name = ods_name
|
||||
self.op_name = op_name
|
||||
self.traits = list(traits)
|
||||
self.alias_kernel_names = list(alias_kernel_names)
|
||||
self.promote_trailing_out_tensor = promote_trailing_out_tensor
|
||||
self.override_arg_types = override_arg_types
|
||||
self.override_return_types = override_return_types
|
||||
self.drop_arg_indices = drop_arg_indices
|
||||
self.drop_return_indices = drop_return_indices
|
||||
self.reg_record = g.get_reg_record(self.kernel_sig)
|
||||
self._emitted = False
|
||||
self._traceback = traceback.extract_stack()[0:-2]
|
||||
|
||||
# Arg and flag transform dicts.
|
||||
self.arg_type_transforms = dict()
|
||||
self.arg_flag_transforms = dict()
|
||||
self.return_type_transforms = dict()
|
||||
self.return_flag_trasforms = dict()
|
||||
|
||||
# Signature mapping.
|
||||
self._sigs_mapped = False
|
||||
self.ods_ins = None
|
||||
self.ods_outs = None
|
||||
self.arg_type_flags = None
|
||||
self.return_type_flags = None
|
||||
|
||||
def __del__(self):
|
||||
if not self._emitted:
|
||||
print("WARNING: Op defined but not emitted. Defined at:", file=sys.stderr)
|
||||
for line in traceback.format_list(self._traceback):
|
||||
sys.stderr.write(line)
|
||||
|
||||
def transforms(self, type_transforms=None, flag_transforms=None):
|
||||
self.arg_transforms(type_transforms=type_transforms,
|
||||
flag_transforms=flag_transforms)
|
||||
self.return_transforms(type_transforms=type_transforms,
|
||||
flag_transforms=flag_transforms)
|
||||
return self
|
||||
|
||||
def arg_transforms(self, type_transforms=None, flag_transforms=None):
|
||||
"""Adds arg type and flag transforms dicts."""
|
||||
if type_transforms:
|
||||
self.arg_type_transforms.update(type_transforms)
|
||||
if flag_transforms:
|
||||
self.arg_flag_transforms.update(flag_transforms)
|
||||
return self
|
||||
|
||||
def return_transforms(self, type_transforms=None, flag_transforms=None):
|
||||
"""Adds return type and flag transform dicts."""
|
||||
if type_transforms:
|
||||
self.return_type_transforms.update(type_transforms)
|
||||
if flag_transforms:
|
||||
self.return_flag_trasforms.update(flag_transforms)
|
||||
return self
|
||||
|
||||
def map_signatures(self):
|
||||
assert not self._sigs_mapped, "Signatures already mapped"
|
||||
self._sigs_mapped = True
|
||||
self.ods_ins, self.arg_type_flags = self.g._map_sigtypes(
|
||||
self.reg_record["arguments"],
|
||||
type_transforms=self.arg_type_transforms,
|
||||
flag_transforms=self.arg_flag_transforms,
|
||||
override_types=self.override_arg_types,
|
||||
drop_indices=self.drop_arg_indices)
|
||||
self.ods_outs, self.return_type_flags = self.g._map_sigtypes(
|
||||
self.reg_record["returns"],
|
||||
type_transforms=self.return_type_transforms,
|
||||
flag_transforms=self.return_flag_trasforms,
|
||||
override_types=self.override_return_types,
|
||||
drop_indices=self.drop_return_indices)
|
||||
return self
|
||||
|
||||
def emit(self):
|
||||
assert not self._emitted, "Op already emitted"
|
||||
self._emitted = True
|
||||
if not self._sigs_mapped:
|
||||
self.map_signatures()
|
||||
self.g.ods_emitter.emit_opdef(self.ods_name,
|
||||
self.op_name,
|
||||
self.reg_record,
|
||||
ods_ins=self.ods_ins,
|
||||
ods_outs=self.ods_outs,
|
||||
traits=self.traits)
|
||||
self.g.impl_emitter.emit_kernel_methods(
|
||||
self.ods_name,
|
||||
self.reg_record,
|
||||
arg_type_flags=self.arg_type_flags,
|
||||
return_type_flags=self.return_type_flags,
|
||||
promote_trailing_out_tensor=self.promote_trailing_out_tensor,
|
||||
alias_kernel_names=self.alias_kernel_names)
|
||||
|
||||
|
||||
class EmitterBase:
|
||||
_INDENT = " "
|
||||
|
||||
|
@ -355,7 +601,8 @@ class CCImplEmitter(EmitterBase):
|
|||
reg_record,
|
||||
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
|
||||
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
|
||||
promote_trailing_out_tensor=False):
|
||||
promote_trailing_out_tensor=False,
|
||||
alias_kernel_names: Sequence[str] = ()):
|
||||
# getTorchKernelMetadata() method.
|
||||
self.print(
|
||||
f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{")
|
||||
|
@ -374,6 +621,9 @@ class CCImplEmitter(EmitterBase):
|
|||
with self.indent():
|
||||
self.print("Torch::BuildKernelMetadata m;")
|
||||
self.print(f"m.kernelName = {self.quote(kernel_name)};")
|
||||
for alias_kernel_name in alias_kernel_names:
|
||||
self.print(
|
||||
f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});")
|
||||
if promote_trailing_out_tensor:
|
||||
self.print("m.promoteTrailingOutTensor = true;")
|
||||
# Arg types/flags.
|
||||
|
@ -393,7 +643,7 @@ class CCImplEmitter(EmitterBase):
|
|||
self.print("return m;")
|
||||
self.print("})();")
|
||||
self.print("return metadata;")
|
||||
self.print("}")
|
||||
self.print("}\n")
|
||||
|
||||
def _format_cpp_str_initlist(self, strings):
|
||||
quoted = [self.quote(s) for s in strings]
|
||||
|
@ -416,6 +666,13 @@ def snakecase_to_camelcase(ident: str):
|
|||
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
|
||||
|
||||
|
||||
def _first_non_none(*args):
|
||||
for arg in args:
|
||||
if arg is not None:
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def _load_ops_as_dict():
|
||||
# Returns a list of dicts, each with a name that is a tuple of the form:
|
||||
# (kernel_signature, variant)
|
||||
|
|
|
@ -0,0 +1 @@
ATenOpRegistrations.txt

@ -61,55 +61,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
|
|||
|
||||
}
|
||||
|
||||
// Our jit library only supports 6 argument convolutions, rather than 9
|
||||
// arguments supported by pytorch. This operation allows us to represent this
|
||||
// limitation temporarily.
|
||||
def aten_ConvolutionOp: aten_Op<"_convolution", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyTensor:$bias,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation
|
||||
);
|
||||
|
||||
let summary = "Convolution operator";
|
||||
let description = [{
|
||||
Convolution operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
uint64_t getOperandTransferVolume(unsigned int idx, bool read);
|
||||
uint64_t getResultTransferVolume(unsigned int idx, bool read);
|
||||
}];
|
||||
}
|
||||
|
||||
// Our jit library only supports 6 argument convolutions, rather than 9
|
||||
// arguments supported by pytorch. This operation allows us to represent this
|
||||
// limitation temporarily.
|
||||
def aten_ConvolutionBackwardOp: aten_Op<"_convolution_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor:$dx, AnyTensor:$dw, AnyTensor:$db)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation
|
||||
);
|
||||
|
||||
let summary = "ConvolutionBackward operator";
|
||||
let description = [{
|
||||
ConvolutionBackward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
|
|
|
@ -37,6 +37,7 @@ const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -55,6 +56,7 @@ const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata DivOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -73,6 +75,7 @@ const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -91,6 +94,7 @@ const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata MulOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -109,6 +113,7 @@ const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -127,6 +132,7 @@ const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -145,6 +151,7 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Unary arithmetic ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -167,6 +174,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata AcosOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -185,6 +193,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata AngleOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -203,6 +212,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata AsinOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -221,6 +231,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata AtanOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -239,6 +250,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -257,6 +269,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ConjOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -275,6 +288,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata CosOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -293,6 +307,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata CoshOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -311,6 +326,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -329,6 +345,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ErfOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -347,6 +364,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -365,6 +383,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -383,6 +402,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ExpOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -401,6 +421,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -419,6 +440,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata FloorOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -437,6 +459,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata FracOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -455,6 +478,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -473,6 +497,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata LogOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -491,6 +516,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata Log10Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -509,6 +535,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -527,6 +554,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata Log2Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -545,6 +573,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata NegOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -563,6 +592,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ReluOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -581,6 +611,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -599,6 +630,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata RoundOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -617,6 +649,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -635,6 +668,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -653,6 +687,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata SignOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -671,6 +706,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata SinOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -689,6 +725,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata SinhOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -707,6 +744,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -725,6 +763,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata TanOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -743,6 +782,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata TanhOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -761,6 +801,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata TruncOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
@ -779,3 +820,184 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
|
|||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NN ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::convolution_overrideable";
|
||||
m.aliasKernelNames.push_back("aten::convolution");
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ConvolutionBackwardOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::convolution_backward_overrideable";
|
||||
m.aliasKernelNames.push_back("aten::convolution_backward");
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor?", "Tensor?", "Tensor?"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata LogSoftmaxOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &LogSoftmaxOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::_log_softmax";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "int", "bool"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata LogSoftmaxBackwardDataOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::_log_softmax_backward_data";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "int", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Loss function ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Torch::KernelMetadata NllLossForwardOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &NllLossForwardOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::nll_loss_forward";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor", "Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata NllLossBackwardOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &NllLossBackwardOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::nll_loss_backward";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata NllLoss2dForwardOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &NllLoss2dForwardOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::nll_loss2d_forward";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor", "Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata NllLoss2dBackwardOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::nll_loss2d_backward";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &CopyInplaceOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::copy_";
|
||||
m.addArgTypes({"Tensor", "Tensor", "bool"});
|
||||
m.addArgConversions({KVC::kNone, KVC::kImmutableTensor, KVC::kDrop});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kDropReturnAndAliasArg0});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
|
|
|
@ -450,3 +450,147 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods<Torc
|
|||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NN ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::convolution_overrideable";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$input,
|
||||
AnyTorchImmutableTensor:$weight,
|
||||
AnyTorchOptionalImmutableTensor:$bias,
|
||||
AnyTorchIntListType:$stride,
|
||||
AnyTorchIntListType:$padding,
|
||||
AnyTorchIntListType:$dilation,
|
||||
AnyTorchBoolType:$transposed,
|
||||
AnyTorchIntListType:$output_padding,
|
||||
AnyTorchIntType:$groups
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$grad_output,
|
||||
AnyTorchImmutableTensor:$input,
|
||||
AnyTorchImmutableTensor:$weight,
|
||||
AnyTorchIntListType:$stride,
|
||||
AnyTorchIntListType:$padding,
|
||||
AnyTorchIntListType:$dilation,
|
||||
AnyTorchBoolType:$transposed,
|
||||
AnyTorchIntListType:$output_padding,
|
||||
AnyTorchIntType:$groups,
|
||||
AnyTorchBoolListType:$output_mask
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalImmutableTensor:$grad_input,
|
||||
AnyTorchOptionalImmutableTensor:$grad_weight,
|
||||
AnyTorchOptionalImmutableTensor:$grad_bias
|
||||
);
|
||||
}
|
||||
|
||||
def aten_LogSoftmaxOp: aten_Op<"log_softmax", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::_log_softmax";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchIntType:$dim,
|
||||
AnyTorchBoolType:$half_to_float
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::_log_softmax_backward_data";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$grad_output,
|
||||
AnyTorchImmutableTensor:$output,
|
||||
AnyTorchIntType:$dim,
|
||||
AnyTorchImmutableTensor:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Loss function ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::nll_loss_forward";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$target,
|
||||
AnyTorchOptionalImmutableTensor:$weight,
|
||||
AnyTorchIntType:$reduction,
|
||||
AnyTorchIntType:$ignore_index
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor:$output,
|
||||
AnyTorchImmutableTensor:$total_weight
|
||||
);
|
||||
}
|
||||
|
||||
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::nll_loss_backward";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$grad_output,
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$target,
|
||||
AnyTorchOptionalImmutableTensor:$weight,
|
||||
AnyTorchIntType:$reduction,
|
||||
AnyTorchIntType:$ignore_index,
|
||||
AnyTorchImmutableTensor:$total_weight
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::nll_loss2d_forward";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$target,
|
||||
AnyTorchOptionalImmutableTensor:$weight,
|
||||
AnyTorchIntType:$reduction,
|
||||
AnyTorchIntType:$ignore_index
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor:$output,
|
||||
AnyTorchImmutableTensor:$total_weight
|
||||
);
|
||||
}
|
||||
|
||||
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::nll_loss2d_backward";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$grad_output,
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$target,
|
||||
AnyTorchOptionalImmutableTensor:$weight,
|
||||
AnyTorchIntType:$reduction,
|
||||
AnyTorchIntType:$ignore_index,
|
||||
AnyTorchImmutableTensor:$total_weight
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::copy_";
|
||||
let arguments = (ins
|
||||
AnyTorchMutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$src
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -44,52 +44,6 @@ def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface
|
|||
}];
|
||||
}
|
||||
|
||||
def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyTensor:$bias,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation,
|
||||
AnyScalar:$transposed,
|
||||
AnyType:$output_padding,
|
||||
AnyScalar:$groups
|
||||
);
|
||||
let summary = "aten convolution_overrideable operator";
|
||||
let description = [{
|
||||
ConvolutionOverrideableOp
|
||||
aten convolution_overrideable operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation,
|
||||
AnyScalar:$transposed,
|
||||
AnyType:$output_padding,
|
||||
AnyScalar:$groups
|
||||
);
|
||||
let summary = "aten convolution_backward_overrideable operator";
|
||||
let description = [{
|
||||
ConvolutionBackwardOverrideableOp
|
||||
aten convolution_backward_overrideable operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
|
@ -123,35 +77,6 @@ def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim,
|
||||
AnyScalar:$half_to_float
|
||||
);
|
||||
let summary = "aten _log_softmax operator";
|
||||
let description = [{
|
||||
LogSoftmaxOp
|
||||
aten _log_softmax operator
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$output,
|
||||
AnyScalar:$dim,
|
||||
AnyTensor:$self
|
||||
);
|
||||
let summary = "aten _log_softmax_backward_data operator";
|
||||
let description = [{
|
||||
LogSoftmaxBackwardDataOp
|
||||
aten _log_softmax_backward_data operator
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
|
@ -445,86 +370,6 @@ def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index
|
||||
);
|
||||
let summary = "aten nll_loss_forward operator";
|
||||
let description = [{
|
||||
NllLossForwardOp
|
||||
aten nll_loss_forward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index,
|
||||
AnyTensor:$total_weight
|
||||
);
|
||||
let summary = "aten nll_loss_backward operator";
|
||||
let description = [{
|
||||
NllLossBackwardOp
|
||||
aten nll_loss_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index
|
||||
);
|
||||
let summary = "aten nll_loss2d_forward operator";
|
||||
let description = [{
|
||||
NllLoss2dForwardOp
|
||||
aten nll_loss2d_forward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index,
|
||||
AnyTensor:$total_weight
|
||||
);
|
||||
let summary = "aten nll_loss2d_backward operator";
|
||||
let description = [{
|
||||
NllLoss2dBackwardOp
|
||||
aten nll_loss2d_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
|
|
|
@ -12,15 +12,4 @@
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "npcomp/Dialect/ATen/IR/ATenOps.td"

// The pytorch convolution operator has 9 arguments, but we only have a jit
// library that supports the first six at the moment.
def : Pat<(aten_ConvolutionOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
$a7, $a8, $a9),
(aten_ConvolutionOp $a1, $a2, $a3, $a4, $a5, $a6)>;

def : Pat<(aten_ConvolutionBackwardOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
$a7, $a8, $a9),
(aten_ConvolutionBackwardOp $a1, $a2, $a3, $a4, $a5, $a6)>;


#endif

@ -18,7 +18,7 @@ namespace Torch {

/// Conversion rule to apply to a value (argument or return).
namespace KernelValueConversion {
enum BitMask {
enum BitMask : uint32_t {
// No coercion.
kNone = 0,

@ -32,7 +32,16 @@ enum BitMask {
// to a 0d tensor.
kPromoteScalar = 8,

LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar)
// Drops the return value and aliases to argument 0.
// TODO: Remove this in favor of general alias metadata processing (note that
// the vast majority are this one case so it isn't so bad to have a special
// case for it if necessary).
kDropReturnAndAliasArg0 = 16,

// Drops the argument/return.
kDrop = 32,

LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kDrop)
};
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
} // namespace KernelValueConversion

@ -74,6 +83,9 @@ struct BuildKernelMetadata : public KernelMetadata {
SmallVector<KernelValueConversion::BitMask, 4> argConversions;
SmallVector<KernelValueConversion::BitMask, 4> returnConversions;

/// Additional alias kernel names to match.
SmallVector<StringRef, 1> aliasKernelNames;

void addArgConversions(
std::initializer_list<KernelValueConversion::BitMask> ilist) {
argConversions.insert(argConversions.end(), ilist);

@ -72,6 +72,11 @@ def AnyTorchImmutableTensor : AnyTypeOf<[
AnyTensor,
], "allowable torch immutable tensor">;

def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
AnyTorchImmutableTensor,
Basicpy_NoneType,
], "allowable torch immutable tensor (or None)">;

def AnyTorchMutableTensor : AnyTypeOf<[
// "Numpy-style" mutable NDArray. While not offering the full generality
// of a Torch tensor, it models the same access patterns and implies the

@ -95,7 +100,28 @@ def AnyTorchScalarType : AnyTypeOf<[
AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;

def AnyTorchBoolType : AnyTypeOf<[
I1,
Basicpy_BoolType,
], "Any permissible bool type">;

def AnyTorchBoolListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any bool list type (bool[])">;

def AnyTorchIntType : AnyTypeOf<[
AnySignedInteger,
AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;

def AnyTorchIntListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any int list type (int[])">;

def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_ListType,

@ -170,40 +170,6 @@ std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
  return toReturn;
}

// _convolution
std::map<std::string, uint64_t> ConvolutionOp::getStatistics() {
  return getConv2dStatistics(this, /*groups*/ 1);
}
std::map<std::string, uint64_t> ConvolutionOverrideableOp::getStatistics() {
  // FIXME
  auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
  auto ia = co.template getAttrOfType<IntegerAttr>("value");
  uint64_t groups = ia.getValue().getZExtValue();

  return getConv2dStatistics(this, groups);
}

uint64_t ConvolutionOp::getOperandTransferVolume(unsigned int idx, bool read) {
  return getConv2dOperandTransferVolume<ConvolutionOp>(this, idx, read);
}

uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
  return getConv2dResultTransferVolume<ConvolutionOp>(this, idx, write);
}

// _convolution_backward
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
  return getConv2dBackwardStatistics(*this, 1);
}
std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
  auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
  auto ia = co.template getAttrOfType<IntegerAttr>("value");
  uint64_t groups = ia.getValue().getZExtValue();

  return getConv2dBackwardStatistics(*this, groups);
}

// div_
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
@ -559,35 +525,6 @@ std::map<std::string, uint64_t> NativeBatchNormBackwardOp::getStatistics() {
  return toReturn;
}

std::map<std::string, uint64_t> NllLossForwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}
std::map<std::string, uint64_t> NllLossBackwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dForwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dBackwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}

// std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() {
//   return getReLUOpStatistics(*this);
// }
@ -12,6 +12,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
@ -28,6 +29,14 @@ using namespace mlir::NPCOMP::Torch;

namespace {

bool isTorchTensorType(StringRef torchType) {
  return torchType == "Tensor" || torchType == "Tensor?";
}

bool isTorchOptionalType(StringRef torchType) {
  return torchType.endswith("?");
}

struct TypeConversion {
  Type targetType;
  std::function<Value(Location loc, Value originalValue,
@ -49,9 +58,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
  // Immutable tensor conversion.
  if (flag & KVC::kImmutableTensor) {
    // TODO: Support the kPromoteScalar flag.
    if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType))
      return None;

    // If the target is optional and the type is NoneType, passthrough.
    if (isTorchOptionalType(targetTorchType) &&
        sourceMlirType.isa<Basicpy::NoneType>())
      return TypeConversion{sourceMlirType, nullptr};

    // Already immutable.
    if (sourceMlirType.isa<TensorType>())
      return TypeConversion{sourceMlirType, nullptr};
@ -86,30 +101,51 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
                       Type sourceMlirType) {
  using KVC = KernelValueConversion::BitMask;
  // Default trivial case.
  if (sourceTorchType == targetTorchType && flag == 0)
  if (sourceTorchType == targetTorchType && flag == 0) {
    LLVM_DEBUG(llvm::dbgs() << " * Return types already match\n");
    return TypeConversion{sourceMlirType, nullptr};
  }

  // Immutable tensor conversion.
  if (flag & KVC::kImmutableTensor) {
    if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
    LLVM_DEBUG(llvm::dbgs()
               << " * Return conversion flag kImmutableTensor\n");
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType)) {
      LLVM_DEBUG(llvm::dbgs()
                 << " * Source or target not a Tensor type\n");
      return None;
    }

    // Already immutable.
    if (sourceMlirType.isa<TensorType>())
    if (sourceMlirType.isa<TensorType>()) {
      LLVM_DEBUG(llvm::dbgs() << " * Source is already immutable\n");
      return TypeConversion{sourceMlirType, nullptr};
    }

    // Convert NdArray type.
    if (auto ndArrayType = sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
    if (sourceMlirType.isa<Basicpy::NoneType>() &&
        isTorchOptionalType(targetTorchType)) {
      LLVM_DEBUG(llvm::dbgs() << " * None Tensor type passthrough\n");
      return TypeConversion{sourceMlirType, nullptr};
    } else if (auto ndArrayType =
                   sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
      auto tensorType = ndArrayType.toTensorType();
      auto callback = [=](Location loc, Value newOpResultValue,
                          PatternRewriter &rewriter) -> Value {
        return rewriter.create<Numpy::CreateArrayFromTensorOp>(
            loc, ndArrayType, newOpResultValue);
      };
      LLVM_DEBUG(llvm::dbgs() << " * Convert return type\n");
      return TypeConversion{tensorType, callback};
    } else {
      LLVM_DEBUG(llvm::dbgs()
                 << " * Return type is not a supported tensor type\n");
      return None;
    }
  }

  LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
  return None;
}
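A successful return conversion is consumed in two steps by the pass below: the converted tensor type is recorded for the new op, and the conversion function (when present) re-wraps the new result as an ndarray for the original users. A simplified sketch of that flow; the name of the conversion-function member of TypeConversion is not visible in this hunk, so "callback" is an assumed name, and "newOp" stands in for the rewritten aten op:

auto conv = convertTorchReturnType(sourceTorchType, targetTorchType, flag,
                                   sourceMlirType);
if (!conv)
  return failure();
resultTypes.push_back(conv->targetType);
// ... later, once the replacement aten op has been created:
Value newResult = newOp->getResult(i);
Value rewritten = conv->callback ? conv->callback(loc, newResult, rewriter)
                                 : newResult;
origOpResultValue.replaceAllUsesWith(rewritten);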
@ -142,9 +178,16 @@ public:
                         const BuildKernelMetadata &buildMetadata) {
    LLVM_DEBUG(llvm::dbgs()
               << "Register kernel call translation for: " << opName << "\n");
    CandidateTransformList &candidates =
        kernelTransforms[buildMetadata.kernelName];
    candidates.emplace_back(opName, buildMetadata);
    {
      CandidateTransformList &candidates =
          kernelTransforms[buildMetadata.kernelName];
      candidates.emplace_back(opName, buildMetadata);
    }

    for (StringRef aliasKernelName : buildMetadata.aliasKernelNames) {
      CandidateTransformList &candidates = kernelTransforms[aliasKernelName];
      candidates.emplace_back(opName, buildMetadata);
    }
  }

  LogicalResult transformKernelCall(KernelCallOp kernelCall,
@ -229,6 +272,8 @@ public:
                          "arg arity mismatch");

    // Convert fixed return types.
    using PostConversionCallback = std::function<void()>;
    SmallVector<PostConversionCallback, 4> postConversionCallbacks;
    struct ConversionInfo {
      Value originalValue;
      TypeConversion conversion;
@ -241,25 +286,49 @@ public:
      KVC flag = candidate.buildMetadata.getReturnConversion(i);
      Value sourceValue = kernelCall.getResult(i);
      Type sourceMlirType = kernelCall.getResultTypes()[i];
      auto conversion = convertTorchReturnType(sourceTorchType, targetTorchType,
                                               flag, sourceMlirType);
      if (!conversion) {
        LLVM_DEBUG(llvm::dbgs() << " - Return type[" << i
                                << "] incompatible: source=" << sourceTorchType
                                << ", target=" << targetTorchType
                                << ", flag=" << flag << "\n");
        return failure();
      if (flag & KVC::kDropReturnAndAliasArg0) {
        // Reduce result arity and alias any uses to arg0.
        if (kernelCall.args().empty()) {
          LLVM_DEBUG(llvm::dbgs()
                     << " - Cannot alias arg0 (no arguments)\n");
          return failure();
        }
        Value arg0 = kernelCall.args()[0];
        postConversionCallbacks.push_back(
            [sourceValue, arg0]() { sourceValue.replaceAllUsesWith(arg0); });
      } else {
        // General, arity-preserving type conversion.
        auto conversion = convertTorchReturnType(
            sourceTorchType, targetTorchType, flag, sourceMlirType);
        if (!conversion) {
          LLVM_DEBUG(llvm::dbgs()
                     << " - Return type[" << i << "] incompatible: source="
                     << sourceTorchType << ", target=" << targetTorchType
                     << ", flag=" << flag << "\n");
          return failure();
        }
        resultTypes.push_back(conversion->targetType);
        resultConversions.push_back({sourceValue, std::move(*conversion)});
      }
      resultTypes.push_back(conversion->targetType);
      resultConversions.push_back({sourceValue, std::move(*conversion)});
    }

    // Convert fixed arg types.
    SmallVector<ConversionInfo, 4> operandInfos;
    for (size_t i = 0; i < fixedArgArity; ++i) {
    for (size_t i = 0, operandIndex = 0; i < fixedArgArity; ++i) {
      // Drop this arg?
      if (candidate.buildMetadata.argConversions[i] & KVC::kDrop)
        continue;
      if (kernelCall.getNumOperands() <= operandIndex) {
        LLVM_DEBUG(llvm::dbgs()
                   << " - Arg operand " << i
                   << " does not exist in kernel call (missing default?)\n");
        return failure();
      }

      // Normal type conversion of the operand.
      operandInfos.emplace_back();
      ConversionInfo &info = operandInfos.back();
      info.originalValue = kernelCall.getOperand(i);
      info.originalValue = kernelCall.getOperand(operandIndex++);
      Type sourceMlirType = info.originalValue.getType();
      auto conversion = convertTorchArgType(
          /*sourceTorchType=*/sourceMetadata.argTypes[i],
@ -312,6 +381,10 @@ public:
      origOpResultValue.replaceAllUsesWith(convertedValue);
    }

    // Post conversion callbacks.
    for (auto &callback : postConversionCallbacks)
      callback();

    // Done.
    rewriter.eraseOp(kernelCall);
    return success();
@ -1,25 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-_convolution-0": {
// CHECK-NEXT: "activation_in": 32768,
// CHECK-NEXT: "activation_out": 65536,
// CHECK-NEXT: "ops:+": 65536,
// CHECK-NEXT: "ops:MAC": 6422528,
// CHECK-NEXT: "parameters_in": 1584,
// CHECK-NEXT: "reads": 34352,
// CHECK-NEXT: "writes": 65536

module {
  func @graph(%arg0: tensor<1x2x128x128xf32>, %arg1: tensor<16x2x7x7xf32>, %arg2: tensor<16xf32>) -> tensor<1x16x64x64xf32> {
    %0 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list<i32>
    %1 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list<i32>
    %2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
    %3 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
    %4 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi64>} : () -> !aten.list<i32>
    %5 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
    %6 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
    %7 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
    %8 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
    %9 = "aten._convolution"(%arg0, %arg1, %arg2, %0, %1, %2) : (tensor<1x2x128x128xf32>, tensor<16x2x7x7xf32>, tensor<16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x16x64x64xf32>
    "std.return"(%9) : (tensor<1x16x64x64xf32>) -> ()
  }
}
@ -1,23 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-convolution_backward_overrideable-0": {
// CHECK-NEXT: "activation_in": 5568,
// CHECK-NEXT: "grad": 5380,
// CHECK-NEXT: "ops:+": 768,
// CHECK-NEXT: "ops:MAC": 345600,
// CHECK-NEXT: "parameters_in": 576,
// CHECK-NEXT: "reads": 6144,
// CHECK-NEXT: "writes": 5380
// CHECK-NEXT: }

// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
module {
  func @graph(%arg0: tensor<3x4x8x8xf32>, %arg1: tensor<3x16x10x10xf32>, %arg2: tensor<4x16x3x3xf32>) -> tensor<4x16x3x3xf32> {
    %0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
    %1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
    %2 = "aten.constant"() {type = "bool", value = false} : () -> i1
    %3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
    %10:3 = "aten.convolution_backward_overrideable"(%arg0, %arg1, %arg2, %0, %1, %0, %2, %1, %3) {layer_name = "L5-convolution_backward_overrideable-0"} : (tensor<3x4x8x8xf32>, tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1, !aten.list<i32>, i32) -> (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>)
    return %10#1 : tensor<4x16x3x3xf32>
  }
}
@ -1,26 +0,0 @@
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph

module {
  func @graph(%arg0: tensor<10xf32>, %arg1: tensor<128xf32>, %arg2: tensor<4x1x28x28xf32>, %arg3: tensor<32x1x3x3xf32>, %arg4: tensor<32xf32>, %arg5: tensor<64x32x3x3xf32>, %arg6: tensor<64xf32>, %arg7: tensor<128x9216xf32>, %arg8: tensor<10x128xf32>) -> tensor<4x10xf32> {
    %0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
    %1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
    %2 = "aten.constant"() {type = "bool", value = false} : () -> i1
    %3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
    %4 = "aten.constant"() {type = "bool", value = true} : () -> i1
    %5 = "aten._convolution"(%arg2, %arg3, %arg4, %0, %1, %0) {layer_name = "L0-_convolution-0"} : (tensor<4x1x28x28xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x32x26x26xf32>
    %6 = "aten.relu"(%5) {layer_name = "L1-relu-0"} : (tensor<4x32x26x26xf32>) -> tensor<4x32x26x26xf32>
    %7 = "aten._convolution"(%6, %arg5, %arg6, %0, %1, %0) {layer_name = "L2-_convolution-1"} : (tensor<4x32x26x26xf32>, tensor<64x32x3x3xf32>, tensor<64xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x64x24x24xf32>
    %8 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi32>} : () -> !aten.list<i32>
    %9:2 = "aten.max_pool2d_with_indices"(%7, %8, %8, %1, %0, %2) {layer_name = "L3-max_pool2d_with_indices-0"} : (tensor<4x64x24x24xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1) -> (tensor<4x64x12x12xf32>, tensor<4x64x12x12xi64>)
    %10 = "aten.constant"() {type = "List[i32]", value = dense<[4, 9216]> : vector<2xi32>} : () -> !aten.list<i32>
    %11 = "aten.view"(%9#0, %10) {layer_name = "L4-view-0"} : (tensor<4x64x12x12xf32>, !aten.list<i32>) -> tensor<4x9216xf32>
    %12 = "aten.t"(%arg7) {layer_name = "L5-t-0"} : (tensor<128x9216xf32>) -> tensor<9216x128xf32>
    %13 = "aten.addmm"(%arg1, %11, %12, %3, %3) {layer_name = "L6-addmm-0"} : (tensor<128xf32>, tensor<4x9216xf32>, tensor<9216x128xf32>, i32, i32) -> tensor<4x128xf32>
    %14 = "aten.relu"(%13) {layer_name = "L7-relu-1"} : (tensor<4x128xf32>) -> tensor<4x128xf32>
    %15 = "aten.t"(%arg8) {layer_name = "L8-t-1"} : (tensor<10x128xf32>) -> tensor<128x10xf32>
    %16 = "aten.addmm"(%arg0, %14, %15, %3, %3) {layer_name = "L9-addmm-1"} : (tensor<10xf32>, tensor<4x128xf32>, tensor<128x10xf32>, i32, i32) -> tensor<4x10xf32>
    %17 = "aten._log_softmax"(%16, %3, %2) {layer_name = "L10-_log_softmax-0"} : (tensor<4x10xf32>, i32, i1) -> tensor<4x10xf32>
    return %17 : tensor<4x10xf32>
  }
}
@ -13,3 +13,86 @@ func @graph(%arg0: !numpy.ndarray<*:?>, %arg1 : !numpy.ndarray<*:?>, %arg2 : si6
  // CHECK: return %[[RESULT_MUT]]
  return %0 : !numpy.ndarray<*:?>
}

// -----
// CHECK-LABEL: func @nll_loss2d_forward
// Contains a Tensor? type mapped to None.
func @nll_loss2d_forward(
    %arg0: !numpy.ndarray<[3,4,8,8]:f32>,
    %arg1: !numpy.ndarray<[3,8,8]:i64>,
    %arg2: !basicpy.NoneType,
    %arg3: i64,
    %arg4: i64) -> (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>) {
  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: %[[TOUTPUT:.*]], %[[TTOTAL_WEIGHT:.*]] = "aten.nll_loss2d_forward"(%[[TARG0]], %[[TARG1]], %arg2, %arg3, %arg4) : (tensor<3x4x8x8xf32>, tensor<3x8x8xi64>, !basicpy.NoneType, i64, i64) -> (tensor<f32>, tensor<f32>)
  // CHECK: %[[AOUTPUT:.*]] = numpy.create_array_from_tensor %[[TOUTPUT]]
  // CHECK: %[[ATOTAL_WEIGHT:.*]] = numpy.create_array_from_tensor %[[TTOTAL_WEIGHT]]
  %0:2 = torch.kernel_call "aten::nll_loss2d_forward"
      %arg0, %arg1, %arg2, %arg3, %arg4 :
      (!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,8,8]:i64>, !basicpy.NoneType, i64, i64) ->
      (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>)
      {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor"]}
  // CHECK: return %[[AOUTPUT]], %[[ATOTAL_WEIGHT]]
  return %0#0, %0#1 : !numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>
}

// -----
// CHECK-LABEL: func @convolution
// Contains a Tensor?, bool, int and list types.
func @convolution(
    %arg0: !numpy.ndarray<[3,16,10,10]:f32>, %arg1: !numpy.ndarray<[4,16,3,3]:f32>,
    %arg2: !numpy.ndarray<[4]:f32>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType,
    %arg5: !basicpy.ListType, %arg6: i1, %arg7: !basicpy.ListType, %arg8: i64) -> !numpy.ndarray<[3,4,8,8]:f32> {
  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
  // CHECK: %[[TRESULT:.*]] = "aten.convolution"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> tensor<3x4x8x8xf32>
  %0 = torch.kernel_call "aten::convolution"
      %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8 :
      (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>,
       !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType,
       !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
      {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  return %0 : !numpy.ndarray<[3,4,8,8]:f32>
}

// -----
// CHECK-LABEL: func @convolution_backward
// Interesting because it has optional tensor returns.
func @convolution_backward(
    %arg0: !numpy.ndarray<[3,4,8,8]:f32>,
    %arg1: !numpy.ndarray<[3,16,10,10]:f32>,
    %arg2: !numpy.ndarray<[4,16,3,3]:f32>,
    %arg3: !basicpy.ListType,
    %arg4: !basicpy.ListType,
    %arg5: !basicpy.ListType,
    %arg6: i1,
    %arg7: !basicpy.ListType,
    %arg8: i64,
    %arg9: !basicpy.ListType) -> (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {
  // CHECK: %[[GRAD_INPUT:.*]], %[[GRAD_WEIGHT:.*]], %[[GRAD_BIAS:.*]] = "aten.convolution_backward"
  // Note that this kernel call masks out the input gradients, which will return as NoneType.
  %0:3 = torch.kernel_call "aten::convolution_backward"
      %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 :
      (!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64, !basicpy.ListType) ->
      (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {sigArgTypes = ["Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor", "Tensor"]}
  // CHECK: %[[AGRAD_WEIGHT:.*]] = numpy.create_array_from_tensor %[[GRAD_WEIGHT]]
  // CHECK: %[[AGRAD_BIAS:.*]] = numpy.create_array_from_tensor %[[GRAD_BIAS]]
  // Key thing: the return is the raw NoneType from the masked input gradient;
  // it does not get converted to an array.
  // CHECK: return %[[GRAD_INPUT]], %[[AGRAD_WEIGHT]], %[[AGRAD_BIAS]]
  return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
}

// -----
// CHECK-LABEL: func @copy_inplace
// Mutable/in-place op conversion, dropping result.
func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: "aten.copy.inplace"(%arg0, %[[TARG1]]) : (!numpy.ndarray<[4]:f32>, tensor<4xf32>) -> ()
  %0 = torch.kernel_call "aten::copy_" %arg0, %arg1 : (!numpy.ndarray<[4]:f32>, !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {sigArgTypes = ["Tensor", "Tensor", "bool"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  // CHECK: return %arg0
  return %0 : !numpy.ndarray<[4]:f32>
}