forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python_sugared_value.cpp
1265 lines (1140 loc) · 44.9 KB
/
python_sugared_value.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/python/module_python.h>
#include <climits>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
#include <Python.h>
namespace torch {
namespace jit {
// Returns the `__name__` of the Python type of `h` (for example "int" or
// "Sequential"); used to build human-readable error messages.
std::string typeString(py::handle h) {
  py::object type_name = h.get_type().attr("__name__");
  return py::str(type_name);
}
// Returns the StrongFunctionPtr wrapped by `obj` when `obj` is one, and
// c10::nullopt for any other Python object.
c10::optional<StrongFunctionPtr> as_function(const py::object& obj) {
  if (!py::isinstance<StrongFunctionPtr>(obj)) {
    return c10::nullopt;
  }
  return py::cast<StrongFunctionPtr>(obj);
}
// Builds the FunctionSchema used to match a call to this Python callable.
// n_args: number of positional args at the call site; PythonValue::call passes
//   the count including the implicit `self` when moduleSelf_ is set. It is
//   forwarded to torch.jit.annotations.get_param_names.
// n_binders: number of values the call site binds. It is only consulted when
//   the callable carries no type signature, to choose the return type:
//   None for 0, Tensor for 1, and a Tuple of n_binders Tensors otherwise.
// When moduleSelf_ is set, `self.original_fn` is the callable inspected, and
// its first parameter is typed as the module's type.
// The schema name is __qualname__ if present, else __name__, else "".
// Throws ErrorReport if a bound (non-static) method has no parameters.
FunctionSchema PythonValue::getSchema(
const size_t n_args,
const size_t n_binders,
const SourceRange& loc) {
auto annotations = py::module::import("torch.jit.annotations");
const auto callable = moduleSelf_ ? py::getattr(self, "original_fn") : self;
// Make sure the function is not a class instantiation (e.g. `Exception()`)
annotations.attr("check_fn")(callable, loc);
auto is_vararg = py::cast<bool>(annotations.attr("is_vararg")(callable));
auto signature = annotations.attr("get_signature")(
callable, rcb ? *rcb : py::none(), loc, bool(moduleSelf_));
std::vector<Argument> args, rets;
auto py_param_names = annotations.attr("get_param_names")(callable, n_args);
auto param_names = py::cast<std::vector<std::string>>(py_param_names);
// names_it walks param_names in step with the arguments emitted below.
auto names_it = param_names.begin();
if (moduleSelf_) {
if (param_names.size() == 0) {
throw ErrorReport(loc)
<< "Non-static method does not have a self argument";
}
// If there is a `self` parameter on the callable, skip it on the names list
args.emplace_back(Argument(*names_it, moduleSelf_->type(), {}, {}, false));
++names_it;
}
if (signature.is_none()) {
// No type signature was provided on the callable, so make a default
// signature where each argument is typed as a Tensor
for (; names_it != param_names.end(); ++names_it) {
args.emplace_back(Argument(
/*name=*/*names_it,
/*type=*/TensorType::get(),
/*N=*/c10::nullopt,
/*default_value=*/c10::nullopt,
/*kwarg_only=*/false));
}
// Use as many outputs as are requested to make the return type
TypePtr ret_type = TensorType::get();
if (n_binders == 0) {
ret_type = NoneType::get();
} else if (n_binders > 1) {
std::vector<TypePtr> tuple_values(n_binders, ret_type);
ret_type = TupleType::create(std::move(tuple_values));
}
rets.emplace_back(Argument("0", ret_type, {}, {}, false));
} else {
// Use the provided type signature
std::vector<TypePtr> arg_types;
TypePtr ret_type;
std::tie(arg_types, ret_type) =
py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
// arg_types does not include self but param_names does, so adjust for that
// if needed
// (the subtraction cannot underflow: with moduleSelf_ set, an empty
// param_names already threw above)
TORCH_INTERNAL_ASSERT(
arg_types.size() == param_names.size() - (moduleSelf_ ? 1 : 0));
auto types_it = arg_types.begin();
for (; types_it != arg_types.end(); ++types_it, ++names_it) {
args.emplace_back(
/*name=*/*names_it,
/*type=*/std::move(*types_it),
/*N=*/c10::nullopt,
/*default_value=*/c10::nullopt,
/*kwarg_only=*/false);
}
rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
}
std::string name;
if (py::hasattr(self, "__qualname__")) {
// Use the qualified name if possible
name = py::str(py::getattr(self, "__qualname__"));
} else if (py::hasattr(self, "__name__")) {
name = py::str(py::getattr(self, "__name__"));
}
return FunctionSchema(name, "", std::move(args), std::move(rets), is_vararg);
}
// Emits a call to this Python callable into `m`'s graph.
// The call is matched against the schema from getSchema(). A `self` argument
// is prepended to the positional args when this value is bound to a module.
// If torch._jit_internal.should_drop(self) is true, the call compiles to a
// prim::RaiseException followed by an uninitialized value of the expected
// return type. Otherwise the callable is wrapped in a PythonOp node whose
// inputs are the matched schema inputs.
std::shared_ptr<SugaredValue> PythonValue::call(
    const SourceRange& loc,
    Function& m,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t n_binders) {
  std::vector<NamedValue> argsWithSelf;
  if (moduleSelf_) {
    argsWithSelf.emplace_back(NamedValue("self", moduleSelf_));
  }
  argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end());

  auto schema = getSchema(argsWithSelf.size(), n_binders, loc);
  MatchedSchema matched_schema =
      matchSchema(schema, loc, *m.graph(), argsWithSelf, kwargs);

  // If a function is marked as dropped, we throw an exception if it is
  // invoked.
  if (py::cast<bool>(py::module::import("torch._jit_internal")
                         .attr("should_drop")(self))) {
    auto g = m.graph();
    auto err_msg = insertConstant(
        *g,
        IValue(
            "This Python function is annotated to be ignored and cannot be run"));
    g->insert(prim::RaiseException, {err_msg}, {}, loc);
    return std::make_shared<SimpleValue>(
        g->insertNode(g->createUninitialized(matched_schema.return_types.at(0)))
            ->output());
  }

  // Release the function object so we can wrap it in a PythonOp
  py::object func = self;
  // One 'd' calling-convention entry per node input. The count comes from the
  // matched schema rather than from the positional args. matchSchema also
  // places keyword arguments into matched_schema.inputs, so a positional-only
  // count can disagree with the number of inputs added below.
  std::string cconv(matched_schema.inputs.size(), 'd');
  Node* new_node = m.graph()->insertNode(
      m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {}));
  new_node->setSourceRange(loc);
  for (auto& i : matched_schema.inputs) {
    new_node->addInput(i);
  }

  Value* output =
      new_node->addOutput()->setType(matched_schema.return_types.at(0));
  return std::make_shared<SimpleValue>(output);
}
// Human-readable description of this value, e.g.
// "python value of type 'Sequential'", used in error messages.
std::string PythonValue::kind() const {
  return "python value of type '" + typeString(self) + "'";
}
// Python values cannot be unpacked as tuples in TorchScript, so this always
// throws an ErrorReport. For ModuleList/Sequential values the message includes
// a hint about __constants__ (see checkForAddToConstantsError).
std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
    const SourceRange& loc,
    Function& m,
    const c10::optional<size_t>& size_hint) {
  std::stringstream ss;
  ss << kind() << " cannot be used as a tuple";
  checkForAddToConstantsError(ss);
  throw ErrorReport(loc) << ss.str();
}
// Attribute lookup on an arbitrary Python value is unsupported, so this always
// throws an ErrorReport. For ModuleList/Sequential values the message includes
// a hint about __constants__ (see checkForAddToConstantsError).
std::shared_ptr<SugaredValue> PythonValue::attr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  std::stringstream ss;
  ss << "attribute lookup is not defined on " << kind();
  checkForAddToConstantsError(ss);
  throw ErrorReport(loc) << ss.str();
}
// Returns `self.<name>` as a Python object. Any Python error raised by the
// lookup is converted into an ErrorReport located at `loc`.
py::object PythonValue::getattr(
    const SourceRange& loc,
    const std::string& name) {
  try {
    return py::getattr(self, name.c_str());
  } catch (const py::error_already_set&) {
    throw ErrorReport(loc) << "object has no attribute " << name;
  }
}
// Appends a "did you forget __constants__" hint to `ss` when this value is a
// torch.nn.ModuleList or torch.nn.Sequential instance.
void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
  auto nn = py::module::import("torch.nn");
  const bool is_module_container =
      py::isinstance(self, nn.attr("ModuleList")) ||
      py::isinstance(self, nn.attr("Sequential"));
  if (is_module_container) {
    ss << ". Did you forget to add it to __constants__? ";
  }
}
std::shared_ptr<SugaredValue> PythonModuleValue::attr(
const SourceRange& loc,
Function& m,
const std::string& field) {
py::object member = getattr(loc, field);
// note: is_constant = true because we consider that global properties
// on modules like math.pi or torch.float to be constants
// even though it is possible, though rare, for someone to mutate them
return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#ifndef __HIP_PLATFORM_HCC__
// Resolves attributes of the Python CUDA module for TorchScript:
// - supported CUDA operators map to `cuda::` builtin functions,
// - `Stream` / `Event` map to the registered torch.classes.cuda custom classes,
// - anything else falls back to a constant Python attribute lookup.
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  // List of all the cuda operators which are supported in JIT. The set is
  // constant, so it is built once instead of on every attribute lookup.
  static const std::unordered_set<std::string> cuda_ops = {
      "current_stream",
      "default_stream",
      "current_device",
      "set_device",
      "device_index",
      "device_count",
      "set_stream",
      "synchronize"};

  if (cuda_ops.find(field) != cuda_ops.end()) {
    // Both current_device and set_device APIs are part of the c10::cuda
    // namespace. To resolve the conflict for JIT, their `cuda::` symbols are
    // prefixed with an underscore.
    const bool needs_underscore =
        field == "current_device" || field == "set_device";
    return std::make_shared<BuiltinFunction>(
        Symbol::cuda(needs_underscore ? "_" + field : field), c10::nullopt);
  }

  if (field == "Stream" || field == "Event") {
    auto class_type = getCustomClass("__torch__.torch.classes.cuda." + field);
    return std::make_shared<ClassValue>(class_type);
  }

  py::object member = getattr(loc, field);
  // note: is_constant = true because we consider that global properties
  // on modules like math.pi or torch.float to be constants
  // even though it is possible, though rare, for someone to mutate them
  return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#endif
// A module is represented in the graph directly by its `self_` Value, so no
// nodes are emitted here.
Value* ModuleValue::asValue(const SourceRange& /*loc*/, Function& /*m*/) {
  return self_;
}
// Views a ModuleList/Sequential module as a sugared tuple of its submodules.
// Throws ErrorReport for any other module kind.
SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
  if (concreteType_->getIterableModuleKind() != IterableModuleKind::LIST) {
    throw ErrorReport(loc)
        << "Only ModuleList or Sequential modules can be used as tuple";
  }
  return getSugaredDict(loc, m)->getModules();
}
// Returns true if every module-typed attribute of this module's class is a
// subtype of `ty`. On the first mismatch it returns false and, when `why_not`
// is non-null, writes an explanation naming the offending attribute.
bool ModuleValue::areAllSubmodulesSubtypeOf(
    const TypePtr& ty,
    std::ostream* why_not) const {
  const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
  const size_t num_attrs = self_type->numAttributes();
  for (size_t idx = 0; idx < num_attrs; ++idx) {
    const auto& attr_type = self_type->getAttribute(idx);
    if (!attr_type->is_module()) {
      continue;
    }
    std::stringstream reason;
    if (attr_type->isSubtypeOfExt(ty, &reason)) {
      continue;
    }
    if (why_not) {
      *why_not << "Attribute " << self_type->getAttributeName(idx)
               << " is not of annotated type " << ty->annotation_str() << ": "
               << reason.str();
    }
    return false;
  }
  return true;
}
// Desugars `module[idx]` for iterable module containers.
// - ModuleList/Sequential (IterableModuleKind::LIST): with a type hint, every
//   submodule must be a subtype of the hint and a prim::ModuleContainerIndex
//   node is emitted. Without a hint, indexing is delegated to the sugared
//   tuple of submodules.
// - ModuleDict (IterableModuleKind::DICT): a constant index is matched
//   against the submodule names at compile time. A non-constant index
//   requires a type hint and emits prim::ModuleContainerIndex.
// Throws ErrorReport for a missing key, an unsupported index, or a module
// that is not one of these containers.
SugaredValuePtr ModuleValue::getitem(
const SourceRange& loc,
Function& m,
Value* idx,
TypePtr type_hint) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
if (type_hint) {
// Check that all submodules comply with the type hint.
std::stringstream ss;
if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
throw ErrorReport(loc) << ss.str();
}
// Emit a prim::ModuleContainerIndex operator. This is needed because
// it's difficult to construct a list in the graph representing the
// ModuleList and use aten::__getitem__ ops to index into it because
// any call to ModuleList.setitem would invalidate that emitted list.
auto graph = m.graph();
auto* getitem_node = graph->insertNode(
graph->create(prim::ModuleContainerIndex, {self_, idx}));
getitem_node->output(0)->setType(type_hint);
return std::make_shared<SimpleValue>(getitem_node->output(0));
} else {
return getSugaredDict(loc, m)->getModules()->getitem(
loc, m, idx, type_hint);
}
} else if (
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (auto ivalue = toIValue(idx)) {
// Constant index: resolve the submodule at compile time by a linear scan
// over the (parallel) key and module tuples.
// NOTE(review): toStringRef() assumes the constant is a string. A
// non-string constant index (e.g. an int) presumably fails inside
// toStringRef rather than with a user-facing ErrorReport; confirm.
auto sd = getSugaredDict(loc, m);
auto idx_str = ivalue->toStringRef();
auto keys_iter = sd->keys_;
auto module_values_iter = sd->modules_;
for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
auto key = keys_iter->tup_.at(i);
auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
if (key_str == idx_str) {
return module_values_iter->tup_.at(i);
}
}
throw ErrorReport(loc) << "Key Error, " << idx_str;
} else if (type_hint) {
// Check that all submodules comply with the type hint.
std::stringstream ss;
if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
throw ErrorReport(loc) << ss.str();
}
// Emit a prim::ModuleContainerIndex operator. This is needed because
// it's difficult to construct a dict in the graph representing the
// ModuleDict and use aten::__getitem__ ops to index into it because
// any call to ModuleDict.setAttr would invalidate that emitted dict.
auto graph = m.graph();
auto* getitem_node = graph->insertNode(
graph->create(prim::ModuleContainerIndex, {self_, idx}));
getitem_node->output(0)->setType(type_hint);
return std::make_shared<SimpleValue>(getitem_node->output(0));
}
throw ErrorReport(loc)
<< "Unable to extract string literal index. "
<< "ModuleDict indexing is only supported with string literals. "
<< "Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...'";
}
throw ErrorReport(loc)
<< "Only ModuleList, Sequential, and ModuleDict modules are subscriptable";
}
// Throws ErrorReport if `self` is typed as an interface. The module
// introspection methods (named after `field` in the message) need a concrete
// class type.
void checkInterface(
    const SourceRange& loc,
    Function& m,
    const std::shared_ptr<ModuleValue>& self,
    const std::string& field) {
  const bool is_interface =
      self->asValue(loc, m)->type()->cast<InterfaceType>() != nullptr;
  if (!is_interface) {
    return;
  }
  throw ErrorReport(loc)
      << "Could not compile " << field
      << "() because module is an interface type. Please file issue.";
}
// Depth-first walk over `self` and all of its nested submodules. For each
// module visited (starting with `self`), appends its dotted name to `keys`
// ("" for the root, then e.g. "a", "a.b") and the module itself to `values`.
// `field` is only used for the interface-type error message.
void recurseThroughNestedModules(
    const SourceRange& loc,
    Function& m,
    std::vector<SugaredValuePtr>& keys,
    std::vector<SugaredValuePtr>& values,
    std::shared_ptr<ModuleValue>& self,
    const std::string& prefix,
    const std::string& field) {
  keys.push_back(
      std::make_shared<SimpleValue>(insertConstant(*m.graph(), prefix)));
  values.push_back(self);
  checkInterface(loc, m, self, field);

  auto sugared_dict = self->getSugaredDict(loc, m);
  const auto& child_names = sugared_dict->keys_->tup_;
  const auto& child_modules = sugared_dict->modules_->tup_;
  for (size_t idx = 0; idx < child_names.size(); ++idx) {
    auto child = std::dynamic_pointer_cast<ModuleValue>(child_modules.at(idx));
    const auto child_name =
        toIValue(child_names.at(idx)->asValue(loc, m))->toStringRef();
    const std::string child_prefix =
        prefix.empty() ? child_name : prefix + "." + child_name;
    recurseThroughNestedModules(
        loc, m, keys, values, child, child_prefix, field);
  }
}
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
const SourceRange& loc,
Function& m) {
std::vector<std::string> paramNames;
std::vector<SugaredValuePtr> values;
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
if (selfType->is_buffer(i)) {
paramNames.push_back(selfType->getAttributeName(i));
}
}
std::vector<SugaredValuePtr> keys;
for (const auto& name : paramNames) {
auto name_v =
std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
m.graph()->insertGetAttr(self_, name);
values.push_back(tryGetAttr(loc, m, name));
keys.push_back(name_v);
}
return std::make_shared<SugaredDict>(
std::make_shared<ModuleValue>(self_, concreteType_),
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
const SourceRange& loc,
Function& m) {
std::vector<std::string> submoduleNames;
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
const auto& attrType = selfType->getAttribute(i);
if (attrType->is_module()) {
submoduleNames.push_back(selfType->getAttributeName(i));
}
}
std::vector<SugaredValuePtr> keys;
std::vector<SugaredValuePtr> values;
for (const auto& name : submoduleNames) {
auto name_v =
std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
Value* module_v = m.graph()->insertGetAttr(self_, name);
auto mod_v = std::make_shared<ModuleValue>(
module_v, concreteType_->findSubmoduleConcreteType(name));
keys.push_back(name_v);
values.push_back(mod_v);
}
return std::make_shared<SugaredDict>(
std::make_shared<ModuleValue>(self_, concreteType_),
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}
std::shared_ptr<SugaredValue> SugaredDict::attr(
const SourceRange& loc,
Function& m,
const std::string& field) {
// Recursive compilation does not maintain module aliasing,
// so we do not add uniqueness checks on
// "children"/"named_children"/"modules"/"named_modules"
checkInterface(loc, m, self_, field);
if (field == "keys") {
return std::make_shared<ModuleDictMethod>(keys_, "keys");
} else if (field == "values" || field == "children") {
return std::make_shared<ModuleDictMethod>(modules_, field);
} else if (
field == "items" || field == "named_children" ||
field == "named_buffers") {
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, keys_);
iterator->addChild(loc, m, modules_);
return std::make_shared<ModuleDictMethod>(iterator, field);
} else if (field == "named_modules" || field == "modules") {
std::vector<SugaredValuePtr> keys;
std::vector<SugaredValuePtr> values;
recurseThroughNestedModules(loc, m, keys, values, self_, "", field);
if (field == "modules") {
return std::make_shared<ModuleDictMethod>(
std::make_shared<SugaredTupleValue>(values), field);
} else {
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
return std::make_shared<ModuleDictMethod>(iterator, field);
}
}
TORCH_INTERNAL_ASSERT(false);
}
// Converts a Python Enum class into a SugaredEnumClass. The class is resolved
// through torch.jit.annotations.try_ann_to_type, which must succeed and
// produce an EnumType (both are asserted).
std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
    const py::object& obj,
    Function& m,
    const SourceRange& loc) {
  py::object annotations = py::module::import("torch.jit.annotations");
  py::object resolved = annotations.attr("try_ann_to_type")(obj, loc);
  TORCH_INTERNAL_ASSERT(!resolved.is_none());
  TypePtr type = py::cast<TypePtr>(resolved);
  return std::make_shared<SugaredEnumClass>(type->expect<EnumType>());
}
// helper function for instantiating a SugaredValue from an IValue
std::shared_ptr<SugaredValue> toSugaredValue(
const IValue& v,
Function& m,
const SourceRange& loc) {
if (v.isTuple()) {
auto tp = v.toTuple();
std::vector<Value*> values;
values.reserve(tp->elements().size());
for (const auto& e : tp->elements()) {
values.push_back(toSugaredValue(e, m, loc)->asValue(loc, m));
}
return toSimple(
m.graph()->insertNode(m.graph()->createTuple(values))->output());
} else {
return toSimple(m.graph()->insertConstant(v, loc));
}
}
// This method controls how we desugar attribute lookups on ScriptModules
// Returns nullptr when `field` cannot be resolved, so callers such as attr()
// and hasAttr() decide how to report the failure. Resolution may insert nodes
// into `m`'s graph (e.g. via insertGetAttr). It may also call into Python
// (torch._jit_internal / torch.jit._recursive), including compiling a
// not-yet-compiled method on demand.
std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
const SourceRange& loc,
Function& m,
const std::string& field) {
// 1. Look inside Module object for the field.
const auto& selfType_ = concreteType_->getJitType();
if (selfType_->cast<InterfaceType>()) {
return std::make_shared<SimpleValue>(self_)->attr(loc, m, field);
}
const auto& selfType = selfType_->expect<ClassType>();
if (selfType->hasAttribute(field) &&
selfType->getAttribute(field)->is_module()) {
// ...if it's a submodule, return it as a new ModuleValue.
if (const auto submoduleConcreteType =
concreteType_->findSubmoduleConcreteType(field)) {
return std::make_shared<ModuleValue>(
m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
}
// No recorded concrete type for this submodule: derive one from its JIT type.
return std::make_shared<ModuleValue>(
m.graph()->insertGetAttr(self_, field),
ConcreteModuleType::fromJitType(selfType->getAttribute(field)));
} else if (selfType->hasAttribute(field) || selfType->findMethod(field)) {
// ...otherwise, methods, parameters, attributes, and buffers are all
// first class so they get returned as SimpleValues
return std::make_shared<SimpleValue>(self_)->attr(loc, m, field);
} else if (selfType->hasConstant(field)) {
auto v = selfType->getConstant(field);
return toSugaredValue(v, m, loc);
}
// 2. Special case: for module dicts we manually desugar items(), keys(),
// values() calls into the appropriate method.
if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (field == "items" || field == "keys" || field == "values") {
return getSugaredDict(loc, m)->attr(loc, m, field);
}
}
if (field == "named_modules" || field == "modules" || field == "children" ||
field == "named_children") {
return getSugaredDict(loc, m)->attr(loc, m, field);
}
if (field == "named_buffers") {
return getSugaredNamedBufferDict(loc, m)->attr(loc, m, field);
}
// 3. Check if this is the name of an overloaded method.
// This can also be a call to a non-script module, or a plain
// python method. If so return this as a python value.
if (const auto overloads = concreteType_->findOverloads(field)) {
return std::make_shared<MethodValue>(self_, *overloads);
}
// 4. Check if it's a function attribute.
if (const auto fnAttr = concreteType_->findFunctionAttribute(field)) {
return std::make_shared<FunctionValue>(*fnAttr);
} else if (const auto builtin = concreteType_->findBuiltinFunction(field)) {
return std::make_shared<BuiltinFunction>(*builtin, /*self=*/c10::nullopt);
}
// 5. Check if it's an attribute of the original Python class that this
// ScriptModule was derived from. The only class attributes we handle are
// methods.
const auto maybePyClass = concreteType_->getPyClass();
if (!maybePyClass) {
// ConcreteType doesn't always have an originating Python class, e.g. if it
// was derived from a serialized ScriptModule. In this case, we've exhausted
// our options for attr lookup.
return nullptr;
}
// Look the name up on the Python class itself; a missing attribute yields
// None, which is not a function and so falls through to `return nullptr`.
py::object unboundMethod = py::getattr(
*maybePyClass, field.c_str(), pybind11::cast<pybind11::none>(Py_None));
if (py::isinstance<py::function>(unboundMethod)) {
bool isStaticFn =
py::cast<bool>(py::module::import("torch._jit_internal")
.attr("is_static_fn")(*maybePyClass, field.c_str()));
if (isStaticFn) {
// Functions within the module annotated with @staticmethod do not need
// binding.
py::object staticFn =
py::module::import("torch._jit_internal")
.attr("get_static_fn")(*maybePyClass, field.c_str());
return toSugaredValue(staticFn, m, loc);
}
// For Python methods that we're trying to call directly, we need to bind
// the method to a self. (see the documentation for lazy_bind in Python for
// more info).
bool isIgnoredFn =
py::cast<bool>(py::module::import("torch._jit_internal")
.attr("is_ignored_fn")(unboundMethod));
if (isIgnoredFn) {
// Create a generated ScriptModule type with module_ set as cpp_module
auto boundMethod = py::module::import("torch.jit._recursive")
.attr("lazy_bind")(concreteType_, unboundMethod);
TORCH_CHECK(py::isinstance<py::function>(boundMethod));
auto rcb =
py::module::import("torch._jit_internal")
.attr("createResolutionCallbackFromClosure")(unboundMethod);
return std::make_shared<PythonValue>(boundMethod, rcb, self_);
}
// If we reach here, it's because this is a "normal" method that just hasn't
// been compiled yet (directly exported methods would have been returned by
// step 1). Just compile it.
auto stub =
py::module::import("torch.jit._recursive")
.attr("compile_unbound_method")(concreteType_, unboundMethod);
TORCH_INTERNAL_ASSERT(!stub.is_none());
// Look up the attribute again, it will be available as a compiled method.
// (This goes through attr(), which throws rather than returning nullptr if
// the lookup still fails.)
return attr(loc, m, field);
}
return nullptr;
}
// True if tryGetAttr can resolve `field`. Note that the probe goes through
// tryGetAttr, which may insert nodes into `m`'s graph as it resolves.
bool ModuleValue::hasAttr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  const auto resolved = tryGetAttr(loc, m, field);
  return static_cast<bool>(resolved);
}
// Desugars `module(args...)`: runs the class's forward pre-hooks, then
// `forward`, then its forward hooks.
// When any hooks exist, the positional args are packed into a tuple
// (`forward_input`), because hooks take their input as a single tuple.
// - Pre-hooks are called as hook(self, forward_input). A non-None result
//   replaces forward_input; a non-tuple result is first wrapped in a 1-tuple.
//   The final forward_input is unpacked back into positional args for
//   `forward`, but only if the original call had positional args.
// - Forward hooks are called as hook(self, forward_input, forward_output),
//   where forward_input is the value after all pre-hooks. A non-None result
//   replaces forward_output.
// NOTE(review): the caller's kwargs and n_binders are also forwarded to every
// hook call; confirm that is intended.
std::shared_ptr<SugaredValue> ModuleValue::call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
c10::ClassTypePtr class_type = concreteType_->getJitType()->cast<ClassType>();
bool have_pre_hooks =
class_type && class_type->getForwardPreHooks().size() != 0;
bool have_hooks = class_type && class_type->getForwardHooks().size() != 0;
std::vector<Value*> arg_values;
// Owns the NamedValues that `args` is re-pointed at after the pre-hooks run.
std::vector<NamedValue> pre_hook_result;
Value* forward_input = nullptr;
std::shared_ptr<Graph> calling_graph = caller.graph();
if (have_pre_hooks || have_hooks) {
// convert forward args into tuple for forward hooks
// (the input of eager hooks are always tuples)
for (const auto& sv : args) {
arg_values.push_back(sv.value(*calling_graph));
}
forward_input =
calling_graph->insertNode(calling_graph->createTuple(arg_values))
->output();
}
// call pre_hooks
if (have_pre_hooks) {
for (const auto& hook : class_type->getForwardPreHooks()) {
TORCH_INTERNAL_ASSERT(forward_input != nullptr);
Value* pre_hook_output =
FunctionValue(hook)
.call(
loc,
caller,
{NamedValue(self_), NamedValue(forward_input)},
kwargs,
n_binders)
->asValue(loc, caller);
if (pre_hook_output->type() != NoneType::get()) {
if (pre_hook_output->type()->kind() != TypeKind::TupleType) {
pre_hook_output =
calling_graph
->insertNode(calling_graph->createTuple({pre_hook_output}))
->output();
}
forward_input = pre_hook_output;
}
}
// de-tuple pre_hook output for forward
at::ArrayRef<Value*> output_nodes =
calling_graph
->insertNode(calling_graph->createTupleUnpack(forward_input))
->outputs();
for (auto& output_node : output_nodes) {
pre_hook_result.emplace_back(NamedValue(output_node));
}
if (args.size() != 0) { // only replace input if it existed
args = pre_hook_result;
}
}
// call forward
std::shared_ptr<SugaredValue> forwardSV =
attr(loc, caller, "forward")->call(loc, caller, args, kwargs, n_binders);
Value* forward_output = forwardSV->asValue(loc, caller);
// call hooks
if (have_hooks) {
for (const auto& hook : class_type->getForwardHooks()) {
Value* forward_hook_output = FunctionValue(hook)
.call(
loc,
caller,
{NamedValue(self_),
NamedValue(forward_input),
NamedValue(forward_output)},
kwargs,
n_binders)
->asValue(loc, caller);
if (forward_hook_output->type() != NoneType::get()) {
forward_output = forward_hook_output;
}
}
}
return std::make_shared<SimpleValue>(forward_output);
}
// This method controls how we desugar attribute lookups on ScriptModules.
// Resolution order: tryGetAttr, then class properties (desugared into a call
// of the property's getter). If both fail, throws ErrorReport with a hint when
// the attribute failed conversion or was explicitly ignored.
std::shared_ptr<SugaredValue> ModuleValue::attr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  if (auto resolved = tryGetAttr(loc, m, field)) {
    return resolved;
  }

  // Properties are read by calling their getter method.
  if (auto prop = concreteType_->getJitType()->expectRef<ClassType>().getProperty(
          field)) {
    return MethodValue(self_, prop->getter->name())
        .call(loc, m, {}, {}, /*n_binders=*/1);
  }

  // We don't define this attr. Bail out with a hint to the user.
  std::string hint;
  if (auto failure = concreteType_->findFailedAttribute(field)) {
    hint = *failure;
  } else if (concreteType_->isIgnoredAttribute(field)) {
    hint = "attribute was ignored during compilation";
  }
  throw ErrorReport(loc)
      << "Module '"
      << concreteType_->getJitType()->expectRef<ClassType>().name()->name()
      << "'"
      << " has no attribute '" << field << "' " << hint;
}
// Desugars `for x in module`: a ModuleDict iterates over its keys, and a
// ModuleList/Sequential iterates over its submodules. Any other module is
// rejected with an ErrorReport.
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
  const auto iterableModuleKind = concreteType_->getIterableModuleKind();
  if (iterableModuleKind == IterableModuleKind::NONE) {
    throw ErrorReport(loc)
        << "Only constant Sequential, ModuleList, or ModuleDict can be used as an iterable";
  }

  auto module_dict = getSugaredDict(loc, m);
  if (iterableModuleKind == IterableModuleKind::DICT) {
    return module_dict->keys_;
  } else if (iterableModuleKind == IterableModuleKind::LIST) {
    return module_dict->modules_;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
}
std::shared_ptr<SugaredValue> PythonClassValue::attr(
const SourceRange& loc,
Function& m,
const std::string& field) {
// Resolve values from the Python object first (e.g. for static methods on
// this type, resolve them as functions)
if (auto* fn = type_->findStaticMethod(field)) {
return std::make_shared<FunctionValue>(fn);
}
auto py_attr = py::getattr(py_type_, field.c_str(), py::none());
if (!py_attr.is_none()) {
return toSugaredValue(py_attr, m, loc);
}
return ClassValue::attr(loc, m, field);
}
// True if the originating Python class object has an attribute named
// `field`. Python errors raised during the lookup are treated as "absent".
bool PythonClassValue::hasAttr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  return py::hasattr(py_type_, field.c_str());
}
void ModuleValue::setAttr(
const SourceRange& loc,
Function& m,
const std::string& field,
Value* newValue) {
// Forward to SimpleValue::setAttr
SimpleValue simple(self_);
simple.setAttr(loc, m, field, newValue);
}
// Calls one of two implementations depending on a compile-time boolean flag.
// `dispatched_fn_` provides "index" / "arg_name" (where to find the flag),
// "default" (used when the flag is not passed), and "if_true" / "if_false"
// (the two implementations).
// The flag is read from the positional arg at "index" if present, otherwise
// from the kwarg named "arg_name", otherwise from "default".
// Throws ErrorReport if the provided flag is not a constant.
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
    const SourceRange& loc,
    Function& caller,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t n_binders) {
  c10::optional<bool> result;
  Graph& graph = *(caller.graph());

  auto index = py::cast<size_t>(dispatched_fn_["index"]);
  auto arg_name = py::str(dispatched_fn_["arg_name"]);

  ErrorReport error(loc);
  if (index < args.size()) {
    // Dispatch flag is in arg list
    result = constant_as<bool>(args.at(index).value(graph));
    error << "Argument for boolean dispatch at position " << index
          << " was not constant";
  } else if (auto i = findInputWithName(arg_name, kwargs)) {
    // Dispatch flag is in kwargs; the message identifies it by name, since
    // there is no position to report.
    result = constant_as<bool>(kwargs[*i].value(graph));
    error << "Keyword argument '" << arg_name
          << "' for boolean dispatch was not constant";
  } else {
    // Didn't find dispatch flag, so use default value
    result = py::cast<bool>(dispatched_fn_["default"]);
    // Asserts that the optional is engaged, not that the default is true.
    TORCH_INTERNAL_ASSERT(result);
  }

  if (!result.has_value()) {
    throw error;
  }

  std::shared_ptr<SugaredValue> value;
  if (*result) {
    value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
  } else {
    value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
  }
  return value->call(loc, caller, args, kwargs, n_binders);
}
std::shared_ptr<SugaredValue> PythonExceptionValue::call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t /*n_binders*/) {
Value* error_message = nullptr;
if (args.size() == 0) {
error_message = insertConstant(*caller.graph(), "", loc);
} else if (args.size() == 1) {
error_message = args.at(0).value(*caller.graph());
} else {
std::vector<Value*> message_values;
message_values.reserve(args.size() + kwargs.size());
for (const auto& inp : args) {
message_values.push_back(inp.value(*caller.graph()));
}
for (const auto& kwarg_inp : kwargs) {
message_values.push_back(kwarg_inp.value(*caller.graph()));
}
error_message =
caller.graph()
->insertNode(caller.graph()->createTuple(message_values))
->output();
}
return std::make_shared<ExceptionMessageValue>(error_message);
}
// True if `obj` is a subclass of `tuple` that also has a `_fields` attribute
// (the shape of a NamedTuple class). A failing subclass check clears the
// Python error and yields false.
bool isNamedTupleClass(const py::object& obj) {
  auto* tuple_class = reinterpret_cast<PyObject*>(&PyTuple_Type);
  const int subclass_check = PyObject_IsSubclass(obj.ptr(), tuple_class);
  if (subclass_check < 0) {
    PyErr_Clear();
    return false;
  }
  return subclass_check == 1 && py::hasattr(obj, "_fields");
}
// Creates (or reuses) the TorchScript TupleType for a Python NamedTuple class
// and registers it in the Python compilation unit under its qualified name.
// Throws ErrorReport if any field has a default value, which is unsupported.
// If a type with the same qualified name is already registered, it is returned
// as long as it is a subtype of the newly built type; otherwise TORCH_CHECK
// fails.
TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
  TORCH_INTERNAL_ASSERT(isNamedTupleClass(obj));
  auto qualifiedName = c10::QualifiedName(py::cast<std::string>(
      py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));

  // Currently don't support default values
  if (py::hasattr(obj, "_field_defaults")) {
    auto default_dict = py::cast<std::map<std::string, py::object>>(
        py::getattr(obj, "_field_defaults"));
    if (!default_dict.empty()) {
      std::string error_msg =
          "Default values are currently not supported"
          " on NamedTuple fields in TorchScript. Fields "
          "with default values: [";
      bool first = true;
      for (const auto& kv : default_dict) {
        if (!first) {
          error_msg += ", ";
        }
        first = false;
        error_msg += kv.first;
      }
      error_msg += "]";
      throw ErrorReport(loc) << error_msg;
    }
  }

  py::object props = py::module::import("torch._jit_internal")
                         .attr("_get_named_tuple_properties")(obj);
  std::string unqualName;
  std::vector<std::string> fields;
  std::vector<TypePtr> annotations;
  std::tie(unqualName, fields, annotations) = py::cast<
      std::tuple<std::string, decltype(fields), decltype(annotations)>>(props);

  auto tt = TupleType::createNamed(qualifiedName, fields, annotations);
  if (auto type = get_python_cu()->get_type(qualifiedName)) {
    TORCH_CHECK(
        type->isSubtypeOf(tt), "Can't redefine NamedTuple: ", tt->repr_str());
    return type;
  }
  get_python_cu()->register_type(tt);
  return tt;
}
// True if `obj` is a subclass of Python's `enum.Enum`. A failing subclass
// check clears the Python error and yields false.
bool isEnumClass(py::object obj) {
  py::object enum_base = py::module::import("enum").attr("Enum");
  const int is_subclass = PyObject_IsSubclass(obj.ptr(), enum_base.ptr());
  if (is_subclass == -1) {
    PyErr_Clear();
    return false;
  }
  return is_subclass == 1;
}
// Turns a Python enum member into a graph constant. The member's class is
// mapped to a TorchScript type via torch.jit.annotations.try_ann_to_type, and
// the member is converted to an IValue of that type.
std::shared_ptr<SugaredValue> createSimpleEnumValue(
    const py::object& obj,
    Function& m,
    const SourceRange& loc) {
  py::object annotations = py::module::import("torch.jit.annotations");
  py::object enum_class = obj.attr("__class__");
  TypePtr enum_type =
      py::cast<TypePtr>(annotations.attr("try_ann_to_type")(enum_class, loc));
  IValue enum_ivalue = toIValue(obj, enum_type);
  return toSimple(m.graph()->insertConstant(enum_ivalue, loc));
}
std::shared_ptr<SugaredValue> PythonSliceClass::call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> args,